mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
more robust reduce_gradient (#10965)
This commit is contained in:
@@ -5,12 +5,13 @@ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
|
||||
from tinygrad.helpers import argsort
|
||||
|
||||
def reduce_gradient(ctx:UOp, ret:UOp):
|
||||
if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
|
||||
def to_inp_shape(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
|
||||
if ret.arg[0] == Ops.ADD: return (to_inp_shape(ctx),)
|
||||
if ret.arg[0] == Ops.MAX:
|
||||
max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
|
||||
div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
|
||||
return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
|
||||
if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)
|
||||
max_is_1s = ret.src[0].ne(to_inp_shape(ret)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
|
||||
div = to_inp_shape(max_is_1s.r(Ops.ADD, ret.arg[1]))
|
||||
return ((max_is_1s/div) * to_inp_shape(ctx),)
|
||||
if ret.arg[0] == Ops.MUL: return (to_inp_shape(ctx * ret) / ret.src[0],)
|
||||
|
||||
# ctx is grad_output
|
||||
pm_gradient = PatternMatcher([
|
||||
|
||||
@@ -165,7 +165,7 @@ class CStyleLanguage(Renderer):
|
||||
(u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||
r[u] = l
|
||||
else:
|
||||
if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
|
||||
if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL, Ops.STORE} or u.dtype == dtypes.void:
|
||||
if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
|
||||
else:
|
||||
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
|
||||
|
||||
Reference in New Issue
Block a user