test new style gated store rendering (#6413)

* test new style gated store rendering

* switch to lidx

* make lidx optional

* fixup [run_process_replay]
This commit is contained in:
qazal
2024-09-09 13:59:22 +08:00
committed by GitHub
parent 90fb17304f
commit ff8a9ac3c1
4 changed files with 36 additions and 8 deletions

View File

@@ -1,27 +1,28 @@
import unittest
from typing import List, cast
import numpy as np
from tinygrad.codegen.uopgraph import full_graph_rewrite, linearize_uop
from tinygrad.device import Buffer, Device
from tinygrad.dtype import PtrDType, DType, dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten
from tinygrad.helpers import dedup, flatten, getenv, prod
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.ops import BinaryOps, UOp, UOps
from tinygrad.renderer import Program
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.lazy import LazyBuffer
def _test_uop_result(inputs:List[Tensor], stores:List[UOp]):
def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
for x in inputs: x.realize()
assert all(x.op is UOps.STORE for x in stores)
# NOTE: we only toposort the stores
uops: List[UOp] = []
def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
uops = dedup(flatten(_recursive_add(st) for st in stores))
outbufs = [Buffer(Device.DEFAULT, 1, cast(DType,u.src[2].dtype)).allocate() for u in uops if u.op is UOps.STORE]
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=cast(DType,u.src[2].dtype)), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is UOps.STORE]
inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs]
src = Device[Device.DEFAULT].renderer.render("test", uops)
ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops))
ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
ei.exec(outbufs+inbufs)
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
@@ -38,5 +39,31 @@ class TestCStyleFailures(unittest.TestCase):
ret = _test_uop_result([Tensor([1])], [store])[0]
self.assertEqual(ret[0], 1)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
class TestPTXFailures(unittest.TestCase):
def test_gated_store_with_alu(self):
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gated_alu_store = UOp(UOps.STORE, None, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
sink = UOp(UOps.SINK, None, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])
def test_gated_store_with_if(self):
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
val = UOp.const(dtypes.int, 1)
if_uop = UOp(UOps.IF, None, (gate_alu, val))
gated_alu_store = UOp(UOps.STORE, None, (a, lidx0, val, if_uop))
sink = UOp(UOps.SINK, None, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
if getenv("PTX"):
with self.assertRaises(AssertionError):
np.testing.assert_equal(ret, [0, 1, 1, 1])
else: np.testing.assert_equal(ret, [0, 1, 1, 1])
if __name__ == '__main__':
unittest.main()

View File

@@ -507,7 +507,7 @@ def type_verify(uops):
if uop is UOps.IF: assert dtype is None and len(src) == 2 and src[0].dtype == dtypes.bool
if uop is UOps.STORE:
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
if len(src) == 4: assert src[3].dtype == dtypes.bool or src[3].op is UOps.IF, f"bad gate {src[3]}"
if uop is UOps.ALU:
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:

View File

@@ -193,7 +193,8 @@ class PTXRenderer(Renderer):
kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
else:
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype,
gate=r[src[3]] if len(src)>3 and src[3].op is not UOps.IF else None, ss=mem_type, offset=src[1].arg))
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))

View File

@@ -127,7 +127,7 @@ class CStyleLanguage(Renderer):
# mark DEFINE_GLOBAL buf as writable
if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE: