diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 6936ebdca6..317510de1e 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -5,12 +5,13 @@ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata from tinygrad.helpers import argsort def reduce_gradient(ctx:UOp, ret:UOp): - if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),) + def to_inp_shape(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape) + if ret.arg[0] == Ops.ADD: return (to_inp_shape(ctx),) if ret.arg[0] == Ops.MAX: - max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype) - div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape) - return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),) - if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],) + max_is_1s = ret.src[0].ne(to_inp_shape(ret)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype) + div = to_inp_shape(max_is_1s.r(Ops.ADD, ret.arg[1])) + return ((max_is_1s/div) * to_inp_shape(ctx),) + if ret.arg[0] == Ops.MUL: return (to_inp_shape(ctx * ret) / ret.src[0],) # ctx is grad_output pm_gradient = PatternMatcher([ diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 93a6a11d21..f4a9d0024c 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -165,7 +165,7 @@ class CStyleLanguage(Renderer): (u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))): r[u] = l else: - if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void: + if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL, Ops.STORE} or u.dtype == dtypes.void: if u.op is Ops.ASSIGN: r[u] = r[u.src[0]] else: l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")