broadcast in ElementwiseMixin.div [pr] (#15897)

This commit is contained in:
chenyu
2026-04-23 16:02:43 -04:00
committed by GitHub
parent 7745e05a2f
commit 782bc6aece
3 changed files with 16 additions and 6 deletions

View File

@@ -22,6 +22,15 @@ class TestTensorUOpBinop(unittest.TestCase):
self.assertIs(_strip_unique((t.eq(1) * Tensor.arange(3)).uop), _strip_unique(t.uop.eq(1) * UOp.arange(3)))
# Tensor's ufix picks float dtype when scalar is float and self is int; UOp should match.
def test_add_scalar_float_on_int(self): _check(self, _t(3), lambda x: x + 1.5)
# div: Tensor.div (default case) delegates to ElementwiseMixin.div; trees must match for Tensor and UOp.
def test_div_tensor_by_tensor(self):
a, b = _t(4).float(), _t(4).float() + 1
self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop))
def test_div_int_by_int(self): _check(self, _t(4), lambda x: x / 3)
def test_div_sum_by_sum(self): _check(self, _t(4).float(), lambda x: x.sum() / (x + 1).sum())
def test_div_broadcast_tensor_by_tensor(self):
a, b = _t(3, 4).float(), _t(4).float() + 1
self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop))
class TestTensorUOpGetitem(unittest.TestCase):
# ---- pure slice patterns ----

View File

@@ -182,7 +182,8 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
return self._binop(Ops.MOD, x, reverse)
def div(self, x: Self | ConstType, reverse: bool = False) -> Self:
return (self.ufix(x) * self.alu(Ops.RECIPROCAL)) if reverse else (self * self.ufix(x).alu(Ops.RECIPROCAL))
lhs, rhs = self._broadcasted(x, reverse)
return lhs * rhs.reciprocal()
def __neg__(self) -> Self:
return self.neg()

View File

@@ -1715,20 +1715,20 @@ class Tensor(OpMixin):
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
```
"""
if rounding_mode is None: return super().div(x, reverse) # type: ignore[arg-type]
numerator, denominator = self._broadcasted(x, reverse)
d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype
if dtypes.is_int(dt:=least_upper_dtype(numerator.dtype, denominator.dtype)) and rounding_mode is not None:
if dtypes.is_int(dt:=least_upper_dtype(numerator.dtype, denominator.dtype)):
numerator, denominator = numerator.cast(dt), denominator.cast(dt)
if rounding_mode == "trunc": return numerator.idiv(denominator)
if rounding_mode == "floor":
truncate_div, truncate_mod = numerator.idiv(denominator), numerator._binop(Ops.MOD, denominator, False)
opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0))
return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div)
d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype
if rounding_mode == "trunc": return d.trunc().cast(output_dtype)
if rounding_mode == "floor": return d.floor().cast(output_dtype)
if rounding_mode is not None: raise RuntimeError(f"{rounding_mode=} is not supported")
return d
raise RuntimeError(f"{rounding_mode=} is not supported")
def mod(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""