import time, math, unittest, functools, platform, warnings, sys import numpy as np from typing import List, Callable import torch from tinygrad.helpers import getenv, DEBUG, DEV, IMAGE, Context from tinygrad import Tensor, Device, dtypes from tinygrad.tensor import _to_np_dtype from tinygrad.renderer.cstyle import QCOMCLRenderer from tinygrad.renderer.nir import NIRRenderer TINY_BACKEND = getenv("TINY_BACKEND") if TINY_BACKEND: import tinygrad.nn.torch # noqa: F401 # pylint: disable=unused-import torch.set_default_device("tiny") warnings.filterwarnings("ignore", message="Non-empty compiler output encountered") FORWARD_ONLY = getenv("FORWARD_ONLY", 0) PRINT_TENSORS = getenv("PRINT_TENSORS", 0) COMPILE_ONLY = Device.DEFAULT == "NULL" def slow_test(test_func): return unittest.skipIf(getenv("SKIP_SLOW_TEST"), "Skipping slow test")(test_func) def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, low=-2, high=2): if tinygrad_fxn is None: tinygrad_fxn = torch_fxn ts, tst = prepare_test_op(low, high, shps, vals, forward_only) st = time.monotonic() out = torch_fxn(*ts) torch_fp = time.monotonic() - st # move inputs to a different device, test the device of intermediate tensors are correct if mt:=getenv("MOVE_TENSOR", ""): for t in tst: t.to_(mt) st = time.monotonic() ret = tinygrad_fxn(*tst).realize() tinygrad_fp = time.monotonic() - st def compare(s, tinygrad_output, torch_output, atol, rtol): if COMPILE_ONLY: return if PRINT_TENSORS: print(s, tinygrad_output, torch_output) try: assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}" assert tinygrad_output.dtype == torch_output.dtype, f"dtype mismatch: tinygrad={tinygrad_output.dtype} | torch={torch_output.dtype}" if np.issubdtype(tinygrad_output.dtype, np.floating): np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol) else: 
np.testing.assert_equal(tinygrad_output, torch_output) except Exception as e: raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}") if DEBUG >= 6: np.set_printoptions(linewidth=200, suppress=True) print(ret.numpy()) print(out.detach().cpu().numpy()) compare("forward pass", ret.numpy(), out.detach().cpu().numpy(), atol=atol, rtol=rtol) torch_fbp, tinygrad_fbp = np.nan, np.nan if not forward_only and not FORWARD_ONLY and ts and tst: st = time.monotonic() torch_grads = torch.autograd.grad(torch_fxn(*ts).sum(), ts) torch_fbp = time.monotonic() - st st = time.monotonic() # NOTE: we now have to recompute the forward pass since we realized it tiny_grads = tinygrad_fxn(*tst).sum().gradient(*tst) Tensor.realize(*tiny_grads) tinygrad_fbp = time.monotonic() - st for i, (t, torch_grad) in enumerate(zip(tiny_grads, torch_grads)): compare(f"backward pass tensor {i}", t.numpy(), torch_grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol) if sys.stdout.isatty(): print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \ (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="") def prepare_test_op(low, high, shps, vals, forward_only=False): if shps is None: ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals] else: np.random.seed(0) np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps] ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data] for i in range(len(ts)): # NOTE: torch default int64 for python ints input if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32) tst = [Tensor(x.detach().cpu().numpy()) for x in ts] return ts, tst class TestOps(unittest.TestCase): def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn=None, expected=None, forward_only=False, exact=False, vals=None, low=-1.5, high=1.5): if DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV": 
self.skipTest('helper_test_exception fails in CI CUDA') ts, tst = prepare_test_op(low, high, shps, vals, forward_only) if tinygrad_fxn is None: tinygrad_fxn = torch_fxn with self.assertRaises(expected) as torch_cm: torch_fxn(*ts) with self.assertRaises(expected) as tinygrad_cm: tinygrad_fxn(*tst) if exact: self.assertEqual(str(torch_cm.exception), str(tinygrad_cm.exception)) if sys.stdout.isatty(): print("\ntesting %40r torch/tinygrad exception: %s / %s" % (shps, torch_cm.exception, tinygrad_cm.exception), end="") def test_full_like(self): a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.float32) b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32) helper_test_op([], lambda: torch.full_like(b, 4), lambda: Tensor.full_like(a, 4), forward_only=True) a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.int32) b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32) helper_test_op([], lambda: torch.full_like(b, 4), lambda: Tensor.full_like(a, 4), forward_only=True) def test_full(self): helper_test_op([], lambda: torch.full((45,65), 4, dtype=torch.int32), lambda: Tensor.full((45,65), 4), forward_only=True) def test_negative_dims(self): creation_methods: List[Callable[..., Tensor]] = [ Tensor.empty, Tensor.rand, Tensor.zeros, Tensor.ones, Tensor.randn, Tensor.randint, Tensor.normal, Tensor.uniform, Tensor.scaled_uniform, Tensor.glorot_uniform ] for method in creation_methods: with self.assertRaises(ValueError): method(-3, 2) with self.assertRaises(ValueError): method((2, -3)) with self.assertRaises(ValueError): method((2, -3, 0)) def test_negative_dims_full(self): with self.assertRaises(ValueError): Tensor.full((-3,), 2) with self.assertRaises(ValueError): Tensor.full((2, -3), 4) with self.assertRaises(ValueError): Tensor.full((2, -3, 0), 4) def test_negative_dims_eye(self): with self.assertRaises(ValueError): Tensor.eye(-3, 3) with self.assertRaises(ValueError): Tensor.eye(3, -3) with self.assertRaises(ValueError): Tensor.eye(-3, -3) def test_negative_dims_kaiming(self): 
creation_methods = [Tensor.kaiming_uniform, Tensor.kaiming_normal] for method in creation_methods: with self.assertRaises(ValueError): method(-3, 3) with self.assertRaises(ValueError): method((-3, 3), 3) with self.assertRaises(ValueError): method((-3, -3), 3) def test_zeros(self): helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True) helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True) helper_test_op([], lambda: torch.zeros([]), lambda: Tensor.zeros([]), forward_only=True) def test_zeros_like(self): a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.float32) b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32) helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True) a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.int32) b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32) helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True) def test_empty_0(self): helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True) def test_ones(self): helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True) helper_test_op([], lambda: torch.ones([45,65]), lambda: Tensor.ones([45,65]), forward_only=True) helper_test_op([], lambda: torch.ones([]), lambda: Tensor.ones([]), forward_only=True) def test_ones_like(self): a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.float32) b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32) helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True) a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.int32) b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32) helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True) def test_eye(self): helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True) helper_test_op([], lambda: torch.eye(3, 
5), lambda: Tensor.eye(3, 5), forward_only=True) helper_test_op([], lambda: torch.eye(5, 3), lambda: Tensor.eye(5, 3), forward_only=True) helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True) helper_test_op([], lambda: torch.eye(0), lambda: Tensor.eye(0), forward_only=True) def test_split(self): def tensor(s): return torch.arange(math.prod(s), dtype=torch.int32).reshape(s), Tensor.arange(math.prod(s)).reshape(s) test_cases = [ (tensor((10,)), 5, {}), (tensor((10,)), [1,4,5], {}), (tensor((10,)), 3, {}), (tensor((3,4,)), 1, {}), (tensor((3,4,)), 1, {'dim':1}), (tensor((4,4,)), [2,2], {}), (tensor((4,4,)), [2,2], {'dim':1}), (tensor((10000,)), 2500, {}), ] for (tor, ten), sizes, args in test_cases: tor_splits, ten_splits = tor.split(sizes, **args), ten.split(sizes, **args) assert len(tor_splits) == len(ten_splits) for tor_chunk, ten_chunk in zip(tor_splits, ten_splits): helper_test_op([], lambda: tor_chunk, lambda: ten_chunk, forward_only=True) def test_chunk(self): tor = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(6, 1) ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1) assert len(tor) == len(ten) for i in range(len(tor)): helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) tor = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(6, 0) ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 0) assert len(tor) == len(ten) for i in range(len(tor)): helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) tor = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(3, -1) ten = Tensor.arange(13).repeat((8, 1)).chunk(3, -1) assert len(tor) == len(ten) for i in range(len(tor)): helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) tor = torch.arange(13, dtype=torch.int32).repeat(8, 3, 3).chunk(3, -2) ten = Tensor.arange(13).repeat((8, 3, 3)).chunk(3, -2) assert len(tor) == len(ten) for i in range(len(tor)): helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) def 
test_unfold(self): helper_test_op([(8,)], lambda x: x.unfold(0, 2, 1)) helper_test_op([(8,)], lambda x: x.unfold(0, 2, 2)) helper_test_op([(8,)], lambda x: x.unfold(0, 7, 3)) helper_test_op([(3,3,3)], lambda x: x.unfold(2, 2, 8)) helper_test_op([(3,3,3)], lambda x: x.unfold(1, 0, 8)) helper_test_op([(3,3,3,3,3)], lambda x: x.unfold(-1, 2, 2)) self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), expected=RuntimeError) self.helper_test_exception([(8,)], lambda x: x.unfold(1, 8, 3), expected=IndexError) self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), expected=RuntimeError) self.helper_test_exception([(8,)], lambda x: x.unfold(0, 1, -1), expected=RuntimeError) def test_meshgrid(self): x, xt = torch.tensor([0.,1.,2.], requires_grad=True), Tensor([0.,1.,2.]) y, yt = torch.tensor([3.,4.,5.,6.], requires_grad=True), Tensor([3.,4.,5.,6.]) z, zt = torch.tensor([7.,8.,9.], requires_grad=True), Tensor([7.,8.,9.]) for indexing in ("ij", "xy"): tor = torch.meshgrid(x, indexing=indexing) ten = xt.meshgrid(indexing=indexing) self.assertEqual(len(tor), len(ten)) for tor_i, ten_i in zip(tor, ten): helper_test_op([], lambda: tor_i, lambda: ten_i) tor = torch.meshgrid(x, y, indexing=indexing) ten = xt.meshgrid(yt, indexing=indexing) self.assertEqual(len(tor), len(ten)) for tor_i, ten_i in zip(tor, ten): helper_test_op([], lambda: tor_i, lambda: ten_i) tor = torch.meshgrid(x, torch.tensor(10., requires_grad=True), y, z, indexing=indexing) ten = xt.meshgrid(Tensor(10.), yt, zt, indexing=indexing) self.assertEqual(len(tor), len(ten)) for tor_i, ten_i in zip(tor, ten): helper_test_op([], lambda: tor_i, lambda: ten_i) self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=RuntimeError) def test_arange(self): helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True) helper_test_op([], lambda: torch.arange(36, dtype=torch.int32), lambda: 
Tensor.arange(36), forward_only=True) helper_test_op([], lambda: torch.arange(5, 10, 3, dtype=torch.int32), lambda: Tensor.arange(5, 10, 3), forward_only=True) helper_test_op([], lambda: torch.arange(10, 5, -3, dtype=torch.int32), lambda: Tensor.arange(10, 5, -3), forward_only=True) helper_test_op([], lambda: torch.arange(11, 5, -3, dtype=torch.int32), lambda: Tensor.arange(11, 5, -3), forward_only=True) helper_test_op([], lambda: torch.arange(1, 78, 2, dtype=torch.int32), lambda: Tensor.arange(1, 78, 2), forward_only=True) helper_test_op([], lambda: torch.arange(5.5, 175.5, 2.5), lambda: Tensor.arange(5.5, 175.5, 2.5), forward_only=True) helper_test_op([], lambda: torch.arange(-30.2, -0.3, 0.75), lambda: Tensor.arange(-30.2, -0.3, 0.75), forward_only=True) helper_test_op([], lambda: torch.arange(-50.3, -380.2, -2.25), lambda: Tensor.arange(-50.3, -380.2, -2.25), forward_only=True) # boundary values that fit exactly in int8 (min=-128, max=127) helper_test_op([], lambda: torch.arange(128, dtype=torch.int8), lambda: Tensor.arange(128, dtype=dtypes.int8), forward_only=True) helper_test_op([], lambda: torch.arange(-128, 128, dtype=torch.int8), lambda: Tensor.arange(-128, 128, dtype=dtypes.int8), forward_only=True) helper_test_op([], lambda: torch.arange(127, -129, -1, dtype=torch.int8), lambda: Tensor.arange(127, -129, -1, dtype=dtypes.int8), forward_only=True) # overflow: tinygrad raises (torch silently wraps) with self.assertRaises(OverflowError): Tensor.arange(2**33, dtype=dtypes.int) with self.assertRaises(OverflowError): Tensor.arange(129, dtype=dtypes.int8) # last=128 overflows with self.assertRaises(OverflowError): Tensor.arange(-129, 128, dtype=dtypes.int8) # start=-129 overflows with self.assertRaises(OverflowError): Tensor.arange(128, 0, -1, dtype=dtypes.int8) # start=128 overflows with self.assertRaises(OverflowError): Tensor.arange(127, -130, -1, dtype=dtypes.int8) # last=-129 overflows def test_arange_big(self): helper_test_op([], lambda: torch.arange(256, 
dtype=torch.int32), lambda: Tensor.arange(256), forward_only=True) def test_arange_4096(self): helper_test_op([], lambda: torch.arange(4096, dtype=torch.int32), lambda: Tensor.arange(4096), forward_only=True) def test_linspace(self): helper_test_op([], lambda: torch.linspace(5, 10, 3), lambda: Tensor.linspace(5, 10, 3), forward_only=True) helper_test_op([], lambda: torch.linspace(5, 10, 1), lambda: Tensor.linspace(5, 10, 1), forward_only=True) helper_test_op([], lambda: torch.linspace(5, 10, 0), lambda: Tensor.linspace(5, 10, 0), forward_only=True) helper_test_op([], lambda: torch.linspace(5, 10, 30), lambda: Tensor.linspace(5, 10, 30), forward_only=True) helper_test_op([], lambda: torch.linspace(-5.5, 5.5, 10), lambda: Tensor.linspace(-5.5, 5.5, 10), forward_only=True) helper_test_op([], lambda: torch.linspace(5.5, -5.5, 10), lambda: Tensor.linspace(5.5, -5.5, 10), forward_only=True) helper_test_op([], lambda: torch.linspace(5, 10, 3, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 3, dtype="int32"), forward_only=True) helper_test_op([], lambda: torch.linspace(5, 10, 20, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 20, dtype="int32"), forward_only=True) helper_test_op([], lambda: torch.linspace(5, -5, 20, dtype=torch.int32), lambda: Tensor.linspace(5, -5, 20, dtype="int32"), forward_only=True) self.helper_test_exception([], lambda: torch.linspace(5, 10, 3, dtype=torch.bool), lambda: Tensor.linspace(5, 10, 3, dtype="bool"), expected=(RuntimeError, ValueError)) self.helper_test_exception([], lambda: torch.linspace(1, 2, -1), lambda: Tensor.linspace(1, 2, -1), expected=(RuntimeError, ValueError)) def test_sum_fake(self): helper_test_op([(256, 1)], lambda x: x.sum(axis=1)) def test_sum_collapse(self): helper_test_op([], lambda: torch.ones(256,256).sum(axis=1), lambda: Tensor.ones(256,256).sum(axis=1), forward_only=True) def test_sum_collapse_neg(self): helper_test_op([], lambda: (-torch.ones(3,3)).sum(axis=1), lambda: (-Tensor.ones(3,3)).sum(axis=1), 
forward_only=True) def test_sum_pad_collapse(self): helper_test_op([], lambda: torch.nn.functional.pad(torch.ones(256,256), pad=(0,64,0,0)).sum(axis=1), lambda: Tensor.ones(256,256).pad(((0,0), (0,64))).sum(axis=1), forward_only=True) def test_sum_twice(self): helper_test_op([(4, 4, 4)], lambda x: x.sum((0, 1)).sum()) helper_test_op([(4, 4, 4)], lambda x: x.sum((0, 2)).sum()) helper_test_op([(4, 4, 4)], lambda x: x.sum((1, 2)).sum()) # this is more complex and won't fold for a while def test_sum_cat_collapse(self): helper_test_op([], lambda: torch.cat([torch.ones(256,256), torch.zeros(256,64)], dim=1).sum(axis=1), lambda: Tensor.cat(Tensor.ones(256,256), Tensor.zeros(256,64), dim=1).sum(axis=1), forward_only=True) def test_max_dont_collapse(self): helper_test_op([], lambda: torch.ones(256,256).max(1)[0], lambda: Tensor.ones(256,256).max(1), forward_only=True) def test_where(self): helper_test_op( [(100,)], lambda x: torch.where(x > 0.5, 4, 2).type(torch.int32), lambda x: (x > 0.5).where(4, 2), forward_only=True) for shps in [[(8,),(1,),(1,)], [(10,10),(10,),(10,)], [(100,)]*3, [(10,10)]*3]: helper_test_op( shps, lambda x, a, b: torch.where(x > 0.5, a, b), lambda x, a, b: (x > 0.5).where(a, b), forward_only=True) def test_where_permute(self): helper_test_op( [(5, 5)], lambda x: torch.where(x > 0.5, 4, 2).type(torch.int32).permute((1, 0)), lambda x: (x > 0.5).where(4, 2).permute((1, 0)), forward_only=True) def _test_cmp(self, fxn, reverse=True): # test different dtypes helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]]) helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]]) helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]]) # test broadcasting for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]: helper_test_op(shps, fxn, fxn, forward_only=True) # test cmp with const helper_test_op(None, lambda x,y: fxn(x,2), lambda x,y: fxn(x,2), forward_only=True, 
vals=[[0.,1,2], [2.,1,0]])
    # with reverse=True the constant is also tried as the left operand
    if reverse: helper_test_op(None, lambda x,y: fxn(2,y), lambda x,y: fxn(2,y), forward_only=True, vals=[[0.,1,2], [2.,1,0]])
    # test special floats  # TODO: fix nan
    specials = [0.0, 1.0, -1.0, math.inf, -math.inf]#, math.nan]
    # every ordered pair of special values
    for s0 in specials:
      for s1 in specials:
        helper_test_op(None, fxn, fxn, forward_only=True, vals=[[s0], [s1]])

  # one test per comparison operator; == does not run the reversed-constant case
  def test_cmp_eq(self): self._test_cmp(lambda x,y: x==y, reverse=False)
  def test_cmp_gt(self): self._test_cmp(lambda x,y: x>y)
  def test_cmp_ge(self): self._test_cmp(lambda x,y: x>=y)
  def test_cmp_lt(self):
    # NOTE(review): the text of this method looks damaged. Source appears to be missing between
    # `x` and `> 0`: `_test_cmp` takes no `forward_only` argument, and `tor` / `ten` are never
    # defined in this scope. The right-shift checks below presumably belong to a separate test
    # whose header and data setup were lost. Restore from version control; do not run as-is.
    self._test_cmp(lambda x,y: x> 0, lambda: (ten >> 0).cast(dtypes.int32), forward_only=True)
    helper_test_op([], lambda: tor >> 2, lambda: (ten >> 2).cast(dtypes.int32), forward_only=True)
    helper_test_op([], lambda: tor >> 31, lambda: (ten >> 31).cast(dtypes.int32), forward_only=True)
    helper_test_op([], lambda: tor.__rshift__(2), lambda: ten.__rshift__(2).cast(dtypes.int32), forward_only=True)
    helper_test_op([], lambda: tor.bitwise_right_shift(2), lambda: ten.rshift(2).cast(dtypes.int32), forward_only=True)

  def test_lshift_signed(self):
    # signed int32 data including the minimum and maximum int32 values
    data = [[-1, -3, 1, 7], [0, -2147483648, 2147483647, -1]]
    tor = torch.tensor(data, dtype=torch.int32)
    ten = Tensor(data, dtype=dtypes.int)
    helper_test_op([], lambda: tor << 0, lambda: ten << 0, forward_only=True)
    helper_test_op([], lambda: tor << 2, lambda: ten << 2, forward_only=True)
    helper_test_op([], lambda: tor << 8, lambda: ten << 8, forward_only=True)
    helper_test_op([], lambda: tor << 31, lambda: ten << 31, forward_only=True)

  def test_rshift_signed(self):
    # same data as test_lshift_signed, shifted right
    data = [[-1, -3, 1, 7], [0, -2147483648, 2147483647, -1]]
    tor = torch.tensor(data, dtype=torch.int32)
    ten = Tensor(data, dtype=dtypes.int)
    helper_test_op([], lambda: tor >> 0, lambda: ten >> 0, forward_only=True)
    helper_test_op([], lambda: tor >> 2, lambda: ten >> 2, forward_only=True)
    helper_test_op([], lambda: tor >> 8, lambda: ten >> 8, forward_only=True)
    helper_test_op([], lambda: tor >> 31, lambda: ten >> 31,
forward_only=True) def test_idiv_shift_rewrite_negative(self): a = Tensor(-5).div(2, rounding_mode="trunc").item() b = Tensor(-5).contiguous().div(2, rounding_mode="trunc").item() self.assertEqual(a, b) self.assertEqual(Tensor(-1).contiguous().div(4, rounding_mode="trunc").item(), 0) # NOTE this is trunc-div behaviour @unittest.skipIf(DEV.renderer == "NAK", "MUFU.SIN is not accurate enough") def test_sin(self): helper_test_op([(45,65)], lambda x: x.sin()) helper_test_op([()], lambda x: x.sin()) # works on real CUDA but not CI if not ((DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU"): helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]]) helper_test_op(None, lambda x: x.sin(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]], atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3) @unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend") @unittest.skipIf(DEV.renderer == "NAK", "MUFU.SIN is not accurate enough") def test_cos(self): helper_test_op([(45,65)], lambda x: x.cos()) helper_test_op([()], lambda x: x.cos()) if not ((DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU"): helper_test_op(None, lambda x: x.cos(), vals=[[math.nan, math.inf, -math.inf, 0.0]]) helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]], atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3) @unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend") @unittest.skipIf(DEV.renderer == "NAK", "MUFU.SIN is not accurate enough") def test_tan(self): # NOTE: backward has much higher diff with input close to pi/2 and -pi/2 helper_test_op([(45,65)], lambda x: x.tan(), low=-1.5, high=1.5) helper_test_op([(45,65)], lambda x: x.tan(), low=-5, high=5) helper_test_op([()], lambda x: 
x.tan()) if not ((DEV.interface.startswith("MOCK") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU"): helper_test_op(None, lambda x: x.tan(), vals=[[math.nan, math.inf, -math.inf, 0.0]]) helper_test_op(None, lambda x: x.tan(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]], atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3) def test_asin(self): helper_test_op([(45,65)], lambda x: x.asin(), low=-1, high=1) helper_test_op([(45,65)], lambda x: x.asin(), low=-300, high=-297) helper_test_op([(45,65)], lambda x: x.asin(), low=300, high=303) def test_acos(self): # high grad atol helper_test_op([(45,65)], lambda x: x.acos(), low=-1, high=1) helper_test_op([(45,65)], lambda x: x.acos(), low=-300, high=-297) helper_test_op([(45,65)], lambda x: x.acos(), low=300, high=303) def test_atan(self): helper_test_op([(45,65)], lambda x: x.atan()) helper_test_op([(45,65)], lambda x: x.atan(), low=-300, high=-297) helper_test_op([(45,65)], lambda x: x.atan(), low=300, high=303) def test_relu(self): helper_test_op([(64,64)], lambda x: x.relu()) helper_test_op([()], lambda x: x.relu()) def test_relu_exact(self): helper_test_op(None, lambda x: x.relu(), vals=[[-1.,0,1]]) def test_relu_maximum_exact(self): helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]]) def test_leaky_relu(self): helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leaky_relu) helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leaky_relu) def test_celu(self): for val in range(1, 5): helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val)) helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val)) def test_selu(self): helper_test_op([(45,65)], torch.nn.functional.selu, Tensor.selu) helper_test_op([()], torch.nn.functional.selu, Tensor.selu) def test_silu(self): 
helper_test_op([(45,65)], torch.nn.functional.silu, Tensor.silu) helper_test_op([()], torch.nn.functional.silu, Tensor.silu) def test_swish(self): helper_test_op([(45,65)], torch.nn.functional.silu, Tensor.swish) helper_test_op([()], torch.nn.functional.silu, Tensor.swish) def test_abs(self): helper_test_op([(45,65)], torch.abs, Tensor.abs) helper_test_op([()], torch.abs, Tensor.abs) def test_abs_exact(self): for v in [-1., -0., 0., 1., math.inf, -math.inf, math.nan, -math.nan]: # abs(nan) gradient is undefined: torch=0, tinygrad=1, jax=-1 helper_test_op(None, torch.abs, Tensor.abs, vals=[[v]], forward_only=math.isnan(v)) def test_log(self): helper_test_op([(45,65)], torch.log, Tensor.log) helper_test_op(None, torch.log, Tensor.log, vals=[[math.inf, -math.inf, math.nan]]) helper_test_op([()], torch.log, Tensor.log) def test_log10(self): helper_test_op([(45,65)], torch.log10, Tensor.log10) helper_test_op(None, torch.log10, Tensor.log10, vals=[[math.inf, -math.inf, math.nan]]) helper_test_op([()], torch.log10, Tensor.log10) def test_log2(self): helper_test_op([(45,65)], torch.log2, Tensor.log2) helper_test_op(None, torch.log2, Tensor.log2, vals=[[math.inf, -math.inf, math.nan]]) helper_test_op([()], torch.log2, Tensor.log2) def test_exp(self): helper_test_op([(45,65)], torch.exp, Tensor.exp) helper_test_op(None, torch.exp, Tensor.exp, vals=[[math.inf, -math.inf, math.nan]]) helper_test_op([()], torch.exp, Tensor.exp) def test_exp2(self): helper_test_op([(45,65)], torch.exp2, Tensor.exp2) helper_test_op(None, torch.exp2, Tensor.exp2, vals=[[math.inf, -math.inf, math.nan]]) helper_test_op([()], torch.exp2, Tensor.exp2) def test_sign(self): helper_test_op([(45,65)], torch.sign, Tensor.sign) helper_test_op([()], torch.sign, Tensor.sign) def test_sign_exact(self): helper_test_op(None, torch.sign, Tensor.sign, vals=[[-1.,0,1]]) def test_copysign(self): helper_test_op([(45,65), (45,65)], torch.copysign, Tensor.copysign) helper_test_op([(45,65), (45,1)], torch.copysign, 
Tensor.copysign) helper_test_op([(45,1), (1,65)], torch.copysign, Tensor.copysign) helper_test_op([(), ()], torch.copysign, Tensor.copysign) @unittest.skipIf(Device.DEFAULT == "WEBGPU", "fails locally") def test_copysign_exact(self): # NOTE: -nan (negative nan) is not tested because we can't detect its sign bit without bitcast v = [-1., -0., 0., 1., math.inf, -math.inf, math.nan] for i in v: for j in v: # torch returns nan gradient for copysign at inf, but mathematically (and per jax) it's ±1 helper_test_op(None, torch.copysign, Tensor.copysign, vals=[[i], [j]], forward_only=math.isinf(i) or math.isnan(i)) def test_logaddexp(self): helper_test_op([(45,65), (45,65)], torch.logaddexp, Tensor.logaddexp) helper_test_op(None, torch.logaddexp, Tensor.logaddexp, vals=[[-1.], [-1.0, 2, 3]]) helper_test_op(None, torch.logaddexp, Tensor.logaddexp, vals=[[-100.0, -200, -300], [-1.0, 2, 3]]) helper_test_op(None, torch.logaddexp, Tensor.logaddexp, vals=[[1.0, 2000, 30000], [-1.0, 2, 3]]) def test_softsign(self): helper_test_op([(45,65)], torch.nn.functional.softsign, Tensor.softsign) helper_test_op([()], torch.nn.functional.softsign, Tensor.softsign) def test_softsign_exact(self): helper_test_op(None, torch.nn.functional.softsign, Tensor.softsign, vals=[[-1.,0,1]]) def test_sigmoid(self): helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid) helper_test_op([()], torch.sigmoid, Tensor.sigmoid) def test_sigmoid_extreme(self): helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=300, high=400) helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=-400, high=-300) x = Tensor([300.0]) self.assertAlmostEqual(x.sigmoid()[0].gradient(x)[0].item(), 0.0) x = Tensor([-300.0]) self.assertAlmostEqual(x.sigmoid()[0].gradient(x)[0].item(), 0.0) def test_sigmoid_alt_extreme(self): def sigmoid(x:Tensor): return x.exp() / (1 + x.exp()) x = Tensor([300.0]) self.assertAlmostEqual(sigmoid(x)[0].gradient(x)[0].item(), 0.0) x = Tensor([-300.0]) 
self.assertAlmostEqual(sigmoid(x)[0].gradient(x)[0].item(), 0.0) def test_logsigmoid(self): helper_test_op([(45,65)], torch.nn.functional.logsigmoid, Tensor.logsigmoid) helper_test_op([()], torch.nn.functional.logsigmoid, Tensor.logsigmoid) def test_hardsigmoid(self): helper_test_op([(45,65)], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid) helper_test_op([()], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid) def test_hardsigmoid_extreme(self): helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=300, high=400) helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=-400, high=-300) def test_softplus(self): helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6) helper_test_op([(45,65)], lambda t: torch.nn.functional.softplus(t, beta=3), lambda t: Tensor.softplus(t, beta=3), grad_atol=1e-6) helper_test_op([(45,65)], lambda t: torch.nn.functional.softplus(t, beta=1/3), lambda t: Tensor.softplus(t, beta=1/3), grad_atol=1e-6) helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6, low=300, high=400) helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6, low=-400, high=-300) helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6) def test_erf(self): helper_test_op([(45,65)], torch.erf, Tensor.erf) helper_test_op([(45,65)], torch.erf, Tensor.erf, low=300, high=400) helper_test_op([(45,65)], torch.erf, Tensor.erf, low=-400, high=-300) helper_test_op([()], torch.erf, Tensor.erf) def test_gelu(self): helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), lambda x: Tensor.gelu(x, approximate="tanh")) helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="none"), lambda x: Tensor.gelu(x, approximate="none")) def test_gelu_extreme(self): helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), lambda x: Tensor.gelu(x, approximate="tanh"), low=300, 
high=400) helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), lambda x: Tensor.gelu(x, approximate="tanh"), low=-400, high=-300) helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="none"), lambda x: Tensor.gelu(x, approximate="none"), low=300, high=400) helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="none"), lambda x: Tensor.gelu(x, approximate="none"), low=-400, high=-300) def test_quick_gelu(self): helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu) helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu) def test_quick_gelu_extreme(self): helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=300, high=400) helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=-400, high=-300) def test_elu(self): helper_test_op([(45,65)], torch.nn.functional.elu, Tensor.elu) helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x, alpha=0.1), lambda x: Tensor.elu(x, alpha=0.1)) helper_test_op([()], torch.nn.functional.elu, Tensor.elu) def test_relu6(self): helper_test_op([(45,65)], torch.nn.functional.relu6, Tensor.relu6) helper_test_op([()], torch.nn.functional.relu6, Tensor.relu6) def test_hardswish(self): helper_test_op([(45,65)], torch.nn.functional.hardswish, Tensor.hardswish, grad_atol=1e-6) helper_test_op([()], torch.nn.functional.hardswish, Tensor.hardswish, grad_atol=1e-6) def test_mish(self): helper_test_op([(45,65)], torch.nn.functional.mish, Tensor.mish) helper_test_op([()], torch.nn.functional.mish, Tensor.mish) def test_small_cumsum(self): helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) @slow_test def test_simple_cumsum(self): helper_test_op([(512)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) helper_test_op([(1022)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, 
axis=0)) @slow_test def test_cumsum(self): helper_test_op([()], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) self.helper_test_exception([()], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1), expected=IndexError) helper_test_op([(20,)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) self.helper_test_exception([(20,)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1), expected=IndexError) self.helper_test_exception([(20,)], lambda x: torch.cumsum(x, dim=-2), lambda x: Tensor.cumsum(x, axis=-2), expected=IndexError) helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1)) helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2)) helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=-1), lambda x: Tensor.cumsum(x, axis=-1)) def test_cumsum_zero_axis(self): helper_test_op([(2,0,4)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1)) helper_test_op([(0,3)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2)) def test_small_cumprod(self): helper_test_op([(10)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0)) @slow_test def test_simple_cumprod(self): helper_test_op([(512)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0)) helper_test_op([(1022)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0)) @slow_test def test_cumprod(self): helper_test_op([()],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0)) self.helper_test_exception([()],lambda x: torch.cumprod(x, dim=1),lambda x: Tensor.cumprod(x, axis=1),expected=IndexError) helper_test_op([(20,)],lambda x: torch.cumprod(x, dim=0),lambda x: 
Tensor.cumprod(x, axis=0)) self.helper_test_exception([(20,)],lambda x: torch.cumprod(x, dim=1),lambda x: Tensor.cumprod(x, axis=1),expected=IndexError) self.helper_test_exception([(20,)],lambda x: torch.cumprod(x, dim=-2),lambda x: Tensor.cumprod(x, axis=-2),expected=IndexError) helper_test_op([(20, 30)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0)) helper_test_op([(20, 30)],lambda x: torch.cumprod(x, dim=1),lambda x: Tensor.cumprod(x, axis=1)) helper_test_op([(20, 30, 40)],lambda x: torch.cumprod(x, dim=2),lambda x: Tensor.cumprod(x, axis=2)) helper_test_op([(20, 30, 40)],lambda x: torch.cumprod(x, dim=-1),lambda x: Tensor.cumprod(x, axis=-1)) def test_cumprod_zero_axis(self): helper_test_op([(2, 0, 4)],lambda x: torch.cumprod(x, dim=1),lambda x: Tensor.cumprod(x, axis=1)) helper_test_op([(0, 3)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0)) helper_test_op([(2, 3, 0)],lambda x: torch.cumprod(x, dim=2),lambda x: Tensor.cumprod(x, axis=2)) def test_small_cummax(self): helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0]) helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True) @slow_test def test_simple_cummax(self): helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0]) helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True) helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0]) helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True) @slow_test def test_cummax(self): helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0]) helper_test_op([()], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: 
Tensor.cummax(x, axis=0)[1], forward_only=True) self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0], expected=IndexError) helper_test_op([(5,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0]) helper_test_op([(5,)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True) self.helper_test_exception([(5,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0], expected=IndexError) self.helper_test_exception([(5,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2)[0], expected=IndexError) helper_test_op([(5,6)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0]) helper_test_op([(5,6)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True) helper_test_op([(5,6)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0]) helper_test_op([(5,6)], lambda x: torch.cummax(x, dim=1).indices.int(), lambda x: Tensor.cummax(x, axis=1)[1], forward_only=True) helper_test_op([(5,6,7)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)[0]) helper_test_op([(5,6,7)], lambda x: torch.cummax(x, dim=2).indices.int(), lambda x: Tensor.cummax(x, axis=2)[1], forward_only=True) helper_test_op([(5,6,7)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1)[0]) helper_test_op([(5,6,7)], lambda x: torch.cummax(x, dim=-1).indices.int(), lambda x: Tensor.cummax(x, axis=-1)[1], forward_only=True) def test_cummax_zero_axis(self): helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0]) helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).indices.int(), lambda x: Tensor.cummax(x, axis=1)[1], forward_only=True) helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, 
axis=0)[0]) helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True) helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)[0]) helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).indices.int(), lambda x: Tensor.cummax(x, axis=2)[1], forward_only=True) def test_small_cummin(self): helper_test_op([(10)], lambda x: torch.cummin(x, dim=0).values, lambda x: Tensor.cummin(x, axis=0)[0]) helper_test_op([(10)], lambda x: torch.cummin(x, dim=0).indices.int(), lambda x: Tensor.cummin(x, axis=0)[1], forward_only=True) @slow_test def test_simple_cummin(self): helper_test_op([(512)], lambda x: torch.cummin(x, dim=0).values, lambda x: Tensor.cummin(x, axis=0)[0]) helper_test_op([(512)], lambda x: torch.cummin(x, dim=0).indices.int(), lambda x: Tensor.cummin(x, axis=0)[1], forward_only=True) helper_test_op([(1022)], lambda x: torch.cummin(x, dim=0).values, lambda x: Tensor.cummin(x, axis=0)[0]) helper_test_op([(1022)], lambda x: torch.cummin(x, dim=0).indices.int(), lambda x: Tensor.cummin(x, axis=0)[1], forward_only=True) @slow_test def test_cummin(self): helper_test_op([()], lambda x: torch.cummin(x, dim=0).values, lambda x: Tensor.cummin(x, axis=0)[0]) helper_test_op([()], lambda x: torch.cummin(x, dim=0).indices.int(), lambda x: Tensor.cummin(x, axis=0)[1], forward_only=True) self.helper_test_exception([()], lambda x: torch.cummin(x, dim=1).values, lambda x: Tensor.cummin(x, axis=1)[0], expected=IndexError) helper_test_op([(5,)], lambda x: torch.cummin(x, dim=0).values, lambda x: Tensor.cummin(x, axis=0)[0]) helper_test_op([(5,)], lambda x: torch.cummin(x, dim=0).indices.int(), lambda x: Tensor.cummin(x, axis=0)[1], forward_only=True) self.helper_test_exception([(5,)], lambda x: torch.cummin(x, dim=1).values, lambda x: Tensor.cummin(x, axis=1)[0], expected=IndexError) self.helper_test_exception([(5,)], lambda x: torch.cummin(x, dim=-2).values, lambda 
x: Tensor.cummin(x, axis=-2)[0], expected=IndexError) helper_test_op([(5,6)], lambda x: torch.cummin(x, dim=0).values, lambda x: Tensor.cummin(x, axis=0)[0]) helper_test_op([(5,6)], lambda x: torch.cummin(x, dim=0).indices.int(), lambda x: Tensor.cummin(x, axis=0)[1], forward_only=True) helper_test_op([(5,6)], lambda x: torch.cummin(x, dim=1).values, lambda x: Tensor.cummin(x, axis=1)[0]) helper_test_op([(5,6)], lambda x: torch.cummin(x, dim=1).indices.int(), lambda x: Tensor.cummin(x, axis=1)[1], forward_only=True) helper_test_op([(5,6,7)], lambda x: torch.cummin(x, dim=2).values, lambda x: Tensor.cummin(x, axis=2)[0]) helper_test_op([(5,6,7)], lambda x: torch.cummin(x, dim=2).indices.int(), lambda x: Tensor.cummin(x, axis=2)[1], forward_only=True) helper_test_op([(5,6,7)], lambda x: torch.cummin(x, dim=-1).values, lambda x: Tensor.cummin(x, axis=-1)[0]) helper_test_op([(5,6,7)], lambda x: torch.cummin(x, dim=-1).indices.int(), lambda x: Tensor.cummin(x, axis=-1)[1], forward_only=True) def test_cummin_zero_axis(self): helper_test_op([(2,0,4)], lambda x: torch.cummin(x, dim=1).values, lambda x: Tensor.cummin(x, axis=1)[0]) helper_test_op([(2,0,4)], lambda x: torch.cummin(x, dim=1).indices.int(), lambda x: Tensor.cummin(x, axis=1)[1], forward_only=True) helper_test_op([(0,3)], lambda x: torch.cummin(x, dim=0).values, lambda x: Tensor.cummin(x, axis=0)[0]) helper_test_op([(0,3)], lambda x: torch.cummin(x, dim=0).indices.int(), lambda x: Tensor.cummin(x, axis=0)[1], forward_only=True) helper_test_op([(2,3,0)], lambda x: torch.cummin(x, dim=2).values, lambda x: Tensor.cummin(x, axis=2)[0]) helper_test_op([(2,3,0)], lambda x: torch.cummin(x, dim=2).indices.int(), lambda x: Tensor.cummin(x, axis=2)[1], forward_only=True) def test_argmax(self): # check if it returns the first index for multiple occurrences helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[2, 2]]) helper_test_op(None, lambda x: 
x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[1, 2, 2]]) if not COMPILE_ONLY: np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), 0) np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), 1) helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True) # regression test for bitwise_not then argmax helper_test_op(None, lambda x: (~x).argmax().type(torch.int32), lambda x: (~x).argmax(), forward_only=True, vals=[[2, 2]]) helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[0, -2**31]]) helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[-2**31, 0]]) # NOTE: torch does not support this on bool helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[False, True]]) helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[True, False]]) @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, QCOMCLRenderer), "QCOM CL vectorized bool bug") def test_argmin(self): # check if it returns the first index for multiple occurrences helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[2, 2]]) helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[3, 2, 2]]) if not COMPILE_ONLY: np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), 0) 
np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), 1) helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmin(0, False).type(torch.int32), lambda x: x.argmin(0, False), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True) helper_test_op([(10,20)], lambda x: x.argmin(1, True).type(torch.int32), lambda x: x.argmin(1, True), forward_only=True) helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[0, -2**31]]) helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[-2**31, 0]]) # NOTE: torch does not support this on bool helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[False, True]]) helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]]) def test_sort(self): for shape in [(0,), (0,5), (1,), (1,5)]: helper_test_op([shape], lambda x: x.sort(0).values, lambda x: x.sort(0)[0], forward_only=True) helper_test_op([shape], lambda x: x.sort(0).indices.type(torch.int32), lambda x: x.sort(0)[1], forward_only=True) for dim in [-1, 0, 1]: for descending in [True, False]: helper_test_op([(8,8,6)], lambda x: x.sort(dim, descending).values, lambda x: x.sort(dim, descending)[0], forward_only=True) helper_test_op([(8,8,6)], lambda x: x.sort(dim, descending).indices.type(torch.int32), lambda x: x.sort(dim, descending)[1], forward_only=True) # repeated values helper_test_op(None, lambda x: x.sort(stable=True).values, lambda x: x.sort()[0], forward_only=True, vals=[[0, 1] * 9]) helper_test_op(None, lambda x: x.sort(stable=True).indices.type(torch.int32), lambda x: x.sort()[1], forward_only=True, vals=[[0, 1] * 9]) 
helper_test_op(None, lambda x: x.sort(stable=True, descending=True).values, lambda x: x.sort(descending=True)[0], forward_only=True, vals=[[0, 1] * 9]) helper_test_op(None, lambda x: x.sort(stable=True, descending=True).indices.type(torch.int32), lambda x: x.sort(descending=True)[1], forward_only=True, vals=[[0, 1] * 9]) def test_argsort(self): for dim in [-1, 0, 1]: for descending in [True, False]: helper_test_op([(8,8,6)], lambda x: torch.argsort(x, dim=dim, descending=descending, stable=True).type(torch.int32), lambda x: x.argsort(dim, descending), forward_only=True) def test_topk(self): helper_test_op([(8)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0], forward_only=True) helper_test_op([(8)], lambda x: x.topk(3).indices.type(torch.int32), lambda x: x.topk(3)[1], forward_only=True) for dim in [0, 1, -1]: for largest in [True, False]: for sorted_ in [True]: # TODO support False helper_test_op([(5,5,4)], lambda x: x.topk(4, dim, largest, sorted_).values, lambda x: x.topk(4, dim, largest, sorted_)[0], forward_only=True) helper_test_op([(5,5,4)], lambda x: x.topk(4, dim, largest, sorted_).indices.type(torch.int32), lambda x: x.topk(4, dim, largest, sorted_)[1], forward_only=True) # repeated values if not COMPILE_ONLY: value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3) np.testing.assert_equal(value.numpy(), [1, 1, 1]) np.testing.assert_equal(indices.numpy(), [0, 1, 3]) value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3, largest=False) np.testing.assert_equal(value.numpy(), [0, 0, 0]) np.testing.assert_equal(indices.numpy(), [2, 4, 6]) self.helper_test_exception([(4)], lambda x: x.topk(5), expected=(RuntimeError, ValueError)) @slow_test def test_einsum(self): # scalar helper_test_op([()], lambda a: torch.einsum('->', a), lambda a: Tensor.einsum('->', a)) # matrix transpose helper_test_op([(10,10)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a)) helper_test_op([(10,10)], lambda a: 
torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a)) helper_test_op([(10,10)], lambda a: torch.einsum('ji', a), lambda a: Tensor.einsum('ji', a)) helper_test_op([(4,6,8)], lambda a: torch.einsum('jki', a), lambda a: Tensor.einsum('jki', a)) helper_test_op([(4,6,8)], lambda a: torch.einsum('dog', a), lambda a: Tensor.einsum('dog', a)) # no -> and empty rhs helper_test_op([(4,6),(6,8)], lambda a, b: torch.einsum('ij,jk', a, b), lambda a, b: Tensor.einsum('ij,jk', a, b)) # sum all elements helper_test_op([(4,6,8)], lambda a: torch.einsum('ijk->', a), lambda a: Tensor.einsum('ijk->', a)) # column sum helper_test_op([(5,5)], lambda a: torch.einsum('ij->j', a), lambda a: Tensor.einsum('ij->j', a)) # row sum helper_test_op([(5,5)], lambda a: torch.einsum('ij->i', a), lambda a: Tensor.einsum('ij->i', a)) # matrix-vector multiplication helper_test_op([(3,4), (4,)], lambda a,b: torch.einsum('ik,k->i', a,b), lambda a,b: Tensor.einsum('ik,k->i', a, b)) # matrix-matrix multiplication helper_test_op([(3,4), (4,5)], lambda a,b: torch.einsum('ik,kj->ij', a,b), lambda a,b: Tensor.einsum('ik,kj->ij', a, b)) # matrix-matrix multiplication, different letter order helper_test_op([(3,4), (4,5)], lambda a,b: torch.einsum('jk,ki->ji', a,b), lambda a,b: Tensor.einsum('jk,ki->ji', a, b)) # dot product helper_test_op([(5),(5)], lambda a,b: torch.einsum('i,i->i', [a,b]), lambda a,b: Tensor.einsum('i,i->i', [a,b])) # hadamard product helper_test_op([(5,6),(5,6)], lambda a,b: torch.einsum('ij,ij->ij', a,b), lambda a,b: Tensor.einsum('ij,ij->ij', a,b)) # outer product helper_test_op([(5,), (5,)], lambda a,b: torch.einsum('i,j->ij', a,b), lambda a,b: Tensor.einsum('i,j->ij',a,b)) # batch matrix multiplication helper_test_op([(2,4,6),(2,6,8)], lambda a,b: torch.einsum('ijk,ikl->ijl', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->ijl', [a, b])) # batch matrix multiplication, result permuted helper_test_op([(2,4,5),(2,5,7)], lambda a,b: torch.einsum('ijk,ikl->jil', [a, b]), lambda 
a,b: Tensor.einsum('ijk,ikl->jil', [a, b])) # batch matrix multiplication, result & input permuted helper_test_op([(4,2,5),(2,5,7)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b])) # batch matrix multiplication, result with different letters helper_test_op([(2,4,6),(2,6,8)], lambda a,b: torch.einsum('ijk,ika->ija', [a, b]), lambda a,b: Tensor.einsum('ijk,ika->ija', [a, b])) # tensor contraction helper_test_op([(3,5,8,10),(11,7,5,13,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b), lambda a,b: Tensor.einsum('pqrs,tuqvr->pstuv', a,b), atol=1e-5) # tensor contraction, input permuted helper_test_op([(3,8,10,5),(11,5,7,13,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b), lambda a,b: Tensor.einsum('prsq,tquvr->pstuv', a,b), atol=1e-5) # tensor contraction, result with different letters helper_test_op([(3,5,8,10),(11,7,5,13,8)], lambda a,b: torch.einsum('zqrs,tuqvr->zstuv', a,b), lambda a,b: Tensor.einsum('zqrs,tuqvr->zstuv', a,b), atol=1e-5) # bilinear transformation helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c])) @slow_test def test_einsum_ellipsis(self): """The expected behavior for einsum is described in the PyTorch docs: https://pytorch.org/docs/stable/generated/torch.einsum.html""" # test ellipsis helper_test_op([(3, 8, 9), (3, 8, 9)], lambda a, b: torch.einsum('...id, ...jd -> ...ij', [a, b]), lambda a, b: Tensor.einsum('...id, ...jd -> ...ij', [a, b])) # ellipsis will come first in the output before the subscript labels, if rhs is not specified helper_test_op([(3, 8, 9), (3, 8, 9)], lambda a, b: torch.einsum('...id, ...jd', [a, b]), lambda a, b: Tensor.einsum('...id, ...jd', [a, b])) # multiple ellipsis in different operands with different shapes are allowed helper_test_op([(2, 3, 4, 5), (5, 2, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]), lambda a, b: Tensor.einsum('i...j,ji...->...', 
[a, b])) # match torch ellipsis handling helper_test_op([(32, 7, 24, 24, 24), (32, 7, 24, 24, 24)], lambda a, b: torch.einsum('ij...,ij...->ij', [a, b]), lambda a, b: Tensor.einsum('ij...,ij...->ij', [a, b])) # multiple ellipsis in one operand are not allowed self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]), lambda a, b: Tensor.einsum('...ik..., ...jk ->', [a, b]), expected=(RuntimeError, IndexError)) # multiple ellipsis must broadcast together self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]), lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]), expected=RuntimeError) def test_einsum_trace(self): # inner product helper_test_op([(5,), (5,)], lambda a, b: torch.einsum('i,i', a, b), lambda a, b: Tensor.einsum('i,i', a, b)) # simple diagonal helper_test_op([(4, 4)], lambda a: torch.einsum('ii->i', a), lambda a: Tensor.einsum('ii->i', a)) # trace (sum of diagonal) helper_test_op([(4, 4)], lambda a: torch.einsum('ii->', a), lambda a: Tensor.einsum('ii->', a)) # batch diagonal helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...i', a), lambda a: Tensor.einsum('...ii->...i', a)) # batch trace helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...', a), lambda a: Tensor.einsum('...ii->...', a)) def test_einsum_shape_check(self): self.helper_test_exception([(3,8,10,5), (11,5,13,16,8)], lambda a, b: torch.einsum('pqrs,tuqvr->pstuv', [a, b]), lambda a, b: Tensor.einsum('pqrs,tuqvr->pstuv', [a, b]), expected=RuntimeError) def test_einsum_arity_check1(self): self.helper_test_exception([(10,15), (15,20), (20,10)], lambda a, b, c: torch.einsum('ij,jk->ij', [a, b, c]), lambda a, b, c: Tensor.einsum('ij,jk->ij', [a, b, c]), expected=(ValueError, RuntimeError)) def test_einsum_arity_check2(self): self.helper_test_exception([(10,10)], lambda a: torch.einsum('ij,jk->ij', a), lambda a: Tensor.einsum('ij,jk->ij', a), expected=(ValueError, RuntimeError)) 
@unittest.skipIf(IMAGE>0, "no 1d dot for images")
def test_dot_1d(self):
  # at least one operand is 1D in every case
  for shapes in [[65, 65], [65, (65,45)], [(45,65), 65], [(8,45,65), 65], [65, (8,65,45)]]:
    helper_test_op(shapes, lambda x,y: x.matmul(y), Tensor.dot)
  # mismatching contraction sizes must raise in both implementations
  for shapes in [[4, (1,2)], [(2,1), 4], [1, 4]]:
    self.helper_test_exception(shapes, lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)

@slow_test
def test_dot(self):
  for shapes in [[(45,65), (65,100)], [(8,45,65), (8,65,100)]]:
    helper_test_op(shapes, lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
  for shapes in [[(2, 4), (1, 3)], [(2, 1), (4, 3)]]:
    self.helper_test_exception(shapes, lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
  # matmul of a scalar tensor with itself is expected to raise
  with self.assertRaises(RuntimeError):
    scalar = Tensor(3.14)
    scalar.matmul(scalar)

def test_mulacc_with_zero_strides(self):
  # two expanded scalars multiplied and reduced over the last axis
  helper_test_op(
    [],
    lambda: torch.tensor(1.0).reshape((1,1,1)).expand(2,4,3).mul(torch.tensor(1.0).reshape((1,1,1)).expand(2,4,3)).sum(-1),
    lambda: Tensor(1.0).reshape((1,1,1)).expand(2,4,3).mul(Tensor(1.0).reshape((1,1,1)).expand(2,4,3)).sum(-1),
    forward_only=True,
  )
  # a (2,4) matrix and a length-4 vector, both expanded to (2,4,3), multiplied and reduced over axes 0 and 2
  mat = [[1.,1.,1.,1.], [1.,1.,1.,1.]]
  vec = [1.,1.,1.,1.]
  helper_test_op(
    [],
    lambda: torch.tensor(mat).reshape((2,4,1)).expand(2,4,3).mul(torch.tensor(vec).reshape((1,4,1)).expand(2,4,3)).sum([0,2]),
    lambda: Tensor(mat).reshape((2,4,1)).expand(2,4,3).mul(Tensor(vec).reshape((1,4,1)).expand(2,4,3)).sum([0,2]),
    forward_only=True,
  )
  # matmul of two all-ones tensors
  helper_test_op(
    [],
    lambda: torch.ones((1,2)).matmul(torch.ones((2,3))),
    lambda: Tensor.ones((1,2)).dot(Tensor.ones((2,3))),
    forward_only=True,
  )

def test_matmul_simple(self):
  # vector @ matrix
  helper_test_op([4, (4,4)], lambda x,y: x.matmul(y), Tensor.dot)

def test_matmul(self):
  # vector @ matrix, larger sizes
  helper_test_op([64, (64,99)], lambda x,y: x.matmul(y), Tensor.dot)

@unittest.skipIf(IMAGE>0, "no batched matmul on images")
def test_matmul_batched(self):
  # vector @ 4D tensor
  helper_test_op([3, (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot)

@unittest.skipIf(IMAGE>0, "no batched matmul on images")
def test_matmul_batched_vector(self):
  # matrix @ 4D tensor
  helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot)

def test_small_gemm(self):
  helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y)

def test_9_gemm(self):
  helper_test_op([(9,9), (9,9)], lambda x,y: x.matmul(y), lambda x,y: x@y)

def test_small_gemm_padded(self):
  # both 9x9 operands are padded by 7 at the end of each dimension before the matmul
  torch_pad, tiny_pad = (0,7,0,7), ((0,7),(0,7))
  helper_test_op([(9,9), (9,9)],
                 lambda x,y: torch.nn.functional.pad(x, torch_pad).matmul(torch.nn.functional.pad(y, torch_pad)),
                 lambda x,y: x.pad(tiny_pad)@y.pad(tiny_pad))

def test_small_gemm_range(self):
  # fixed inputs: 0..63 and 64..127 laid out as 8x8 float32 matrices
  lhs = np.arange(0,64,dtype=np.float32).reshape(8,8)
  rhs = np.arange(64,128,dtype=np.float32).reshape(8,8)
  helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[lhs, rhs])

def test_small_gemm_eye(self):
  # fixed inputs: two 8x8 float32 identity matrices
  helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])

@unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "not precise enough when emulating")
def test_gemm_fp16(self):
  # one loosened tolerance for forward and backward, absolute and relative
  tol = 5e-3
  helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=tol, rtol=tol, grad_atol=tol, grad_rtol=tol)

def test_gemm(self):
  helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y))

@slow_test
def test_big_gemm(self):
  helper_test_op([(256,256), (256,256)], lambda x,y: x.matmul(y), atol=1e-4)

@unittest.skipIf(IMAGE>0, "no 0 in shape matmul on images")
def test_gemm_with_zeros_shape(self):
  # every pair has a zero-length dimension somewhere
  for shapes in [[(8,8), (8,0)], [(0,8), (8,8)], [(0,8), (8,0)], [(8,0), (0,8)], [(0,0), (0,0)], [0, (0,8)], [0, 0]]:
    helper_test_op(shapes, lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)

@slow_test
def test_broadcastdot(self):
  helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
  # a scalar tensor on the left of @ is expected to raise
  with self.assertRaises(RuntimeError):
    scalar = Tensor(3.14)
    ones = Tensor.ones(3,3)
    scalar @ ones

@slow_test
def test_multidot(self):
  for shapes in [[(10,45,65), (10,65,45)], [(3,3,45,65), (3,3,65,45)]]:
    helper_test_op(shapes, lambda x,y: x @ y, Tensor.dot, atol=1e-4)

def test_sum_simple(self):
  helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]])

# NOTE: simple test for locals
# FORWARD_ONLY=1 DEBUG=4 python3 test/test_ops.py TestOps.test_sum_full
def test_sum_full(self):
  helper_test_op([16384], lambda x: x.sum())

def test_sum_relu(self):
  helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu())

def test_sum_tiny(self):
  helper_test_op([(4,2,2)], lambda x: x.sum(axis=(0,2)))

def test_sum(self):
  helper_test_op([(45,3)], lambda x: x.sum())
  for axis in [3, (1,3), (0,2), (1,2), 1]:
    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=axis))
  helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1, keepdim=True))
  # scalar input: no axis, then axis given positionally
  helper_test_op([()], lambda x: x.sum())
  for axis in [0, -1, ()]:
    helper_test_op([()], lambda x: x.sum(axis))
  # out-of-range axes must raise IndexError in both implementations
  for shape, axis in [((3,4,5,6), 5), ((), 1), ((), (1,))]:
    self.helper_test_exception([shape], lambda x: x.sum(axis), expected=IndexError)

def test_sum_dtype_arg(self):
  helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(dtype=dtypes.float32))
  # float64 accumulation is only checked when the renderer lists float64 as supported
  if dtypes.float64 in Device[Device.DEFAULT].renderer.supported_dtypes():
    helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(dtype=dtypes.float64))
  # a string is not accepted as dtype
  with self.assertRaises(AttributeError):
    Tensor([1.0, 2.0]).sum(dtype="")

def test_sum_with_zeros_shape(self):
  for axis in [(0,), (1,), (0,1)]:
    helper_test_op([(4, 0)], lambda x: x.sum(axis=axis))

def test_prod(self):
  small = [[1.0, 2.0, 3.0]]
  helper_test_op(None, lambda x: x.prod(), vals=small)
  # same check again under Context(NOOPT=1)
  with Context(NOOPT=1):
    helper_test_op(None, lambda x: x.prod(), vals=small)
  # torch takes dim=, tinygrad takes axis=
  for axis, kwargs in [(3, {}), (1, {}), (1, {"keepdim": True})]:
    helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=axis, **kwargs), lambda x: x.prod(axis=axis, **kwargs))
  helper_test_op([()], lambda x: x.prod())
  for axis in [0, -1]:
    helper_test_op([()], lambda x: x.prod(axis))

def test_prod_dtype_arg(self):
  # a string is not accepted as dtype
  with self.assertRaises(AttributeError):
    Tensor([1.0, 2.0]).prod(dtype="")

@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, QCOMCLRenderer), "QCOM CL vectorized bool bug")
def test_min(self):
  for shape in [(3,3), (45,3)]:
    helper_test_op([shape], lambda x: x.min())
  helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
  helper_test_op([()], lambda x: x.min())
  # fixed int and bool inputs, forward only
  for vals in [[[0, -2**31]], [[-2**31, 0]], [[False, True]], [[True, False]]]:
    helper_test_op(None, lambda x: x.min(), forward_only=True, vals=vals)

def test_max(self):
  helper_test_op([(45,3)], lambda x: x.max())
  helper_test_op([(45,3)], lambda x: x.max().mul(0.5))
  # fixed input where the maximum value occurs more than once
  helper_test_op(None, lambda x: x.max().mul(0.5), vals=[[[1.0,1.0,0.0,1.0]],])
  # with an axis the torch side takes element [0] of the result
  helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
  helper_test_op([()], lambda x: x.max())
  # fixed int and bool inputs, forward only
  for vals in [[[0, -2**31]], [[-2**31, 0]], [[False, True]], [[True, False]]]:
    helper_test_op(None, lambda x: x.max(), forward_only=True, vals=vals)

def test_const_reduce(self):
  # reduce a tensor filled with the constant 2
  for reduce_name in ["sum", "prod", "max"]:
    helper_test_op([(3,3)], lambda x: getattr(torch.full_like(x, 2), reduce_name)(), lambda x: getattr(x.full_like(2), reduce_name)(),
                   forward_only=True)

def test_any(self):
  helper_test_op([(3,4,5,6)], lambda x: x.any(), forward_only=True)
  for vals in [[[True, True]], [[True, False]], [[False, False]]]:
    helper_test_op(None, lambda x: x.any(), vals=vals, forward_only=True)
  helper_test_op([()], lambda x: x.any(), forward_only=True)

def test_any_axis(self):
  helper_test_op([(3,4,5,6)], lambda x: x.any(axis=(1,2)), forward_only=True)

def test_any_zero_axis(self):
  # both reduced axes have length zero
  helper_test_op([(1,0,3,0,5)], lambda x: x.any(axis=(1,3)), forward_only=True)
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, QCOMCLRenderer), "QCOM CL vectorized bool bug") def test_all(self): helper_test_op([(3,4,5,6)], lambda x: x.all(), forward_only=True) helper_test_op(None, lambda x: x.all(), vals=[[True, True]], forward_only=True) helper_test_op(None, lambda x: x.all(), vals=[[True, False]], forward_only=True) helper_test_op(None, lambda x: x.all(), vals=[[False, False]], forward_only=True) helper_test_op([()], lambda x: x.all(), forward_only=True) def test_all_axis(self): helper_test_op([(3,4,5,6)], lambda x: x.all(axis=(1,2)), forward_only=True) def test_all_zero_axis(self): helper_test_op([(1,0,3,0,5)], lambda x: x.all(axis=(1,3)), forward_only=True) def test_all_large(self): for exp in [15, 16, 20]: helper_test_op(None, lambda: torch.ones(2**exp).bool().all(), lambda: Tensor.ones(2**exp).bool().all(), vals=[], forward_only=True) def test_isclose(self): helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x), forward_only=True) helper_test_op([(3, 4, 5, 6), (3, 4, 5, 6)], lambda x,y: x.isclose(y), forward_only=True) helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x, equal_nan=True), forward_only=True) helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x + 1e-6), forward_only=True) helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x + 1e-9), forward_only=True) helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x + 1e-6, atol=0.0), forward_only=True) helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x + 1e-9, atol=0.0), forward_only=True) helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x + 1e-6, rtol=0.01), forward_only=True) helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x + 1e-9, rtol=0.01), forward_only=True) helper_test_op(None, lambda x,y: x.isclose(y), vals=[[1e-7, 1e-8, 1e-9], [0.0, 0.0, 0.0]], forward_only=True) def test_isclose_edge_cases(self): for a in [math.inf, -math.inf, math.nan, 0.0]: for b in [math.inf, -math.inf, math.nan, 0.0]: helper_test_op(None, lambda x,y: x.isclose(y), vals=[[a], [b]], 
forward_only=True) helper_test_op(None, lambda x,y: x.isclose(y, equal_nan=True), vals=[[a], [b]], forward_only=True) def test_mean(self): helper_test_op([(3,4,5,6)], lambda x: x.mean()) helper_test_op([()], lambda x: x.mean()) def test_mean_axis(self): helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2))) def test_mean_zero_axis(self): helper_test_op([(1,0,3,0,5)], lambda x: x.mean(axis=(1,3))) @slow_test def test_var(self): helper_test_op([(15, 25, 35)], lambda x: x.var()) helper_test_op([(15, 25, 35)], lambda x: x.var(correction=0)) helper_test_op([(15, 25, 35)], lambda x: x.var(correction=5)) # TODO: fix this # helper_test_op([(10, 2)], lambda x: x.var(correction=50)) @slow_test def test_var_axis(self): helper_test_op([(15, 25, 35)], lambda x: x.var(0)) helper_test_op([(15, 25, 35)], lambda x: x.var(2)) helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2])) helper_test_op([(15, 25, 35)], lambda x: x.var(0, correction=0)) helper_test_op([(15, 25, 35)], lambda x: x.var(2, correction=0)) helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2], correction=0)) def test_var_zero_in_axis(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="var\\(\\): degrees of freedom is <= 0") helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3))) helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=0)) helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=5)) def test_var_one_in_axis(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="var\\(\\): degrees of freedom is <= 0") helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3))) # TODO: fix backward helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3), correction=5), forward_only=True) helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,4), correction=5), forward_only=True) helper_test_op([(1,)], lambda x: x.var(axis=(0,), correction=0)) helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3), correction=0)) 
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,4))) helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,4), correction=0)) def test_var_keepdim(self): helper_test_op([(15, 25, 35)], lambda x: x.var(keepdim=True)) helper_test_op([(15, 25, 35)], lambda x: x.var(0, keepdim=True, correction=0)) @slow_test def test_std(self): helper_test_op([(15, 25, 35)], lambda x: x.std()) helper_test_op([(15, 25, 35)], lambda x: x.std(correction=0)) helper_test_op([(15, 25, 35)], lambda x: x.std(correction=5)) @slow_test def test_std_axis(self): helper_test_op([(15, 25, 35)], lambda x: x.std(0)) helper_test_op([(15, 25, 35)], lambda x: x.std(2)) helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2])) helper_test_op([(15, 25, 35)], lambda x: x.std(0, correction=0)) helper_test_op([(15, 25, 35)], lambda x: x.std(2, correction=0)) helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2], correction=0)) def test_std_zero_in_axis(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="std\\(\\): degrees of freedom is <= 0") helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3))) helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=0)) helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=5)) def test_std_one_in_axis(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="std\\(\\): degrees of freedom is <= 0") helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,3))) helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,3), correction=5)) helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,4), correction=5)) # TODO: fix backward helper_test_op([(1,)], lambda x: x.std(axis=(0,), correction=0), forward_only=True) helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,3), correction=0), forward_only=True) helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,4))) helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,4), correction=0)) def test_std_keepdim(self): helper_test_op([(15, 25, 35)], lambda 
x: x.std(keepdim=True)) helper_test_op([(15, 25, 35)], lambda x: x.std(0, keepdim=True, correction=0)) @slow_test def test_std_mean(self): helper_test_op([(15,25,35)], lambda x: torch.stack(torch.std_mean(x)), lambda x: Tensor.stack(*x.std_mean())) helper_test_op([(15,25,35)], lambda x: torch.stack(torch.std_mean(x, correction=5)), lambda x: Tensor.stack(*x.std_mean(correction=5))) helper_test_op([(15,25,35)], lambda x: torch.stack(torch.std_mean(x, keepdim=True, correction=0)), lambda x: Tensor.stack(*x.std_mean(keepdim=True, correction=0))) helper_test_op([(3,4,5,6)], lambda x: torch.stack(torch.std_mean(x, axis=(1,2))), lambda x: Tensor.stack(*x.std_mean(axis=(1,2)))) def test_std_mean_loaded_nan(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="std_mean\\(\\): degrees of freedom is <= 0") helper_test_op([(1,0,3,0,5)], lambda x: torch.stack(torch.std_mean(x, axis=(1,3))), lambda x: Tensor.stack(*x.std_mean(axis=(1,3)))) def test_softmax(self): helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7) helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7) helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7) helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7) def test_softmax_other_axis(self): helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=2e-7) helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=2e-7) helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=2e-7) def test_softmax_argmax(self): helper_test_op([(45,65)], lambda x: x.softmax(0).argmax().type(torch.int32), lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: x.softmax(1).argmax().type(torch.int32), lambda x: x.softmax(1).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7) def 
test_log_softmax(self): helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) helper_test_op([()], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) helper_test_op([()], torch.nn.LogSoftmax(dim=-1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) def test_log_softmax_other_axis(self): helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7) helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7) helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7) def test_normalize(self): helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x), lambda x: x.normalize(), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, dim=0), lambda x: x.normalize(dim=0), atol=1e-7, grad_atol=1e-7) helper_test_op([(10,10,10)], lambda x: torch.nn.functional.normalize(x, dim=2), lambda x: x.normalize(dim=2), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, p=1), lambda x: x.normalize(p=1), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, p=3, dim=0), lambda x: x.normalize(p=3, dim=0), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, p=0), lambda x: x.normalize(p=0), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.nn.functional.normalize(x, p=-1), lambda x: x.normalize(p=-1), atol=1e-7, grad_atol=1e-7) def test_logsumexp(self): helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logsumexp(x, 
dim=1), lambda x: x.logsumexp(1), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1, keepdim=True), lambda x: x.logsumexp(1, True), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=3), lambda x: x.logsumexp(3), atol=1e-7, grad_atol=1e-7) helper_test_op([(45)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7) @slow_test def test_logcumsumexp(self): helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6)], lambda x: torch.logcumsumexp(x, dim=2), lambda x: x.logcumsumexp(2), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6,6)], lambda x: torch.logcumsumexp(x, dim=2), lambda x: x.logcumsumexp(2), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6,6)], lambda x: torch.logcumsumexp(x, dim=3), lambda x: x.logcumsumexp(3), atol=1e-7, grad_atol=1e-7) helper_test_op([(45)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=-1), lambda x: x.logcumsumexp(-1), atol=1e-7, 
grad_atol=1e-7) def test_logcumsumexp_numerical(self): helper_test_op(None, lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7, vals=[[0.0, 100.0]]) def test_sinh(self): helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6) # TODO: backward nan instead of inf helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6, low=-300, high=-297, forward_only=True) helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6, low=300, high=303, forward_only=True) def test_cosh(self): helper_test_op([(45,65)], lambda x: x.cosh(), grad_atol=1e-6) # TODO: backward nan instead of inf helper_test_op([(45,65)], lambda x: x.cosh(), grad_atol=1e-6, low=-300, high=-297, forward_only=True) helper_test_op([(45,65)], lambda x: x.cosh(), grad_atol=1e-6, low=300, high=303, forward_only=True) def test_tanh(self): helper_test_op([(45,65)], lambda x: x.tanh(), grad_atol=1e-6) def test_tanh_extreme(self): helper_test_op([(45,65)], lambda x: x.tanh(), grad_atol=1e-6, low=-300, high=-297) helper_test_op([(45,65)], lambda x: x.tanh(), grad_atol=1e-6, low=300, high=303) def test_hardtanh(self): for val in range(10, 30, 5): helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), grad_atol=1e-6) helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), grad_atol=1e-6) def test_asinh(self): helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6) # TODO: this one has larger tol? 
helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, rtol=2e-2, grad_rtol=2e-2, low=-300, high=-297) helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6, low=300, high=303) def test_acosh(self): helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6) helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-3, grad_rtol=1e-2, low=-300, high=-297) helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6, low=300, high=303) def test_atanh(self): helper_test_op([(45,65)], lambda x: x.atanh(), grad_atol=1e-6) helper_test_op([(45,65)], lambda x: x.atanh(), grad_atol=1e-6, low=-300, high=-297) helper_test_op([(45,65)], lambda x: x.atanh(), grad_atol=1e-6, low=300, high=303) def test_topo_sort(self): helper_test_op([(45,65)], lambda x: (x+x)*x, grad_atol=1e-6) helper_test_op([()], lambda x: (x+x)*x, grad_atol=1e-6) def test_flip_eye_crash(self): helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)), lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True) def test_broadcast_full(self): for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), (torch.div, Tensor.div), (torch.pow, Tensor.pow)]: for shapes in [((5,3,14,16), (5,1,14,1)), ((1,3,1,7,1), (2,1,5,1,8))]: with self.subTest(op=torch_op.__name__, shapes=shapes): if tinygrad_op != Tensor.pow: helper_test_op(shapes, torch_op, tinygrad_op) else: helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3) def test_broadcast_simple(self): helper_test_op([(45,65), (45,1)], lambda x,y: x/y) helper_test_op([(45,65), ()], lambda x,y: x/y) @slow_test def test_broadcast_partial(self): for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), (torch.div, Tensor.div), (torch.pow, Tensor.pow)]: for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)), ((4,1), (4,5)), ((1,4), (5,4))]: with self.subTest(op=torch_op.__name__, shapes=shapes): # NOTE: ANE backwards? 
if tinygrad_op != Tensor.pow: helper_test_op(shapes, torch_op, tinygrad_op) else: helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3) def test_slice_in_bounds_1dim(self): helper_test_op([(3)], lambda x: x[1:3]) helper_test_op([(3)], lambda x: x[0:2]) helper_test_op([(3)], lambda x: x[-2:2]) def test_slice_on_0dim_tensor(self): helper_test_op([()], lambda x: x[None]) with self.assertRaises(IndexError): a = Tensor(3.14) a[0] def test_slice_int_indexing(self): helper_test_op([(3)], lambda x: x[0]) helper_test_op([(3)], lambda x: x[2]) helper_test_op([(3)], lambda x: x[-1]) helper_test_op([(3)], lambda x: x[-3]) helper_test_op([(10,10)], lambda x: x[1]) helper_test_op([(3,3,3)], lambda x: x[1,1,1]) def test_slice_in_bounds_multidim(self): helper_test_op([(3,3,3)], lambda x: x[1:2]) helper_test_op([(3,3,3)], lambda x: x[1:2, 2]) helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2]) helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1]) def test_slice_with_none(self): helper_test_op([(3,3,3)], lambda x: x[None]) helper_test_op([(3,3,3)], lambda x: x[1:2, None]) helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2]) helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1]) helper_test_op([(3,3,3)], lambda x: x[None, None, 1, None, 2, 0:2]) def test_slice_with_const_tensor(self): t = Tensor.zeros(1, dtype=dtypes.int) helper_test_op([(3,3,3)], lambda x: x[:, [0], :], lambda x: x[:, t, :]) helper_test_op([(3,3,3)], lambda x: x[:, [0], :], lambda x: x[:, t.contiguous(), :]) def test_slice_one_endpoint_out_of_bounds(self): helper_test_op([(3,3,3)], lambda x: x[0:4]) helper_test_op([(3,3,3)], lambda x: x[-6:4]) helper_test_op([(3,3,3)], lambda x: x[1:50]) helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1]) def test_slice_stride_gt_one(self): helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4]) helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4]) helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4]) helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, 
::4]) @unittest.skipIf(COMPILE_ONLY, "test requires runtime") def test_slice_negative_strides(self): # Torch doesn't support slicing with negative steps a = np.random.randn(10, 10, 10).astype(np.float32) t = Tensor(a) np.testing.assert_allclose(a[::-1], t[::-1].numpy()) np.testing.assert_allclose(a[::-2], t[::-2].numpy()) np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy()) np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy()) np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy()) np.testing.assert_allclose(a[2:5:-1, :, :], t[2:5:-1, :, :].numpy()) # shape = (0, 10, 10) np.testing.assert_allclose(a[:, 2:5:-1, :], t[:, 2:5:-1, :].numpy()) # shape = (0, 10, 10) np.testing.assert_allclose(a[:, :, 2:5:-1], t[:, :, 2:5:-1].numpy()) # shape = (0, 10, 10) def test_slice_both_endpoints_out_of_bounds(self): helper_test_op([(3,3,3)], lambda x: x[5:10]) helper_test_op([(3,3,3)], lambda x: x[-15:-7]) def test_slice_start_gt_end(self): helper_test_op([(3,3,3)], lambda x: x[-2:2]) helper_test_op([(3,3,3)], lambda x: x[-2:-5]) def test_slice_empty(self): helper_test_op([(10,10)], lambda x: x[1:1]) def test_slice_zero_in_shape(self): helper_test_op([(10,10)], lambda x: x[1:1]) # x.shape = (0, 10) helper_test_op([(3,3,3)], lambda x: x[-2:-5]) # x.shape = (0, 3, 3) def test_slice_errors(self): a = Tensor.ones(4, 3) b = Tensor(2) with self.assertRaisesRegex(IndexError, "too many"): a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds) with self.assertRaisesRegex(IndexError, "out of bounds"): a[1, 3] # IndexError: (out of bounds). with self.assertRaisesRegex(IndexError, "out of bounds"): a[1, -4] with self.assertRaisesRegex(IndexError, "single ellipsis"): a[..., ...] 
# IndexError: only single ellipsis with self.assertRaises(ValueError): a[::0, 1] # no 0 strides with self.assertRaises(TypeError): a[:Tensor([3]), 1] # Tensor can't be used as a slice parameter with self.assertRaises(IndexError): b[:] # slice cannot be applied to a 0-dim tensor def test_slice_ellipsis(self): helper_test_op([(3,3,3,3)], lambda x: x[..., 0]) helper_test_op([(3,3,3,3)], lambda x: x[0, ...]) helper_test_op([(3,3,3,3)], lambda x: x[0, ..., 0]) helper_test_op([(3,3,3,3)], lambda x: x[0:3, ..., 2:3]) helper_test_op([(3,3,3,3)], lambda x: x[None, 0:3, ..., 0, None]) # this was the failure in llama early realizing freqs_cis def test_double_slice(self): helper_test_op([(4,4)], lambda x: x[:, 1:2][1:2]) helper_test_op([(4,4)], lambda x: x[1:3][1:2]) helper_test_op([(4,4)], lambda x: x[:, 1:2][0:1]) helper_test_op([(4,4)], lambda x: x[:, 1:2][:, 0:1]) def test_pad(self): helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad(padding=(1,2,3,4))) helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad(padding=(-1,2,-3,4))) helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(padding=(1,2,3,4),value=5)) helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=5), lambda x: x.pad(padding=(-1,2,-3,4),value=5)) helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=math.inf), lambda x: x.pad(padding=(1,2,3,4),value=math.inf)) helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=-math.inf), lambda x: x.pad(padding=(-1,2,-3,4),value=-math.inf)) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2)))) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad(((-3,4), (-1,2)))) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), 
(1,2)), value=5)) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1)) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1)) # raise error for uneven pads self.helper_test_exception([(3,3)], lambda x: torch.nn.functional.pad(x, (2,0,2)), lambda x: x.pad((2,0,2)), expected=(RuntimeError, ValueError)) # raise error for too many or too little pads self.helper_test_exception([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0,1,0,3,0)), lambda x: x.pad((0,0,0,0,1,0,3,0)), expected=(RuntimeError, ValueError)) # raise error for mode string typo self.helper_test_exception([(3,3,3)], lambda x: torch.nn.functional.pad(x, (3,0), mode="typo"), lambda x: x.pad((3,0), mode="typo"), expected=NotImplementedError) x = Tensor.ones(3,3) with self.assertRaises(ValueError): x.pad((None,(0,1),(3,0))) with self.assertRaises(ValueError): x.pad(((0,1),)) def test_pad_reflect_mode(self): helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,2,3,2), mode="reflect"), lambda x: x.pad((0,2,3,2), mode="reflect")) helper_test_op([(5,5,5)], lambda x: torch.nn.functional.pad(x, (0,2), mode="reflect"), lambda x: x.pad((0,2), mode="reflect")) helper_test_op([(1,1,5,5,5)], lambda x: torch.nn.functional.pad(x, (1,2,3,4,1,2), mode="reflect"), lambda x: x.pad((1,2,3,4,1,2), mode="reflect")) helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,2,-1), mode="reflect"), lambda x: x.pad((-1,2,2,-1), mode="reflect")) helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-3,0,-3), mode="reflect"), lambda x: x.pad((3,-3,0,-3), mode="reflect")) helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-5,1,-5), mode="reflect"), lambda x: x.pad((3,-5,1,-5), mode="reflect")) helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,0,0,-5), mode="reflect"), lambda x: x.pad((0,0,0,-5), mode="reflect")) # max pad 
size for reflect is exactly once: pad < input size helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (4,4,0,4), mode="reflect"), lambda x:x.pad((4,4,0,4),mode="reflect")) # raise error for relfection padding when: pad >= input size self.helper_test_exception([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,5,0,0),mode="reflect"), lambda x: x.pad((3,5,0,0),mode="reflect"), expected=(RuntimeError, ValueError)) def test_pad_replicate_mode(self): helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,2,3,2), mode="replicate"), lambda x: x.pad((0,2,3,2), mode="replicate")) helper_test_op([(5,5,5)], lambda x: torch.nn.functional.pad(x, (0,2), mode="replicate"), lambda x: x.pad((0,2), mode="replicate")) helper_test_op([(1,1,5,5,5)], lambda x: torch.nn.functional.pad(x, (1,2,3,4,1,2),mode="replicate"),lambda x:x.pad((1,2,3,4,1,2),mode="replicate")) helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,2,-1), mode="replicate"), lambda x: x.pad((-1,2,2,-1), mode="replicate")) helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-3,0,-3), mode="replicate"), lambda x: x.pad((3,-3,0,-3), mode="replicate")) helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-5,1,-5), mode="replicate"), lambda x: x.pad((3,-5,1,-5), mode="replicate")) helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,0,0,-5), mode="replicate"), lambda x: x.pad((0,0,0,-5), mode="replicate")) # no max pad sizes for replicate helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,11,0,30), mode="replicate"), lambda x: x.pad((3,11,0,30), mode="replicate")) def test_pad_circular_mode(self): helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,2,3,2), mode="circular"), lambda x: x.pad((0,2,3,2), mode="circular")) helper_test_op([(5,5,5)], lambda x: torch.nn.functional.pad(x, (0,2), mode="circular"), lambda x: x.pad((0,2), mode="circular")) helper_test_op([(1,1,5,5,5)], lambda x: 
torch.nn.functional.pad(x, (1,2,3,5,1,2),mode="circular"),lambda x:x.pad((1,2,3,5,1,2),mode="circular")) # circular pad cannot wrap around more than once self.helper_test_exception([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,6,0,0), mode="circular"), lambda x: x.pad((3,6,0,0), mode="circular"), expected=(RuntimeError, ValueError)) helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (1,-2,2,-1), mode="circular"), lambda x: x.pad((1,-2,2,-1), mode="circular")) def test_pad_reshape(self): helper_test_op([(1, 2)], lambda x: torch.nn.functional.pad(x, (0, 1, 1, 0)).reshape((3, 2)), lambda x: x.pad((0, 1, 1, 0)).reshape((3, 2))) helper_test_op([(1, 2)], lambda x: torch.nn.functional.pad(x, (0, 2, 1, 1)).reshape((4, 3)), lambda x: x.pad((0, 2, 1, 1)).reshape((4, 3))) helper_test_op([(1, 1, 1, 2)], lambda x: torch.nn.functional.pad(x, (0, 4, 2, 2, 1, 2, 0, 2)).reshape((4, 3, 6, 5)), lambda x: x.pad(((0, 2), (1, 2), (2, 2), (0, 4))).reshape((4, 3, 6, 5))) def test_pad_slice(self): for value in 0., 3.456: helper_test_op([(1)], lambda x: torch.nn.functional.pad(x,(1,0), value=value)[0], lambda x: x.pad(((1,0),), value=value)[0]) helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(1,0), value=value)[0], lambda x: x.pad(((1,0),), value=value)[0]) helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(3,0), value=value)[0:1], lambda x: x.pad(((3,0),), value=value)[0:1]) helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(0,3), value=value)[6], lambda x: x.pad(((0,3),), value=value)[6]) helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(0,3), value=value)[4:6], lambda x: x.pad(((0,3),), value=value)[4:6]) helper_test_op([(5,5)], lambda x: torch.nn.functional.pad(x,(0,0,1,0), value=value)[0], lambda x: x.pad(((1,0),(0,0)), value=value)[0]) helper_test_op([(2,2)], lambda x: torch.nn.functional.pad(x,(0,1,0,0), value=value)[0,2], lambda x: x.pad(((0,0),(0,1)), value=value)[0,2]) helper_test_op([(4,4)], lambda x: 
torch.nn.functional.pad(x,(0,0,1,0), value=value)[0,2], lambda x: x.pad(((1,0),(0,0)), value=value)[0,2]) helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(0,0,0,2), value=value)[5], lambda x: x.pad(((0,2),(0,0)), value=value)[5]) helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(0,0,0,2), value=value)[3:5], lambda x: x.pad(((0,2),(0,0)), value=value)[3:5]) helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,0,0,0), value=value)[1,0], lambda x: x.pad(((0,0),(3,0)), value=value)[1,0]) helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,0,0,0), value=value)[1,0:4], lambda x: x.pad(((0,0),(3,0)), value=value)[1,0:4]) helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[0], lambda x: x.pad(((1,2),(3,4)), value=value)[0]) helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[:,1], lambda x: x.pad(((1,2),(3,4)), value=value)[:,1]) helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[:,4], lambda x: x.pad(((1,2),(3,4)), value=value)[:,4]) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x,(0,3,0,0), value=value)[:,4:6], lambda x: x.pad(((0,0),(0,3)), value=value)[:,4:6]) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x,(0,1,3,2), value=value)[0:2,:], lambda x: x.pad(((3,2),(0,1)), value=value)[0:2,:]) helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[0:2,:,:], lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[0:2,:,:]) helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[2:4,:,:], lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[2:4,:,:]) def test_stack_slice(self): helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack(*[x for i in range(3)])[0,:]) helper_test_op([(5)], lambda x: torch.stack([x for i in range(3)])[0,0], lambda x: Tensor.stack(*[x for i in range(3)])[0,0]) helper_test_op([(4,4)], 
lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack(*[x for i in range(4)])[3]) def test_transpose(self): helper_test_op([(3,3)], lambda x: x.T) helper_test_op([(3,3,3)], lambda x: x.transpose(1,2)) helper_test_op([(3,3,3)], lambda x: x.transpose(0,2)) def test_permute(self): helper_test_op([(1,2,3,4)], lambda x: x.permute((3,0,2,1))) helper_test_op([(3,4,5,6)], lambda x: x.permute((3,2,1,0))) helper_test_op([(3,4,5,6)], lambda x: x.permute((-2,-1,1,0))) helper_test_op([()], lambda x: x.permute(())) self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,2)), expected=RuntimeError) self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,1,2,3,3,3)), expected=RuntimeError) self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,0,1,2,3)), expected=RuntimeError) def test_reshape(self): helper_test_op([(4,3,6,6)], lambda x: x.reshape((12,6,6))) helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,3,6,6))) helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,1,6,6))) helper_test_op([(4,3,6,6)], lambda x: x.reshape((4,3,6,6)), lambda x: x.reshape((None,None,6,6))) helper_test_op([()], lambda x: x.reshape(())) helper_test_op([(1,)], lambda x: x.reshape(())) helper_test_op([()], lambda x: x.reshape((1,))) helper_test_op([()], lambda x: x.reshape((1,1,1))) self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,2)), expected=RuntimeError) self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,-1,2)), expected=RuntimeError) with self.assertRaises(ValueError): x = Tensor.ones((4,3,6,6)) x.reshape([]) def test_view(self): helper_test_op([(4,3,6,6)], lambda x: x.view((12,6,6))) helper_test_op([(4,3,6,6)], lambda x: x.view((-1,3,6,6))) helper_test_op([(6,)], lambda x: x.view(2, 3)) helper_test_op([(6,1)], lambda x: x.view([2, 3])) helper_test_op([(1,6)], lambda x: x.view((3, 2))) helper_test_op([(3,2)], lambda x: x.view((2, 3))) helper_test_op([(3,2)], lambda x: x.view(6)) def test_flip(self): helper_test_op([(4,3,6,6)], 
lambda x: x.flip((0,))) helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1))) helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1,3))) helper_test_op([(4,3,6,6)], lambda x: x.flip((3,))) helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1,3)).flip(0)) helper_test_op([(4,3,6,6)], lambda x: x.flip((-1,))) helper_test_op([()], lambda x: x.flip(())) helper_test_op([(1,)], lambda x: x.flip(())) helper_test_op([(4,3,6,6)], lambda x: x.flip(())) self.helper_test_exception([(3,4)], lambda x: x.flip((0,0)), expected=RuntimeError) self.helper_test_exception([(3,4)], lambda x: x.flip((1,1)), expected=RuntimeError) self.helper_test_exception([(3,4)], lambda x: x.flip((1,-1)), expected=RuntimeError) def test_squeeze(self): helper_test_op([(1,3,6,6)], lambda x: x.squeeze(0)) helper_test_op([(4,3,1,6)], lambda x: x.squeeze(1)) helper_test_op([(4,3,6,6)], lambda x: x.squeeze(3)) self.helper_test_exception([(4,3,6,6)], lambda x: x.squeeze(50), expected=IndexError) self.helper_test_exception([(4,3,6,6)], lambda x: x.squeeze(50), expected=IndexError) helper_test_op([(4,3,6,1)], lambda x: x.squeeze(-1)) helper_test_op([(4,3,6,6)], lambda x: x.squeeze()) helper_test_op([(1,3,6,6)], lambda x: x.squeeze()) helper_test_op([(2,3,1)], lambda x: x.squeeze()) helper_test_op([()], lambda x: x.squeeze(-1)) helper_test_op([()], lambda x: x.squeeze(0)) helper_test_op([()], lambda x: x.squeeze()) self.helper_test_exception([()], lambda x: x.squeeze(10), expected=IndexError) self.helper_test_exception([()], lambda x: x.squeeze(1), expected=IndexError) self.helper_test_exception([()], lambda x: x.squeeze(-2), expected=IndexError) def test_unsqueeze(self): helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(0)) helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(4)) helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(-1)) helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(-3)) helper_test_op([()], lambda x: x.unsqueeze(0)) def test_flatten(self): for axis in range(3): helper_test_op([(4,3,6,6)], lambda x: 
x.flatten(start_dim=axis)) for axis in range(3): helper_test_op([(4,3,6,6)], lambda x: x.flatten(end_dim=axis)) helper_test_op([(4,3,6,6)], lambda x: x.flatten(start_dim=1, end_dim=3)) helper_test_op([()], lambda x: x.flatten()) helper_test_op([(1,)], lambda x: x.flatten()) def test_unflatten(self): helper_test_op([(4,3,6,6)], lambda x: x.unflatten(0, (2, 2))) helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2))) helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1))) def test_diag(self): helper_test_op([(5,)], lambda x: x.diag()) def test_diagonal(self): helper_test_op([(5,5)], lambda x: x.diagonal()) helper_test_op([(3,4)], lambda x: x.diagonal()) # rectangular helper_test_op([(4,3)], lambda x: x.diagonal()) # rectangular (other way) helper_test_op([(3,3,3)], lambda x: x.diagonal(dim1=-2, dim2=-1)) # batched helper_test_op([(4,5,6)], lambda x: x.diagonal(dim1=-2, dim2=-1)) # batched rectangular helper_test_op([(2,3,4,5)], lambda x: x.diagonal(dim1=-2, dim2=-1)) # 4D batched helper_test_op([(5,5)], lambda x: x.diagonal(offset=1)) # positive offset helper_test_op([(5,5)], lambda x: x.diagonal(offset=-1)) # negative offset helper_test_op([(3,5)], lambda x: x.diagonal(offset=2)) # offset on rectangular self.helper_test_exception([(3,3)], lambda x: x.diagonal(dim1=0, dim2=0), expected=RuntimeError) def test_roll(self): helper_test_op([(2, 4)], lambda x: x.roll(1)) helper_test_op([(2, 4)], lambda x: x.roll((1,))) self.helper_test_exception([(2, 4)], lambda x: x.roll((1, 2)), expected=RuntimeError) helper_test_op([(2, 4)], lambda x: x.roll(1, 0)) helper_test_op([(2, 4)], lambda x: x.roll(-1, 0)) helper_test_op([(2, 4)], lambda x: x.roll(shifts=(2, 1), dims=(0, 1))) helper_test_op([(2, 4, 6)], lambda x: x.roll(1, 0)) helper_test_op([(2, 4)], lambda x: x.roll(1, -1)) helper_test_op([(2, 4)], lambda x: x.roll(-1, -1)) helper_test_op([(2, 4)], lambda x: x.roll(5, 0)) helper_test_op([(2, 4)], lambda x: x.roll(-5, 0)) helper_test_op([(2, 4, 6)], lambda x: 
x.roll(shifts=(2, -3), dims=(0, 2))) helper_test_op([(2, 4, 6)], lambda x: x.roll(shifts=(1, 2, -1), dims=(0, 1, 2))) helper_test_op([(2, 4)], lambda x: x.roll(0, 0)) helper_test_op([(2, 4, 6)], lambda x: x.roll(shifts=(0, 0), dims=(0, 1))) helper_test_op([(2, 4, 6)], lambda x: x.roll(shifts=(0, 2), dims=(0, 1))) self.helper_test_exception([(3, 3)], lambda x: x.roll(shifts=1, dims=(0, 1)), expected=RuntimeError) self.helper_test_exception([(10,)], lambda x: x.roll(shifts=(1, 2), dims=0), expected=RuntimeError) def test_detach(self): helper_test_op([(4,3,6,6)], lambda x: x.detach(), forward_only=True) helper_test_op([()], lambda x: x.detach(), forward_only=True) def test_expand(self): helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,2,6))) helper_test_op([(1,1,1,1)], lambda x: x.expand((4,3,2,6))) helper_test_op([(4,3,1,6)], lambda x: x.expand((6,1,4,3,2,6))) helper_test_op([(4,3,1,6)], lambda x: x.expand((0,1,4,3,2,6))) helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,0,6))) helper_test_op([()], lambda x: x.expand((4,3,2,6))) helper_test_op([()], lambda x: x.expand([])) with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(4,1,1,6) with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(4,6,1,6) with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(3,1,6) with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,2,6).expand(4,3,0,6) @unittest.skip("very slow") def test_sd_big_conv(self): # internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows a int helper_test_op([(1,256,64,64), (512,256,3,3)], lambda x,w: torch.nn.functional.conv2d(x, w), lambda x,w: x.conv2d(w), atol=1e-3) @unittest.skip("slow") def test_large_bs_conv(self): # large batch size can cause OpenCL image to exceed max image height on macOS # (or cause the conv kernel to overflow short sampling coords) helper_test_op([(4096,3,3,3), (1,3,3,3)], lambda x,w: torch.nn.functional.conv2d(x, w), lambda x,w: 
x.conv2d(w), atol=1e-3) @unittest.skip("slow") def test_large_ic_conv(self): # large input channel count can cause OpenCL image to exceed max image width on macOS helper_test_op([(1,2048,3,3), (1,2048,3,3)], lambda x,w: torch.nn.functional.conv2d(x, w), lambda x,w: x.conv2d(w)) def test_biased_conv2d(self): C = 8 helper_test_op([(1,C,5,5), (C,C,1,1), (C,)], lambda x,w,b: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w,b).relu(),w,b), lambda x,w,b: Tensor.conv2d(x,w,b).relu().conv2d(w,b)) def test_simple_conv2d(self): helper_test_op([(1,4,9,9), (4,4,3,3)], lambda x,w: torch.nn.functional.conv2d(x,w), lambda x,w: Tensor.conv2d(x,w), grad_rtol=1e-5) def test_simple_conv2d_bias(self): helper_test_op([(1,4,9,9), (4,4,3,3), (4,)], lambda x,w,b: torch.nn.functional.conv2d(x,w,b), lambda x,w,b: Tensor.conv2d(x,w,b), grad_rtol=1e-5) @slow_test @unittest.skipIf(IMAGE>0, "no conv3d on images") def test_simple_conv3d(self): helper_test_op([(1,4,9,9,9), (4,4,3,3,3)], lambda x,w: torch.nn.functional.conv3d(x,w), lambda x,w: Tensor.conv2d(x,w), grad_rtol=1e-5) @slow_test @unittest.skipIf(IMAGE>0, "no conv3d on images") def test_padded_conv3d(self): helper_test_op([(1,4,5,5,5), (4,4,3,3,3)], lambda x,w: torch.nn.functional.conv3d(x,w,padding=1), lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]), grad_rtol=1e-5) def test_simple_conv2d_m4(self): helper_test_op([(1,16,9,9), (16,16,3,3)], lambda x,w: torch.nn.functional.conv2d(x,w), lambda x,w: Tensor.conv2d(x,w), atol=1e-05, grad_rtol=1e-5) def test_simple_conv2d_1x1(self): helper_test_op([(1,4,9,9), (4,4,1,1)], lambda x,w: torch.nn.functional.conv2d(x,w), lambda x,w: Tensor.conv2d(x,w), grad_rtol=1e-5) def test_simple_conv2d_1x1_m4(self): helper_test_op([(1,16,32,32), (16,16,1,1)], lambda x,w: torch.nn.functional.conv2d(x,w), lambda x,w: Tensor.conv2d(x,w), grad_rtol=1e-5) @slow_test def test_nested_conv2d(self): helper_test_op([(1,32,9,9), (32,32,3,3), (32,32,3,3)], lambda x,w1,w2: 
torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w1).relu(), w2), lambda x,w1,w2: x.conv2d(w1).relu().conv2d(w2)) # expect reduce nodes == 3 def test_simple_conv2d_nhwc(self): # weights (from tf): filter_height x filter_width x in_channels x out_channels helper_test_op([(2,9,9,10), (3,3,10,20)], lambda x,w: torch.nn.functional.conv2d(x.permute(0,3,1,2),w.permute(3,2,0,1)), lambda x,w: Tensor.conv2d(x.permute(0,3,1,2),w.permute(3,2,0,1)), atol=1e-5, grad_rtol=1e-5) def test_simple_conv2d_batched(self): helper_test_op([(2,4,9,9), (4,4,3,3)], lambda x,w: torch.nn.functional.conv2d(x,w), lambda x,w: Tensor.conv2d(x,w), grad_rtol=1e-5) # conv transpose def test_simple_conv_transpose2d(self): helper_test_op([(2,4,9,9), (4,4,3,3)], lambda x,w: torch.nn.functional.conv_transpose2d(x,w), lambda x,w: Tensor.conv_transpose2d(x,w), grad_rtol=1e-5) def test_bias_conv_transpose2d(self): helper_test_op([(2,4,9,9), (4,4,3,3), (4,)], lambda x,w,b: torch.nn.functional.conv_transpose2d(x,w,b), lambda x,w,b: Tensor.conv_transpose2d(x,w,b), grad_rtol=1e-5) def test_grouped_conv_transpose2d(self): helper_test_op([(2,4,9,9), (4,4,3,3)], lambda x,w: torch.nn.functional.conv_transpose2d(x,w,groups=2), lambda x,w: Tensor.conv_transpose2d(x,w,groups=2), grad_rtol=1e-5) @slow_test def test_padded_conv_transpose2d(self): for padding in [(1,2), (2,1), 2, 1, 0]: helper_test_op([(2,4,9,9), (4,4,3,3)], lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=padding), lambda x,w: Tensor.conv_transpose2d(x,w,padding=padding), grad_rtol=1e-5) self.helper_test_exception([(2,16,2,2), (32,16,3,3)], lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=(1,1,1)), lambda x,w: Tensor.conv_transpose2d(x,w,padding=(1,1,1)), expected=(RuntimeError, ValueError)) @slow_test def test_dilated_conv_transpose2d(self): for dilation in [(1,2), (2,1), 2, 1]: helper_test_op([(2,4,9,9), (4,4,3,3)], lambda x,w: torch.nn.functional.conv_transpose2d(x,w,dilation=dilation), lambda x,w: 
Tensor.conv_transpose2d(x,w,dilation=dilation), grad_rtol=1e-5) def test_strided_conv_transpose2d(self): for stride in [(2,1), (1,2), 1]: helper_test_op([(2,4,4,5), (4,4,3,3)], lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride), lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride), atol=1e-5, grad_rtol=1e-5) @slow_test def test_output_padded_conv_transpose2d(self): for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]: helper_test_op([(2,4,6,5), (4,4,3,3),(4,)], lambda x,w,b: torch.nn.functional.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride), lambda x,w,b: Tensor.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride), grad_rtol=1e-5) @slow_test @unittest.skipIf(IMAGE>0, "no conv3d on images") def test_simple_conv_transpose3d(self): helper_test_op([(2,4,9,9,9), (4,4,3,3,3)], lambda x,w: torch.nn.functional.conv_transpose3d(x,w), lambda x,w: Tensor.conv_transpose2d(x,w), grad_rtol=1e-5) @unittest.skipIf((IMAGE>0), "no conv1d on images") def test_conv1d(self): for bs in [1,8]: for cin in [1,3]: for H in [1,2,5]: for groups in [1,3] if cin == 3 and H == 5 else [1]: with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H): helper_test_op([(bs,cin,11), (6,cin//groups,H)], lambda x,w: torch.nn.functional.conv1d(x,w,groups=groups), lambda x,w: Tensor.conv2d(x,w,groups=groups), grad_rtol=1e-5) @unittest.skipIf(IMAGE>0, "no conv1d on images") def test_simple_padding_conv1d(self): bs = 6 cin = 2 groups = 1 H = 5 p = (1,1) helper_test_op([(bs,cin,11), (6,cin//groups,H)], lambda x,w: torch.nn.functional.conv1d(torch.nn.functional.pad(x, p),w), lambda x,w: Tensor.conv2d(x,w,padding=p)) @unittest.skipIf(IMAGE>0, "no conv1d on images") def test_strided_conv1d_simple(self): bs, H = 2, 3 helper_test_op([(bs,1,5), (1,1,H)], lambda x,w: torch.nn.functional.conv1d(x,w,stride=2), lambda x,w: Tensor.conv2d(x,w,stride=2)) @unittest.skipIf(IMAGE>0, "no conv1d on images") def test_asymmetric_padding_conv1d(self): 
for p in [(0,1), (2,1), (2,0)]: with self.subTest(p): for n in [3,4]: for k in [2]: helper_test_op([(1,1,n), (1,1,k)], lambda x,w: torch.nn.functional.conv1d(torch.nn.functional.pad(x, p),w), lambda x,w: Tensor.conv2d(x,w,padding=p)) def _test_conv2d(self, bs=1, cin=1, cout=6): for H in [2,3]: for W in [1,3,5]: for groups in [1,3] if cin == 3 and cout == 6 and H == 3 and W == 3 else [1]: with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W): helper_test_op([(bs,cin,5,7), (cout,cin//groups,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups), lambda x,w: Tensor.conv2d(x,w,groups=groups), grad_rtol=1e-5) def test_conv2d(self): self._test_conv2d(bs=1, cin=3) @slow_test def test_conv2d_bs_4_cin_3(self): self._test_conv2d(bs=4, cin=3, cout=2) def test_conv2d_bs_1_cin_1(self): self._test_conv2d(bs=1, cin=1) @slow_test def test_conv2d_bs_4_cin_1(self): self._test_conv2d(bs=4, cin=1) def test_conv2d_errors(self): # kernel size cannot be larger than input size self.helper_test_exception([(1,1,6,7), (6,1,3,3)], lambda x,w:torch.nn.functional.conv2d(x,w,dilation=3), lambda x,w: Tensor.conv2d(x,w,dilation=3), expected=(RuntimeError, AssertionError)) # regression test for https://github.com/tinygrad/tinygrad/pull/7549/ self.helper_test_exception([(2,16,2,2), (32,16,3,3)], lambda x,w:torch.nn.functional.conv2d(x,w), lambda x,w: Tensor.conv2d(x,w), expected=(RuntimeError, AssertionError)) self.helper_test_exception([(2,16,2,2), (32,16,3,3)], lambda x,w:torch.nn.functional.conv2d(x,w,padding=(1,1,1)), lambda x,w: Tensor.conv2d(x,w,padding=(1,1,1)), expected=(RuntimeError, ValueError)) @slow_test def test_large_input_conv2d(self): bs = 4 cin = 16 groups = 1 H = 5 W = 2 helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups), # needed to relax tolerance for larger input lambda x,w: Tensor.conv2d(x,w,groups=groups), atol=1e-4, grad_atol=3e-4, grad_rtol=1e-4) def 
test_simple_grouped_conv2d(self): bs = 1 groups = 2 rcout = 1 cin = 2 helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)], lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups), lambda x,w: Tensor.conv2d(x,w,groups=groups), grad_rtol=1e-5) def test_medium_grouped_conv2d(self): bs = 1 groups = 2 rcout = 2 cin = 2 helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)], lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups), lambda x,w: Tensor.conv2d(x,w,groups=groups), grad_rtol=1e-5) def test_depthwise_conv2d(self): bs = 1 groups = 32 rcout = 1 cin = 1 helper_test_op([(bs,groups*cin,32,32), (groups*rcout,cin,1,1)], lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups), lambda x,w: Tensor.conv2d(x,w,groups=groups), grad_rtol=1e-5) def test_grouped_conv2d(self): bs = 4 groups = 5 rcout = 7 cin = 3 helper_test_op([(bs,groups*cin,5,5), (groups*rcout,cin,3,3)], lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups), lambda x,w: Tensor.conv2d(x,w,groups=groups), grad_rtol=1e-5) def test_fancy_conv2d(self): bs = 2 cin = 3 cout = 1 groups = 3 H,W = 3,3 helper_test_op([(bs,cin,11,28), (groups*cout,cin//groups,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups), lambda x,w: Tensor.conv2d(x,w,groups=groups), grad_rtol=1e-5) @slow_test def test_strided_conv2d_simple(self): bs,H,W = 2,3,1 helper_test_op([(bs,1,5,1), (1,1,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,stride=2), lambda x,w: Tensor.conv2d(x,w,stride=2)) @unittest.skipUnless(Device.DEFAULT == "CPU" and DEV.renderer == "LLVM", "DEVECTORIZE=0 only for LLVM") def test_strided_conv2d_simple_vec(self): self.test_strided_conv2d_simple() @slow_test def test_strided_conv2d(self): bs = 4 cin = 3 H,W = 3,3 with self.subTest(stride := 2): helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,stride=2), lambda x,w: Tensor.conv2d(x,w,stride=stride)) with self.subTest(stride := (2,1)): helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: 
torch.nn.functional.conv2d(x,w,stride=stride), lambda x,w: Tensor.conv2d(x,w,stride=(2,1))) def test_negative_padding_conv2d(self): n,k = 10, 3 helper_test_op([(1,1,n,n), (1,1,k,k)], lambda x,w: torch.nn.functional.conv2d(x[:, :, 1:-1, 1:-1],w), lambda x,w: Tensor.conv2d(x,w,padding=-1)) helper_test_op([(1,1,n,n), (1,1,k,k)], lambda x,w: torch.nn.functional.conv2d(x[:, :, 1:, 1:],w), lambda x,w: Tensor.conv2d(x,w,padding=(-1,0,-1,0))) def test_simple_padding_conv2d(self): p = (1,1,1,1) helper_test_op(None, lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w), lambda x,w: Tensor.conv2d(x,w,padding=p), vals=[[[[[2.,3.]]]], [[[[1.]]]]]) def test_asymmetric_padding_conv2d(self): for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]: with self.subTest(p): for n in [3,4]: for k in [2]: helper_test_op([(1,1,n,n), (1,1,k,k)], lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w), lambda x,w: Tensor.conv2d(x,w,padding=p)) helper_test_op([(1,1,n,n), (1,1,k,k)], lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w), lambda x,w: Tensor.conv2d(x,w,padding=p)) def test_padded_conv2d_p21(self): bs,cin,H,W,padding = 4, 3, 3, 3, (2,1) helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding), lambda x,w: Tensor.conv2d(x,w,padding=padding)) def test_padded_conv2d_p22(self): bs,cin,H,W,padding = 4, 3, 3, 3, (2,2) helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding), lambda x,w: Tensor.conv2d(x,w,padding=padding)) def test_padded_conv2d_1x1(self): bs,cin,H,W,padding = 4, 3, 1, 1, 2 helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding), lambda x,w: Tensor.conv2d(x,w,padding=padding)) def test_padded_conv2d_bs1(self): bs,cin,H,W,padding = 1, 3, 3, 3, 1 helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding), lambda x,w: 
Tensor.conv2d(x,w,padding=padding)) def test_padding_add(self): helper_test_op([(64,64), (60,60)], lambda x,w: x+torch.nn.functional.pad(w, (2,2,2,2)), lambda x,w: x+w.pad((2,2,2,2))) def test_dilated_conv2d(self): bs = 4 cin = 3 H,W = 3,3 for d in [2, (2,1)]: with self.subTest(dilation := d): helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation), lambda x,w: Tensor.conv2d(x,w,dilation=dilation)) def test_max_pool2d_simple(self): ksz = (2,2) helper_test_op([(1,1,2,3)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), lambda x: Tensor.max_pool2d(x, kernel_size=ksz)) @slow_test def test_max_pool2d(self): for ksz in [(2,2), (3,3), 2, 3, (3,2), (5,5), (5,1)]: with self.subTest(kernel_size=ksz): helper_test_op([(32,2,11,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), lambda x: Tensor.max_pool2d(x, kernel_size=ksz)) @slow_test def test_max_pool2d_padding(self): for ksz in [(2,2), (3,3), 2, 3, (3,2)]: for p in [1, (1,0), (0,1)]: with self.subTest(kernel_size=ksz, padding=p): helper_test_op([(4,2,11,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=p), lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=p)) self.helper_test_exception([(4,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), padding=(1,1,1)), lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), padding=(1,1,1)), expected=(RuntimeError, ValueError)) @slow_test def test_max_pool2d_asymmetric_padding(self): for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]: with self.subTest(padding=p): helper_test_op([(4,2,111,28)], lambda x: torch.nn.functional.max_pool2d(torch.nn.functional.pad(x, p, value=float("-inf")), kernel_size=(5,5)), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), padding=p)) @slow_test def test_max_pool2d_padding_int(self): ksz = (2,2) helper_test_op([(4,2,11,28)], lambda x: torch.nn.functional.max_pool2d(x.int(), kernel_size=ksz, padding=1), lambda x: 
Tensor.max_pool2d(x.int(), kernel_size=ksz, padding=1), forward_only=True) @slow_test def test_max_pool2d_bigger_stride(self): for stride in [(2,3), (3,2), 2, 3]: with self.subTest(stride=stride): helper_test_op([(4,2,11,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride), lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride)) @slow_test def test_max_pool2d_bigger_stride_dilation(self): for stride, dilation in zip([(2,3), (3,2), 2, 3, 4], [(3,2), (2,3), 2, 3, 6]): with self.subTest(stride=stride): helper_test_op([(4,2,11,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation), lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation)) @unittest.skipIf( Device.DEFAULT in {"CUDA", "NV"}, "CUDA fails on this") def test_max_pool2d_unit_stride(self): helper_test_op([(3, 2, 17, 14)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=1)) @slow_test def test_max_pool2d_smaller_stride(self): for stride in [(2,3), (3,2), 2, 3]: with self.subTest(stride=stride): helper_test_op([(3, 2, 17, 14)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride)) @slow_test def test_max_pool2d_dilation(self): for dilation in [(2, 3), (3, 2), 2, 3]: helper_test_op([(3, 2, 17, 14)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation)) def test_max_pool2d_ceil_mode(self): shape = (1,1,6,6) for ksz in [(3,3), 3, (3,2), 4]: with self.subTest(kernel_size=ksz): helper_test_op([shape], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True), lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True)) def 
test_max_pool2d_ceil_mode_output_size_reduce_by_one(self): # sliding window ignored from end region helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True), lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True)) def test_max_pool2d_return_indices(self): # batch and multi-channel helper_test_op([(2,3,6,6)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1].type(torch.int32), lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1], forward_only=True) # dilation helper_test_op([(1,1,10,10)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1].type(torch.int32), lambda x: Tensor.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1], forward_only=True) # padding helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1].type(torch.int32), lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1], forward_only=True) # ceil mode padding helper_test_op([(1, 1, 7, 7)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1].type(torch.int32), lambda x: Tensor.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1], forward_only=True) # global maxpool helper_test_op([(1,1,12,13)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1].type(torch.int32), lambda x: Tensor.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1], forward_only=True) # multiple identical values in same window and overlapping windows helper_test_op(None, lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1].type(torch.int32), lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=1, 
return_indices=True)[1], vals=[[[[[1]*6]*6]]], forward_only=True) # Tensor.ones(1,1,6,6) # overlapping max indices helper_test_op(None, lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=1, return_indices=True)[1].type(torch.int32), lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=1, return_indices=True)[1], vals=[[[[[1,2]*3]*6]]], forward_only=True) # Tensor([1,2,1,2,1,2]).expand(1,1,6,6) @slow_test def test_max_unpool2d(self): args = {"kernel_size":(5,5), "stride":(6,5)} helper_test_op([(8,3,50,50)], lambda x: torch.nn.functional.max_unpool2d(*torch.nn.functional.max_pool2d(x, return_indices=True, **args), **args), lambda x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, return_indices=True, **args), **args), forward_only=True) args = {"kernel_size":(3,3), "stride":(6,7), "padding":1} helper_test_op([(8,3,30,30)], lambda x: torch.nn.functional.max_unpool2d(*torch.nn.functional.max_pool2d(x, return_indices=True, **args), **args, output_size=(30,30)), lambda x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, return_indices=True, **args), **args, output_size=(30,30)), forward_only=True) # batch_size and channel_size of output_size are ignored helper_test_op([(1,3,7,6)], lambda x: torch.nn.functional.max_unpool2d(*torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True), kernel_size=(2,2), output_size=(99,99,7,6)), lambda x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True), kernel_size=(2,2), output_size=(99,99,7,6)), forward_only=True) def test_max_unpool2d_inf(self): data = [[[[math.inf, -math.inf, math.nan], [1.0, 2.0, 3.0]]]] ksz = (2,2) helper_test_op((), lambda: torch.nn.functional.max_unpool2d( *torch.nn.functional.max_pool2d(torch.tensor(data), kernel_size=ksz, return_indices=True), kernel_size=ksz ), lambda: Tensor.max_unpool2d( *Tensor.max_pool2d(Tensor(data), kernel_size=ksz, return_indices=True), kernel_size=ksz ), forward_only=True) @slow_test def test_avg_pool2d(self): shape = 
(32,2,11,28) for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: with self.subTest(kernel_size=ksz): helper_test_op([shape], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz), lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5) # regression test for https://github.com/tinygrad/tinygrad/pull/7581 helper_test_op([(1,1,8,8)], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)), lambda x: Tensor.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)), rtol=1e-5) @slow_test def test_avg_pool2d_padding(self): shape = (32,2,11,28) for ksz in [(2,2), (3,3), 2, 3, (3,2)]: for p in [1, (1,0), (0,1)]: with self.subTest(kernel_size=ksz, padding=p): helper_test_op([shape], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=p), lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=p), rtol=1e-5) with self.assertRaises(ValueError): Tensor.avg_pool2d(Tensor.randn((32,2,11,28)), kernel_size=(2,2), padding=(1,1,1)) def test_avg_pool2d_asymmetric_padding(self): shape = (32,2,11,28) for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]: with self.subTest(padding=p): helper_test_op([shape], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(5,5), padding=1), lambda x: Tensor.avg_pool2d(x, kernel_size=(5,5), padding=1), rtol=1e-5) self.helper_test_exception([shape], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(2,2), padding=(1,1,1)), lambda x: Tensor.avg_pool2d(x, kernel_size=(2,2), padding=(1,1,1)), expected=(RuntimeError, ValueError)) @slow_test def test_avg_pool2d_padding_not_counted(self): shape = (32,2,11,28) for ksz in [(2,2), (3,3), 2, 3, (3,2)]: with self.subTest(kernel_size=ksz): helper_test_op([shape], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), rtol=1e-5) def test_avg_pool2d_ceil_mode(self): shape = (1,1,6,6) for ksz in [(3,3), 3, (3,2), 4]: with 
self.subTest(kernel_size=ksz): helper_test_op([shape], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True), lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True), rtol=1e-5) def test_avg_pool2d_ceil_mode_padding_not_counted(self): shape = (1,1,6,6) for ksz in [(3,3), 3, (3,2), 4]: with self.subTest(kernel_size=ksz): helper_test_op([shape], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False), lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False), rtol=1e-5) def test_avg_pool2d_ceil_mode_output_size_reduce_by_one(self): # sliding window ignored from end region helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True), lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True)) def test_avg_pool2d_ceil_mode_include_pad_output_size_reduce_by_one(self): # sliding window ignored from end region helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True), lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True)) def test_global_avg_pool2d(self): helper_test_op([(32,2,11,28)], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(11,28)), lambda x: Tensor.avg_pool2d(x, kernel_size=(11,28)), rtol=1e-5) def test_avg_pool3d(self): # TODO: AMD:LLVM has larger atol # TODO: DEV=PYTHON backward hangs? 
atol = 1e-2 if DEV.device == "AMD" and DEV.renderer == "LLVM" else 1e-6 helper_test_op([(1,1,16,16,16)], lambda x: torch.nn.functional.avg_pool3d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False), lambda x: Tensor.avg_pool2d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False), atol=atol, rtol=1e-5, forward_only=True) def test_interpolate_linear(self): for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]: helper_test_op([(2,3)+in_sz], lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="linear"), lambda x: Tensor.interpolate(x, size=out_sz, mode="linear")) def test_interpolate_linear_corners_aligned(self): for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]: helper_test_op([(2,3)+in_sz], lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="linear", align_corners=True), lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True)) def test_interpolate_nearest(self, mode="nearest"): for in_sz, out_sz in [((13,),(9,)), ((9,),(13,))]: helper_test_op([(2,3)+in_sz], lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode), lambda x: Tensor.interpolate(x, size=out_sz, mode=mode)) for in_sz, out_sz in [((13,10),(9,11)), ((13,9),(11,10)), ((9,11),(10,13))]: helper_test_op([(2,3)+in_sz], lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode), lambda x: Tensor.interpolate(x, size=out_sz, mode=mode)) for in_sz, out_sz in [((5,2,8),(3,6,4))]: helper_test_op([(2,3)+in_sz], lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode), lambda x: Tensor.interpolate(x, size=out_sz, mode=mode)) def test_interpolate_nearest_exact(self): self.test_interpolate_nearest("nearest-exact") @slow_test def test_interpolate_bilinear(self): for in_sz, out_sz in [((12,20),(9,31)), ((12,9),(31,20)), ((9,31),(20,12))]: helper_test_op([(2,3)+in_sz], lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="bilinear"), lambda x: Tensor.interpolate(x, size=out_sz, mode="linear"), 
atol=1e-4)

  @slow_test
  def test_interpolate_bilinear_corners_aligned(self):
    # 4-D inputs: torch mode="bilinear" vs Tensor.interpolate mode="linear", both with align_corners=True.
    for in_sz, out_sz in [((12,20),(9,31)), ((12,9),(31,20)), ((9,31),(20,12))]:
      helper_test_op([(2,3)+in_sz],
        lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="bilinear", align_corners=True),
        lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True), atol=1e-4)

  def test_interpolate_trilinear(self):
    # 5-D inputs: torch mode="trilinear" vs Tensor.interpolate mode="linear".
    for in_sz, out_sz in [((5,2,8),(3,6,4))]:
      helper_test_op([(2,3)+in_sz],
        lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="trilinear"),
        lambda x: Tensor.interpolate(x, size=out_sz, mode="linear"), atol=1e-4)

  def test_interpolate_trilinear_corners_aligned(self):
    # Same as test_interpolate_trilinear but with align_corners=True on both sides.
    for in_sz, out_sz in [((5,2,8),(3,6,4))]:
      helper_test_op([(2,3)+in_sz],
        lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="trilinear", align_corners=True),
        lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True), atol=1e-4)

  @slow_test
  def test_cat(self):
    # Tensor.cat vs torch.cat on three 3-D inputs for every dim in -2..2, plus zero-sized axes and a 0-d error case.
    for dim in range(-2, 3):
      helper_test_op([(45,65,9), (45,65,9), (45,65,9)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))

    # zero in non-cat axis
    helper_test_op([(45,0,9), (45,0,9), (45,0,9)], lambda x,y,z: torch.cat((x,y,z), 0), lambda x,y,z: x.cat(y, z, dim=0))

    # zero in cat axis
    helper_test_op([(45,0,9), (45,1,9), (45,2,9)], lambda x,y,z: torch.cat((x,y,z), 1), lambda x,y,z: x.cat(y, z, dim=1))
    helper_test_op([(45,0,9), (45,0,9), (45,0,9)], lambda x,y,z: torch.cat((x,y,z), 1), lambda x,y,z: x.cat(y, z, dim=1))

    # concatenating a tensor built from a python scalar is expected to raise IndexError
    with self.assertRaises(IndexError):
      a = Tensor(3.14)
      a.cat(a)

  def test_multicat(self):
    # Tensor.cat vs torch.cat on three 2-D inputs for dim in -1..1.
    for dim in range(-1, 2):
      helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))

  @unittest.skipIf(COMPILE_ONLY, "test requires runtime")
  def test_stack(self):
    # Tensor.stack vs torch.stack, called both with unpacked tensors and with a single tuple of tensors.
    # NOTE(review): both calls read `dim`, so both are placed inside the loop -- confirm against the original layout.
    for dim in range(-1, 3):
      helper_test_op([(5,6,3), (5,6,3), (5,6,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack(x, y, z, dim=dim))
      helper_test_op([(5,6,3), (5,6,3), (5,6,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack((x, y, z), dim=dim))

    # out-of-range dim
    with self.assertRaises(IndexError):
      Tensor.stack(Tensor.randn(45, 65, 3), dim=77)

    # mixing a tuple of tensors with an extra positional tensor
    with self.assertRaises(ValueError):
      Tensor.stack((Tensor([1, 2]), Tensor([3, 4])), Tensor([5, 6]))

    # stacking two scalar tensors gives a length-2 vector
    np.testing.assert_allclose(Tensor.stack(Tensor(3.14), Tensor(3.14)).numpy(), np.array([3.14, 3.14]))

  def test_stack_max(self):
    # max over the stacked axis; torch returns (values, indices), hence the [0] on the torch side.
    helper_test_op(None, lambda x, y: torch.stack((x, y)).max(axis=0)[0], lambda x, y: Tensor.stack(x, y).max(axis=0), vals=[[1.], [2.]])

  def test_repeat(self):
    # Tensor.repeat (takes a list) vs torch repeat (takes unpacked ints) with 3 to 6 repeat entries, on 3-D and 0-d inputs.
    x = Tensor.randn(4, 6, 3)
    base_repeats = [2, 4, 3]

    for reps in [[], [4], [2, 1], [3, 2, 2]]:
      repeats = base_repeats + reps
      helper_test_op([(4, 6, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
      helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))

    # fewer repeat entries than tensor dimensions
    with self.assertRaises(ValueError):
      x.repeat((2, 4))

    # a zero repeat produces an empty axis
    np.testing.assert_allclose(x.repeat((2, 0, 4)).numpy(), Tensor.zeros(8, 0, 12).numpy())

  def test_repeat_interleave(self):
    # repeat_interleave with no dim and with dim in {1, 0, -1, -2}; same lambda is used for torch and tinygrad.
    helper_test_op([(3, 3)], lambda x: x.repeat_interleave(6))
    helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 1))
    helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 0))
    helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -1))
    helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -2))

  def test_simple_repeat(self):
    # repeat with more repeat entries (3) than input dimensions (2).
    repeats = [3, 3, 4]
    helper_test_op([(3, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))

  def test_clip(self):
    # clip with various bounds, including one-sided (None) bounds, min > max, and the both-None error case.
    helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2))
    # NOTE: torch set backward to 1 at the boundaries
    # https://github.com/pytorch/pytorch/blob/7a41b66367c38d0af3e8a90f7be48d6b281e7bca/tools/autograd/derivatives.yaml#L421
    helper_test_op(None, lambda x: x.clip(-2.5, 1.5), vals=[[-3.0, -2.5, 0, 1.5, 2]])
    helper_test_op([(45,65)], lambda x: x.clip(0, 0))
    helper_test_op([(45,65)], lambda x: x.clip(10, 100))
    helper_test_op([(45,65)], lambda x: x.clip(0, 0.1))
    helper_test_op([(45,65)], lambda x: x.clip(-0.3, -0.2))
    helper_test_op([(45,65)], lambda x: x.clip(3, 0))  # min > max
    helper_test_op([(45,65)], lambda x: x.clip(None, 0))
    helper_test_op([(45,65)], lambda x: x.clip(0, None))
    self.helper_test_exception([(45,65)], lambda x: x.clip(None, None), expected=RuntimeError)

  def test_matvecmat(self):
    # (1,128) @ (128,128), relu, then @ (128,128).
    helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z)

  def test_matvec(self):
    # (1,128) @ (128,128) followed by relu.
    helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu())

  @unittest.skip("this test is broken #862")
  def test_max_nan(self):
    # Unconditionally skipped; would check that max over [1, nan] is nan.
    n = Tensor([1, float("nan")]).max().numpy()
    assert math.isnan(n.item()), f"{n.item()} is not nan"

  @unittest.skipIf(COMPILE_ONLY, "test requires runtime")
  def test_inf_where(self):
    # (x < 0) is False everywhere for an all-inf tensor, so where() must pick the constant 1 for every element.
    x = Tensor.full((3, 3), float("inf"))
    n = (x < 0).where(x, 1).numpy()
    assert np.all(n == 1.)

  def _get_index_randoms(self):
    # Returns five random int64 torch index tensors (a..e) followed by int32 tinygrad copies of them (i,j,k,o,p).
    # indices cannot have gradient
    a = torch.randint(low=-1, high=1, size=(2,1,1,1,1,1), dtype=torch.int64, requires_grad=False)
    b = torch.randint(high=1, size=(1,3,1,1,1,1), dtype=torch.int64, requires_grad=False)
    c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False)
    d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False)
    e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False)
    i, j, k, o, p = [Tensor(tor.detach().cpu().numpy().astype(np.int32)) for tor in [a,b,c,d,e]]
    return a,b,c,d,e,i,j,k,o,p

  def test_fancy_indexing_inf(self):
    # Tensor-indexing data that contains inf, -inf and nan; shps=() so both lambdas take no inputs.
    data = [math.inf, -math.inf, math.nan]
    helper_test_op((), lambda: torch.tensor(data)[torch.tensor([0, 1, 2])], lambda: Tensor(data)[Tensor([0, 1, 2])])

  @slow_test
  def test_slice_fancy_indexing_no_dim_collapse(self):
    # Tensor indices mixed with full slices and Ellipsis only (no int or None entries).
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    # no dim collapse from int or dim injection from None
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,d,e], lambda x: x[i,j,k,o,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[:,b,c,d,:], lambda x: x[:,j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,...], lambda x: x[i,j,...])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,...,e], lambda x: x[i,...,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[...,c,:,e], lambda x: x[...,k,:,p])

  @slow_test
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, QCOMCLRenderer), "QCOM CL vectorized bool bug")
  def test_slice_fancy_indexing_dim_collapse_int(self):
    # Tensor indices mixed with plain ints and strided slices.
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    # dim collapse from int
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,b,c,d,e], lambda x: x[1,j,k,o,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,3,d,e], lambda x: x[i,j,3,o,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,b,2,d,2], lambda x: x[1,j,2,o,2])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,2,2,2,e], lambda x: x[i,2,2,2,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,:,3:11:2,d,0:2], lambda x: x[1,:,3:11:2,o,0:2])

  @slow_test
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, QCOMCLRenderer), "QCOM CL vectorized bool bug")
  def test_slice_fancy_indexing_dim_inject_none(self):
    # Tensor indices mixed with None entries at leading, trailing and interior positions.
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    # dim injection from None
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,b,c,d,e], lambda x: x[None,j,k,o,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,d,None], lambda x: x[i,j,k,o,None])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,None,d,e], lambda x: x[i,j,None,o,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,b,c,d,None], lambda x: x[None,j,k,o,None])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,None,None,None], lambda x: x[None,None,None,None,None])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,c,d,e], lambda x: x[None,None,j,k,o,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,c,None,None], lambda x: x[None,None,j,k,None,None])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,None,None,c,d,e], lambda x: x[i,None,None,k,o,p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,None,None,c,None,None], lambda x: x[i,None,None,k,None,None])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,None,d,e], lambda x: x[None,None,j,None,o,p])

  def test_slice_fancy_indexing_dim_inject_and_collapse(self):
    # Tensor indices mixed with both ints and None in the same index expression.
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() # noqa
    # dim injection and collapse
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,b,None,d,1], lambda x: x[1,j,None,o,1])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,b,2,d,None], lambda x: x[None,j,2,o,None])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[...,1,d,None], lambda x: x[...,1,o,None])

  def test_slice_fancy_indexing_with_tensors(self):
    # Literal index tensors of differing rank (0-d, 1-d, 2-d), including negative indices.
    # indexing using idx with different dim
    helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor(1)],
                            lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor(1)])
    helper_test_op([(2,3)], lambda x: x[torch.tensor([1]), torch.tensor([[0,0,0],[0,0,0]])],
                            lambda x: x[Tensor([1]), Tensor([[0,0,0],[0,0,0]])])
    helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor([2,1,1])],
                            lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor([2,1,1])])
    helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,1,-1],[-1,-2,0]]), torch.tensor([2,1,-1])],
                            lambda x: x[Tensor([[0,1,-1],[-1,-2,0]]), Tensor([2,1,-1])])

  @slow_test
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, QCOMCLRenderer), "QCOM CL vectorized bool bug")
  def test_slice_fancy_indexing_list_indices(self):
    # Python lists / nested lists / tuples used as index entries alongside tensor indices.
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[((0,),)])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(0,),b,c,d,:], lambda x: x[(0,),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[[0]]],b,c,d,[[1]]], lambda x: x[[[[0]]],j,k,o,[[1]]])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1,0,-1),b,c,d,:], lambda x: x[(1,0,-1),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,(1,2,3),...], lambda x: x[i,j,k,(1,2,3),...])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[[1],[2],[3]],...], lambda x: x[i,j,k,[[1],[2],[3]],...])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,(2,1,0),c,(-2,1,0),e], lambda x: x[i,(2,1,0),k,(-2,1,0),p])

  @slow_test
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, QCOMCLRenderer), "QCOM CL vectorized bool bug")
  def test_slice_fancy_indexing_tuple_indices(self):
    # Tuples / nested tuples used as index entries alongside tensor indices, ints and None.
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(((0,),),)], lambda x: x[(((0,),),)])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(0,),b,c,d,:], lambda x: x[(0,),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1,0),b,c,d,:], lambda x: x[(1,0),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,(1,2,3),...], lambda x: x[i,j,k,(1,2,3),...])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,((2,),(1,),(0,)),c,(2,1,0)], lambda x: x[i,((2,),(1,),(0,)),k,(2,1,0)])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,(2,1,0),None,c,(2,1,0),e], lambda x: x[1,(2,1,0),None,k,(2,1,0),p])

  @slow_test
  def test_slice_fancy_indexing_list_with_tensors(self):
    # The whole index passed as an explicit tuple that contains tensors.
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,)], lambda x: x[(i,)])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,1)], lambda x: x[(i,1)])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,(1,1))], lambda x: x[(i,(1,1))])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,b,c,d,e)], lambda x: x[(i,j,k,o,p)])

  def test_slice_fancy_indexing_errors(self):
    # Invalid index dtypes and non-broadcastable index shapes must raise.
    a = Tensor.ones(10,11,12)
    # tensors used as indices must be int tensors
    with self.assertRaises(IndexError): a[Tensor(1.1)]
    with self.assertRaises(IndexError): a[[1.1]]
    with self.assertRaises(IndexError): a[Tensor([True, False])]
    with self.assertRaises(IndexError): a[[True, False]]
    # shape mismatch, cannot broadcast. either exception is okay
    with self.assertRaises((IndexError, ValueError)): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1), Tensor.randint(2,4,4,1)]
    with self.assertRaises((IndexError, ValueError)): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1,1)]
    # this is fine
    helper_test_op([(5, 6)], lambda x: x[[True, False, 2]])

  def test_gather(self):
    # gather along every positive and negative dim, plus index-shape error cases and literal (incl. -inf) inputs.
    # indices cannot have gradient
    # indices cannot be negative (torch gather)
    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32)
    helper_test_op([(4,5,6)], lambda x: x.gather(dim=0, index=b), lambda x: x.gather(dim=0, index=a))
    helper_test_op([(4,5,6)], lambda x: x.gather(dim=1, index=b), lambda x: x.gather(dim=1, index=a))
    helper_test_op([(4,5,6)], lambda x: x.gather(dim=2, index=b), lambda x: x.gather(dim=2, index=a))
    helper_test_op([(3,4,5)], lambda x: x.gather(dim=0, index=b), lambda x: x.gather(dim=0, index=a))
    helper_test_op([(4,5,6)], lambda x: x.gather(dim=-1, index=b), lambda x: x.gather(dim=-1, index=a))
    helper_test_op([(4,5,6)], lambda x: x.gather(dim=-2, index=b), lambda x: x.gather(dim=-2, index=a))
    helper_test_op([(4,5,6)], lambda x: x.gather(dim=-3, index=b), lambda x: x.gather(dim=-3, index=a))
    # index with fewer dimensions than the input
    self.helper_test_exception([(4,5,6)], lambda x: x.gather(dim=0, index=torch.tensor([1], dtype=torch.int64)),
                                           lambda x: x.gather(dim=0, index=Tensor([1], dtype=dtypes.int32)), expected=(RuntimeError, AssertionError))
    # input smaller than the (3,4,5) index
    self.helper_test_exception([(2,1,1)], lambda x: x.gather(dim=0, index=b),
                                           lambda x: x.gather(dim=0, index=a), expected=(RuntimeError, AssertionError))
    helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
                         lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
                         vals=[[1., 2., 3.]])
    # gather with inf values
    helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
                         lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
                         vals=[[-float("inf"), 2., 3.]])

  def test_scatter(self):
    # scatter with tensor and scalar sources (forward only), plus shape / kwarg / dtype error cases.
    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32)
    for dim in (0,1,2,-1,-2,-3):
      helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src),
                                          lambda x,src: x.scatter(dim=dim, index=a, src=src), forward_only=True)

    helper_test_op([(3,4,5), (3,4,5)], lambda x,src: x.scatter(dim=1, index=b, src=src),
                                        lambda x,src: x.scatter(dim=1, index=a, src=src), forward_only=True)
    helper_test_op([(10,3,10), (10,10,10)], lambda x,src: x.scatter(dim=1, index=b, src=src),
                                             lambda x,src: x.scatter(dim=1, index=a, src=src), forward_only=True)

    self.helper_test_exception([(2,3,10), (10,10,10)], lambda x,src: x.scatter(dim=1, index=b, src=src),
                                                        lambda x,src: x.scatter(dim=1, index=a, src=src), expected=(RuntimeError, AssertionError))
    self.helper_test_exception([(10,3,10), (10,3,10)], lambda x,src: x.scatter(dim=1, index=b, src=src),
                                                        lambda x,src: x.scatter(dim=1, index=a, src=src), expected=(RuntimeError, AssertionError))
    # unknown keyword argument
    self.helper_test_exception([(3,4,5), (3,4,5)], lambda x,src: x.scatter(dim=1, index=b, src=src, mode="typo"),
                                                    lambda x,src: x.scatter(dim=1, index=a, src=src, mode="typo"), expected=TypeError)
    # self cast to half while src keeps the default float dtype
    self.helper_test_exception([(3,4,5), (3,4,5)], lambda x,src: x.half().scatter(dim=1, index=b, src=src),
                                                    lambda x,src: x.half().scatter(dim=1, index=a, src=src), expected=RuntimeError)

    # scalar source: torch takes it as `value=`, tinygrad as `src=`
    helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=3), lambda x: x.scatter(dim=1, index=a, src=3), forward_only=True)
    helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf")),
                              lambda x: x.scatter(dim=1, index=a, src=float("inf")), forward_only=True)

    # overlapping indices with 0s
    b = torch.tensor([0,0], requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32)
    helper_test_op(None, lambda x,src: x.scatter(0, b, src), lambda x,src: x.scatter(0, a, src), forward_only=True,
                   vals=[[1.,2.,3.,4.], [1.,0.]])

  def test_scatter_add(self):
    # scatter with reduce="add" and inf / nan scalar sources (forward only).
    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32)
    helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="add"),
                              lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="add"), forward_only=True)

    # TODO: fails for webgpu
    if Device.DEFAULT != "WEBGPU":
      helper_test_op([(4,5,6)], lambda x: x.scatter(1, b, float("nan"), reduce="add"),
                                lambda x: x.scatter(1, a, float("nan"), reduce="add"), forward_only=True)

  def test_scatter_mul(self):
    # scatter with reduce="multiply" and inf / nan scalar sources (forward only).
    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32)
    helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="multiply"),
                              lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="multiply"), forward_only=True)

    # TODO: fails for webgpu
    if Device.DEFAULT != "WEBGPU":
      helper_test_op([(4,5,6)], lambda x: x.scatter(1, b, float("nan"), reduce="multiply"),
                                lambda x: x.scatter(1, a, float("nan"), reduce="multiply"), forward_only=True)

  def test_scatter_no_reduce_tensor_src(self):
    # scatter with a tensor src together with reduce= is expected to raise TypeError.
    with self.assertRaises(TypeError):
      Tensor.ones(4).scatter(dim=1, index=Tensor([0]), src=Tensor.ones(4), reduce="add")

  @slow_test
  def test_scatter_reduce(self):
    # scatter_reduce for every reduce mode and several dims, with and without include_self (forward only).
    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32))
    for reduce in ("sum", "prod", "mean", "amin", "amax"):
      for dim in (-1,1,-3):
        helper_test_op([(3,4,5), (3,4,5)],
          lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce),
          lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce), forward_only=True)
        helper_test_op([(3,4,5), (3,4,5)],
          lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce, include_self=False),
          lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce, include_self=False), forward_only=True)

  def test_scatter_reduce_prod_zeros(self):
    # reduce="prod" into an all-zero destination; only `src` is a generated input, the destinations are fixed zeros.
    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32)
    x = Tensor.zeros([4,5,6]).float()
    y = torch.zeros([4,5,6]).float()
    helper_test_op([(4,5,6)],
      lambda src: y.scatter_reduce(dim=1, index=b, src=src, reduce="prod"),
      lambda src: x.scatter_reduce(dim=1, index=a, src=src, reduce="prod"), forward_only=True)

  def test_scatter_reduce_errors(self):
    # Error cases for scatter_reduce; RuntimeError is passed positionally as helper_test_exception's `expected`.
    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32)
    # invalid reduce arg
    self.helper_test_exception([(4,5,6), (4,5,6)],
      lambda x,src: x.scatter_reduce(dim=0, index=b, src=src, reduce="INVALID"),
      lambda x,src: x.scatter_reduce(dim=0, index=a, src=src, reduce="INVALID"),
      RuntimeError)
    # dtype mismatch
    self.helper_test_exception([(4,5,6), (4,5,6)],
      lambda x,src: x.half().scatter_reduce(dim=0, index=b, src=src, reduce="sum"),
      lambda x,src: x.half().scatter_reduce(dim=0, index=a, src=src, reduce="sum"),
      RuntimeError)

  @slow_test
  def test_scaled_dot_product_attention(self):
    # SDPA with default arguments, then with a generated float attn_mask input.
    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)],
                   torch.nn.functional.scaled_dot_product_attention,
                   Tensor.scaled_dot_product_attention)
    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)],
                   lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m),
                   lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))

  @slow_test
  def test_scaled_dot_product_attention_mismatch_ls(self):
    # First input has size 4 on axis 2 while the other two have size 16.
    helper_test_op([(32,8,4,64), (32,8,16,64), (32,8,16,64)],
                   torch.nn.functional.scaled_dot_product_attention,
                   Tensor.scaled_dot_product_attention)

  @slow_test
  def test_scaled_dot_product_attention_causal(self):
    # is_causal=True alone must match torch; combined with attn_mask both sides are expected to raise.
    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)],
                   lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True),
                   lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))

    self.helper_test_exception([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)],
                   lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True,attn_mask=m),
                   lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True,attn_mask=m),
                   expected=RuntimeError)

  @slow_test
  def test_scaled_dot_product_attention_gqa(self):
    # enable_gqa=True with size 32 on axis 1 of the first input and size 8 on the other two.
    helper_test_op([(32,32,16,64), (32,8,16,64), (32,8,16,64)],
                   lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,enable_gqa=True),
                   lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True))

  def test_scaled_dot_product_attention_gqa_errors(self):
    # Axis-1 size 31 is not a multiple of 8; note the torch side is called without enable_gqa.
    self.helper_test_exception([(32,31,16,64), (32,8,16,64), (32,8,16,64)],
      lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z),
      lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True),
      expected=(AssertionError, RuntimeError, ValueError, IndexError))

  def test_binary_crossentropy(self):
    # binary_crossentropy / binary_crossentropy_logits vs torch, including cross-checks between the two formulations.
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
                                       lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1)),
                                       lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1)),
                                       lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
                                       lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))

  def test_binary_crossentropy_reductions(self):
    # Both BCE variants for each reduction mode.
    for r in ("mean", "sum", "none"):
      helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(), y.clip(0,1), reduction=r),
                                         lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1), reduction=r))
      helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x, y.clip(0,1), reduction=r),
                                         lambda x,y: x.binary_crossentropy_logits(y.clip(0,1), reduction=r))

  def test_binary_crossentropy_logits_pos_weights(self):
    # BCE-with-logits using a per-class pos_weight vector of length 10.
    pos_weight = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1),
                                                                                                      pos_weight=torch.tensor(pos_weight)),
                                       lambda x,y: x.binary_crossentropy_logits(y.clip(0,1),pos_weight=Tensor(pos_weight)))

  def test_cross_entropy_class_probabilities(self):
    # cross_entropy where the target is a float tensor of the same shape as the input (1-D, 2-D and 4-D).
    helper_test_op([(32,), (32,)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
    helper_test_op([(32,4,4,4), (32,4,4,4)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))

  def test_cross_entropy_class_indices(self):
    # cross_entropy where the target is a list of integer class indices; a (32,1) float target must raise.
    classes = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
    helper_test_op([(32,10)], lambda x: torch.nn.functional.cross_entropy(x, torch.tensor(classes)),
                              lambda x: x.cross_entropy(Tensor(classes)))
    self.helper_test_exception([(32,10), (32,1)], lambda x,y: torch.nn.functional.cross_entropy(x, y),
                                                   lambda x,y: x.cross_entropy(y), expected=(AssertionError, RuntimeError))

  def test_cross_entropy_reductions(self):
    # cross_entropy for each valid reduction; an unknown reduction string must raise ValueError.
    for r in ("mean", "sum", "none"):
      helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r),
                                         lambda x,y: x.cross_entropy(y, reduction=r))
    self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"),
                                                   lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)

  def test_cross_entropy_smoothing(self):
    # cross_entropy with label_smoothing, for both probability targets and class-index targets.
    # NOTE(review): the class-index check reads `ls`, so it is placed inside the loop -- confirm against the original layout.
    for ls in (0., 0.3, 0.7, 1.):
      helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
                                         lambda x,y: x.cross_entropy(y, label_smoothing=ls))
      classes = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
      helper_test_op([(32,10)], lambda x: torch.nn.functional.cross_entropy(x, torch.tensor(classes), label_smoothing=ls),
                                lambda x: x.cross_entropy(Tensor(classes), label_smoothing=ls))

  def test_sparse_categorical_crossentropy(self):
    # sparse_categorical_crossentropy vs torch.nn.CrossEntropyLoss: defaults, combined kwargs, and a batched input.
    classes = np.random.randint(0, 10, (12,), dtype=np.int32).tolist()
    helper_test_op([(12,10)], lambda x: torch.nn.CrossEntropyLoss()(x, torch.tensor(classes)),
                              lambda x: x.sparse_categorical_crossentropy(Tensor(classes)))

    # combine args
    helper_test_op([(12,10)],
      lambda x: torch.nn.CrossEntropyLoss(reduction="mean", ignore_index=classes[0], label_smoothing=0.3)(x, torch.tensor(classes)),
      lambda x: x.sparse_categorical_crossentropy(Tensor(classes), reduction="mean", ignore_index=classes[0], label_smoothing=0.3))

    # with batch. somehow this does not match torch
    # (the torch side permutes axes 1 and 2 before the loss; the tinygrad side takes x as-is)
    classes = np.random.randint(0, 10, (3,12), dtype=np.int32).tolist()
    helper_test_op([(3,12,10)], lambda x: torch.nn.CrossEntropyLoss()(x.permute(0,2,1), torch.tensor(classes)),
                                lambda x: x.sparse_categorical_crossentropy(Tensor(classes)))

  def test_sparse_categorical_crossentropy_reductions(self):
    # Each reduction mode, with fresh random class indices per iteration.
    for r in ("mean", "sum", "none"):
      classes = np.random.randint(0, 10, (12,), dtype=np.int32).tolist()
      helper_test_op([(12,10)], lambda x: torch.nn.CrossEntropyLoss(reduction=r)(x, torch.tensor(classes)),
                                lambda x: x.sparse_categorical_crossentropy(Tensor(classes), reduction=r))

  def test_sparse_categorical_crossentropy_ignore_index(self):
    # ignore_index values that are absent from (-1) and present in (0, 3) the fixed class list.
    classes = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    for i in (-1, 0, 3):
      helper_test_op([(12,10)], lambda x: torch.nn.CrossEntropyLoss(ignore_index=i)(x, torch.tensor(classes)),
                                lambda x: x.sparse_categorical_crossentropy(Tensor(classes), ignore_index=i))

  def test_sparse_categorical_crossentropy_label_smoothing(self):
    # label_smoothing of 0.3 and 0.9, with fresh random class indices per iteration.
    for s in (0.3, 0.9):
      classes = np.random.randint(0, 10, (12,), dtype=np.int32).tolist()
      helper_test_op([(12,10)], lambda x: torch.nn.CrossEntropyLoss(label_smoothing=s)(x, torch.tensor(classes)),
                                lambda x: x.sparse_categorical_crossentropy(Tensor(classes), label_smoothing=s))

  def test_nll_loss(self):
    # nll_loss applied to log_softmax output, with random integer targets.
    target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
    helper_test_op([(32,10)],
      lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target)),
      lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target)))

  def test_nll_loss_3d(self):
    # nll_loss on a 5-D input with a 4-D target; log_softmax is taken over axis 1.
    target = np.random.randint(0, 10, (32,3,3,3), dtype=np.int32).tolist()
    helper_test_op([(32,10,3,3,3)],
      lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target)),
      lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target)))

  def test_nll_loss_reductions(self):
    # nll_loss for each valid reduction; an unknown reduction string must raise ValueError.
    target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
    for r in ("mean", "sum", "none"):
      helper_test_op([(32,10)],
        lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target), reduction=r),
        lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), reduction=r))
    self.helper_test_exception([(32,10)],
      lambda x: torch.nn.functional.nll_loss(x, torch.tensor(target), reduction="typo"),
      lambda x: x.nll_loss(Tensor(target), reduction="typo"), expected=ValueError)

  def test_nll_loss_weight(self):
    # nll_loss with a length-10 weight vector drawn from a normal distribution, for each reduction.
    target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
    weight = np.random.normal(0, 1, (10,)).astype(np.float32).tolist()
    for r in ("mean", "sum", "none"):
      helper_test_op([(32,10)],
        lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target), torch.tensor(weight), reduction=r),
        lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), Tensor(weight), reduction=r))

  def test_nll_loss_3d_weight(self):
    # Weighted nll_loss on a 5-D input with a 4-D target, for each reduction.
    target = np.random.randint(0, 10, (16,3,3,3), dtype=np.int32).tolist()
    weight = np.random.normal(0, 1, (10,)).astype(np.float32).tolist()
    for r in ("mean", "sum", "none"):
      helper_test_op([(16,10,3,3,3)],
        lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target), torch.tensor(weight), reduction=r),
        lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), Tensor(weight), reduction=r))

  def test_nll_loss_ignore_index(self):
    # nll_loss with ignore_index=1 on fixed 3x3 logits; the tinygrad side calls log_softmax() with its default axis.
    logits = [[2.0, 0.5, -1.0], [1.5, 2.5, -0.5], [0.0, -2.0, 1.0]]
    target = [0, 1, 2]
    helper_test_op(None,
      lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target), ignore_index=1),
      lambda x: x.log_softmax().nll_loss(Tensor(target), ignore_index=1),
      vals=[logits])

  def test_one_hot(self):
    # one_hot on 1-D and 3-D integer data; torch output is cast to int32 before comparison.
    data = [1, 2, 4]
    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6).type(torch.int32),
                       lambda: Tensor(data).one_hot(6), forward_only=True)
    # like jax.nn.one_hot, num_classes must be non-negative (torch accepts -1 for auto-inference, we don't)
    with self.assertRaises(ValueError): Tensor(data).one_hot(-1)
    data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8).type(torch.int32),
                       lambda: Tensor(data).one_hot(8), forward_only=True)

  def test_masked_fill(self):
    # masked_fill with -inf using a detached boolean mask derived from the input itself.
    helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
    helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))

  @unittest.skipIf((DEV.interface.startswith("MOCK") or Device.DEFAULT == "PYTHON"), "very slow on MOCKGPU because reduce does not fold")
  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu runtime issue")
  @unittest.skipIf(Device.DEFAULT == "QCOM", "QCOM fails with: Resource deadlock avoided")
  def test_masked_select(self):
    # masked_select with an elementwise mask and with a scalar True mask (forward only).
    helper_test_op([(32, 10)], lambda x: x.masked_select(x>0.5), lambda x: x.masked_select(x>0.5), forward_only=True)
    helper_test_op([(32, 10)], lambda x: x.masked_select(torch.tensor(True)), lambda x: x.masked_select(Tensor(True)), forward_only=True)

  @unittest.skipIf(COMPILE_ONLY, "test requires runtime")
  def test_masked_select_size(self):
    # masked_select with an explicit output `size`: exact fit, padding with fill_value, truncation, and empty input.
    t = Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
    mask = Tensor([True, False, True, False, True, False, False, False, True])
    np.testing.assert_equal(t.masked_select(mask, size=4).numpy(), [0, 2, 4, 8])
    np.testing.assert_equal(t.masked_select(mask, size=6, fill_value=-1).numpy(), [0, 2, 4, 8, -1, -1])
    np.testing.assert_equal(t.masked_select(mask, size=2).numpy(), [0, 2])
    np.testing.assert_equal(Tensor([], dtype=dtypes.int32).masked_select(Tensor([], dtype=dtypes.bool), size=2, fill_value=-1).numpy(), [-1, -1])
    # fill_value must not alter output dtype
    self.assertEqual(Tensor([1.0, 2.0]).masked_select(Tensor([True, False]), size=3, fill_value=-1).dtype, dtypes.default_float)

  @slow_test
  def test_nonzero(self):
    # nonzero on boolean masks of rank 1-3 and on scalar inputs; the torch result is cast with .int() before comparison.
    helper_test_op([(32, 10)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True)
    helper_test_op([(20,)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True)
    helper_test_op([(10, 5, 3)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True)
    for v in (0, 1, 0.0, 2.5, True, False):
      helper_test_op(None, lambda x: x.nonzero().int(), lambda x: x.nonzero(), vals=[v], forward_only=True)

  @unittest.skipIf(COMPILE_ONLY, "test requires runtime")
  def test_nonzero_size(self):
    # nonzero with an explicit output `size`: exact fit, padding with fill_value, 2-D, scalar and empty inputs.
    np.testing.assert_equal(Tensor([1, 0, 2, 0, 3]).nonzero(size=3).numpy(), [[0], [2], [4]])
    np.testing.assert_equal(Tensor([1, 0, 2, 0, 3]).nonzero(size=5, fill_value=-1).numpy(), [[0], [2], [4], [-1], [-1]])
    np.testing.assert_equal(Tensor([[1, 0], [0, 2]]).nonzero(size=2).numpy(), [[0, 0], [1, 1]])
    self.assertEqual(Tensor(5).nonzero(size=4).shape, (4, 0))
    np.testing.assert_equal(Tensor([], dtype=dtypes.int32).nonzero(size=3, fill_value=-1).numpy(), [[-1], [-1], [-1]])
    # fill_value must not promote dtype to float
    self.assertEqual(Tensor([1, 0]).nonzero(size=3, fill_value=-1.5).dtype, dtypes.default_int)

  def test_cast(self):
    # .float() / .int() / .bool() casts; only the float-to-float case also checks gradients.
    helper_test_op([(3, 3)], lambda x: x.float())
    helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
    helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
    helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
    helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)

  def test_bitcast(self):
    # torch .view(torch.int32) vs tinygrad .bitcast(dtypes.int32).
    helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)

  def test_int_or(self):
    # OR-ing an int tensor with 0xFFFFFFFF; the value is only asserted (== -1) when a runtime is available.
    t = (Tensor([0], dtype='int') | 0xFFFFFFFF).item()
    if not COMPILE_ONLY: assert t == -1

# Forward-only comparisons against torch where the value is cast to uint8 inside the tested expression.
class TestOpsUint8(unittest.TestCase):
  def test_cast(self):
    # float -> uint8 cast with inputs drawn from [0, 255].
    helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True, low=0, high=255)

  def test_cast_relu(self):
    # relu before the uint8 cast, so no negative value reaches the cast.
    helper_test_op([(2,3,64,64)], lambda x: x.relu().type(torch.uint8), lambda x: x.relu().cast('uint8'), forward_only=True)

  def test_interpolate_bilinear(self):
    # Bilinear resize of a uint8 tensor to (10, 10); tinygrad names the mode "linear".
    out_sz = (10, 10)
    helper_test_op([(2,3,64,64)],
      lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="bilinear"),
      lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="linear"), forward_only=True)

  def test_interpolate_nearest(self):
    # Nearest-neighbour resize of a uint8 tensor to (10, 10).
    out_sz = (10, 10)
    helper_test_op([(2,3,64,64)],
      lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="nearest"),
      lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="nearest"), forward_only=True)

  def test_interpolate_nearest_exact(self):
    # "nearest-exact" resize of a uint8 tensor to (10, 10).
    out_sz = (10, 10)
    helper_test_op([(2,3,64,64)],
      lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="nearest-exact"),
      lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="nearest-exact"), forward_only=True)

  def test_min(self):
    # min over uint8 data, including the values 128 and 255.
    helper_test_op(None,
      lambda x: x.type(torch.uint8).min(),
      lambda x: x.cast(dtypes.uint8).min(), forward_only=True, vals=[[[0, 1, 2], [3, 4, 5]]])
    helper_test_op(None,
      lambda x: x.type(torch.uint8).min(),
      lambda x: x.cast(dtypes.uint8).min(), forward_only=True, vals=[[0, 128, 255, 64, 32, 16]])

# Script entry point: seed numpy's global RNG, then run every test with verbose output.
if __name__ == '__main__':
  np.random.seed(1337)
  unittest.main(verbosity=2)