remove cast before view (#8613)

* remove cast before view

* greener

* indexing

* that passes too

* openpilot too

* ack

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz
2025-01-14 12:04:58 -08:00
committed by GitHub
parent 393eec3201
commit bfbe81df71
9 changed files with 22 additions and 25 deletions

View File

@@ -91,7 +91,7 @@ jobs:
- name: Run GPT2 w HALF
run: HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
@@ -217,7 +217,7 @@ jobs:
- name: Run GPT2 w HALF
run: NV=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (NVIDIA)
@@ -406,7 +406,7 @@ jobs:
- name: Run GPT2 w HALF
run: AMD=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (AMD)

View File

@@ -297,7 +297,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model kernel count and gate usage
run: |
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot alt model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx

View File

@@ -66,7 +66,8 @@ class TestArange(unittest.TestCase):
return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])
class TestIndexing(unittest.TestCase):
@unittest.expectedFailure
# update: passing after CAST_BEFORE_VIEW=1 deletion
# @unittest.expectedFailure
def test_arange_2_reduce(self):
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
needle[1337] = 1

View File

@@ -132,11 +132,12 @@ class TestMovedConstFolding(unittest.TestCase):
def test_cast_padded(self):
# NOTE: this is folded due to CAST_BEFORE_VIEW
# update: CAST_BEFORE_VIEW=1 is no longer supported
if is_dtype_supported(dtypes.int16):
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
if is_dtype_supported(dtypes.uint16):
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
# not folded
if is_dtype_supported(dtypes.int64):

View File

@@ -120,7 +120,7 @@ class TestImageDType(unittest.TestCase):
loss = x.image_dot(w1).image_dot(w2).float().max()
loss.backward()
sched = unwrap(w1.grad).schedule()
self.assertEqual(len(sched), 10)
self.assertEqual(len(sched), 9)
for s,ei in zip(sched, lower_schedule(sched[:])):
ei.run()
if s.outputs[0].dtype == dtypes.float:

View File

@@ -1436,6 +1436,7 @@ class TestSchedule(unittest.TestCase):
def test_late_fusion_post_expand(self):
self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_padded_view(self):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
@@ -1446,6 +1447,7 @@ class TestSchedule(unittest.TestCase):
self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
# NOTE: we might want to reconsider pushing this cast before the shrink
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_after_shrink(self):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
@@ -1455,6 +1457,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(realized_view.lazydata.base.realized.size, 2)
self.assertListEqual(realized_view.tolist(), [[0, 1]])
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_const_view(self):
a = Tensor.ones((4, 4), dtype=dtypes.float32)
casted_view = a.cast(dtypes.int32)
@@ -1464,6 +1467,7 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(realized_const_view, 1))
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_padded_const(self):
a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
casted_view = a.cast(dtypes.float32)
@@ -1566,7 +1570,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10)
out = (x + a[2]).sum()
self.check_schedule(out, 1)
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_contiguous(self):
@@ -1574,7 +1578,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10).contiguous()
out = (x + a[2]).sum()
self.check_schedule(out, 2)
self.check_schedule(out, 3)
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_child(self):
@@ -1582,7 +1586,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10)+1
out = (x + a[2]).sum()
self.check_schedule(out, 1)
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_contiguous_child(self):
@@ -1590,7 +1594,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = (Tensor.arange(10)+1).contiguous()
out = (x + a[2]).sum()
self.check_schedule(out, 2)
self.check_schedule(out, 3)
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_childless_base(self):

View File

@@ -81,6 +81,8 @@ class TestTiny(unittest.TestCase):
# *** a model ***
# TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
@unittest.skipIf(IMAGE>0, "failing because of make things that can't be images not images")
def test_mnist_model(self):
layers = [
nn.Conv2d(1, 32, 5), Tensor.relu,

View File

@@ -500,7 +500,7 @@ break_sched = PatternMatcher([
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
ctx.allbufs[buf_uop] = view
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.src:
for x in op.base.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
# BUFFER_VIEW overrides the underlying buffer
# TODO: this should be a shrink on the buffer

View File

@@ -361,17 +361,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def cast(self, dtype:DType, bitcast=False):
if bitcast: return self.bitcast(dtype)
if self._device is not None and self._device.startswith("DISK"): raise RuntimeError("CAST isn't supported on DISK")
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
# NOTE: we have to apply the movementops here, we can't use VIEW (yet)
# TODO: move this to the scheduler
ret = self.base.cast(dtype, bitcast)
op_arg = []
mop = self
while mop is not self.base:
op_arg.append((mop.op, mop.arg))
mop = mop.src[0]
for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
return ret
return UOp(Ops.CAST, dtype, (self,))
def bitcast(self, dtype:DType):
if self.st is not None and self.shape and ((self.shape[-1]*self.dtype.itemsize)%dtype.itemsize != 0):
@@ -477,7 +466,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def base(self) -> UOp:
if self.op in GroupOp.Movement: return self.src[0].base
return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st)
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)