Files
tinygrad/tinygrad/codegen/lowerer.py
2024-07-22 13:38:08 -03:00

186 lines
10 KiB
Python

from __future__ import annotations
from typing import List, Tuple, cast, Optional, Any, Dict
import functools
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.shape.symbolic import sint
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo, MemBuffer
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.renderer import Renderer
from tinygrad.helpers import getenv, all_int, get_contraction
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.bigint, x) if isinstance(x, int) else x.render(render_ops, ctx)
render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.bigint, self.b),
MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else UOp(UOps.DEFINE_VAR, dtypes.int32, (), self),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
if getenv("UOP_IS_SYMBOLIC"):
# TODO: change this once UOps is ready to replace symbolic
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
# TODO: dtypes.realint
iexpr = variable_to_uop(view.offset)
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
if sh != 1 and st != 0: iexpr = iexpr + idx*variable_to_uop(st)
if m is not None:
if m[0] != 0: vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
return iexpr, vexpr
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
for view in reversed(st.views[0:-1]):
view = view.minify()
acc, idxs = 1, []
for _d in reversed(view.shape):
d = variable_to_uop(_d)
idxs.append((idx//acc)%d)
acc *= d
idx, valid = _uop_view(view, idxs[::-1], valid)
if isinstance(dtype, ImageDType):
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), ((idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4))
return idx, valid
else:
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
idx, valid = st.expr_idxs(fake_idxs)
ctx = dict(zip(fake_idxs, idxs))
uvalid = valid.render(render_ops, ctx)
if isinstance(dtype, ImageDType):
image_idxs = (idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4
uidx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), tuple(x.render(render_ops, ctx) for x in image_idxs))
else:
uidx = idx.render(render_ops, ctx)
if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
assert uvalid.dtype == dtypes.bool
return uidx, uvalid
def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
# TODO: symbolic shape
if not all_int(dims): return dims
while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
for i,m in enumerate(max_sizes):
if dims[i] * dims[i+1] <= m:
dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
break
else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
return dims
def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
if reverse: dims = dims[::-1]
limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (), (i, f"{prefix}{i}", s)) for i,s in enumerate(limited)]
if limited != dims:
ret = []
# cast for mypy, get_contraction won't be None
for idx, contraction in zip(raw_idxs, cast(List[List[int]], get_contraction(dims, limited))):
if len(contraction) == 1: ret.append(idx)
else:
for c in contraction:
ret.append(idx % dims[c])
idx //= dims[c]
return ret[::-1] if reverse else ret
class IndependentLowerer:
def lower(self, ast:LazyOp, opts:Renderer) -> UOp:
self.output_count = len(ast.src)
ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
# NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
full_shape = ast.full_shape
first_upcasted = len(full_shape)-ki.upcasted
# if there's no reduce, this is first_upcasted
first_reduce = [x!=y for x,y in zip(ast.src[0].arg.st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
local_loads = [x for x in ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == -1]
# NOTE: this is taking the first one...there may be subtlelies here with multireduces
group_for_reduces = sum([x!=y for x,y in zip(
local_loads[0].arg.st.shape[first_reduce:first_upcasted], ast.src[0].arg.st.shape[first_reduce:first_upcasted])]) if local_loads else 0
global_dims = first_reduce-ki.local_dims
if opts.has_local:
if ki.dont_use_locals:
assert ki.local_dims == 0, "can't use locals if there's no local dims"
self.idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
else:
# define indexes for GPU-like execution
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))
for i,g in enumerate(full_shape[:first_reduce])]
# reduce loops
self.idxs += [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, True))
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), ((i,g),)))
# late indexes (group for reduce)
self.ridxs = self.idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):
self.ridxs[a] = UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
self.uop_cache: Dict[LazyOp, UOp] = {}
return self.to_uop(ast)
def to_uop(self, x:LazyOp) -> UOp:
if uop:=self.uop_cache.get(x, None): return uop
ret = self._to_uop(x)
self.uop_cache[x] = ret
return ret
def _to_uop(self, x:LazyOp) -> UOp:
if x.op in BufferOps:
idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs,
x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.arg.dtype)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
if x.op is BufferOps.CONST:
dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
return UOp.alu(TernaryOps.WHERE, valid, UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
if x.arg.idx == -1:
buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
else:
buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
(x.arg.idx, x.arg.idx < self.output_count))
if x.op is BufferOps.LOAD:
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
# NOTE: only store the local reduceop in the first thread
if x.arg.idx != -1:
has_valid = True
for oidx, ridx in zip(self.idxs, self.ridxs):
if oidx != ridx: valid = valid * oidx.eq(0)
return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))
in_uops = tuple(self.to_uop(y) for y in x.src)
if x.op is MetaOps.KERNEL: return UOp(UOps.SINK, src=in_uops)
if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
if x.op in ReduceOps:
dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
if x.op is ReduceOps.WMMA:
wmma_sz, upcast_axis = x.arg[4], x.arg[6]
ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
con = tuple(UOp(UOps.CONTRACT, dtype.vec(4), tuple([UOp(UOps.GEP, dtype, (ret,), i+j*4) for i in range(4)])) for j in range(4))
return UOp(UOps.EXPAND, dtype, con, arg=((upcast_axis[2], wmma_sz[2]//4),))
# NOTE: always using ridxs is fine here
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
return UOp.alu(x.op, *in_uops)
def lazyop_to_uop(ast:LazyOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)