mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-12 07:55:14 +08:00
186 lines
10 KiB
Python
186 lines
10 KiB
Python
from __future__ import annotations
|
|
from typing import List, Tuple, cast, Optional, Any, Dict
|
|
import functools
|
|
from tinygrad.shape.shapetracker import ShapeTracker, View
|
|
from tinygrad.shape.symbolic import sint
|
|
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
|
|
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo, MemBuffer
|
|
from tinygrad.codegen.uops import UOp, UOps
|
|
from tinygrad.renderer import Renderer
|
|
from tinygrad.helpers import getenv, all_int, get_contraction
|
|
|
|
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
|
|
from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
|
|
def variable_to_uop(x, ctx=None) -> UOp:
  """Convert a shape value into a UOp.

  Python ints become ``dtypes.bigint`` constants; any other value is rendered through ``render_ops``,
  with ``ctx`` forwarded to its ``render`` method.
  """
  if isinstance(x, int): return UOp.const(dtypes.bigint, x)
  return x.render(render_ops, ctx)
|
|
# Dispatch table handed to ``.render(render_ops, ctx)``: maps each symbolic node class to a callable
# ``(self, ops, ctx)`` that builds the UOp for that node.
render_ops: Any = {
  # constants are emitted as bigint
  NumNode: lambda self, ops, ctx: UOp.const(dtypes.bigint, self.b),
  # binary nodes: render the left operand, convert the right operand (int or node) with variable_to_uop
  MulNode: lambda self, ops, ctx: self.a.render(ops, ctx) * variable_to_uop(self.b, ctx),
  DivNode: lambda self, ops, ctx: self.a.render(ops, ctx) // variable_to_uop(self.b, ctx),
  ModNode: lambda self, ops, ctx: self.a.render(ops, ctx) % variable_to_uop(self.b, ctx),
  LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
  # a Variable found in ctx is replaced by its mapped UOp; otherwise it becomes a DEFINE_VAR
  Variable: lambda self, ops, ctx: UOp(UOps.DEFINE_VAR, dtypes.int32, (), self) if ctx is None or self not in ctx else ctx[self],
  # n-ary nodes: left fold over the rendered children (+ for sums, * for ands)
  SumNode: lambda self, ops, ctx: functools.reduce(
    lambda acc, node: acc + node.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops, ctx)),
  AndNode: lambda self, ops, ctx: functools.reduce(
    lambda acc, node: acc * node.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops, ctx)),
}
|
|
|
|
if getenv("UOP_IS_SYMBOLIC"):
  # TODO: change this once UOps is ready to replace symbolic
  def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
    """Fold one View into an index expression and a validity expression.

    The index starts at ``view.offset`` and gains ``idx*stride`` for every dimension whose shape is not 1
    and whose stride is not 0. ``vexpr`` is multiplied by a bound check for each mask limit that differs
    from the trivial ``0`` / ``shape`` bound. Returns ``(index, valid)``.
    """
    # TODO: dtypes.realint
    index_expr = variable_to_uop(view.offset)
    masks = [None]*len(view.shape) if view.mask is None else view.mask
    for dim_idx, dim_size, stride, bounds in zip(idxs, view.shape, view.strides, masks):
      if dim_size != 1 and stride != 0: index_expr = index_expr + dim_idx*variable_to_uop(stride)
      if bounds is None: continue
      # only emit a comparison when the bound actually restricts the dimension
      if bounds[0] != 0: vexpr = vexpr * dim_idx.ge(variable_to_uop(bounds[0]))
      if bounds[1] != dim_size: vexpr = vexpr * dim_idx.lt(variable_to_uop(bounds[1]))
    return index_expr, vexpr

  def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
    """Build ``(index, valid)`` UOps for ``st`` directly from its views.

    The last view consumes ``idxs``; each earlier view (walked back to front, after ``minify()``) gets
    per-dimension indexes obtained by splitting the running flat index with ``//`` and ``%``.
    For an ImageDType the flat index is turned into a 3-wide VECTORIZE.
    """
    idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
    for view in reversed(st.views[:-1]):
      view = view.minify()
      stride_acc, view_idxs = 1, []
      # walk the dims from last to first, peeling one index off the flat idx per dim
      for size in reversed(view.shape):
        size_uop = variable_to_uop(size)
        view_idxs.append((idx//stride_acc)%size_uop)
        stride_acc *= size_uop
      idx, valid = _uop_view(view, view_idxs[::-1], valid)
    if isinstance(dtype, ImageDType):
      width = dtype.shape[1]
      idx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), ((idx // 4) % width, (idx // (4 * width)), idx % 4))
    return idx, valid
else:
  def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
    """Build ``(index, valid)`` UOps for ``st`` via ``st.expr_idxs``.

    Placeholder Variables stand in for the indexes while the expressions are built, and are swapped for
    the real ``idxs`` UOps at render time through the ``ctx`` mapping.
    """
    placeholders = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
    idx, valid = st.expr_idxs(placeholders)
    ctx = dict(zip(placeholders, idxs))
    uvalid = valid.render(render_ops, ctx)
    if not isinstance(dtype, ImageDType):
      uidx = idx.render(render_ops, ctx)
    else:
      width = dtype.shape[1]
      image_idxs = (idx // 4) % width, (idx // (4 * width)), idx % 4
      uidx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), tuple(component.render(render_ops, ctx) for component in image_idxs))
    # a constant validity is re-created as a bool const so the dtype assertion below holds
    if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
    assert uvalid.dtype == dtypes.bool
    return uidx, uvalid
|
|
|
|
def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
  """Merge adjacent dims until they fit within ``max_sizes``.

  While there are more dims than ``max_sizes`` entries, or some dim exceeds its limit, the first adjacent
  pair ``(dims[i], dims[i+1])`` whose product is ``<= max_sizes[i]`` is replaced by that product.
  Dims that are not all ints are returned unchanged.

  Returns:
    The (possibly shorter) tuple of dims.

  Raises:
    RuntimeError: if the dims still do not fit and no adjacent pair can be merged.
  """
  # TODO: symbolic shape
  if not all_int(dims): return dims
  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
    # Only positions with a right-hand neighbour are merge candidates. Scanning every max_sizes entry
    # made dims[i+1] raise IndexError (instead of the RuntimeError below) when len(dims) <= len(max_sizes).
    for i,m in enumerate(max_sizes[:len(dims)-1]):
      if dims[i] * dims[i+1] <= m:
        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
        break
    else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
  return dims
|
|
|
|
def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
  """Create one SPECIAL index UOp per launch dim and map them back onto ``dims``.

  When ``max_sizes`` is given, ``dims`` is first passed through ``_limit_dims``. If that merged any dims,
  each merged SPECIAL index is split back into one index per original dim using ``%`` and ``//``.
  With ``reverse=True`` the dims are reversed on entry and the result is reversed on return.
  """
  if reverse: dims = dims[::-1]
  limited = dims if max_sizes is None else _limit_dims(dims, max_sizes)
  raw_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (), (i, f"{prefix}{i}", s)) for i,s in enumerate(limited)]
  ret = raw_idxs
  if limited != dims:
    ret = []
    # cast for mypy, get_contraction won't be None
    groups = cast(List[List[int]], get_contraction(dims, limited))
    for idx, contraction in zip(raw_idxs, groups):
      if len(contraction) == 1:
        ret.append(idx)
        continue
      # peel one original dim off the merged index at a time
      for c in contraction:
        ret.append(idx % dims[c])
        idx //= dims[c]
  return ret[::-1] if reverse else ret
|
|
|
|
class IndependentLowerer:
  """Lowers a LazyOp AST into a UOp graph.

  ``lower`` builds the per-dimension index UOps (``self.idxs`` and ``self.ridxs``) and then converts the
  AST recursively through ``to_uop``, which caches one UOp per LazyOp in ``self.uop_cache``.
  """

  def lower(self, ast:LazyOp, opts:Renderer) -> UOp:
    """Set up the index UOps for ``ast`` and return the UOp produced for its root.

    ``opts.has_local`` selects between SPECIAL indexes (via ``get_grouped_dims``) and plain RANGE loops
    for the dims before ``first_reduce``.
    """
    self.output_count = len(ast.src)

    ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
    # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
    full_shape = ast.full_shape
    first_upcasted = len(full_shape)-ki.upcasted
    # if there's no reduce, this is first_upcasted
    # (the appended (0,) / (1,) sentinels guarantee .index(True) finds a mismatch at first_upcasted at the latest)
    first_reduce = [x!=y for x,y in zip(ast.src[0].arg.st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
    local_loads = [x for x in ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == -1]
    # NOTE: this is taking the first one...there may be subtleties here with multireduces
    group_for_reduces = sum([x!=y for x,y in zip(
      local_loads[0].arg.st.shape[first_reduce:first_upcasted], ast.src[0].arg.st.shape[first_reduce:first_upcasted])]) if local_loads else 0
    global_dims = first_reduce-ki.local_dims

    if opts.has_local:
      if ki.dont_use_locals:
        assert ki.local_dims == 0, "can't use locals if there's no local dims"
        self.idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
      else:
        # define indexes for GPU-like execution
        self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
                    get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
    else:
      # all loops are RANGES
      self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))
        for i,g in enumerate(full_shape[:first_reduce])]

    # reduce loops
    self.idxs += [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, True))
      for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]

    # upcast loops
    for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
      assert isinstance(g, int), "needs to be int to upcast/unroll"
      self.idxs.append(UOp(UOps.EXPAND, dtypes.bigint, tuple(UOp.const(dtypes.bigint, j) for j in range(0, g)), ((i,g),)))

    # late indexes (group for reduce)
    # ridxs is a copy of idxs where the group-for-reduce positions are replaced by RANGE loops (ids offset by 1000)
    self.ridxs = self.idxs[:]
    for a in range(first_reduce, first_reduce+group_for_reduces):
      self.ridxs[a] = UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(full_shape[a])), (1000+a, True))

    self.uop_cache: Dict[LazyOp, UOp] = {}
    return self.to_uop(ast)

  def to_uop(self, x:LazyOp) -> UOp:
    """Return the UOp for ``x``, computing it with ``_to_uop`` on the first request and caching it."""
    # explicit None check: a cached entry must count as a hit regardless of its truthiness
    if (uop := self.uop_cache.get(x)) is not None: return uop
    ret = self._to_uop(x)
    self.uop_cache[x] = ret
    return ret

  def _to_uop(self, x:LazyOp) -> UOp:
    """Convert a single LazyOp into a UOp, recursing into its sources through ``to_uop``.

    BufferOps become a WHERE (CONST), LOAD or STORE; KERNEL becomes a SINK; CAST/BITCAST map to their UOps;
    ReduceOps become REDUCE (or WMMA wrapped in CONTRACT/EXPAND); everything else goes through ``UOp.alu``.
    """
    if x.op in BufferOps:
      # loads with idx == -1 are indexed with ridxs, everything else with idxs
      idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs,
        x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.arg.dtype)
      # TODO: check has_valid in UPat, not here
      # a constant-True validity needs no gate on the load/store
      has_valid = valid.op is not UOps.CONST or valid.arg is not True
      if x.op is BufferOps.CONST:
        dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
        return UOp.alu(TernaryOps.WHERE, valid, UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
      if x.arg.idx == -1:
        # idx == -1 selects the DEFINE_LOCAL "temp" buffer sized by the shapetracker
        buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
      else:
        # the second arg element is True for buffer indexes below output_count
        buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
                  (x.arg.idx, x.arg.idx < self.output_count))
      if x.op is BufferOps.LOAD:
        # a load that has a source is ordered after it through a BARRIER
        barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
        return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
      # NOTE: only store the local reduceop in the first thread
      if x.arg.idx != -1:
        has_valid = True
        # positions where idxs and ridxs differ are the group-for-reduce dims: gate the store on index 0 there
        for oidx, ridx in zip(self.idxs, self.ridxs):
          if oidx != ridx: valid = valid * oidx.eq(0)
      return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))

    in_uops = tuple(self.to_uop(y) for y in x.src)
    if x.op is MetaOps.KERNEL: return UOp(UOps.SINK, src=in_uops)
    if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
    if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
    if x.op in ReduceOps:
      dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
      if x.op is ReduceOps.WMMA:
        wmma_sz, upcast_axis = x.arg[4], x.arg[6]
        ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=(
          UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
          UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
          UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)

        # regroup the WMMA result into four 4-wide CONTRACTs of GEPs, then EXPAND over the third upcast axis
        con = tuple(UOp(UOps.CONTRACT, dtype.vec(4), tuple([UOp(UOps.GEP, dtype, (ret,), i+j*4) for i in range(4)])) for j in range(4))
        return UOp(UOps.EXPAND, dtype, con, arg=((upcast_axis[2], wmma_sz[2]//4),))
      # NOTE: always using ridxs is fine here
      return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
    return UOp.alu(x.op, *in_uops)
|
|
|
|
def lazyop_to_uop(ast:LazyOp, opts:Renderer) -> UOp:
  """Lower ``ast`` to a UOp with a fresh ``IndependentLowerer`` configured by ``opts``."""
  lowerer = IndependentLowerer()
  return lowerer.lower(ast, opts)
|