mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
remove TestShapeSpec, it relies on ShapeTracker [pr] (#12369)
This commit is contained in:
@@ -544,86 +544,6 @@ class TestUopsObject(unittest.TestCase):
|
||||
with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)]
|
||||
assert len(ret) == 10000
|
||||
|
||||
|
||||
class TestShapeSpec(unittest.TestCase):
|
||||
# ** CONST is CONST(VIEW(DEVICE)) -> RESHPAE -> EXPAND
|
||||
|
||||
def test_expanded_const(self):
|
||||
a = Tensor(1).uop
|
||||
self.assertEqual(a.st, ShapeTracker.from_shape(()))
|
||||
a = Tensor.ones((4, 4)).uop
|
||||
self.assertEqual(a.st, ShapeTracker.from_shape(()).reshape((1,1)).expand((4,4)))
|
||||
|
||||
# NOTE: CONST ShapeTracker comes from its source
|
||||
def test_scalar_const(self):
|
||||
a = Tensor(0).uop
|
||||
self.assertEqual(a.st, ShapeTracker.from_shape(()))
|
||||
|
||||
def test_scalar_var(self):
|
||||
vv = UOp.variable("a", 1, 4).bind(2)
|
||||
t = Tensor(vv).uop
|
||||
self.assertEqual(t.st, ShapeTracker.from_shape(()))
|
||||
|
||||
# ** ASSIGN is ASSIGN(VIEW(BUFFER), new_val)
|
||||
|
||||
def test_assign_flat(self):
|
||||
buffer = Tensor.arange(4).realize()
|
||||
a = buffer.assign(Tensor.zeros((4,), dtype=dtypes.int))
|
||||
assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat()))
|
||||
assert assign_pattern.match(a.uop, {})
|
||||
a.realize()
|
||||
self.assertEqual(buffer.tolist(), [0, 0, 0, 0])
|
||||
|
||||
def test_assign_permuted(self):
|
||||
buffer = Tensor.arange(4).reshape(2, 1, 2).contiguous().realize()
|
||||
a = buffer.permute((1, 2, 0)).assign(Tensor.arange(4).reshape(1, 2, 2).contiguous())
|
||||
a.realize()
|
||||
self.assertEqual(buffer.tolist(), [[[0, 2]], [[1, 3]]])
|
||||
|
||||
def test_assign_reshaped(self):
|
||||
buffer = Tensor.ones((4,)).contiguous().realize()
|
||||
a = buffer.reshape((2, 2)).assign(Tensor.zeros((2, 2)))
|
||||
assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER))), UPat()))
|
||||
assert assign_pattern.match(a.uop, {})
|
||||
a.realize()
|
||||
self.assertEqual(buffer.tolist(), [0, 0, 0, 0])
|
||||
|
||||
# setitem is a partial assign
|
||||
def test_setitem(self):
|
||||
a = Tensor.ones((4,)).contiguous().realize()
|
||||
assign = a.shrink(((1, 2),)).assign(Tensor.zeros((1,)))
|
||||
# the ASSIGN UOp has size=1
|
||||
self.assertEqual(assign.uop.size, 1)
|
||||
# the ASSIGN views the buffer with a shrunk st
|
||||
self.assertEqual(assign.uop.src[0].st, ShapeTracker.from_shape((4,)).shrink(((1, 2),)))
|
||||
# the underlying BUFFER has a size=4
|
||||
self.assertEqual(assign.uop.buf_uop.size, 4)
|
||||
# NOTE: output shape is different from the BUFFER shape
|
||||
self.assertNotEqual(assign.uop.shape, a.uop.shape)
|
||||
assign.realize()
|
||||
self.assertEqual(a.tolist(), [1, 0, 1, 1])
|
||||
|
||||
def test_buffer_st(self):
|
||||
a = UOp.new_buffer(Device.DEFAULT, 10, dtypes.float)
|
||||
self.assertEqual(a.st, ShapeTracker.from_shape((10,)))
|
||||
|
||||
def test_ops_st(self):
|
||||
# view / mop
|
||||
a = Tensor.empty(4, 2, 1).permute((1, 2, 0)).uop
|
||||
self.assertEqual(a.st, ShapeTracker.from_shape((4, 2, 1)).permute((1, 2, 0)))
|
||||
# alu / reduce
|
||||
alu = a*2
|
||||
self.assertEqual(alu.st, ShapeTracker.from_shape((2, 1, 4)))
|
||||
r = Tensor.empty(4, 4).sum(axis=1)
|
||||
self.assertEqual(r.uop.st, ShapeTracker.from_shape((4,)))
|
||||
|
||||
def test_st_wmma_none(self):
|
||||
A = UOp(Ops.DEFINE_VAR, dtypes.float.vec(16), arg=('a', UOp.const(dtypes.float, 0), UOp.const(dtypes.float, 1)))
|
||||
B = UOp(Ops.DEFINE_VAR, dtypes.float.vec(16), arg=('b', UOp.const(dtypes.float, 0), UOp.const(dtypes.float, 2)))
|
||||
C = UOp(Ops.DEFINE_VAR, dtypes.float.vec(16), arg=('c', UOp.const(dtypes.float, 0), UOp.const(dtypes.float, 3)))
|
||||
wmma = UOp(Ops.WMMA, dtypes.float.vec(16), (A, B, C))
|
||||
assert wmma.st is None
|
||||
|
||||
class TestUOpChildren(unittest.TestCase):
|
||||
def test_children_exist(self):
|
||||
a = UOp.variable("weird_name_234", 0, 10)
|
||||
|
||||
Reference in New Issue
Block a user