mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
test new style gated store rendering (#6413)
* test new style gated store rendering * switch to lidx * make lidx optional * fixup [run_process_replay]
This commit is contained in:
@@ -1,27 +1,28 @@
|
||||
import unittest
|
||||
from typing import List, cast
|
||||
import numpy as np
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, linearize_uop
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.dtype import PtrDType, DType, dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import dedup, flatten
|
||||
from tinygrad.helpers import dedup, flatten, getenv, prod
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
from tinygrad.ops import BinaryOps, UOp, UOps
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
||||
def _test_uop_result(inputs:List[Tensor], stores:List[UOp]):
|
||||
def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
|
||||
for x in inputs: x.realize()
|
||||
assert all(x.op is UOps.STORE for x in stores)
|
||||
# NOTE: we only toposort the stores
|
||||
uops: List[UOp] = []
|
||||
def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
|
||||
uops = dedup(flatten(_recursive_add(st) for st in stores))
|
||||
outbufs = [Buffer(Device.DEFAULT, 1, cast(DType,u.src[2].dtype)).allocate() for u in uops if u.op is UOps.STORE]
|
||||
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=cast(DType,u.src[2].dtype)), \
|
||||
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is UOps.STORE]
|
||||
inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs]
|
||||
src = Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops))
|
||||
ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
|
||||
ei.exec(outbufs+inbufs)
|
||||
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
|
||||
|
||||
@@ -38,5 +39,31 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
ret = _test_uop_result([Tensor([1])], [store])[0]
|
||||
self.assertEqual(ret[0], 1)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
|
||||
class TestPTXFailures(unittest.TestCase):
|
||||
def test_gated_store_with_alu(self):
|
||||
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||
gated_alu_store = UOp(UOps.STORE, None, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
|
||||
sink = UOp(UOps.SINK, None, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
def test_gated_store_with_if(self):
|
||||
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||
val = UOp.const(dtypes.int, 1)
|
||||
if_uop = UOp(UOps.IF, None, (gate_alu, val))
|
||||
gated_alu_store = UOp(UOps.STORE, None, (a, lidx0, val, if_uop))
|
||||
sink = UOp(UOps.SINK, None, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
|
||||
if getenv("PTX"):
|
||||
with self.assertRaises(AssertionError):
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
else: np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -507,7 +507,7 @@ def type_verify(uops):
|
||||
if uop is UOps.IF: assert dtype is None and len(src) == 2 and src[0].dtype == dtypes.bool
|
||||
if uop is UOps.STORE:
|
||||
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
|
||||
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
|
||||
if len(src) == 4: assert src[3].dtype == dtypes.bool or src[3].op is UOps.IF, f"bad gate {src[3]}"
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
|
||||
|
||||
@@ -193,7 +193,8 @@ class PTXRenderer(Renderer):
|
||||
kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
|
||||
f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
|
||||
else:
|
||||
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
|
||||
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype,
|
||||
gate=r[src[3]] if len(src)>3 and src[3].op is not UOps.IF else None, ss=mem_type, offset=src[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
|
||||
|
||||
@@ -127,7 +127,7 @@ class CStyleLanguage(Renderer):
|
||||
# mark DEFINE_GLOBAL buf as writable
|
||||
if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
|
||||
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
|
||||
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.RANGE:
|
||||
|
||||
Reference in New Issue
Block a user