mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
use UOp.st for kernel reduce axes (#8499)
* use UOp.st for kernel reduce axes [pr]
* do not return dict
This commit is contained in:
@@ -2157,9 +2157,9 @@ class TestKernelOpts(unittest.TestCase):
|
||||
data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
|
||||
data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
|
||||
helper_linearizer_ast(sink, [data1, data2], opts=[
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)]
|
||||
#[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)],
|
||||
#[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)],
|
||||
#[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)]
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
|
||||
@@ -75,10 +75,10 @@ class TestVerifyAST(unittest.TestCase):
|
||||
|
||||
def test_buffer_uops_st(self):
|
||||
a = Tensor.randn(4, 4)+2
|
||||
uop_sts = verify_ast(a.schedule()[-1].ast)
|
||||
store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0]
|
||||
verify_ast(ast:=a.schedule()[-1].ast)
|
||||
store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0]
|
||||
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
|
||||
const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
|
||||
const_st = [u.st for u in ast.toposort if u.op is Ops.VALID][0]
|
||||
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
|
||||
|
||||
def test_assert_swizzle(self):
|
||||
|
||||
@@ -10,7 +10,7 @@ from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, print_uops, typ
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, ContextVar
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
|
||||
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
@@ -57,7 +57,7 @@ class Kernel:
|
||||
if ast.op is Ops.SINK: self.ast = ast
|
||||
|
||||
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
|
||||
try: uop_sts_map = verify_ast(self.ast)
|
||||
try: verify_ast(self.ast)
|
||||
except AssertionError as e:
|
||||
print("INVALID AST")
|
||||
print(self.ast)
|
||||
@@ -80,8 +80,8 @@ class Kernel:
|
||||
# add the shapetrackers for each reduce
|
||||
# we use this to track which axes are reduced in each reduce
|
||||
for x in self.reduceops:
|
||||
self.sts.append(uop_sts_map[x])
|
||||
self.sts.append(uop_sts_map[x.src[0]])
|
||||
self.sts.append(unwrap(x.st))
|
||||
self.sts.append(unwrap(x.src[0].st))
|
||||
|
||||
# move all reduce axes to the end
|
||||
reduce = list(enumerate(zip(self.full_shape, self.output_shape)))
|
||||
@@ -707,7 +707,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) ->
|
||||
raise AssertionError(f"found implicit expand {sizes} {shapes}")
|
||||
sts[uop] = st
|
||||
|
||||
def verify_ast(ast:UOp) -> dict[UOp, ShapeTracker]:
|
||||
def verify_ast(ast:UOp) -> None:
|
||||
assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
|
||||
assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
|
||||
sts: dict[UOp, ShapeTracker] = {}
|
||||
@@ -715,4 +715,3 @@ def verify_ast(ast:UOp) -> dict[UOp, ShapeTracker]:
|
||||
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
|
||||
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
|
||||
type_verify(list(sts))
|
||||
return sts
|
||||
|
||||
Reference in New Issue
Block a user