mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
reorder AST matchers + comments [pr] (#9193)
This commit is contained in:
@@ -226,15 +226,24 @@ create_kernels = merge_views+PatternMatcher([
|
||||
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
|
||||
])
|
||||
|
||||
# **** convert Kernel to a ScheduleItem (for legacy reasons)
|
||||
# **** fix kernel AST
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItem:
|
||||
ast: UOp
|
||||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...]
|
||||
# ** create buffer ops + enumerate buffers
|
||||
|
||||
# **** Kernel creation
|
||||
def load_buf(ctx:list[UOp], x:UOp):
|
||||
if x not in ctx: ctx.append(x)
|
||||
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
|
||||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), load_buf),
|
||||
# STORE (except for COPY/BUFFER_VIEW)
|
||||
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
|
||||
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
|
||||
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
|
||||
])
|
||||
|
||||
# ** push views to buffer ops
|
||||
|
||||
def apply_swizzle(u:UOp) -> UOp:
|
||||
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
|
||||
@@ -288,13 +297,22 @@ view_right = merge_views+PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
|
||||
# ** unbind variables
|
||||
|
||||
def unbind_shapetracker(ctx:dict[Variable, int], x:UOp) -> UOp|None:
|
||||
st = unwrap(x.st).simplify()
|
||||
if any(x.op is Ops.BIND for x in st.vars()):
|
||||
st, var_vals = st.unbind()
|
||||
ctx.update(var_vals)
|
||||
return st.to_uop() if st != x.st else None
|
||||
|
||||
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
|
||||
ctx[var.replace(src=())] = val.arg
|
||||
return var
|
||||
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
|
||||
|
||||
# ** fix_kernel_ops
|
||||
|
||||
def check_load_st(glbl:UOp, view:UOp):
|
||||
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
||||
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
||||
@@ -307,7 +325,7 @@ def check_load_st(glbl:UOp, view:UOp):
|
||||
|
||||
fix_kernel_ops = PatternMatcher([
|
||||
# BIND in shapetracker becomes DEFINE_VAR
|
||||
(UPat(Ops.VIEW, name="x"), _append_st_vars),
|
||||
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
|
||||
# remove CONTIGUOUS/ASSIGN/DEVICE
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
|
||||
@@ -318,23 +336,12 @@ fix_kernel_ops = PatternMatcher([
|
||||
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
|
||||
])
|
||||
|
||||
def load_buf(ctx:list[UOp], x:UOp):
|
||||
if x not in ctx: ctx.append(x)
|
||||
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
|
||||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), load_buf),
|
||||
# STORE (except for COPY/BUFFER_VIEW)
|
||||
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
|
||||
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
|
||||
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
|
||||
])
|
||||
|
||||
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
|
||||
ctx[var.replace(src=())] = val.arg
|
||||
return var
|
||||
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
|
||||
# TODO: replace this with the KERNEL UOp
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItem:
|
||||
ast: UOp
|
||||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...]
|
||||
|
||||
def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
|
||||
assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
|
||||
|
||||
Reference in New Issue
Block a user