diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index 76be7195d2..cf7869d4ce 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -1,27 +1,28 @@ import unittest from typing import List, cast import numpy as np +from tinygrad.codegen.uopgraph import full_graph_rewrite, linearize_uop from tinygrad.device import Buffer, Device from tinygrad.dtype import PtrDType, DType, dtypes from tinygrad.engine.realize import CompiledRunner -from tinygrad.helpers import dedup, flatten +from tinygrad.helpers import dedup, flatten, getenv, prod from tinygrad.renderer.cstyle import CStyleLanguage from tinygrad.ops import BinaryOps, UOp, UOps from tinygrad.renderer import Program from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.lazy import LazyBuffer -def _test_uop_result(inputs:List[Tensor], stores:List[UOp]): +def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None): for x in inputs: x.realize() - assert all(x.op is UOps.STORE for x in stores) # NOTE: we only toposort the stores uops: List[UOp] = [] def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop] uops = dedup(flatten(_recursive_add(st) for st in stores)) - outbufs = [Buffer(Device.DEFAULT, 1, cast(DType,u.src[2].dtype)).allocate() for u in uops if u.op is UOps.STORE] + outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=cast(DType,u.src[2].dtype)), \ + initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is UOps.STORE] inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs] src = Device[Device.DEFAULT].renderer.render("test", uops) - ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops)) + ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, local_size=local_size)) ei.exec(outbufs+inbufs) return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs] @@ -38,5 +39,31 @@ class TestCStyleFailures(unittest.TestCase): ret = _test_uop_result([Tensor([1])], [store])[0] self.assertEqual(ret[0], 1) +@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local") +class TestPTXFailures(unittest.TestCase): + def test_gated_store_with_alu(self): + a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) + gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) + gated_alu_store = UOp(UOps.STORE, None, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu)) + sink = UOp(UOps.SINK, None, (gated_alu_store,)) + uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer)) + ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0] + np.testing.assert_equal(ret, [0, 1, 1, 1]) + + def test_gated_store_with_if(self): + a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) + gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) + val = UOp.const(dtypes.int, 1) + if_uop = UOp(UOps.IF, None, (gate_alu, val)) + gated_alu_store = UOp(UOps.STORE, None, (a, lidx0, val, if_uop)) + sink = UOp(UOps.SINK, None, (gated_alu_store,)) + uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer)) + ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0] + + if getenv("PTX"): + with self.assertRaises(AssertionError): + np.testing.assert_equal(ret, [0, 1, 1, 1]) + else: np.testing.assert_equal(ret, [0, 1, 1, 1]) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 803cd1527f..966658b0e6 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -507,7 +507,7 @@ def type_verify(uops): if uop is UOps.IF: assert dtype is None and len(src) == 2 and src[0].dtype == dtypes.bool if uop is UOps.STORE: assert dtype is None, f"{uop} dtype must be None, got {dtype}" - if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}" + if len(src) == 4: assert src[3].dtype == dtypes.bool or src[3].op is UOps.IF, f"bad gate {src[3]}" if uop is UOps.ALU: if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}" elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}: diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 813000d31d..165632e006 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -193,7 +193,8 @@ class PTXRenderer(Renderer): kk((f"@{r[src[3]]} " if len(src)>3 else "") + \ f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};") else: - kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg)) + kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, + gate=r[src[3]] if len(src)>3 and src[3].op is not UOps.IF else None, ss=mem_type, offset=src[1].arg)) else: assert dtype is not None, f"None dtype for uop {uop}" if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:])) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index f0d82e5d9f..b6762b6b3e 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -127,7 +127,7 @@ class CStyleLanguage(Renderer): # mark DEFINE_GLOBAL buf as writable if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True)) rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) - kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store) + kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store) else: assert dtype is not None, f"None dtype for uop {uop}" if uop is UOps.RANGE: