mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
broadcast in ElementwiseMixin.div [pr] (#15897)
This commit is contained in:
@@ -22,6 +22,15 @@ class TestTensorUOpBinop(unittest.TestCase):
|
||||
self.assertIs(_strip_unique((t.eq(1) * Tensor.arange(3)).uop), _strip_unique(t.uop.eq(1) * UOp.arange(3)))
|
||||
# Tensor's ufix picks float dtype when scalar is float and self is int; UOp should match.
|
||||
def test_add_scalar_float_on_int(self): _check(self, _t(3), lambda x: x + 1.5)
|
||||
# div: Tensor.div (default case) delegates to ElementwiseMixin.div; trees must match for Tensor and UOp.
|
||||
def test_div_tensor_by_tensor(self):
|
||||
a, b = _t(4).float(), _t(4).float() + 1
|
||||
self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop))
|
||||
def test_div_int_by_int(self): _check(self, _t(4), lambda x: x / 3)
|
||||
def test_div_sum_by_sum(self): _check(self, _t(4).float(), lambda x: x.sum() / (x + 1).sum())
|
||||
def test_div_broadcast_tensor_by_tensor(self):
|
||||
a, b = _t(3, 4).float(), _t(4).float() + 1
|
||||
self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop))
|
||||
|
||||
class TestTensorUOpGetitem(unittest.TestCase):
|
||||
# ---- pure slice patterns ----
|
||||
|
||||
@@ -182,7 +182,8 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
|
||||
return self._binop(Ops.MOD, x, reverse)
|
||||
|
||||
def div(self, x: Self | ConstType, reverse: bool = False) -> Self:
|
||||
return (self.ufix(x) * self.alu(Ops.RECIPROCAL)) if reverse else (self * self.ufix(x).alu(Ops.RECIPROCAL))
|
||||
lhs, rhs = self._broadcasted(x, reverse)
|
||||
return lhs * rhs.reciprocal()
|
||||
|
||||
def __neg__(self) -> Self:
|
||||
return self.neg()
|
||||
|
||||
@@ -1715,20 +1715,20 @@ class Tensor(OpMixin):
|
||||
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
|
||||
```
|
||||
"""
|
||||
if rounding_mode is None: return super().div(x, reverse) # type: ignore[arg-type]
|
||||
numerator, denominator = self._broadcasted(x, reverse)
|
||||
d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
|
||||
output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype
|
||||
if dtypes.is_int(dt:=least_upper_dtype(numerator.dtype, denominator.dtype)) and rounding_mode is not None:
|
||||
if dtypes.is_int(dt:=least_upper_dtype(numerator.dtype, denominator.dtype)):
|
||||
numerator, denominator = numerator.cast(dt), denominator.cast(dt)
|
||||
if rounding_mode == "trunc": return numerator.idiv(denominator)
|
||||
if rounding_mode == "floor":
|
||||
truncate_div, truncate_mod = numerator.idiv(denominator), numerator._binop(Ops.MOD, denominator, False)
|
||||
opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0))
|
||||
return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div)
|
||||
d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
|
||||
output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype
|
||||
if rounding_mode == "trunc": return d.trunc().cast(output_dtype)
|
||||
if rounding_mode == "floor": return d.floor().cast(output_dtype)
|
||||
if rounding_mode is not None: raise RuntimeError(f"{rounding_mode=} is not supported")
|
||||
return d
|
||||
raise RuntimeError(f"{rounding_mode=} is not supported")
|
||||
|
||||
def mod(self, x:Tensor|ConstType, reverse=False) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user