diff --git a/test/null/test_tensor_uop_mixin.py b/test/null/test_tensor_uop_mixin.py index 6c3023af77..e5009a5859 100644 --- a/test/null/test_tensor_uop_mixin.py +++ b/test/null/test_tensor_uop_mixin.py @@ -22,6 +22,15 @@ class TestTensorUOpBinop(unittest.TestCase): self.assertIs(_strip_unique((t.eq(1) * Tensor.arange(3)).uop), _strip_unique(t.uop.eq(1) * UOp.arange(3))) # Tensor's ufix picks float dtype when scalar is float and self is int; UOp should match. def test_add_scalar_float_on_int(self): _check(self, _t(3), lambda x: x + 1.5) + # div: Tensor.div (default case) delegates to ElementwiseMixin.div; trees must match for Tensor and UOp. + def test_div_tensor_by_tensor(self): + a, b = _t(4).float(), _t(4).float() + 1 + self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop)) + def test_div_int_by_int(self): _check(self, _t(4), lambda x: x / 3) + def test_div_sum_by_sum(self): _check(self, _t(4).float(), lambda x: x.sum() / (x + 1).sum()) + def test_div_broadcast_tensor_by_tensor(self): + a, b = _t(3, 4).float(), _t(4).float() + 1 + self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop)) class TestTensorUOpGetitem(unittest.TestCase): # ---- pure slice patterns ---- diff --git a/tinygrad/mixin/elementwise.py b/tinygrad/mixin/elementwise.py index 9df2c03bce..c722064f58 100644 --- a/tinygrad/mixin/elementwise.py +++ b/tinygrad/mixin/elementwise.py @@ -182,7 +182,8 @@ class ElementwiseMixin(DTypeMixin, CreationMixin): return self._binop(Ops.MOD, x, reverse) def div(self, x: Self | ConstType, reverse: bool = False) -> Self: - return (self.ufix(x) * self.alu(Ops.RECIPROCAL)) if reverse else (self * self.ufix(x).alu(Ops.RECIPROCAL)) + lhs, rhs = self._broadcasted(x, reverse) + return lhs * rhs.reciprocal() def __neg__(self) -> Self: return self.neg() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e9825d4f28..822279b674 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1715,20 +1715,20 @@ class Tensor(OpMixin): print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy()) ``` """ + if rounding_mode is None: return super().div(x, reverse) # type: ignore[arg-type] numerator, denominator = self._broadcasted(x, reverse) - d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal() - output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype - if dtypes.is_int(dt:=least_upper_dtype(numerator.dtype, denominator.dtype)) and rounding_mode is not None: + if dtypes.is_int(dt:=least_upper_dtype(numerator.dtype, denominator.dtype)): numerator, denominator = numerator.cast(dt), denominator.cast(dt) if rounding_mode == "trunc": return numerator.idiv(denominator) if rounding_mode == "floor": truncate_div, truncate_mod = numerator.idiv(denominator), numerator._binop(Ops.MOD, denominator, False) opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0)) return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div) + d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal() + output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype if rounding_mode == "trunc": return d.trunc().cast(output_dtype) if rounding_mode == "floor": return d.floor().cast(output_dtype) - if rounding_mode is not None: raise RuntimeError(f"{rounding_mode=} is not supported") - return d + raise RuntimeError(f"{rounding_mode=} is not supported") def mod(self, x:Tensor|ConstType, reverse=False) -> Tensor: """