diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 3527bd963b..dbcb455c7a 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -5,8 +5,8 @@ from tinygrad.helpers import argsort def cat_gradient(ctx:UOp, ret:UOp) -> tuple[UOp, ...]: axis = ret.arg - dim_cumsum = list(itertools.accumulate([s.shape[axis] for s in ret.src], initial=0)) - return tuple(ctx.shrink(tuple([(dim_cumsum[i], dim_cumsum[i+1]) if j==axis else (0, ctx.shape[j]) + dim_acc = list(itertools.accumulate([s.shape[axis] for s in ret.src], initial=0)) + return tuple(ctx.shrink(tuple([(dim_acc[i], dim_acc[i+1]) if j==axis else (0, ctx.shape[j]) for j in range(len(ctx.shape))])) for i in range(len(ret.src))) def reduce_gradient(ctx:UOp, ret:UOp, op:Ops): diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 69561c12f4..f3e0a15506 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -114,9 +114,8 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None: def lower_cat(cat:UOp) -> UOp: axis = cat.arg - dim_cumsum = list(itertools.accumulate([s.shape[axis] for s in cat.src], initial=0)) - padded = [s.pad(tuple((dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==axis else (0,0) - for j in range(len(s.shape)))) for i,s in enumerate(cat.src)] + dim_acc = list(itertools.accumulate([s.shape[axis] for s in cat.src], initial=0)) + padded = [s.pad(tuple((dim_acc[i], dim_acc[-1]-dim_acc[i+1]) if j==axis else (0,0) for j in range(len(s.shape)))) for i,s in enumerate(cat.src)] ret = padded[0] for p in padded[1:]: ret = ret.alu(Ops.ADD, p) return ret diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 5aa89deaa1..4793f54a56 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -264,14 +264,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.MULTI if len(self.src) == 0: return None case Ops.CAT: - shapes = [s._shape for s in self.src] - if any(s is None for s in shapes): raise RuntimeError("CAT requires all sources to have shapes") + shapes = [s.shape for s in self.src] axis = self.arg - if axis == -1: return (len(shapes),) + shapes[0] # type: ignore # new leading axis (stack) for s in shapes[1:]: - if len(s) != len(shapes[0]) or not all(a==b for i,(a,b) in enumerate(zip(s, shapes[0])) if i!=axis): # type: ignore + if len(s) != len(shapes[0]) or not all(a==b for i,(a,b) in enumerate(zip(s, shapes[0])) if i!=axis): raise ValueError(f"CAT shape mismatch: {shapes}") - return tuple(ssimplify(sum(s[i] for s in shapes)) if i==axis else shapes[0][i] for i in range(len(shapes[0]))) # type: ignore + return tuple(ssimplify(sum(s[i] for s in shapes)) if i==axis else shapes[0][i] for i in range(len(shapes[0]))) # movement ops change the shape # NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 9b562c0fb1..b042fceb60 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -47,7 +47,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", - **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", + **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.CAT:"#C1FFD7", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6", Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",