more robust reduce_gradient (#10965)

This commit is contained in:
George Hotz
2025-06-24 14:09:33 -07:00
committed by GitHub
parent 8743ca40e2
commit c2f5f0f198
2 changed files with 7 additions and 6 deletions

View File

@@ -5,12 +5,13 @@ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort
def reduce_gradient(ctx:UOp, ret:UOp):
if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
def to_inp_shape(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
if ret.arg[0] == Ops.ADD: return (to_inp_shape(ctx),)
if ret.arg[0] == Ops.MAX:
max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)
max_is_1s = ret.src[0].ne(to_inp_shape(ret)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
div = to_inp_shape(max_is_1s.r(Ops.ADD, ret.arg[1]))
return ((max_is_1s/div) * to_inp_shape(ctx),)
if ret.arg[0] == Ops.MUL: return (to_inp_shape(ctx * ret) / ret.src[0],)
# ctx is grad_output
pm_gradient = PatternMatcher([

View File

@@ -165,7 +165,7 @@ class CStyleLanguage(Renderer):
(u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
r[u] = l
else:
if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL, Ops.STORE} or u.dtype == dtypes.void:
if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
else:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")