simplify tensors before scheduling [pr] (#8580)

* delete forced_realize

* put that back

* work

* remove forced_realize

* expectedFailures

* contiguous(buffer)

* multi

* expectedFailures

* cleaner create_subbuffer

* more comments

* remove that

* note

* realizes

* work

* one upat and image is back

* remove

* cleaner

* fix test_complex_backward for now

---------

Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
qazal
2025-01-20 16:42:42 -05:00
committed by GitHub
parent 02ad450e22
commit 08eb1f1f56
3 changed files with 30 additions and 56 deletions

View File

@@ -113,7 +113,8 @@ class TestImageDType(unittest.TestCase):
assert it.lazydata.base.realized._buf != b1
# issue caused by: don't realize image to image casts. this is part of a larger problem
@unittest.expectedFailure
#@unittest.expectedFailure
# update: passing after tensor_map
def test_lil_model(self):
with Context(IMAGE=2):
x = Tensor.zeros(1, 1)

View File

@@ -220,7 +220,7 @@ class TestSchedule(unittest.TestCase):
GlobalCounters.reset()
expr = (a*b)/b
expr.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.kernel_count, 0) # the scheduler can fold divs now!
self.assertEqual(GlobalCounters.global_ops, 0)
np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))
@@ -229,7 +229,7 @@ class TestSchedule(unittest.TestCase):
GlobalCounters.reset()
expr = a/a
expr.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.kernel_count, 0)
self.assertEqual(GlobalCounters.global_ops, 0)
np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
@@ -2204,7 +2204,7 @@ class TestConst(unittest.TestCase):
sched = add.schedule()
self.assertEqual(len(sched), 0)
# b+0 and b share the same underlying device memory
self.assertIs(add.lazydata.realized, b.lazydata.realized)
self.assertIs(add.lazydata.buffer, b.lazydata.buffer)
self.assertListEqual(add.tolist(), [2, 2, 2, 2])
def test_src_masked_const_folding(self):

View File

@@ -2,7 +2,7 @@ import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar
from tinygrad.dtype import DType, ImageDType, dtypes
@@ -88,15 +88,15 @@ class ScheduleContext:
# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
# SINK is passthrough
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
# skip creating buffers for CONST/BIND/DEVICE/BUFFER
if buf.base.is_realized or buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
# VIEW is passthrough
if buf is not buf.base:
cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(unwrap(buf.st))
cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
return ret
# make things that can't be images not images
dtype = buf.dtype
@@ -105,9 +105,9 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
dtype = buf.dtype.base
# ASSIGN already has a target buffer, otherwise we create a new one
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
# track the underlying tensor uop for this buffer
ctx.tensor_uops[buf_uop] = [buf]
ctx.tensor_uops[buf_uop] = tensor_map[buf]
# (early) bufferize
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
return ret
@@ -358,10 +358,8 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
case _: return None
return reduce.const_like(ret)
def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp):
if contig.src[0].op is Ops.VIEW and len(contig.src[0].src):
old_base = contig.src[0].src[0]
if old_base.op is Ops.VIEW and (sti:=unwrap(contig.src[0].st).invert(old_base.shape)) is not None: ctx.contiguous[old_base] = base.view(sti)
def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti)
def replace_contiguous(ctx:ScheduleContext, alu:UOp):
new_src = list(alu.src)
for i,s in enumerate(alu.src):
@@ -372,8 +370,6 @@ ops_folding = symbolic_simple+PatternMatcher([
# op with size 0 is zero
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
# if the uop folded to a CONST we can delete the BUFFER
(UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
# DETACH is a NOOP here
(UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
# reduce of size 0 is the identity element
@@ -386,13 +382,16 @@ ops_folding = symbolic_simple+PatternMatcher([
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
# remove cast to image when it's already a contiguous image
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# double contiguous is one contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.CONTIGUOUS),)), lambda root: root.src[0]),
# support for using a contiguous permuted view instead of the parent view if one exists
(UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous),
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# remove CONST/BIND/BUFFER/VIEW from SINK
(UPat(Ops.SINK, name="root"),
@@ -400,34 +399,6 @@ ops_folding = symbolic_simple+PatternMatcher([
if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
# ** buffer merging
def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp:
assert v1.st is not None and v2.st is not None and v1.st == v2.st, f"implicit movementop {v1.st} {v2.st}"
# if b2 is realized also realize b1
if b2 in ctx.realizes:
ctx.realizes[b1] = b1
del ctx.realizes[b2]
# ops referring to b2 now ref to b1
ctx.tensor_uops[b1] += ctx.tensor_uops[b2]
del ctx.tensor_uops[b2]
# merge
return v1
def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp):
# early become
for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): ctx.becomes_map[luop] = b1.view(unwrap(luop.st))
return v1
merge_bufs = PatternMatcher([
# merge base
(UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"), UPat())))), merge),
(UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"),)))), merge_realized),
# collapse view
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat())).view(name="mv"))), lambda mv:mv),
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).view(name="mv"))), lambda mv:mv),
])
# ** this decides which ops get realized
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
@@ -481,7 +452,7 @@ def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[x] = m
if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
if b not in ctx.realizes: return x # collapse BUFFER
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
@@ -523,15 +494,13 @@ remove_movement_ops = PatternMatcher([
@track_rewrites(named=True)
def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
# if using VIZ, do a graph rewrite to visualize the Tensor graph
if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+ops_folding, ctx:=ScheduleContext())
rev_tensor_map: dict[UOp, list[UOp]] = {}
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
# add BUFFER uops
sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={})
# const folding and fusion
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx)
sink = graph_rewrite(sink, merge_bufs, ctx)
# create the scheduler context
graph_rewrite(sink, create_ctx, ctx)
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
# add realizes
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
# group realizes into kernels
store_groups = group_realizes(ctx)
graph_rewrite(sink, break_sched, ctx)
@@ -539,13 +508,17 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
prescheduled: list[ScheduleItem] = []
for store_uops in store_groups:
small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
# TODO: this still exists because symbolic folding is happening after bufferization
if not all(x.op is Ops.STORE for x in small_sink.src): continue
if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}")
prescheduled.append(schedule_uop(small_sink, ctx))
# can only schedule once
for buf_uop in store_uops:
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
# tensors can become an existing buffer, no ScheduleItem needed
for k,v in tensor_map.items():
# NOTE: we only add base tensors to becomes_map
if k is not v and v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
# add kernel children
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)