mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
simplify tensors before scheduling [pr] (#8580)
* delete forced_realize * put that back * work * remove forced_realize * expectedFailures * contiguous(buffer) * multi * expectedFailures * cleaner create_subbuffer * more comments * remove that * note * realizes * work * one upat and image is back * remove * cleaner * fix test_complex_backward for now --------- Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
@@ -113,7 +113,8 @@ class TestImageDType(unittest.TestCase):
|
||||
assert it.lazydata.base.realized._buf != b1
|
||||
|
||||
# issue caused by: don't realize image to image casts. this is part of a larger problem
|
||||
@unittest.expectedFailure
|
||||
#@unittest.expectedFailure
|
||||
# update: passing after tensor_map
|
||||
def test_lil_model(self):
|
||||
with Context(IMAGE=2):
|
||||
x = Tensor.zeros(1, 1)
|
||||
|
||||
@@ -220,7 +220,7 @@ class TestSchedule(unittest.TestCase):
|
||||
GlobalCounters.reset()
|
||||
expr = (a*b)/b
|
||||
expr.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0) # the scheduler can fold divs now!
|
||||
self.assertEqual(GlobalCounters.global_ops, 0)
|
||||
np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))
|
||||
|
||||
@@ -229,7 +229,7 @@ class TestSchedule(unittest.TestCase):
|
||||
GlobalCounters.reset()
|
||||
expr = a/a
|
||||
expr.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
self.assertEqual(GlobalCounters.global_ops, 0)
|
||||
np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
|
||||
|
||||
@@ -2204,7 +2204,7 @@ class TestConst(unittest.TestCase):
|
||||
sched = add.schedule()
|
||||
self.assertEqual(len(sched), 0)
|
||||
# b+0 and b share the same underlying device memory
|
||||
self.assertIs(add.lazydata.realized, b.lazydata.realized)
|
||||
self.assertIs(add.lazydata.buffer, b.lazydata.buffer)
|
||||
self.assertListEqual(add.tolist(), [2, 2, 2, 2])
|
||||
|
||||
def test_src_masked_const_folding(self):
|
||||
|
||||
@@ -2,7 +2,7 @@ import sys, atexit, functools, pickle
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
|
||||
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify
|
||||
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map
|
||||
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes
|
||||
@@ -88,15 +88,15 @@ class ScheduleContext:
|
||||
|
||||
# wrap tensor uops around a VIEW(BUFFER, <uop>)
|
||||
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
|
||||
def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
# SINK is passthrough
|
||||
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
|
||||
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
|
||||
# skip creating buffers for CONST/BIND/DEVICE/BUFFER
|
||||
if buf.base.is_realized or buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
|
||||
# VIEW is passthrough
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(unwrap(buf.st))
|
||||
cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
|
||||
return ret
|
||||
# make things that can't be images not images
|
||||
dtype = buf.dtype
|
||||
@@ -105,9 +105,9 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
dtype = buf.dtype.base
|
||||
# ASSIGN already has a target buffer, otherwise we create a new one
|
||||
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
|
||||
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
|
||||
# track the underlying tensor uop for this buffer
|
||||
ctx.tensor_uops[buf_uop] = [buf]
|
||||
ctx.tensor_uops[buf_uop] = tensor_map[buf]
|
||||
# (early) bufferize
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
|
||||
return ret
|
||||
@@ -358,10 +358,8 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
|
||||
case _: return None
|
||||
return reduce.const_like(ret)
|
||||
|
||||
def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp):
|
||||
if contig.src[0].op is Ops.VIEW and len(contig.src[0].src):
|
||||
old_base = contig.src[0].src[0]
|
||||
if old_base.op is Ops.VIEW and (sti:=unwrap(contig.src[0].st).invert(old_base.shape)) is not None: ctx.contiguous[old_base] = base.view(sti)
|
||||
def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp):
|
||||
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti)
|
||||
def replace_contiguous(ctx:ScheduleContext, alu:UOp):
|
||||
new_src = list(alu.src)
|
||||
for i,s in enumerate(alu.src):
|
||||
@@ -372,8 +370,6 @@ ops_folding = symbolic_simple+PatternMatcher([
|
||||
# op with size 0 is zero
|
||||
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
|
||||
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
|
||||
# if the uop folded to a CONST we can delete the BUFFER
|
||||
(UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
|
||||
# DETACH is a NOOP here
|
||||
(UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
|
||||
# reduce of size 0 is the identity element
|
||||
@@ -386,13 +382,16 @@ ops_folding = symbolic_simple+PatternMatcher([
|
||||
# no COPY to same device, except clone (arg is True)
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
|
||||
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
|
||||
# remove cast to image when it's already a contiguous image
|
||||
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
|
||||
lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
|
||||
# remove contiguous if we can just view the buffer
|
||||
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
|
||||
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
|
||||
# double contiguous is one contiguous
|
||||
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.CONTIGUOUS),)), lambda root: root.src[0]),
|
||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||
(UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous),
|
||||
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
|
||||
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
|
||||
# remove CONST/BIND/BUFFER/VIEW from SINK
|
||||
(UPat(Ops.SINK, name="root"),
|
||||
@@ -400,34 +399,6 @@ ops_folding = symbolic_simple+PatternMatcher([
|
||||
if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
|
||||
])
|
||||
|
||||
# ** buffer merging
|
||||
|
||||
def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp:
|
||||
assert v1.st is not None and v2.st is not None and v1.st == v2.st, f"implicit movementop {v1.st} {v2.st}"
|
||||
# if b2 is realized also realize b1
|
||||
if b2 in ctx.realizes:
|
||||
ctx.realizes[b1] = b1
|
||||
del ctx.realizes[b2]
|
||||
# ops referring to b2 now ref to b1
|
||||
ctx.tensor_uops[b1] += ctx.tensor_uops[b2]
|
||||
del ctx.tensor_uops[b2]
|
||||
# merge
|
||||
return v1
|
||||
|
||||
def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp):
|
||||
# early become
|
||||
for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): ctx.becomes_map[luop] = b1.view(unwrap(luop.st))
|
||||
return v1
|
||||
|
||||
merge_bufs = PatternMatcher([
|
||||
# merge base
|
||||
(UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"), UPat())))), merge),
|
||||
(UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"),)))), merge_realized),
|
||||
# collapse view
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat())).view(name="mv"))), lambda mv:mv),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).view(name="mv"))), lambda mv:mv),
|
||||
])
|
||||
|
||||
# ** this decides which ops get realized
|
||||
|
||||
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
|
||||
@@ -481,7 +452,7 @@ def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
|
||||
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
|
||||
|
||||
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
|
||||
if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[x] = m
|
||||
if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
|
||||
if b not in ctx.realizes: return x # collapse BUFFER
|
||||
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
|
||||
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
|
||||
@@ -523,15 +494,13 @@ remove_movement_ops = PatternMatcher([
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
|
||||
# if using VIZ, do a graph rewrite to vizualize the Tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
|
||||
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+ops_folding, ctx:=ScheduleContext())
|
||||
rev_tensor_map: dict[UOp, list[UOp]] = {}
|
||||
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
|
||||
# add BUFFER uops
|
||||
sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={})
|
||||
# const folding and fusion
|
||||
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx)
|
||||
sink = graph_rewrite(sink, merge_bufs, ctx)
|
||||
# create the scheduler context
|
||||
graph_rewrite(sink, create_ctx, ctx)
|
||||
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
|
||||
# add realizes
|
||||
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
|
||||
# group realizes into kernels
|
||||
store_groups = group_realizes(ctx)
|
||||
graph_rewrite(sink, break_sched, ctx)
|
||||
@@ -539,13 +508,17 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
|
||||
prescheduled: list[ScheduleItem] = []
|
||||
for store_uops in store_groups:
|
||||
small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
|
||||
# TODO: this still exists because symbolic folding is happening after bufferization
|
||||
if not all(x.op is Ops.STORE for x in small_sink.src): continue
|
||||
if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}")
|
||||
prescheduled.append(schedule_uop(small_sink, ctx))
|
||||
# can only schedule once
|
||||
for buf_uop in store_uops:
|
||||
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
|
||||
|
||||
# tensors can become an existing buffer, no ScheduleItem needed
|
||||
for k,v in tensor_map.items():
|
||||
# NOTE: we only add base tensors to becomes_map
|
||||
if k is not v and v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
|
||||
|
||||
# add kernel children
|
||||
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
||||
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
|
||||
|
||||
Reference in New Issue
Block a user