move gradient to mixin [PR] (#16526)

This commit is contained in:
chenyu
2026-06-07 00:05:02 -04:00
committed by GitHub
parent 4e7c6260b0
commit 90b556ca48
4 changed files with 40 additions and 27 deletions

View File

@@ -59,6 +59,14 @@ class TestTensorUOpClone(unittest.TestCase):
u = UOp.const(dtypes.float, 2.0)
self.assertIs(_strip_unique(Tensor(u).clone().uop), _strip_unique(u.clone()))
class TestTensorUOpGradient(unittest.TestCase):
  """Tensor.gradient must hand back exactly the UOp that UOp.gradient produces."""
  def test_gradient(self):
    inp = _t(3, 3).float()
    loss = (inp * 2).sum()
    # gradient is requested without an explicit seed on both the Tensor and the UOp side
    (tensor_grad,) = loss.gradient(inp)
    (uop_grad,) = loss.uop.gradient(inp.uop)
    # identity, not equality: the Tensor result has to wrap the very same UOp object
    self.assertIs(tensor_grad.uop, uop_grad)
class TestTensorUOpGetitem(unittest.TestCase):
# ---- pure slice patterns ----
def test_slice_full(self): _check(self, _t(4), lambda x: x[slice(None)])

View File

@@ -20,6 +20,9 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
def unique_const(fill_value:ConstType, **kwargs): raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
@staticmethod
def const(dtype, b):
  """
  Placeholder for the constant-creation helper.

  Always raises NotImplementedError; per the error message, creation helpers are only supported on Tensor and UOp.
  """
  raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
@property
def _uop(self) -> UOp:
  """
  Hook returning the UOp behind this object.

  The mixin itself has no answer, so this raises NotImplementedError until a subclass overrides it.
  """
  raise NotImplementedError
def _wrap_uop(self, u:UOp) -> Self:
  """
  Hook turning a UOp back into an object of the implementing class.

  The mixin itself has no answer, so this raises NotImplementedError until a subclass overrides it.
  """
  raise NotImplementedError
@classmethod
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
@@ -382,6 +385,28 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
def __matmul__(self, x:Self) -> Self:
  """Implements `self @ x` by delegating to `self.matmul(x)`."""
  product = self.matmul(x)
  return product
def __rmatmul__(self, x:Self) -> Self:
  """
  Implements the reflected `x @ self` operator.

  Delegates to `self.matmul` with `True` as the second positional argument (presumably the flag that swaps operand order -- confirm against `matmul`).
  """
  product = self.matmul(x, True)
  return product
def gradient(self, *targets:Self, gradient:Self|None=None) -> list[Self]:
  """
  Computes the gradient of self with respect to each of the targets.

  One result is returned per target, in the order the targets were passed. A target that has no entry in the computed
  gradients gets a `const_like(0)` of that target instead.

  When `gradient` is omitted, self must have shape `()` and `self.const_like(1.0)` is used as the incoming gradient.
  Raises RuntimeError unless self and every target are floating point.

  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor.eye(3)
  y = Tensor([[2.0,0,-2.0]])
  z = y.matmul(x).sum()
  dx, dy = z.gradient(x, y)
  print(dx.tolist())  # dz/dx
  print(dy.tolist())  # dz/dy
  ```
  """
  assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
  if not self.is_floating_point() or not all(t.is_floating_point() for t in targets):
    raise RuntimeError("only float Tensors have gradient")
  # imported here rather than at module level, as in the original block
  from tinygrad.gradient import compute_gradient
  seed = self.const_like(1.0) if gradient is None else gradient
  wrt = [t._uop for t in targets]
  computed = compute_gradient(self._uop, seed._uop, set(wrt))
  results = []
  for u in wrt:
    # targets that the computed gradients do not contain fall back to zeros
    g = computed[u] if u in computed else u.const_like(0)
    results.append(self._wrap_uop(g))
  return results
def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Self:
"""
Returns the minimum value of the tensor along the specified axis or axes.

View File

@@ -9,9 +9,8 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
from tinygrad.mixin import OpMixin
from tinygrad.schedule import create_linear_with_vars
from tinygrad.device import Buffer, canonicalize_device
from tinygrad.engine.realize import run_linear
@@ -157,6 +156,9 @@ class Tensor(OpMixin):
# alu and const_like are used by the mixins
def alu(self, op: Ops, *src: Tensor) -> Tensor:
  """
  Applies `op` to self and `src` by routing through `_apply_uop`.

  The callback given to `_apply_uop` calls `.alu(op, ...)` on the first value it receives, with the remaining values as operands.
  """
  def _alu_on(*u):
    return u[0].alu(op, *u[1:])
  return self._apply_uop(_alu_on, *src)
@property
def _uop(self) -> UOp:
  """The UOp behind this Tensor, i.e. `self.uop`."""
  backing = self.uop
  return backing
def _wrap_uop(self, u:UOp) -> Tensor:
  """Wraps the UOp `u` in a new Tensor."""
  wrapped = Tensor(u)
  return wrapped
def const_like(self, b:ConstType) -> Tensor:
  """Returns a Tensor wrapping `self.uop.const_like(b)`."""
  filled = self.uop.const_like(b)
  return Tensor(filled)
@staticmethod
def const(dtype:DType, b:ConstType|UOp) -> Tensor:
@@ -810,31 +812,6 @@ class Tensor(OpMixin):
# ***** toposort and backward pass *****
def gradient(self, *targets:Tensor, gradient:Tensor|None=None) -> list[Tensor]:
  """
  Computes the gradient of self with respect to each of the targets.

  One Tensor is returned per target, in the order the targets were passed. A target that has no entry in the computed
  gradients gets a `const_like(0)` of its uop instead.

  When `gradient` is omitted, self must have shape `()` and a `1.0` Tensor with the dtype and device of self is used
  as the incoming gradient. Raises RuntimeError unless self and every target are floating point.

  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor.eye(3)
  y = Tensor([[2.0,0,-2.0]])
  z = y.matmul(x).sum()
  dx, dy = z.gradient(x, y)
  print(dx.tolist())  # dz/dx
  print(dy.tolist())  # dz/dy
  ```
  """
  assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
  if not self.is_floating_point() or not all(t.is_floating_point() for t in targets):
    raise RuntimeError("only float Tensors have gradient")
  seed = Tensor(1.0, dtype=self.dtype, device=self.device) if gradient is None else gradient
  wrt = [t.uop for t in targets]
  computed = compute_gradient(self.uop, seed.uop, set(wrt))
  # targets that the computed gradients do not contain fall back to zeros
  return [Tensor(u.const_like(0) if (g:=computed.get(u)) is None else g) for u in wrt]
def backward(self, gradient:Tensor|None=None) -> Tensor:
"""
Propagates the gradient of a tensor backwards through the computation graph.

View File

@@ -491,6 +491,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
perm = src.permute(tuple([i for i in range(src.ndim) if i not in slice_idx] + slice_idx))
return perm.index(*non_slice_args, ptr=True)
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
@property
def _uop(self) -> UOp:
  """A UOp is its own backing UOp, so this returns `self`."""
  return self
def _wrap_uop(self, u:UOp) -> UOp:
  """No wrapping is needed on UOp: `u` is returned unchanged."""
  return u
def const_like(self, b:ConstLike, dtype:DType|None=None):
  """
  Builds a constant through `UOp.const` with value `b` and shape `self._shape`.

  The dtype is `dtype` when it is truthy, otherwise `self.dtype.base`.
  """
  chosen_dtype = dtype or self.dtype.base
  return UOp.const(chosen_dtype, b, shape=self._shape)
def ufix(self, x):