reorder AST matchers + comments [pr] (#9193)

This commit is contained in:
qazal
2025-02-21 15:31:15 +02:00
committed by GitHub
parent 2eab8021fb
commit 8bb80b6e5e

View File

@@ -226,15 +226,24 @@ create_kernels = merge_views+PatternMatcher([
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
])
# **** convert Kernel to a ScheduleItem (for legacy reasons)
# **** fix kernel AST
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
# ** create buffer ops + enumerate buffers
# **** Kernel creation
def load_buf(ctx:list[UOp], x:UOp):
if x not in ctx: ctx.append(x)
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
add_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), load_buf),
# STORE (except for COPY/BUFFER_VIEW)
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
])
# ** push views to buffer ops
def apply_swizzle(u:UOp) -> UOp:
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -288,13 +297,22 @@ view_right = merge_views+PatternMatcher([
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
def _append_st_vars(ctx:dict[Variable, int], x:UOp) -> UOp|None:
# ** unbind variables
def unbind_shapetracker(ctx:dict[Variable, int], x:UOp) -> UOp|None:
st = unwrap(x.st).simplify()
if any(x.op is Ops.BIND for x in st.vars()):
st, var_vals = st.unbind()
ctx.update(var_vals)
return st.to_uop() if st != x.st else None
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
ctx[var.replace(src=())] = val.arg
return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
# ** fix_kernel_ops
def check_load_st(glbl:UOp, view:UOp):
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
@@ -307,7 +325,7 @@ def check_load_st(glbl:UOp, view:UOp):
fix_kernel_ops = PatternMatcher([
# BIND in shapetracker becomes DEFINE_VAR
(UPat(Ops.VIEW, name="x"), _append_st_vars),
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
# remove CONTIGUOUS/ASSIGN/DEVICE
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
@@ -318,23 +336,12 @@ fix_kernel_ops = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
])
def load_buf(ctx:list[UOp], x:UOp):
if x not in ctx: ctx.append(x)
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
add_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), load_buf),
# STORE (except for COPY/BUFFER_VIEW)
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
])
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
ctx[var.replace(src=())] = val.arg
return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
# TODO: replace this with the KERNEL UOp
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"