shorter schedule ops import [pr] (#7531)

This commit is contained in:
qazal
2024-11-04 16:02:43 +02:00
committed by GitHub
parent b5718ae135
commit 1d4df72798

View File

@@ -2,7 +2,7 @@ import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict
from tinygrad.ops import MetaOps, GroupOp, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -59,7 +59,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
buf.buffer.options = None
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
# consts are always fused and generated
if buf.op is MetaOps.CONST:
if buf.op is Ops.CONST:
if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0))
# everything else has BUFFER
@@ -69,11 +69,11 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
# everything else needs sources
src = tuple(to_uop(x, ctx, cache) for x in buf.srcs)
if buf.op in GroupOp.Reduce: ret = src[0].r(buf.op, buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
elif buf.op is MetaOps.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
elif buf.op is Ops.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
elif buf.op is Ops.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
elif buf.op in GroupOp.Meta: ret = UOp(buf.op, buf.dtype, (ubuf, *src), buf.arg)
elif buf.op is UnaryOps.CAST: ret = UOp(Ops.CAST, dtype, src)
elif buf.op is UnaryOps.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
elif buf.op is Ops.CAST: ret = UOp(Ops.CAST, dtype, src)
elif buf.op is Ops.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
else: ret = UOp(Ops.ALU, dtype, src, buf.op)
cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
@@ -244,7 +244,7 @@ break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) == 0: return [], {}
if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
for out in outs: out.forced_realize = True
# create the big graph
ctx = ScheduleContext()