clean up test_tesnor_uop_mixin (#16525)

most of those don't have UNIQUE anymore
This commit is contained in:
chenyu
2026-06-06 23:25:44 -04:00
committed by GitHub
parent 2a2f81dd3d
commit 4e7c6260b0

View File

@@ -18,25 +18,25 @@ class TestTensorUOpBinop(unittest.TestCase):
# Tensor's binop upcasts mixed dtypes via least_upper_dtype + explicit CAST; UOp should match.
def test_mul_float_int(self):
t = _t(3).float()
self.assertIs(_strip_unique((t * Tensor.arange(3)).uop), _strip_unique(t.uop * UOp.arange(3)))
self.assertIs((t * Tensor.arange(3)).uop, t.uop * UOp.arange(3))
def test_mul_bool_int(self):
t = _t(3)
self.assertIs(_strip_unique((t.eq(1) * Tensor.arange(3)).uop), _strip_unique(t.uop.eq(1) * UOp.arange(3)))
self.assertIs((t.eq(1) * Tensor.arange(3)).uop, t.uop.eq(1) * UOp.arange(3))
# Tensor's ufix picks float dtype when scalar is float and self is int; UOp should match.
def test_add_scalar_float_on_int(self): _check(self, _t(3), lambda x: x + 1.5)
# div: Tensor.div (default case) delegates to ElementwiseMixin.div; trees must match for Tensor and UOp.
def test_div_tensor_by_tensor(self):
a, b = _t(4).float(), _t(4).float() + 1
self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop))
self.assertIs((a/b).uop, a.uop/b.uop)
def test_div_int_by_int(self): _check(self, _t(4), lambda x: x / 3)
def test_div_sum_by_sum(self): _check(self, _t(4).float(), lambda x: x.sum() / (x + 1).sum())
def test_div_broadcast_tensor_by_tensor(self):
a, b = _t(3, 4).float(), _t(4).float() + 1
self.assertIs(_strip_unique((a/b).uop), _strip_unique(a.uop/b.uop))
self.assertIs((a/b).uop, a.uop/b.uop)
# isclose used `self == other` which is Python identity on UOp (not elementwise); now uses .eq().
def test_isclose(self):
t = _t(4).float()
self.assertIs(_strip_unique(t.isclose(t).uop), _strip_unique(t.uop.isclose(t.uop)))
self.assertIs(t.isclose(t).uop, t.uop.isclose(t.uop))
# __floordiv__/mod/fmod and div(rounding_mode=...) dispatch on dtype in mixin
def test_floordiv_int(self): _check(self, _t(4), lambda x: x // 3)
def test_floordiv_float(self): _check(self, _t(4).float() + 1.5, lambda x: x // 2.0)
@@ -113,8 +113,8 @@ class TestTensorUOpCumMinMax(unittest.TestCase):
def _check_pair(self, t, fn):
vt, it = fn(t)
vu, iu = fn(t.uop)
self.assertIs(_strip_unique(vt.uop), _strip_unique(vu))
self.assertIs(_strip_unique(it.uop), _strip_unique(iu))
self.assertIs(vt.uop, vu)
self.assertIs(it.uop, iu)
def test_cummax_1d(self): self._check_pair(_t(5), lambda x: x.cummax(0))
def test_cummax_2d(self): self._check_pair(_t(3, 4), lambda x: x.cummax(1))
def test_cummax_0d(self): self._check_pair(_t(1).reshape(()), lambda x: x.cummax(0))
@@ -122,10 +122,10 @@ class TestTensorUOpCumMinMax(unittest.TestCase):
def test_cummin_2d(self): self._check_pair(_t(3, 4), lambda x: x.cummin(1))
class TestTensorUOpArgMinMax(unittest.TestCase):
def _check_stripped(self, t, fn): self.assertIs(_strip_unique(fn(t).uop), _strip_unique(fn(t.uop)))
def test_argmax(self): self._check_stripped(_t(3, 4), lambda x: x.argmax(axis=1))
def test_argmax_flat(self): self._check_stripped(_t(3, 4), lambda x: x.argmax())
def test_argmin(self): self._check_stripped(_t(3, 4), lambda x: x.argmin(axis=0))
def _check(self, t, fn): self.assertIs(fn(t).uop, fn(t.uop))
def test_argmax(self): self._check(_t(3, 4), lambda x: x.argmax(axis=1))
def test_argmax_flat(self): self._check(_t(3, 4), lambda x: x.argmax())
def test_argmin(self): self._check(_t(3, 4), lambda x: x.argmin(axis=0))
class TestTensorUOpSequential(unittest.TestCase):
def test_sequential(self): _check(self, _t(4), lambda x: x.sequential([lambda y: y * 2, lambda y: y + 1]))
@@ -133,32 +133,32 @@ class TestTensorUOpSequential(unittest.TestCase):
class TestTensorUOpOneHot(unittest.TestCase):
def test_one_hot(self):
t = _t(5)
self.assertIs(_strip_unique(t.one_hot(5).uop), _strip_unique(t.uop.one_hot(5)))
self.assertIs(t.one_hot(5).uop, t.uop.one_hot(5))
class TestTensorUOpSort(unittest.TestCase):
def _check(self, t, **kw):
tv, ti = t.sort(**kw)
uv, ui = t.uop.sort(**kw)
self.assertIs(_strip_unique(tv.uop), _strip_unique(uv))
self.assertIs(_strip_unique(ti.uop), _strip_unique(ui))
self.assertIs(tv.uop, uv)
self.assertIs(ti.uop, ui)
def test_sort_1d(self): self._check(Tensor([0.5, 0.1, 0.3]).float())
def test_sort_descending(self): self._check(Tensor([0.5, 0.1, 0.3]).float(), descending=True)
def test_sort_2d(self): self._check(_t(2, 4).float())
def test_sort_single(self): self._check(Tensor([1.0]).float())
def test_argsort(self):
t = Tensor([0.5, 0.1, 0.3]).float()
self.assertIs(_strip_unique(t.argsort().uop), _strip_unique(t.uop.argsort()))
self.assertIs(t.argsort().uop, t.uop.argsort())
def test_topk(self):
t = _t(2, 4).float()
tv, ti = t.topk(2)
uv, ui = t.uop.topk(2)
self.assertIs(_strip_unique(tv.uop), _strip_unique(uv))
self.assertIs(_strip_unique(ti.uop), _strip_unique(ui))
self.assertIs(tv.uop, uv)
self.assertIs(ti.uop, ui)
class TestTensorUOpAllclose(unittest.TestCase):
def test_allclose(self):
a, b = _t(4).float(), _t(4).float()
self.assertIs(_strip_unique(a.allclose(b).uop), _strip_unique(a.uop.allclose(b.uop)))
self.assertIs(a.allclose(b).uop, a.uop.allclose(b.uop))
class TestTensorUOpBitcast(unittest.TestCase):
def test_bitcast_same_dtype(self): _check(self, _t(4).float(), lambda x: x.bitcast(dtypes.float32))
@@ -168,25 +168,22 @@ class TestTensorUOpRand(unittest.TestCase):
k = UOp.empty((2,), dtype=dtypes.uint32)
c = UOp.zeros(2, dtype=dtypes.uint32)
for num in (1, 4, 7, 1024):
self.assertIs(_strip_unique(Tensor.random_bits(Tensor(k), Tensor(c), num).uop),
_strip_unique(UOp.random_bits(k, c, num)))
self.assertIs(Tensor.random_bits(Tensor(k), Tensor(c), num).uop, UOp.random_bits(k, c, num))
def test_bits_to_rand_float32(self):
bits_uop = UOp.empty((8,), dtype=dtypes.uint32)
for shape in ((8,), (2, 4), (5,)):
self.assertIs(_strip_unique(Tensor._bits_to_rand(Tensor(bits_uop), shape, dtypes.float32).uop),
_strip_unique(UOp._bits_to_rand(bits_uop, shape, dtypes.float32)))
self.assertIs(Tensor._bits_to_rand(Tensor(bits_uop), shape, dtypes.float32).uop, UOp._bits_to_rand(bits_uop, shape, dtypes.float32))
class TestTensorUOpGather(unittest.TestCase):
def _check(self, t, dim, idx):
self.assertIs(_strip_unique(t.gather(dim, idx).uop), _strip_unique(t.uop.gather(dim, idx.uop)))
self.assertIs(t.gather(dim, idx).uop, t.uop.gather(dim, idx.uop))
def test_gather_1d(self): self._check(_t(5), 0, Tensor([2, 1, 0, 1, 2], dtype=dtypes.int32))
def test_gather_dim0(self): self._check(_t(3, 4), 0, Tensor([[0, 1, 2, 0], [1, 2, 0, 1], [2, 0, 1, 2]], dtype=dtypes.int32))
def test_gather_dim1(self): self._check(_t(3, 4), 1, Tensor([[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1]], dtype=dtypes.int32))
class TestTensorUOpInterpolate(unittest.TestCase):
def _check(self, t, mode):
self.assertIs(_strip_unique(t.interpolate(size=(2, 2), mode=mode).uop),
_strip_unique(t.uop.interpolate(size=(2, 2), mode=mode)))
self.assertIs(t.interpolate(size=(2, 2), mode=mode).uop, t.uop.interpolate(size=(2, 2), mode=mode))
def test_interpolate_nearest(self): self._check(_t(1, 1, 4, 4).float(), "nearest")
def test_interpolate_nearest_exact(self): self._check(_t(1, 1, 4, 4).float(), "nearest-exact")
def test_interpolate_linear(self): self._check(_t(1, 1, 4, 4).float(), "linear")
@@ -194,51 +191,46 @@ class TestTensorUOpInterpolate(unittest.TestCase):
class TestTensorUOpLoss(unittest.TestCase):
def test_cross_entropy(self):
t, Y = _t(2, 3).float(), Tensor([1, 2], dtype=dtypes.int32)
self.assertIs(_strip_unique(t.cross_entropy(Y).uop), _strip_unique(t.uop.cross_entropy(Y.uop)))
self.assertIs(t.cross_entropy(Y).uop, t.uop.cross_entropy(Y.uop))
def test_sparse_categorical_crossentropy(self):
t, Y = _t(2, 3).float(), Tensor([1, 2], dtype=dtypes.int32)
self.assertIs(_strip_unique(t.sparse_categorical_crossentropy(Y).uop), _strip_unique(t.uop.sparse_categorical_crossentropy(Y.uop)))
self.assertIs(t.sparse_categorical_crossentropy(Y).uop, t.uop.sparse_categorical_crossentropy(Y.uop))
def test_sparse_categorical_crossentropy_ignore_index(self):
t, Y = _t(2, 3).float(), Tensor([1, 2], dtype=dtypes.int32)
self.assertIs(_strip_unique(t.sparse_categorical_crossentropy(Y, ignore_index=0).uop),
_strip_unique(t.uop.sparse_categorical_crossentropy(Y.uop, ignore_index=0)))
self.assertIs(t.sparse_categorical_crossentropy(Y, ignore_index=0).uop, t.uop.sparse_categorical_crossentropy(Y.uop, ignore_index=0))
def test_nll_loss(self):
t, Y = _t(2, 3).float().log_softmax(), Tensor([1, 2], dtype=dtypes.int32)
self.assertIs(_strip_unique(t.nll_loss(Y).uop), _strip_unique(t.uop.nll_loss(Y.uop)))
self.assertIs(t.nll_loss(Y).uop, t.uop.nll_loss(Y.uop))
def test_nll_loss_weight(self):
t, Y, w = _t(2, 3).float().log_softmax(), Tensor([1, 2], dtype=dtypes.int32), _t(3).float()
self.assertIs(_strip_unique(t.nll_loss(Y, weight=w).uop), _strip_unique(t.uop.nll_loss(Y.uop, weight=w.uop)))
self.assertIs(t.nll_loss(Y, weight=w).uop, t.uop.nll_loss(Y.uop, weight=w.uop))
def test_nll_loss_ignore_index(self):
t, Y = _t(2, 3).float().log_softmax(), Tensor([1, 2], dtype=dtypes.int32)
self.assertIs(_strip_unique(t.nll_loss(Y, ignore_index=1).uop), _strip_unique(t.uop.nll_loss(Y.uop, ignore_index=1)))
self.assertIs(t.nll_loss(Y, ignore_index=1).uop, t.uop.nll_loss(Y.uop, ignore_index=1))
def test_nll_loss_none_reduction(self):
t, Y = _t(2, 3).float().log_softmax(), Tensor([1, 2], dtype=dtypes.int32)
self.assertIs(_strip_unique(t.nll_loss(Y, reduction="none").uop), _strip_unique(t.uop.nll_loss(Y.uop, reduction="none")))
self.assertIs(t.nll_loss(Y, reduction="none").uop, t.uop.nll_loss(Y.uop, reduction="none"))
def test_nll_loss_weight_ignore_index(self):
t, Y, w = _t(2, 3).float().log_softmax(), Tensor([1, 2], dtype=dtypes.int32), _t(3).float()
self.assertIs(_strip_unique(t.nll_loss(Y, weight=w, ignore_index=1).uop),
_strip_unique(t.uop.nll_loss(Y.uop, weight=w.uop, ignore_index=1)))
self.assertIs(t.nll_loss(Y, weight=w, ignore_index=1).uop, t.uop.nll_loss(Y.uop, weight=w.uop, ignore_index=1))
class TestTensorUOpScatter(unittest.TestCase):
def test_scatter(self):
x, idx, src = _t(3, 4).float(), Tensor([[0, 1, 2, 0]], dtype=dtypes.int32), _t(1, 4).float()
self.assertIs(_strip_unique(x.scatter(0, idx, src).uop), _strip_unique(x.uop.scatter(0, idx.uop, src.uop)))
self.assertIs(x.scatter(0, idx, src).uop, x.uop.scatter(0, idx.uop, src.uop))
def test_scatter_scalar_src(self):
x, idx = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32)
self.assertIs(_strip_unique(x.scatter(1, idx, 3.14).uop), _strip_unique(x.uop.scatter(1, idx.uop, 3.14)))
self.assertIs(x.scatter(1, idx, 3.14).uop, x.uop.scatter(1, idx.uop, 3.14))
# inf cannot be cast to int — this regresses if scalar src is routed through index.dtype first
def test_scatter_inf_src(self):
x, idx = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32)
self.assertIs(_strip_unique(x.scatter(1, idx, float("inf")).uop),
_strip_unique(x.uop.scatter(1, idx.uop, float("inf"))))
self.assertIs(x.scatter(1, idx, float("inf")).uop, x.uop.scatter(1, idx.uop, float("inf")))
def test_scatter_add(self):
x, idx = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32)
self.assertIs(_strip_unique(x.scatter(1, idx, 3.14, reduce="add").uop),
_strip_unique(x.uop.scatter(1, idx.uop, 3.14, reduce="add")))
self.assertIs(x.scatter(1, idx, 3.14, reduce="add").uop, x.uop.scatter(1, idx.uop, 3.14, reduce="add"))
def test_scatter_multiply(self):
x, idx = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32)
self.assertIs(_strip_unique(x.scatter(1, idx, 3.14, reduce="multiply").uop),
_strip_unique(x.uop.scatter(1, idx.uop, 3.14, reduce="multiply")))
self.assertIs(x.scatter(1, idx, 3.14, reduce="multiply").uop, x.uop.scatter(1, idx.uop, 3.14, reduce="multiply"))
# tensor src with reduce hits the "elif reduce: raise" branch in both Tensor and UOp paths
def test_scatter_tensor_src_with_reduce_raises(self):
x, idx, src = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32), _t(1, 2).float()
@@ -247,8 +239,7 @@ class TestTensorUOpScatter(unittest.TestCase):
class TestTensorUOpScatterReduce(unittest.TestCase):
def _check(self, x, idx, src, **kw):
self.assertIs(_strip_unique(x.scatter_reduce(0, idx, src, **kw).uop),
_strip_unique(x.uop.scatter_reduce(0, idx.uop, src.uop, **kw)))
self.assertIs(x.scatter_reduce(0, idx, src, **kw).uop, x.uop.scatter_reduce(0, idx.uop, src.uop, **kw))
def test_sum(self): self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="sum")
def test_prod(self): self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="prod")
def test_mean(self): self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean")
@@ -269,12 +260,12 @@ class TestTensorUOpPool(unittest.TestCase):
t = _t(1, 1, 5, 5).float()
vt, it = t.max_pool2d(return_indices=True)
vu, iu = t.uop.max_pool2d(return_indices=True)
self.assertIs(_strip_unique(vt.uop), _strip_unique(vu))
self.assertIs(_strip_unique(it.uop), _strip_unique(iu))
self.assertIs(vt.uop, vu)
self.assertIs(it.uop, iu)
def test_max_unpool2d(self):
t = _t(1, 1, 4, 4).float()
out, idx = t.max_pool2d(return_indices=True)
self.assertIs(_strip_unique(out.max_unpool2d(idx).uop), _strip_unique(out.uop.max_unpool2d(idx.uop)))
self.assertIs(out.max_unpool2d(idx).uop, out.uop.max_unpool2d(idx.uop))
class TestTensorUOpCat(unittest.TestCase):
def test_cat_dim0(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=0))
@@ -342,8 +333,8 @@ class TestTensorUOpQR(unittest.TestCase):
def _check(self, t):
qt, rt = t.qr()
qu, ru = t.uop.qr()
self.assertIs(_strip_unique(qt.uop), _strip_unique(qu))
self.assertIs(_strip_unique(rt.uop), _strip_unique(ru))
self.assertIs(qt.uop, qu)
self.assertIs(rt.uop, ru)
def test_qr_square(self): self._check(_t(3, 3).float())
def test_qr_tall(self): self._check(_t(4, 3).float())
def test_qr_wide(self): self._check(_t(3, 4).float())
@@ -354,9 +345,9 @@ class TestTensorUOpSVD(unittest.TestCase):
def _check(self, t, **kw):
ut, st, vt = t.svd(**kw)
uu, su, vu = t.uop.svd(**kw)
self.assertIs(_strip_unique(ut.uop), _strip_unique(uu))
self.assertIs(_strip_unique(st.uop), _strip_unique(su))
self.assertIs(_strip_unique(vt.uop), _strip_unique(vu))
self.assertIs(ut.uop, uu)
self.assertIs(st.uop, su)
self.assertIs(vt.uop, vu)
def test_svd_square(self): self._check(_t(2, 2).float())
def test_svd_tall(self): self._check(_t(3, 2).float())
def test_svd_wide(self): self._check(_t(2, 3).float())
@@ -403,31 +394,31 @@ class TestTensorUOpCreation(unittest.TestCase):
def test_invalids(self):
self.assertIs(_strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop), _strip_unique(UOp.invalids(2, 3, dtype=dtypes.int8)))
def test_arange(self):
self.assertIs(_strip_unique(Tensor.arange(5).uop), _strip_unique(UOp.arange(5)))
self.assertIs(Tensor.arange(5).uop, UOp.arange(5))
def test_arange_empty(self):
self.assertIs(_strip_unique(Tensor.arange(5, 5).uop), _strip_unique(UOp.arange(5, 5)))
self.assertIs(Tensor.arange(5, 5).uop, UOp.arange(5, 5))
def test_arange_step(self):
self.assertIs(_strip_unique(Tensor.arange(5, 10, 2).uop), _strip_unique(UOp.arange(5, 10, 2)))
self.assertIs(Tensor.arange(5, 10, 2).uop, UOp.arange(5, 10, 2))
def test_linspace(self):
self.assertIs(_strip_unique(Tensor.linspace(0, 10, 5).uop), _strip_unique(UOp.linspace(0, 10, 5)))
self.assertIs(Tensor.linspace(0, 10, 5).uop, UOp.linspace(0, 10, 5))
def test_linspace_one_step(self):
self.assertIs(_strip_unique(Tensor.linspace(5, 10, 1).uop), _strip_unique(UOp.linspace(5, 10, 1)))
self.assertIs(Tensor.linspace(5, 10, 1).uop, UOp.linspace(5, 10, 1))
def test_eye(self):
self.assertIs(_strip_unique(Tensor.eye(3).uop), _strip_unique(UOp.eye(3)))
self.assertIs(Tensor.eye(3).uop, UOp.eye(3))
def test_eye_rect(self):
self.assertIs(_strip_unique(Tensor.eye(2, 4).uop), _strip_unique(UOp.eye(2, 4)))
self.assertIs(Tensor.eye(2, 4).uop, UOp.eye(2, 4))
def test_triu(self):
t = _t(3, 4)
self.assertIs(_strip_unique(t.triu().uop), _strip_unique(t.uop.triu()))
self.assertIs(t.triu().uop, t.uop.triu())
def test_triu_diagonal(self):
t = _t(3, 4)
self.assertIs(_strip_unique(t.triu(diagonal=1).uop), _strip_unique(t.uop.triu(diagonal=1)))
self.assertIs(t.triu(diagonal=1).uop, t.uop.triu(diagonal=1))
def test_tril(self):
t = _t(3, 4)
self.assertIs(_strip_unique(t.tril().uop), _strip_unique(t.uop.tril()))
self.assertIs(t.tril().uop, t.uop.tril())
def test_tril_diagonal(self):
t = _t(3, 4)
self.assertIs(_strip_unique(t.tril(diagonal=-1).uop), _strip_unique(t.uop.tril(diagonal=-1)))
self.assertIs(t.tril(diagonal=-1).uop, t.uop.tril(diagonal=-1))
if __name__ == "__main__":
unittest.main()