mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
remove cast before view (#8613)
* remove cast before view
* greener
* indexing
* that passes too
* openpilot too
* ack

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
6
.github/workflows/benchmark.yml
vendored
6
.github/workflows/benchmark.yml
vendored
@@ -91,7 +91,7 @@ jobs:
|
||||
- name: Run GPT2 w HALF
|
||||
run: HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
|
||||
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
|
||||
- name: Train MNIST
|
||||
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
|
||||
- name: Run 10 CIFAR training steps
|
||||
@@ -217,7 +217,7 @@ jobs:
|
||||
- name: Run GPT2 w HALF
|
||||
run: NV=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
|
||||
run: NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: Speed (NVIDIA)
|
||||
@@ -406,7 +406,7 @@ jobs:
|
||||
- name: Run GPT2 w HALF
|
||||
run: AMD=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
|
||||
run: AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: Speed (AMD)
|
||||
|
||||
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -297,7 +297,7 @@ jobs:
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot alt model correctness (float32)
|
||||
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
|
||||
|
||||
@@ -66,7 +66,8 @@ class TestArange(unittest.TestCase):
|
||||
return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
# update: passing after CAST_BEFORE_VIEW=1 deletion
|
||||
# @unittest.expectedFailure
|
||||
def test_arange_2_reduce(self):
|
||||
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
|
||||
needle[1337] = 1
|
||||
|
||||
@@ -132,11 +132,12 @@ class TestMovedConstFolding(unittest.TestCase):
|
||||
|
||||
def test_cast_padded(self):
|
||||
# NOTE: this is folded due to CAST_BEFORE_VIEW
|
||||
# update: CAST_BEFORE_VIEW=1 is no longer supported
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
if is_dtype_supported(dtypes.uint16):
|
||||
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
|
||||
# not folded
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
|
||||
@@ -120,7 +120,7 @@ class TestImageDType(unittest.TestCase):
|
||||
loss = x.image_dot(w1).image_dot(w2).float().max()
|
||||
loss.backward()
|
||||
sched = unwrap(w1.grad).schedule()
|
||||
self.assertEqual(len(sched), 10)
|
||||
self.assertEqual(len(sched), 9)
|
||||
for s,ei in zip(sched, lower_schedule(sched[:])):
|
||||
ei.run()
|
||||
if s.outputs[0].dtype == dtypes.float:
|
||||
|
||||
@@ -1436,6 +1436,7 @@ class TestSchedule(unittest.TestCase):
|
||||
def test_late_fusion_post_expand(self):
|
||||
self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
|
||||
|
||||
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_padded_view(self):
|
||||
a = Tensor.arange(4).reshape(1, 4)
|
||||
casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
|
||||
@@ -1446,6 +1447,7 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
|
||||
|
||||
# NOTE: we might want to reconsider pushing this cast before the shrink
|
||||
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_after_shrink(self):
|
||||
a = Tensor.arange(4).reshape(1, 4)
|
||||
casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
|
||||
@@ -1455,6 +1457,7 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertEqual(realized_view.lazydata.base.realized.size, 2)
|
||||
self.assertListEqual(realized_view.tolist(), [[0, 1]])
|
||||
|
||||
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_const_view(self):
|
||||
a = Tensor.ones((4, 4), dtype=dtypes.float32)
|
||||
casted_view = a.cast(dtypes.int32)
|
||||
@@ -1464,6 +1467,7 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(realized_const_view, 1))
|
||||
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
|
||||
|
||||
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
|
||||
def test_cast_padded_const(self):
|
||||
a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
|
||||
casted_view = a.cast(dtypes.float32)
|
||||
@@ -1566,7 +1570,7 @@ class TestIndexing(unittest.TestCase):
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10)
|
||||
out = (x + a[2]).sum()
|
||||
self.check_schedule(out, 1)
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_index_contiguous(self):
|
||||
@@ -1574,7 +1578,7 @@ class TestIndexing(unittest.TestCase):
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10).contiguous()
|
||||
out = (x + a[2]).sum()
|
||||
self.check_schedule(out, 2)
|
||||
self.check_schedule(out, 3)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_index_child(self):
|
||||
@@ -1582,7 +1586,7 @@ class TestIndexing(unittest.TestCase):
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = Tensor.arange(10)+1
|
||||
out = (x + a[2]).sum()
|
||||
self.check_schedule(out, 1)
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_index_contiguous_child(self):
|
||||
@@ -1590,7 +1594,7 @@ class TestIndexing(unittest.TestCase):
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
a = (Tensor.arange(10)+1).contiguous()
|
||||
out = (x + a[2]).sum()
|
||||
self.check_schedule(out, 2)
|
||||
self.check_schedule(out, 3)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_arange_childless_base(self):
|
||||
|
||||
@@ -81,6 +81,8 @@ class TestTiny(unittest.TestCase):
|
||||
|
||||
# *** a model ***
|
||||
|
||||
# TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
|
||||
@unittest.skipIf(IMAGE>0, "failing because of make things that can't be images not images")
|
||||
def test_mnist_model(self):
|
||||
layers = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
|
||||
@@ -500,7 +500,7 @@ break_sched = PatternMatcher([
|
||||
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
|
||||
ctx.allbufs[buf_uop] = view
|
||||
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
|
||||
for x in op.src:
|
||||
for x in op.base.src:
|
||||
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
|
||||
# BUFFER_VIEW overrides the underlying buffer
|
||||
# TODO: this should be a shrink on the buffer
|
||||
|
||||
@@ -361,17 +361,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def cast(self, dtype:DType, bitcast=False):
|
||||
if bitcast: return self.bitcast(dtype)
|
||||
if self._device is not None and self._device.startswith("DISK"): raise RuntimeError("CAST isn't supported on DISK")
|
||||
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
|
||||
# NOTE: we have to apply the movementops here, we can't use VIEW (yet)
|
||||
# TODO: move this to the scheduler
|
||||
ret = self.base.cast(dtype, bitcast)
|
||||
op_arg = []
|
||||
mop = self
|
||||
while mop is not self.base:
|
||||
op_arg.append((mop.op, mop.arg))
|
||||
mop = mop.src[0]
|
||||
for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
|
||||
return ret
|
||||
return UOp(Ops.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype:DType):
|
||||
if self.st is not None and self.shape and ((self.shape[-1]*self.dtype.itemsize)%dtype.itemsize != 0):
|
||||
@@ -477,7 +466,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@property
|
||||
def base(self) -> UOp:
|
||||
if self.op in GroupOp.Movement: return self.src[0].base
|
||||
return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
|
||||
return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
|
||||
def view(self, new_st:ShapeTracker) -> UOp:
|
||||
if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st)
|
||||
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
|
||||
|
||||
Reference in New Issue
Block a user