# Patch summary (descriptive header only; this text precedes the first "diff --git" line and belongs to no hunk).
#
# test/test_jit.py
#   * Imports `dtypes` from tinygrad.dtype.
#   * Adds `test_jitted_view` in the hunk headed by `class TestJit`. The jitted function sums its input
#     over axis 1, adds 5, bitcasts the result to int32, moves it to device "<Device.DEFAULT>:1" and
#     realizes it. The test calls it 5 times on a fresh `Tensor.randn(10, 1000)` placed on
#     "<Device.DEFAULT>:0" and compares each result with the same computation done in NumPy
#     (`.view(np.int32)`). It is skipped when CI is set and the default device is GPU, CUDA or METAL.
#
# tinygrad/engine/jit.py
#   * In the hunk headed by `class MultiGraphRunner`, every lookup/insert on `w_dependency_map` and
#     `r_dependency_map` is keyed by `id(rawbuf.base._buf)` instead of `id(rawbuf._buf)`.
#     Presumably this makes a view and its base buffer share one dependency entry - TODO confirm
#     against the Buffer implementation.
#
# NOTE(review): line breaks and leading whitespace were reconstructed from a whitespace-collapsed copy.
# Blank context lines were placed to satisfy the hunk line counts; 2-space indentation is assumed.
# Verify against the target tree (e.g. `git apply --check`) before applying.
diff --git a/test/test_jit.py b/test/test_jit.py
index 20a4aa1289..359019212b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -7,6 +7,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.engine.jit import TinyJit
 from tinygrad.device import Device
 from tinygrad.helpers import CI
+from tinygrad.dtype import dtypes
 
 def _simple_test(add, extract=lambda x: x, N=10):
   for _ in range(5):
@@ -304,6 +305,21 @@ class TestJit(unittest.TestCase):
     np.testing.assert_allclose(a.numpy(), xc.numpy(), atol=1e-4, rtol=1e-5)
     np.testing.assert_allclose(b.numpy(), yc.numpy(), atol=1e-4, rtol=1e-5)
 
+  @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU/CUDA/METAL in CI, fine to run on AMD/NV")
+  def test_jitted_view(self):
+    d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
+
+    def f(a):
+      x1 = a.sum(axis=(1,))
+      x = (x1 + 5).bitcast(dtypes.int32)
+      y = x.to(d1)
+      return y.realize()
+
+    jf = TinyJit(f)
+    for _ in range(5):
+      a = Tensor.randn(10, 1000, device=d0).realize()
+      xc = jf(a)
+      np.testing.assert_allclose((a.numpy().sum(axis=(1,)) + 5).view(np.int32), xc.numpy(), atol=1e-4, rtol=1e-5)
 
 @unittest.skip("Pending multioutput implementation #3607")
 class TestMultioutputJit(unittest.TestCase):
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 02908e8743..eb4f030e1f 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -97,12 +97,12 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
     wait_nodes = []
 
     for rawbuf in read + write:
-      if id(rawbuf._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf._buf)])
+      if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
     for rawbuf in write:
-      if id(rawbuf._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf._buf)))
+      if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
 
-    for rawbuf in read: self.r_dependency_map[id(rawbuf._buf)].append(new_dependency)
-    for rawbuf in write: self.w_dependency_map[id(rawbuf._buf)] = new_dependency
+    for rawbuf in read: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
+    for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
     return list({id(x):x for x in wait_nodes}.values())
 
 ReturnType = TypeVar('ReturnType')