mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
494 lines
28 KiB
Python
494 lines
28 KiB
Python
import math, unittest
|
|
from tinygrad import Tensor, dtypes
|
|
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, graph_rewrite
|
|
|
|
# Matches UNIQUE / LUNIQUE nodes; the rule below resets their arg to 0 and returns None when the arg is already 0.
_unique_pat = UPat((Ops.UNIQUE, Ops.LUNIQUE), name="u")
_strip_unique_pm = PatternMatcher([(_unique_pat, lambda u: u.replace(arg=0) if u.arg != 0 else None)])
|
|
def _strip_unique(u: UOp) -> UOp:
  # run the unique-stripping rewrite over the graph rooted at `u`
  stripped = graph_rewrite(u, _strip_unique_pm)
  return stripped
|
|
|
|
def _t(*shape):
  # arange over the total element count, viewed with the requested shape
  numel = math.prod(shape)
  return Tensor.arange(numel).reshape(*shape)
|
|
|
|
def _ti(data):
  # int32 index tensor built from python data
  index = Tensor(data, dtype=dtypes.int32)
  return index
|
|
|
|
# Tensor().func().uop should be the same as UOp.func()
def _check(tc: unittest.TestCase, t: Tensor, fn):
  """Assert that `fn` yields the identical UOp whether applied to `t` or directly to `t.uop`.

  Args:
    tc: test case whose `assertIs` performs (and reports) the comparison.
    t: input; `fn` is called once with `t` and once with `t.uop`.
    fn: callable that accepts either form; the result of `fn(t)` must expose `.uop`.
  """
  # evaluate each path exactly once: the failure message then shows the very objects that were compared,
  # and fn is not re-run just to format a message that is discarded when the assertion passes
  tensor_uop, uop = fn(t).uop, fn(t.uop)
  tc.assertIs(tensor_uop, uop, f"\ntensor.uop = {tensor_uop}\nuop = {uop}")
|
|
|
|
class TestTensorUOpBinop(unittest.TestCase):
  """Binary ops on a Tensor must yield the very same UOp as the same op applied to its `.uop`."""

  # Tensor's binop upcasts mixed dtypes via least_upper_dtype + explicit CAST; UOp should match.
  def test_mul_float_int(self):
    lhs = _t(3).float()
    via_tensor = (lhs * Tensor.arange(3)).uop
    via_uop = lhs.uop * UOp.arange(3)
    self.assertIs(via_tensor, via_uop)

  def test_mul_bool_int(self):
    base = _t(3)
    via_tensor = (base.eq(1) * Tensor.arange(3)).uop
    via_uop = base.uop.eq(1) * UOp.arange(3)
    self.assertIs(via_tensor, via_uop)

  # Tensor's ufix picks float dtype when scalar is float and self is int; UOp should match.
  def test_add_scalar_float_on_int(self):
    _check(self, _t(3), lambda v: v + 1.5)

  # div: Tensor.div (default case) delegates to ElementwiseMixin.div; trees must match for Tensor and UOp.
  def test_div_tensor_by_tensor(self):
    num = _t(4).float()
    den = _t(4).float() + 1
    self.assertIs((num / den).uop, num.uop / den.uop)

  def test_div_int_by_int(self):
    _check(self, _t(4), lambda v: v / 3)

  def test_div_sum_by_sum(self):
    _check(self, _t(4).float(), lambda v: v.sum() / (v + 1).sum())

  def test_div_broadcast_tensor_by_tensor(self):
    num = _t(3, 4).float()
    den = _t(4).float() + 1
    self.assertIs((num / den).uop, num.uop / den.uop)

  # isclose used `self == other` which is Python identity on UOp (not elementwise); now uses .eq().
  def test_isclose(self):
    base = _t(4).float()
    via_tensor = base.isclose(base).uop
    via_uop = base.uop.isclose(base.uop)
    self.assertIs(via_tensor, via_uop)

  # __floordiv__/mod/fmod and div(rounding_mode=...) dispatch on dtype in mixin
  def test_floordiv_int(self):
    _check(self, _t(4), lambda v: v // 3)

  def test_floordiv_float(self):
    _check(self, _t(4).float() + 1.5, lambda v: v // 2.0)

  def test_rfloordiv_int(self):
    _check(self, _t(4)+1, lambda v: 7 // v)

  def test_mod_int(self):
    _check(self, _t(4), lambda v: v % 3)

  def test_mod_float(self):
    _check(self, _t(4).float() + 1.5, lambda v: v % 2.0)

  def test_div_trunc_int(self):
    _check(self, _t(4), lambda v: v.div(3, rounding_mode="trunc"))

  def test_div_trunc_float(self):
    _check(self, _t(4).float() + 1.5, lambda v: v.div(2.0, rounding_mode="trunc"))

  def test_fmod_int(self):
    _check(self, _t(4), lambda v: v.fmod(3))

  def test_fmod_float(self):
    _check(self, _t(4).float() + 1.5, lambda v: v.fmod(2.0))

  def test_floordiv_bool(self):
    _check(self, _t(4).cast(dtypes.bool), lambda v: v // True)

  def test_mod_bool(self):
    _check(self, _t(4).cast(dtypes.bool), lambda v: v % True)

  def test_fmod_bool(self):
    _check(self, _t(4).cast(dtypes.bool), lambda v: v.fmod(True))
|
|
|
|
class TestTensorUOpClone(unittest.TestCase):
  """clone() parity; both sides go through _strip_unique before the identity comparison."""

  def test_clone(self):
    src = _t(3, 4).float()
    via_tensor = _strip_unique(src.clone().uop)
    via_uop = _strip_unique(src.uop.clone())
    self.assertIs(via_tensor, via_uop)

  def test_clone_deviceless_const(self):
    const = UOp.const(dtypes.float, 2.0)
    via_tensor = _strip_unique(Tensor(const).clone().uop)
    via_uop = _strip_unique(const.clone())
    self.assertIs(via_tensor, via_uop)
|
|
|
|
class TestTensorUOpGradient(unittest.TestCase):
  """gradient() on a Tensor and on its UOp must return the same UOp."""

  def test_gradient(self):
    inp = _t(3, 3).float()
    loss = (inp * 2).sum()
    # each call is expected to return exactly one gradient
    [tensor_grad] = loss.gradient(inp)
    [uop_grad] = loss.uop.gradient(inp.uop)
    self.assertIs(tensor_grad.uop, uop_grad)
|
|
|
|
class TestTensorUOpGetitem(unittest.TestCase):
  """Indexing a Tensor must build the same UOp as indexing its `.uop` with the equivalent index."""

  # ---- pure slice patterns ----
  def test_slice_full(self):
    _check(self, _t(4), lambda v: v[slice(None)])

  def test_slice_positive(self):
    _check(self, _t(8), lambda v: v[1:5])

  def test_slice_open_start(self):
    _check(self, _t(8), lambda v: v[:5])

  def test_slice_open_stop(self):
    _check(self, _t(8), lambda v: v[3:])

  def test_slice_negative_start(self):
    _check(self, _t(8), lambda v: v[-3:])

  def test_slice_negative_stop(self):
    _check(self, _t(8), lambda v: v[:-2])

  def test_slice_both_negative(self):
    _check(self, _t(8), lambda v: v[-5:-1])

  # ---- slice with stride ----
  def test_slice_stride(self):
    _check(self, _t(6), lambda v: v[::2])

  def test_slice_start_stop_stride(self):
    _check(self, _t(6), lambda v: v[1:5:2])

  def test_slice_reverse(self):
    _check(self, _t(6), lambda v: v[::-1])

  def test_slice_singleton_negative_step(self):
    _check(self, _t(8), lambda v: v[3:2:-1])

  # ---- empty / out-of-bounds slice ----
  def test_slice_empty(self):
    _check(self, _t(6), lambda v: v[3:1])

  def test_slice_oob_stop(self):
    _check(self, _t(6), lambda v: v[0:100])

  # ---- single int (reduces a dim) ----
  def test_int_positive(self):
    _check(self, _t(8), lambda v: v[3])

  def test_int_negative(self):
    _check(self, _t(8), lambda v: v[-1])

  # ---- ellipsis ----
  def test_ellipsis_only(self):
    _check(self, _t(2, 3, 4), lambda v: v[...])

  def test_ellipsis_then_int(self):
    _check(self, _t(2, 3, 4), lambda v: v[..., -1])

  def test_ellipsis_then_slice(self):
    _check(self, _t(2, 3, 4), lambda v: v[..., 1:3])

  def test_ellipsis_then_none(self):
    _check(self, _t(2, 3), lambda v: v[..., None])

  # ---- None (unsqueeze) ----
  def test_none_front(self):
    _check(self, _t(4), lambda v: v[None])

  def test_none_back(self):
    _check(self, _t(4), lambda v: v[:, None])

  def test_none_middle(self):
    _check(self, _t(2, 3), lambda v: v[:, None, :])

  def test_multiple_none(self):
    _check(self, _t(2, 3), lambda v: v[None, :, None])

  # ---- mixed multi-dim ----
  def test_int_then_slice(self):
    _check(self, _t(2, 3), lambda v: v[1, :])

  def test_multi_int(self):
    _check(self, _t(2, 3, 4), lambda v: v[1, 2])

  def test_mixed_slice_int(self):
    _check(self, _t(2, 3, 4), lambda v: v[0:2, -1, 1:3])

  def test_mixed_slice_slice(self):
    _check(self, _t(3, 4, 5), lambda v: v[1:3, :, 0:2])

  def test_high_rank_combo(self):
    _check(self, _t(4, 5, 6), lambda v: v[1:3, :, -1, None])

  # ---- advanced indexing: UOp index on UOp must match Tensor index on Tensor (same uop) ----
  def _check_adv(self, t, idx):
    # idx is a Tensor or a tuple mixing Tensor index arrays with slice/int/None/Ellipsis
    def unwrap(item):
      return item.uop if isinstance(item, Tensor) else item
    if isinstance(idx, Tensor):
      ui = idx.uop
    elif isinstance(idx, tuple):
      ui = tuple(unwrap(item) for item in idx)
    else:
      ui = idx
    self.assertIs(t[idx].uop, t.uop[ui])

  def test_adv_single(self):
    self._check_adv(_t(5), _ti([2,1,0,1,2]))

  def test_adv_negative(self):
    self._check_adv(_t(5), _ti([-1,-2,0]))

  def test_adv_out_of_bounds(self):
    # oob -> 0
    self._check_adv(_t(5), _ti([4,7,2]))

  def test_adv_2d_dim0(self):
    self._check_adv(_t(3,4), _ti([2,0,1]))

  def test_adv_2d_index_array(self):
    # shaped index
    self._check_adv(_t(4,5), _ti([[0,1],[2,3]]))

  def test_adv_two_consecutive(self):
    # linear path
    self._check_adv(_t(3,4), (_ti([2,0,1]), _ti([1,2,3])))

  def test_adv_two_broadcast(self):
    self._check_adv(_t(3,4), (_ti([2,0,1]), _ti([[1],[2],[3]])))

  def test_adv_three_consecutive(self):
    self._check_adv(_t(3,4,5), (_ti([2,0]), _ti([1,3]), _ti([4,2])))

  def test_adv_after_slice(self):
    self._check_adv(_t(2,3,4), (slice(None), _ti([2,0,1])))

  def test_adv_non_consec_permute(self):
    self._check_adv(_t(2,3,4), (_ti([1,0]), slice(None), _ti([2,0])))

  def test_adv_idx_then_int(self):
    self._check_adv(_t(3,4), (_ti([2,0,1]), 2))

  def test_adv_int_then_idx(self):
    self._check_adv(_t(3,4), (1, _ti([2,0,1])))

  def test_adv_idx_then_none(self):
    self._check_adv(_t(3,4), (_ti([2,0,1]), None))

  def test_adv_none_then_idx(self):
    self._check_adv(_t(3,4), (None, _ti([2,0,1])))

  def test_adv_ellipsis_then_idx(self):
    self._check_adv(_t(2,3,4), (Ellipsis, _ti([2,0,1])))

  def test_adv_idx_slice_mix(self):
    self._check_adv(_t(4,5,6), (_ti([1,3]), slice(1,4), _ti([2,0])))

  # bool index is unsupported
  def test_adv_bool_index_rejected(self):
    with self.assertRaises(IndexError):
      _t(5)[_t(5) > 2]
    with self.assertRaises(IndexError):
      _t(5).uop[(_t(5) > 2).uop]

  # python list/tuple indices
  def test_adv_python_list(self):
    via_tensor = _strip_unique(_t(5)[[2,1,0]].uop)
    via_uop = _strip_unique(_t(5).uop[[2,1,0]])
    self.assertIs(via_tensor, via_uop)

  def test_adv_python_list_negative(self):
    via_tensor = _strip_unique(_t(5)[[-1,-2,0]].uop)
    via_uop = _strip_unique(_t(5).uop[[-1,-2,0]])
    self.assertIs(via_tensor, via_uop)

  def test_adv_python_list_nested(self):
    via_tensor = _strip_unique(_t(3,4)[[0,1],[2,0]].uop)
    via_uop = _strip_unique(_t(3,4).uop[[0,1],[2,0]])
    self.assertIs(via_tensor, via_uop)

  def test_adv_python_list_2d_index(self):
    via_tensor = _strip_unique(_t(4,5)[[[0,1],[2,3]]].uop)
    via_uop = _strip_unique(_t(4,5).uop[[[0,1],[2,3]]])
    self.assertIs(via_tensor, via_uop)
|
|
|
|
class TestTensorUOpCumalu(unittest.TestCase):
  """cumsum / cumprod parity between Tensor and UOp."""

  def test_cumsum_1d(self):
    _check(self, _t(5), lambda v: v.cumsum())

  def test_cumsum_2d(self):
    _check(self, _t(3, 4), lambda v: v.cumsum(1))

  def test_cumsum_non_last(self):
    _check(self, _t(3, 4), lambda v: v.cumsum(0))

  def test_cumsum_large(self):
    # exercises _split_cumalu
    _check(self, _t(600), lambda v: v.cumsum())

  def test_cumprod(self):
    _check(self, _t(4), lambda v: v.cumprod(0))
|
|
|
|
class TestTensorUOpCumMinMax(unittest.TestCase):
  """cummax / cummin return a (values, indices) pair; both halves must match between Tensor and UOp."""

  def _check_pair(self, t, fn):
    # fn returns a 2-tuple; compare element by element
    tensor_vals, tensor_idx = fn(t)
    uop_vals, uop_idx = fn(t.uop)
    self.assertIs(tensor_vals.uop, uop_vals)
    self.assertIs(tensor_idx.uop, uop_idx)

  def test_cummax_1d(self):
    self._check_pair(_t(5), lambda v: v.cummax(0))

  def test_cummax_2d(self):
    self._check_pair(_t(3, 4), lambda v: v.cummax(1))

  def test_cummax_0d(self):
    self._check_pair(_t(1).reshape(()), lambda v: v.cummax(0))

  def test_cummin_1d(self):
    self._check_pair(_t(5), lambda v: v.cummin(0))

  def test_cummin_2d(self):
    self._check_pair(_t(3, 4), lambda v: v.cummin(1))
|
|
|
|
class TestTensorUOpArgMinMax(unittest.TestCase):
  """argmax / argmin parity between Tensor and UOp."""

  def _check(self, t, fn):
    via_tensor = fn(t).uop
    via_uop = fn(t.uop)
    self.assertIs(via_tensor, via_uop)

  def test_argmax(self):
    self._check(_t(3, 4), lambda v: v.argmax(axis=1))

  def test_argmax_flat(self):
    self._check(_t(3, 4), lambda v: v.argmax())

  def test_argmin(self):
    self._check(_t(3, 4), lambda v: v.argmin(axis=0))
|
|
|
|
class TestTensorUOpSequential(unittest.TestCase):
  """sequential() parity between Tensor and UOp."""

  def test_sequential(self):
    def double(y): return y * 2
    def add_one(y): return y + 1
    _check(self, _t(4), lambda v: v.sequential([double, add_one]))
|
|
|
|
class TestTensorUOpOneHot(unittest.TestCase):
  """one_hot() parity between Tensor and UOp."""

  def test_one_hot(self):
    labels = _t(5)
    via_tensor = labels.one_hot(5).uop
    via_uop = labels.uop.one_hot(5)
    self.assertIs(via_tensor, via_uop)
|
|
|
|
class TestTensorUOpSort(unittest.TestCase):
  """sort / argsort / topk parity between Tensor and UOp."""

  def _check(self, t, **kw):
    # sort returns (values, indices); both must match
    tensor_vals, tensor_idx = t.sort(**kw)
    uop_vals, uop_idx = t.uop.sort(**kw)
    self.assertIs(tensor_vals.uop, uop_vals)
    self.assertIs(tensor_idx.uop, uop_idx)

  def test_sort_1d(self):
    self._check(Tensor([0.5, 0.1, 0.3]).float())

  def test_sort_descending(self):
    self._check(Tensor([0.5, 0.1, 0.3]).float(), descending=True)

  def test_sort_2d(self):
    self._check(_t(2, 4).float())

  def test_sort_single(self):
    self._check(Tensor([1.0]).float())

  def test_argsort(self):
    data = Tensor([0.5, 0.1, 0.3]).float()
    via_tensor = data.argsort().uop
    via_uop = data.uop.argsort()
    self.assertIs(via_tensor, via_uop)

  def test_topk(self):
    data = _t(2, 4).float()
    tensor_vals, tensor_idx = data.topk(2)
    uop_vals, uop_idx = data.uop.topk(2)
    self.assertIs(tensor_vals.uop, uop_vals)
    self.assertIs(tensor_idx.uop, uop_idx)
|
|
|
|
class TestTensorUOpAllclose(unittest.TestCase):
  """allclose() parity between Tensor and UOp."""

  def test_allclose(self):
    left = _t(4).float()
    right = _t(4).float()
    via_tensor = left.allclose(right).uop
    via_uop = left.uop.allclose(right.uop)
    self.assertIs(via_tensor, via_uop)
|
|
|
|
class TestTensorUOpCast(unittest.TestCase):
  """cast() given a dtype name as a string."""

  def test_cast_str_dtype(self):
    base = _t(4)
    via_tensor = base.cast("float32").uop
    via_uop = base.uop.cast("float32")
    self.assertIs(via_tensor, via_uop)
    # the string must resolve to the float32 dtype on the UOp path
    self.assertIs(base.uop.cast("float32").dtype, dtypes.float32)
|
|
|
|
class TestTensorUOpBitcast(unittest.TestCase):
  """bitcast() parity, including a dtype given as a string."""

  def test_bitcast_same_dtype(self):
    _check(self, _t(4).float(), lambda v: v.bitcast(dtypes.float32))

  def test_bitcast_str_dtype(self):
    base = _t(4)
    via_tensor = base.bitcast("uint32").uop
    via_uop = base.uop.bitcast("uint32")
    self.assertIs(via_tensor, via_uop)
    # the string must resolve to the uint32 dtype on the UOp path
    self.assertIs(base.uop.bitcast("uint32").dtype, dtypes.uint32)
|
|
|
|
class TestTensorUOpRand(unittest.TestCase):
  """Random-number building blocks: the Tensor entry points must match their UOp counterparts."""

  def test_random_bits(self):
    key = UOp.empty((2,), dtype=dtypes.uint32)
    counter = UOp.zeros(2, dtype=dtypes.uint32)
    for num in (1, 4, 7, 1024):
      via_tensor = Tensor.random_bits(Tensor(key), Tensor(counter), num).uop
      via_uop = UOp.random_bits(key, counter, num)
      self.assertIs(via_tensor, via_uop)

  def test_bits_to_rand_float32(self):
    bits_uop = UOp.empty((8,), dtype=dtypes.uint32)
    for shape in ((8,), (2, 4), (5,)):
      via_tensor = Tensor._bits_to_rand(Tensor(bits_uop), shape, dtypes.float32).uop
      via_uop = UOp._bits_to_rand(bits_uop, shape, dtypes.float32)
      self.assertIs(via_tensor, via_uop)

  def test_threefry(self):
    words = _t(4).cast(dtypes.uint64)
    via_tensor = words.threefry(words).uop
    via_uop = words.uop.threefry(words.uop)
    self.assertIs(via_tensor, via_uop)

  def test_threefry_random_bits(self):
    key = UOp.empty((2,), dtype=dtypes.uint32)
    c0 = UOp.arange(4, dtype=dtypes.uint32)
    c1 = UOp.arange(4, dtype=dtypes.uint32)
    via_tensor = Tensor._threefry_random_bits(Tensor(key), Tensor(c0), Tensor(c1)).uop
    via_uop = UOp._threefry_random_bits(key, c0, c1)
    self.assertIs(via_tensor, via_uop)

  def test_rand(self):
    key = UOp.empty((2,), dtype=dtypes.uint32)
    counter = UOp.zeros(2, dtype=dtypes.uint32)
    # a regular shape and one with a zero-sized dim
    self.assertIs(Tensor._rand(Tensor(key), Tensor(counter), (2, 2), dtypes.float32).uop, UOp._rand(key, counter, (2, 2), dtypes.float32))
    self.assertIs(Tensor._rand(Tensor(key), Tensor(counter), (0, 3), dtypes.float32).uop, UOp._rand(key, counter, (0, 3), dtypes.float32))
|
|
|
|
class TestTensorUOpGather(unittest.TestCase):
  """gather() parity between Tensor and UOp."""

  def _check(self, t, dim, idx):
    via_tensor = t.gather(dim, idx).uop
    via_uop = t.uop.gather(dim, idx.uop)
    self.assertIs(via_tensor, via_uop)

  def test_gather_1d(self):
    index = Tensor([2, 1, 0, 1, 2], dtype=dtypes.int32)
    self._check(_t(5), 0, index)

  def test_gather_dim0(self):
    index = Tensor([[0, 1, 2, 0], [1, 2, 0, 1], [2, 0, 1, 2]], dtype=dtypes.int32)
    self._check(_t(3, 4), 0, index)

  def test_gather_dim1(self):
    index = Tensor([[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1]], dtype=dtypes.int32)
    self._check(_t(3, 4), 1, index)
|
|
|
|
class TestTensorUOpInterpolate(unittest.TestCase):
  """interpolate() to size (2, 2) in each mode, Tensor vs UOp."""

  def _check(self, t, mode):
    via_tensor = t.interpolate(size=(2, 2), mode=mode).uop
    via_uop = t.uop.interpolate(size=(2, 2), mode=mode)
    self.assertIs(via_tensor, via_uop)

  def test_interpolate_nearest(self):
    self._check(_t(1, 1, 4, 4).float(), "nearest")

  def test_interpolate_nearest_exact(self):
    self._check(_t(1, 1, 4, 4).float(), "nearest-exact")

  def test_interpolate_linear(self):
    self._check(_t(1, 1, 4, 4).float(), "linear")
|
|
|
|
class TestTensorUOpLoss(unittest.TestCase):
  """Loss functions: Tensor and UOp versions must produce the same UOp."""

  def test_cross_entropy(self):
    logits = _t(2, 3).float()
    target = Tensor([1, 2], dtype=dtypes.int32)
    self.assertIs(logits.cross_entropy(target).uop, logits.uop.cross_entropy(target.uop))

  def test_sparse_categorical_crossentropy(self):
    logits = _t(2, 3).float()
    target = Tensor([1, 2], dtype=dtypes.int32)
    self.assertIs(logits.sparse_categorical_crossentropy(target).uop, logits.uop.sparse_categorical_crossentropy(target.uop))

  def test_sparse_categorical_crossentropy_ignore_index(self):
    logits = _t(2, 3).float()
    target = Tensor([1, 2], dtype=dtypes.int32)
    via_tensor = logits.sparse_categorical_crossentropy(target, ignore_index=0).uop
    via_uop = logits.uop.sparse_categorical_crossentropy(target.uop, ignore_index=0)
    self.assertIs(via_tensor, via_uop)

  def test_nll_loss(self):
    logp = _t(2, 3).float().log_softmax()
    target = Tensor([1, 2], dtype=dtypes.int32)
    self.assertIs(logp.nll_loss(target).uop, logp.uop.nll_loss(target.uop))

  def test_nll_loss_weight(self):
    logp = _t(2, 3).float().log_softmax()
    target = Tensor([1, 2], dtype=dtypes.int32)
    weight = _t(3).float()
    self.assertIs(logp.nll_loss(target, weight=weight).uop, logp.uop.nll_loss(target.uop, weight=weight.uop))

  def test_nll_loss_ignore_index(self):
    logp = _t(2, 3).float().log_softmax()
    target = Tensor([1, 2], dtype=dtypes.int32)
    self.assertIs(logp.nll_loss(target, ignore_index=1).uop, logp.uop.nll_loss(target.uop, ignore_index=1))

  def test_nll_loss_none_reduction(self):
    logp = _t(2, 3).float().log_softmax()
    target = Tensor([1, 2], dtype=dtypes.int32)
    self.assertIs(logp.nll_loss(target, reduction="none").uop, logp.uop.nll_loss(target.uop, reduction="none"))

  def test_nll_loss_weight_ignore_index(self):
    logp = _t(2, 3).float().log_softmax()
    target = Tensor([1, 2], dtype=dtypes.int32)
    weight = _t(3).float()
    via_tensor = logp.nll_loss(target, weight=weight, ignore_index=1).uop
    via_uop = logp.uop.nll_loss(target.uop, weight=weight.uop, ignore_index=1)
    self.assertIs(via_tensor, via_uop)
|
|
|
|
class TestTensorUOpScatter(unittest.TestCase):
  """scatter() parity between Tensor and UOp, for tensor and scalar sources."""

  def test_scatter(self):
    base = _t(3, 4).float()
    index = Tensor([[0, 1, 2, 0]], dtype=dtypes.int32)
    src = _t(1, 4).float()
    self.assertIs(base.scatter(0, index, src).uop, base.uop.scatter(0, index.uop, src.uop))

  def test_scatter_scalar_src(self):
    base = _t(3, 4).float()
    index = Tensor([[0, 1]], dtype=dtypes.int32)
    self.assertIs(base.scatter(1, index, 3.14).uop, base.uop.scatter(1, index.uop, 3.14))

  # inf cannot be cast to int — this regresses if scalar src is routed through index.dtype first
  def test_scatter_inf_src(self):
    base = _t(3, 4).float()
    index = Tensor([[0, 1]], dtype=dtypes.int32)
    self.assertIs(base.scatter(1, index, float("inf")).uop, base.uop.scatter(1, index.uop, float("inf")))

  def test_scatter_add(self):
    base = _t(3, 4).float()
    index = Tensor([[0, 1]], dtype=dtypes.int32)
    self.assertIs(base.scatter(1, index, 3.14, reduce="add").uop, base.uop.scatter(1, index.uop, 3.14, reduce="add"))

  def test_scatter_multiply(self):
    base = _t(3, 4).float()
    index = Tensor([[0, 1]], dtype=dtypes.int32)
    self.assertIs(base.scatter(1, index, 3.14, reduce="multiply").uop, base.uop.scatter(1, index.uop, 3.14, reduce="multiply"))

  # tensor src with reduce hits the "elif reduce: raise" branch in both Tensor and UOp paths
  def test_scatter_tensor_src_with_reduce_raises(self):
    base = _t(3, 4).float()
    index = Tensor([[0, 1]], dtype=dtypes.int32)
    src = _t(1, 2).float()
    with self.assertRaises(TypeError):
      base.scatter(1, index, src, reduce="add")
    with self.assertRaises(TypeError):
      base.uop.scatter(1, index.uop, src.uop, reduce="add")
|
|
|
|
class TestTensorUOpScatterReduce(unittest.TestCase):
  """scatter_reduce() along dim 0 for each reduce mode, Tensor vs UOp."""

  def _check(self, x, idx, src, **kw):
    via_tensor = x.scatter_reduce(0, idx, src, **kw).uop
    via_uop = x.uop.scatter_reduce(0, idx.uop, src.uop, **kw)
    self.assertIs(via_tensor, via_uop)

  def _check_common(self, **kw):
    # every test uses the same base / index / source tensors and only varies the keyword arguments
    base = _t(3, 4).float()
    index = Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32)
    src = Tensor.ones(3, 4).float()
    self._check(base, index, src, **kw)

  def test_sum(self):
    self._check_common(reduce="sum")

  def test_prod(self):
    self._check_common(reduce="prod")

  def test_mean(self):
    self._check_common(reduce="mean")

  def test_amax(self):
    self._check_common(reduce="amax")

  def test_amin(self):
    self._check_common(reduce="amin")

  def test_mean_exclude_self(self):
    self._check_common(reduce="mean", include_self=False)
|
|
|
|
class TestTensorUOpPool(unittest.TestCase):
  """avg_pool2d / max_pool2d / max_unpool2d parity between Tensor and UOp."""

  def test_avg_pool2d(self):
    _check(self, _t(1, 1, 5, 5).float(), lambda v: v.avg_pool2d())

  def test_avg_pool2d_padding(self):
    _check(self, _t(1, 1, 5, 5).float(), lambda v: v.avg_pool2d(padding=1))

  def test_avg_pool2d_ceil(self):
    _check(self, _t(1, 1, 5, 5).float(), lambda v: v.avg_pool2d(ceil_mode=True))

  def test_avg_pool2d_no_count_pad(self):
    _check(self, _t(1, 1, 5, 5).float(), lambda v: v.avg_pool2d(padding=1, count_include_pad=False))

  def test_max_pool2d(self):
    _check(self, _t(1, 1, 5, 5).float(), lambda v: v.max_pool2d())

  def test_max_pool2d_padding(self):
    _check(self, _t(1, 1, 5, 5).float(), lambda v: v.max_pool2d(padding=1))

  def test_max_pool2d_ceil(self):
    _check(self, _t(1, 1, 5, 5).float(), lambda v: v.max_pool2d(ceil_mode=True))

  def test_max_pool2d_return_indices(self):
    image = _t(1, 1, 5, 5).float()
    # with return_indices=True the result is a (values, indices) pair
    tensor_vals, tensor_idx = image.max_pool2d(return_indices=True)
    uop_vals, uop_idx = image.uop.max_pool2d(return_indices=True)
    self.assertIs(tensor_vals.uop, uop_vals)
    self.assertIs(tensor_idx.uop, uop_idx)

  def test_max_unpool2d(self):
    image = _t(1, 1, 4, 4).float()
    pooled, indices = image.max_pool2d(return_indices=True)
    via_tensor = pooled.max_unpool2d(indices).uop
    via_uop = pooled.uop.max_unpool2d(indices.uop)
    self.assertIs(via_tensor, via_uop)
|
|
|
|
class TestTensorUOpCat(unittest.TestCase):
  """cat() of a value with itself along various dims, Tensor vs UOp."""

  def test_cat_dim0(self):
    _check(self, _t(2, 3), lambda v: v.cat(v, dim=0))

  def test_cat_dim1(self):
    _check(self, _t(2, 3), lambda v: v.cat(v, dim=1))

  def test_cat_3tensors(self):
    _check(self, _t(2, 3), lambda v: v.cat(v, v, dim=0))

  def test_cat_neg_dim(self):
    _check(self, _t(2, 3, 4), lambda v: v.cat(v, dim=-1))
|
|
|
|
class TestTensorUOpPad(unittest.TestCase):
  """pad() parity across flat/grouped padding specs and the circular/reflect/replicate modes."""

  def test_pad_flat(self):
    _check(self, _t(4, 5), lambda v: v.pad((1, 2, 0, 3)))

  def test_pad_flat_negative(self):
    _check(self, _t(4, 5), lambda v: v.pad((1, -1, 0, 2), value=-1.0))

  def test_pad_grouped_none(self):
    _check(self, _t(4, 5), lambda v: v.pad((None, (0, 3))))

  def test_pad_circular(self):
    _check(self, _t(4, 5), lambda v: v.pad(((1, 2), (0, 3)), mode="circular"))

  def test_pad_circular_zero_after(self):
    _check(self, _t(4, 5), lambda v: v.pad(((1, 0), (2, 0)), mode="circular"))

  def test_pad_reflect(self):
    _check(self, _t(4, 5), lambda v: v.pad(((1, 2), (0, 3)), mode="reflect"))

  def test_pad_reflect_negative(self):
    _check(self, _t(4, 5), lambda v: v.pad(((1, -1), (0, 2)), mode="reflect"))

  def test_pad_replicate(self):
    _check(self, _t(4, 5), lambda v: v.pad(((1, 2), (0, 3)), mode="replicate"))

  def test_pad_replicate_negative(self):
    _check(self, _t(4, 5), lambda v: v.pad(((1, -1), (0, 2)), mode="replicate"))
|
|
|
|
class TestTensorUOpStack(unittest.TestCase):
  """stack() of a value with itself along various dims, Tensor vs UOp."""

  def test_stack_dim0(self):
    _check(self, _t(2, 3), lambda v: v.stack(v, dim=0))

  def test_stack_dim1(self):
    _check(self, _t(2, 3), lambda v: v.stack(v, dim=1))

  def test_stack_3tensors(self):
    _check(self, _t(2, 3), lambda v: v.stack(v, v, dim=0))

  def test_stack_new_last(self):
    _check(self, _t(2, 3), lambda v: v.stack(v, dim=-1))
|
|
|
|
class TestTensorUOpConv2d(unittest.TestCase):
  """conv2d / conv_transpose2d parity; the weight is passed as a Tensor or as its UOp to match the input."""

  def test_conv2d_basic(self):
    w = _t(1, 1, 2, 2).float()
    def conv(x):
      kernel = w if isinstance(x, Tensor) else w.uop
      return x.conv2d(kernel)
    _check(self, _t(1, 1, 3, 3).float(), conv)

  def test_conv2d_padded(self):
    w = _t(1, 1, 2, 2).float()
    def conv(x):
      kernel = w if isinstance(x, Tensor) else w.uop
      return x.conv2d(kernel, padding=1)
    _check(self, _t(1, 1, 3, 3).float(), conv)

  def test_conv2d_negative_padding(self):
    w = _t(1, 1, 3, 3).float()
    def conv(x):
      kernel = w if isinstance(x, Tensor) else w.uop
      return x.conv2d(kernel, padding=(-1,-1,-1,-1))
    _check(self, _t(1, 1, 5, 5).float(), conv)

  def test_conv2d_multichannel_bias(self):
    w = _t(4, 2, 3, 3).float()
    b = _t(4).float()
    def conv(x):
      # weight and bias go in positionally, in that order
      kernel, bias = (w, b) if isinstance(x, Tensor) else (w.uop, b.uop)
      return x.conv2d(kernel, bias)
    _check(self, _t(2, 2, 5, 5).float(), conv)

  def test_conv2d_stride_dilation(self):
    w = _t(2, 2, 2, 2).float()
    def conv(x):
      kernel = w if isinstance(x, Tensor) else w.uop
      return x.conv2d(kernel, stride=2, dilation=2)
    _check(self, _t(1, 2, 6, 6).float(), conv)

  def test_conv2d_groups(self):
    w = _t(4, 1, 2, 2).float()
    def conv(x):
      kernel = w if isinstance(x, Tensor) else w.uop
      return x.conv2d(kernel, groups=4)
    _check(self, _t(1, 4, 4, 4).float(), conv)

  def test_conv2d_3d(self):
    w = _t(1, 1, 2, 2, 2).float()
    def conv(x):
      kernel = w if isinstance(x, Tensor) else w.uop
      return x.conv2d(kernel)
    _check(self, _t(1, 1, 3, 3, 3).float(), conv)

  def test_conv_transpose2d_basic(self):
    w = _t(1, 1, 2, 2).float()
    def conv_t(x):
      kernel = w if isinstance(x, Tensor) else w.uop
      return x.conv_transpose2d(kernel)
    _check(self, _t(1, 1, 3, 3).float(), conv_t)

  def test_conv_transpose2d_stride(self):
    w = _t(1, 1, 2, 2).float()
    def conv_t(x):
      kernel = w if isinstance(x, Tensor) else w.uop
      return x.conv_transpose2d(kernel, stride=2)
    _check(self, _t(1, 1, 3, 3).float(), conv_t)
|
|
|
|
class TestTensorUOpEinsum(unittest.TestCase):
  """einsum is called on the type of the input, so one callable serves both Tensor and UOp."""

  def test_einsum_dot(self):
    def dot(x):
      return type(x).einsum("ij,ij->", x, x)
    _check(self, _t(2, 3), dot)

  def test_einsum_transpose(self):
    def transpose(x):
      return type(x).einsum("ij->ji", x)
    _check(self, _t(2, 3), transpose)
|
|
|
|
class TestTensorUOpSoftmax(unittest.TestCase):
  """softmax / log_softmax parity between Tensor and UOp."""

  def test_softmax_default(self):
    _check(self, _t(2, 3).float(), lambda v: v.softmax())

  def test_softmax_axis0(self):
    _check(self, _t(2, 3).float(), lambda v: v.softmax(axis=0))

  def test_log_softmax_default(self):
    _check(self, _t(2, 3).float(), lambda v: v.log_softmax())

  def test_log_softmax_axis0(self):
    _check(self, _t(2, 3).float(), lambda v: v.log_softmax(axis=0))
|
|
|
|
class TestTensorUOpQR(unittest.TestCase):
  """qr() returns a pair; both factors must match between Tensor and UOp."""

  def _check(self, t):
    tensor_q, tensor_r = t.qr()
    uop_q, uop_r = t.uop.qr()
    self.assertIs(tensor_q.uop, uop_q)
    self.assertIs(tensor_r.uop, uop_r)

  def test_qr_square(self):
    self._check(_t(3, 3).float())

  def test_qr_tall(self):
    self._check(_t(4, 3).float())

  def test_qr_wide(self):
    self._check(_t(3, 4).float())

  def test_qr_zero_col(self):
    self._check(Tensor([[0.0, 1.0], [0.0, 2.0]]))

  def test_qr_batched(self):
    self._check(_t(2, 3, 3).float())
|
|
|
|
class TestTensorUOpSVD(unittest.TestCase):
  """svd() returns a triple; every element must match between Tensor and UOp."""

  def _check(self, t, **kw):
    tensor_parts = t.svd(**kw)
    uop_parts = t.uop.svd(**kw)
    # unpack explicitly so a result that is not a 3-tuple fails loudly
    tensor_u, tensor_s, tensor_v = tensor_parts
    uop_u, uop_s, uop_v = uop_parts
    self.assertIs(tensor_u.uop, uop_u)
    self.assertIs(tensor_s.uop, uop_s)
    self.assertIs(tensor_v.uop, uop_v)

  def test_svd_square(self):
    self._check(_t(2, 2).float())

  def test_svd_tall(self):
    self._check(_t(3, 2).float())

  def test_svd_wide(self):
    self._check(_t(2, 3).float())

  def test_svd_odd_num(self):
    # exercises odd-num runoff path
    self._check(_t(3, 3).float())

  def test_svd_batched(self):
    self._check(_t(2, 2, 2).float())

  def test_svd_nonfull(self):
    self._check(_t(3, 2).float(), full_matrices=False)
|
|
|
|
# UOp.empty / UOp.empty_like are the canonical buffer allocators; Tensor.empty / Tensor.empty_like just forward.
class TestUOpEmpty(unittest.TestCase):
  """Checks on the shape / dtype / device / axis reported by UOp.empty and UOp.empty_like results."""

  def test_empty_dtype_string(self):
    buf = UOp.empty((3, 4), dtype="float32")
    self.assertEqual(buf.dtype, dtypes.float32)

  def test_empty_like_dtype_override(self):
    buf = Tensor.ones(3, 4).uop.empty_like(dtype=dtypes.int8)
    self.assertEqual((buf.shape, buf.dtype), ((3, 4), dtypes.int8))
    self.assertTrue(buf.has_buffer_identity())

  def test_empty_like_sharded_to_single_device(self):
    # regression: sharded source, override to single device must yield full logical shape with no axis
    sharded = Tensor.ones(8, 4).shard(("NULL:0", "NULL:1"), axis=0)
    # singleton tuple also canonicalizes to single device
    for dev in ("NULL:2", ("NULL:2",)):
      buf = sharded.uop.empty_like(device=dev, dtype=dtypes.int32)
      self.assertEqual((buf.shape, buf.device, buf.dtype, buf.axis), ((8, 4), "NULL:2", dtypes.int32, None))
      # NOTE(review): loop nesting of this check was reconstructed from a flattened source — confirm against upstream
      self.assertTrue(buf.has_buffer_identity())

  def test_empty_direct_singleton_tuple_device(self):
    # regression: direct UOp.empty with a singleton-tuple device + axis must not trip .multi()'s tuple assert
    buf = UOp.empty((4,), dtype=dtypes.float32, device=("NULL:0",), axis=0)
    self.assertEqual((buf.shape, buf.device, buf.axis), ((4,), "NULL", None))
|
|
|
|
class TestTensorUOpCreation(unittest.TestCase):
  """Creation helpers: the Tensor constructors must build the same UOp as the UOp constructors.

  full / zeros / ones / invalids are compared after _strip_unique; the remaining tests compare directly.
  """

  def test_full(self):
    via_tensor = _strip_unique(Tensor.full((2, 3), 42).uop)
    via_uop = _strip_unique(UOp.full((2, 3), 42))
    self.assertIs(via_tensor, via_uop)

  def test_full_kwargs(self):
    via_tensor = _strip_unique(Tensor.full((2, 3), 42, dtype=dtypes.int8, device="NULL").uop)
    via_uop = _strip_unique(UOp.full((2, 3), 42, dtype=dtypes.int8, device="NULL"))
    self.assertIs(via_tensor, via_uop)

  def test_full_symbolic_fill(self):
    fill = UOp.variable("x", 1, 10).bind(5)
    filled = Tensor.full((2, 3), fill)
    self.assertEqual(filled.shape, (2, 3))

  def test_zeros(self):
    via_tensor = _strip_unique(Tensor.zeros(2, 3).uop)
    via_uop = _strip_unique(UOp.zeros(2, 3))
    self.assertIs(via_tensor, via_uop)

  def test_ones(self):
    via_tensor = _strip_unique(Tensor.ones(2, 3).uop)
    via_uop = _strip_unique(UOp.ones(2, 3))
    self.assertIs(via_tensor, via_uop)

  def test_invalids(self):
    via_tensor = _strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop)
    via_uop = _strip_unique(UOp.invalids(2, 3, dtype=dtypes.int8))
    self.assertIs(via_tensor, via_uop)

  def test_arange(self):
    via_tensor = Tensor.arange(5).uop
    self.assertIs(via_tensor, UOp.arange(5))

  def test_arange_empty(self):
    via_tensor = Tensor.arange(5, 5).uop
    self.assertIs(via_tensor, UOp.arange(5, 5))

  def test_arange_step(self):
    via_tensor = Tensor.arange(5, 10, 2).uop
    self.assertIs(via_tensor, UOp.arange(5, 10, 2))

  def test_linspace(self):
    via_tensor = Tensor.linspace(0, 10, 5).uop
    self.assertIs(via_tensor, UOp.linspace(0, 10, 5))

  def test_linspace_one_step(self):
    via_tensor = Tensor.linspace(5, 10, 1).uop
    self.assertIs(via_tensor, UOp.linspace(5, 10, 1))

  def test_eye(self):
    via_tensor = Tensor.eye(3).uop
    self.assertIs(via_tensor, UOp.eye(3))

  def test_eye_rect(self):
    via_tensor = Tensor.eye(2, 4).uop
    self.assertIs(via_tensor, UOp.eye(2, 4))

  def test_triu(self):
    mat = _t(3, 4)
    self.assertIs(mat.triu().uop, mat.uop.triu())

  def test_triu_diagonal(self):
    mat = _t(3, 4)
    self.assertIs(mat.triu(diagonal=1).uop, mat.uop.triu(diagonal=1))

  def test_tril(self):
    mat = _t(3, 4)
    self.assertIs(mat.tril().uop, mat.uop.tril())

  def test_tril_diagonal(self):
    mat = _t(3, 4)
    self.assertIs(mat.tril(diagonal=-1).uop, mat.uop.tril(diagonal=-1))
|
|
|
|
# run the test suite when this file is executed directly
if __name__ == "__main__":
  unittest.main()
|