From c2f5f0f198beb8c8a40cee9ae6d5787f3cdd104f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 24 Jun 2025 14:09:33 -0700 Subject: [PATCH] more robust reduce_gradient (#10965) --- tinygrad/gradient.py | 11 ++++++----- tinygrad/renderer/cstyle.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 6936ebdca6..317510de1e 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -5,12 +5,13 @@ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata from tinygrad.helpers import argsort def reduce_gradient(ctx:UOp, ret:UOp): - if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),) + def to_inp_shape(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape) + if ret.arg[0] == Ops.ADD: return (to_inp_shape(ctx),) if ret.arg[0] == Ops.MAX: - max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype) - div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape) - return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),) - if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],) + max_is_1s = ret.src[0].ne(to_inp_shape(ret)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype) + div = to_inp_shape(max_is_1s.r(Ops.ADD, ret.arg[1])) + return ((max_is_1s/div) * to_inp_shape(ctx),) + if ret.arg[0] == Ops.MUL: return (to_inp_shape(ctx * ret) / ret.src[0],) # ctx is grad_output pm_gradient = PatternMatcher([ diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 93a6a11d21..f4a9d0024c 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -165,7 +165,7 @@ class CStyleLanguage(Renderer): (u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))): r[u] = l else: - if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void: + if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL, Ops.STORE} or u.dtype == dtypes.void: if u.op is Ops.ASSIGN: r[u] = r[u.src[0]] else: l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")