mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
shorter schedule ops import [pr] (#7531)
This commit is contained in:
@@ -2,7 +2,7 @@ import sys, atexit, functools, itertools
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict
|
||||
from tinygrad.ops import MetaOps, GroupOp, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
|
||||
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
|
||||
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -59,7 +59,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
|
||||
buf.buffer.options = None
|
||||
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
|
||||
# consts are always fused and generated
|
||||
if buf.op is MetaOps.CONST:
|
||||
if buf.op is Ops.CONST:
|
||||
if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
|
||||
return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0))
|
||||
# everything else has BUFFER
|
||||
@@ -69,11 +69,11 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
|
||||
# everything else needs sources
|
||||
src = tuple(to_uop(x, ctx, cache) for x in buf.srcs)
|
||||
if buf.op in GroupOp.Reduce: ret = src[0].r(buf.op, buf.arg)
|
||||
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
|
||||
elif buf.op is MetaOps.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
|
||||
elif buf.op is Ops.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
|
||||
elif buf.op is Ops.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
|
||||
elif buf.op in GroupOp.Meta: ret = UOp(buf.op, buf.dtype, (ubuf, *src), buf.arg)
|
||||
elif buf.op is UnaryOps.CAST: ret = UOp(Ops.CAST, dtype, src)
|
||||
elif buf.op is UnaryOps.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
|
||||
elif buf.op is Ops.CAST: ret = UOp(Ops.CAST, dtype, src)
|
||||
elif buf.op is Ops.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
|
||||
else: ret = UOp(Ops.ALU, dtype, src, buf.op)
|
||||
cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
|
||||
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
|
||||
@@ -244,7 +244,7 @@ break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not MetaOps.CONST)) == 0: return [], {}
|
||||
if len(outs:=dedup(x.base for x in outs if x.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
|
||||
for out in outs: out.forced_realize = True
|
||||
# create the big graph
|
||||
ctx = ScheduleContext()
|
||||
|
||||
Reference in New Issue
Block a user