This commit is contained in:
George Hotz
2026-03-07 10:11:21 +08:00
parent be0f9d1055
commit af1db22b25
4 changed files with 8 additions and 11 deletions

View File

@@ -5,8 +5,8 @@ from tinygrad.helpers import argsort
def cat_gradient(ctx:UOp, ret:UOp) -> tuple[UOp, ...]:
axis = ret.arg
dim_cumsum = list(itertools.accumulate([s.shape[axis] for s in ret.src], initial=0))
return tuple(ctx.shrink(tuple([(dim_cumsum[i], dim_cumsum[i+1]) if j==axis else (0, ctx.shape[j])
dim_acc = list(itertools.accumulate([s.shape[axis] for s in ret.src], initial=0))
return tuple(ctx.shrink(tuple([(dim_acc[i], dim_acc[i+1]) if j==axis else (0, ctx.shape[j])
for j in range(len(ctx.shape))])) for i in range(len(ret.src)))
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):

View File

@@ -114,9 +114,8 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
def lower_cat(cat:UOp) -> UOp:
axis = cat.arg
dim_cumsum = list(itertools.accumulate([s.shape[axis] for s in cat.src], initial=0))
padded = [s.pad(tuple((dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==axis else (0,0)
for j in range(len(s.shape)))) for i,s in enumerate(cat.src)]
dim_acc = list(itertools.accumulate([s.shape[axis] for s in cat.src], initial=0))
padded = [s.pad(tuple((dim_acc[i], dim_acc[-1]-dim_acc[i+1]) if j==axis else (0,0) for j in range(len(s.shape)))) for i,s in enumerate(cat.src)]
ret = padded[0]
for p in padded[1:]: ret = ret.alu(Ops.ADD, p)
return ret

View File

@@ -264,14 +264,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.MULTI if len(self.src) == 0: return None
case Ops.CAT:
shapes = [s._shape for s in self.src]
if any(s is None for s in shapes): raise RuntimeError("CAT requires all sources to have shapes")
shapes = [s.shape for s in self.src]
axis = self.arg
if axis == -1: return (len(shapes),) + shapes[0] # type: ignore # new leading axis (stack)
for s in shapes[1:]:
if len(s) != len(shapes[0]) or not all(a==b for i,(a,b) in enumerate(zip(s, shapes[0])) if i!=axis): # type: ignore
if len(s) != len(shapes[0]) or not all(a==b for i,(a,b) in enumerate(zip(s, shapes[0])) if i!=axis):
raise ValueError(f"CAT shape mismatch: {shapes}")
return tuple(ssimplify(sum(s[i] for s in shapes)) if i==axis else shapes[0][i] for i in range(len(shapes[0]))) # type: ignore
return tuple(ssimplify(sum(s[i] for s in shapes)) if i==axis else shapes[0][i] for i in range(len(shapes[0])))
# movement ops change the shape
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking

View File

@@ -47,7 +47,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.CAT:"#C1FFD7",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",