mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
simpler
This commit is contained in:
@@ -5,8 +5,8 @@ from tinygrad.helpers import argsort
|
||||
|
||||
def cat_gradient(ctx:UOp, ret:UOp) -> tuple[UOp, ...]:
|
||||
axis = ret.arg
|
||||
dim_cumsum = list(itertools.accumulate([s.shape[axis] for s in ret.src], initial=0))
|
||||
return tuple(ctx.shrink(tuple([(dim_cumsum[i], dim_cumsum[i+1]) if j==axis else (0, ctx.shape[j])
|
||||
dim_acc = list(itertools.accumulate([s.shape[axis] for s in ret.src], initial=0))
|
||||
return tuple(ctx.shrink(tuple([(dim_acc[i], dim_acc[i+1]) if j==axis else (0, ctx.shape[j])
|
||||
for j in range(len(ctx.shape))])) for i in range(len(ret.src)))
|
||||
|
||||
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
||||
|
||||
@@ -114,9 +114,8 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||
|
||||
def lower_cat(cat:UOp) -> UOp:
|
||||
axis = cat.arg
|
||||
dim_cumsum = list(itertools.accumulate([s.shape[axis] for s in cat.src], initial=0))
|
||||
padded = [s.pad(tuple((dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==axis else (0,0)
|
||||
for j in range(len(s.shape)))) for i,s in enumerate(cat.src)]
|
||||
dim_acc = list(itertools.accumulate([s.shape[axis] for s in cat.src], initial=0))
|
||||
padded = [s.pad(tuple((dim_acc[i], dim_acc[-1]-dim_acc[i+1]) if j==axis else (0,0) for j in range(len(s.shape)))) for i,s in enumerate(cat.src)]
|
||||
ret = padded[0]
|
||||
for p in padded[1:]: ret = ret.alu(Ops.ADD, p)
|
||||
return ret
|
||||
|
||||
@@ -264,14 +264,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.MULTI if len(self.src) == 0: return None
|
||||
|
||||
case Ops.CAT:
|
||||
shapes = [s._shape for s in self.src]
|
||||
if any(s is None for s in shapes): raise RuntimeError("CAT requires all sources to have shapes")
|
||||
shapes = [s.shape for s in self.src]
|
||||
axis = self.arg
|
||||
if axis == -1: return (len(shapes),) + shapes[0] # type: ignore # new leading axis (stack)
|
||||
for s in shapes[1:]:
|
||||
if len(s) != len(shapes[0]) or not all(a==b for i,(a,b) in enumerate(zip(s, shapes[0])) if i!=axis): # type: ignore
|
||||
if len(s) != len(shapes[0]) or not all(a==b for i,(a,b) in enumerate(zip(s, shapes[0])) if i!=axis):
|
||||
raise ValueError(f"CAT shape mismatch: {shapes}")
|
||||
return tuple(ssimplify(sum(s[i] for s in shapes)) if i==axis else shapes[0][i] for i in range(len(shapes[0]))) # type: ignore
|
||||
return tuple(ssimplify(sum(s[i] for s in shapes)) if i==axis else shapes[0][i] for i in range(len(shapes[0])))
|
||||
|
||||
# movement ops change the shape
|
||||
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
|
||||
|
||||
@@ -47,7 +47,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.CAT:"#C1FFD7",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
|
||||
Reference in New Issue
Block a user