don't use uops.add while constructing (#4913)

* don't use uops.add while constructing

* rebase

* bugfixes

* have to use BFS

* prove it's late

* simpler uop symbolic test (why we did this)

* use dict, not set
This commit is contained in:
George Hotz
2024-06-12 13:31:34 +02:00
committed by GitHub
parent d894acbb50
commit 11a03cbbf5
3 changed files with 62 additions and 39 deletions

View File

@@ -12,13 +12,9 @@ from tinygrad.ops import BinaryOps
import functools
def render(self) -> str:
graph = UOpGraph()
# NOTE: we need STORE so the ALU op has children
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(0,True))
def recursive_add(x):
graph.add(x.uop, x.dtype, x.vin, x.arg)
for c in x.vin: recursive_add(c)
recursive_add(UOp(UOps.STORE, None, (glbl,UOp.const(dtypes.int, 0),self)))
graph = UOpGraph([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))])
graph.linearize()
from tinygrad.renderer.cstyle import CStyleLanguage
class TestRenderer(CStyleLanguage):

View File

@@ -4,7 +4,7 @@ import itertools, math, functools
from collections import defaultdict
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name
from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name, flatten
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
@@ -49,7 +49,7 @@ class Linearizer(Kernel):
# NOTE: the consts have to be cached for deduping of downstream uops to work
def const(self, b:ConstType|Variable, dtype:DType=dtypes.int32) -> UOp:
return self.uops.add(UOps.DEFINE_VAR, dtype, (), b) if isinstance(b, Variable) else UOp.const(dtype, b)
return UOp(UOps.DEFINE_VAR, dtype, (), b) if isinstance(b, Variable) else UOp.const(dtype, b)
def get_reduce_acc(self, reduceop:LazyOp):
if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
@@ -99,7 +99,7 @@ class Linearizer(Kernel):
key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
if key not in self.load_cache:
if acc is not None:
self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
acc_count += 1
elif this_const is not None:
self.load_cache[key] = self.const(this_const, localtype)
@@ -110,16 +110,16 @@ class Linearizer(Kernel):
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple()
self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4),
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
if localtype == localtype.scalar():
idx_small = idx%4
res = idx_small.render(self.render_ops, self)
out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
for ix in range(idx_small.max, idx_small.min, -1):
rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
sel = UOp.alu(BinaryOps.CMPLT, res, self.const(ix))
out = UOp.alu(TernaryOps.WHERE, sel, rvv, out)
self.load_cache[key] = out
@@ -128,8 +128,8 @@ class Linearizer(Kernel):
assert buf_uop is not None, f"buffer {i} wasn't UOped"
rendered_idx = idx.render(self.render_ops, self)
valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
return ret
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
@@ -152,7 +152,7 @@ class Linearizer(Kernel):
amt = len(grouped)
idx, valid = self.sts[i].expr_idxs(k)
assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
store_offset = store_offset_new
stores = []
@@ -160,17 +160,17 @@ class Linearizer(Kernel):
idx, valid = self.sts[i].expr_idxs(_idx)
if isinstance(buf.dtype, ImageDType):
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), \
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
tuple(x.render(self.render_ops, self) for x in image_idx))
else:
rendered_idx = idx.render(self.render_ops, self)
if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var)))
else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
return stores
# render loop
def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
new_loops = {x.expr:self.uops.add(UOps.RANGE, dtypes.int32, (
new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, (
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
self.loop_uops.update(new_loops)
@@ -240,17 +240,17 @@ class Linearizer(Kernel):
return strides
upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
# cast initial accs
wmmas = [self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
for x in range(0, len(accs[reduceop]), wmma_sz[2])]
for it in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
# TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
wmmas[wmma_idx] = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
# phi the last wmmas back to accs
accs[reduceop] = [self.uops.add(UOps.PHI, tc.dtype_out, (acc, self.uops.add(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
accs[reduceop] = [UOp(UOps.PHI, tc.dtype_out, (acc, UOp(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2])))
for z, acc in enumerate(accs[reduceop])]
else:
assert not locals_to_store, "storing locals isn't supported here"
@@ -269,12 +269,12 @@ class Linearizer(Kernel):
if self.group_for_reduces:
fake_global_idxs = [x*0 for x in global_idxs]
stores = self.global_store(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators
barrier = self.uops.add(UOps.BARRIER, None, tuple(stores))
barrier = UOp(UOps.BARRIER, None, tuple(stores))
if self.opts.has_local:
fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
barrier = self.uops.add(UOps.IF, None, (if_cond, barrier))
barrier = UOp(UOps.IF, None, (if_cond, barrier))
# create new late reduce local loops and replace local_idxs that have been used
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
@@ -333,23 +333,22 @@ class Linearizer(Kernel):
if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
# uops
self.uops:UOpGraph = UOpGraph()
self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
self.loop_uops: Dict[str, UOp] = {}
# add global buffers
for i,buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL,
self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL,
buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
(buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
# add var vals
for i,var in enumerate(self.vars):
assert var.expr is not None
self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var)
self.loop_uops[var.expr] = UOp(UOps.DEFINE_VAR, dtypes.int32, (), var)
# define local buffers
for aliases in self.local_alias.values():
for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
(), (lb.name, self.sts[self.bufs.index(lb)].size))
# add a local buffer for multistage reduce. # TODO: use local alias
if self.group_for_reduces:
@@ -358,7 +357,7 @@ class Linearizer(Kernel):
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
self.bufs.append(LocalBuffer(name:=f"temp{i if len(self.reduceops) > 1 else ''}", buf_size:=self.sts[-1].size, temp_dtype))
self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size)))
self.buf_uops.append(UOp(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size)))
# kernel name (before late upcast)
self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
@@ -381,11 +380,11 @@ class Linearizer(Kernel):
self.local_size: Optional[List[int]] = None
if self.dont_use_locals:
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
elif self.opts.has_local:
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs]
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
else:
self.render_loop(loop_global_idxs+loop_local_idxs, 1)
if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size))
@@ -403,8 +402,10 @@ class Linearizer(Kernel):
# render reduceops by depth
for reduceop in self.reduceops:
self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
stores = self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs)
# only the final stores are needed to define the full UOps graph
self.uops:UOpGraph = UOpGraph(flatten(stores))
# maybe graph the uops
if DEBUG >= 5: self.uops.print()
@@ -428,7 +429,7 @@ class Linearizer(Kernel):
# TODO: delete render_reduceop and move the logic for group_for_reduces to Block
local_idxs[:], upcast_idxs[:] = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs, full_upcast_idxs,
reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r])
return accs[r]
return [accs[r]]
# load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
@@ -442,7 +443,7 @@ class Linearizer(Kernel):
if x in cache: return cache[x]
if x.op in BufferOps: return loaded_buffers[x.arg]
if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
return [UOp(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST,
self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)]
if x.op in ReduceOps and reduce_acc is None:
assert offs is None, "not available if we aren't doing reduce"
@@ -459,7 +460,7 @@ class Linearizer(Kernel):
ret.append(acc[off])
for off in range(len(acc)):
if input_acc[off] != acc[off]:
acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]))
else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)]
cache[x] = ret
return ret

View File

@@ -249,13 +249,39 @@ constant_folder = PatternMatcher([
# *** uop graph ***
class UOpGraph:
def __init__(self):
def __init__(self, add_nodes:Optional[List[UOp]]=None):
self.nodes: Dict[Tuple, UOp] = {}
self._uops: Optional[List[UOp]] = None
if add_nodes is not None: self.multiadd(add_nodes)
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
def __getitem__(self, index) -> UOp: return self.uops[index]
def multiadd(self, unprocessed_nodes:List[UOp]):
# add nodes to graph in reverse BFS order
# TODO: i feel like this is written in a few places, possible to library it?
in_degree: DefaultDict[UOp, int] = defaultdict(int)
children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
all_nodes: Dict[UOp, None] = dict()
while len(unprocessed_nodes):
n = unprocessed_nodes.pop(0)
if n in all_nodes: continue
all_nodes[n] = None
for x in n.vin:
in_degree[x] += 1
children[x].append(n)
unprocessed_nodes += list(n.vin)
queue = [x for x in all_nodes if in_degree[x] == 0]
replace_nodes: Dict[UOp, UOp] = {}
while len(queue):
n = queue.pop(0)
if n in replace_nodes: continue
replace_nodes[n] = self.add(n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
for x in children[n]:
in_degree[x] -= 1
if in_degree[x] == 0:
queue.append(x)
def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]