mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-08 05:54:59 +08:00
move gradient to mixin [PR] (#16526)
This commit is contained in:
@@ -59,6 +59,14 @@ class TestTensorUOpClone(unittest.TestCase):
|
||||
u = UOp.const(dtypes.float, 2.0)
|
||||
self.assertIs(_strip_unique(Tensor(u).clone().uop), _strip_unique(u.clone()))
|
||||
|
||||
class TestTensorUOpGradient(unittest.TestCase):
  """Checks that the Tensor-level and UOp-level `gradient` entry points agree."""
  def test_gradient(self):
    # scalar output built from a float input, so no explicit seed gradient has to be passed
    inp = _t(3, 3).float()
    loss = (inp * 2).sum()
    # each call must yield exactly one gradient (the unpacking fails otherwise)
    [tensor_grad] = loss.gradient(inp)
    [uop_grad] = loss.uop.gradient(inp.uop)
    # identity, not equality: the Tensor result must hold the very same UOp object
    self.assertIs(tensor_grad.uop, uop_grad)
|
||||
|
||||
class TestTensorUOpGetitem(unittest.TestCase):
|
||||
# ---- pure slice patterns ----
|
||||
def test_slice_full(self): _check(self, _t(4), lambda x: x[slice(None)])
|
||||
|
||||
@@ -20,6 +20,9 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
def unique_const(fill_value:ConstType, **kwargs): raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
|
||||
@staticmethod
def const(dtype, b):
  """
  Placeholder for the constant-creation helper on the mixin.

  Always raises `NotImplementedError`; per the error message, creation helpers are only
  supported on Tensor and UOp.
  """
  raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
|
||||
@property
def _uop(self) -> UOp:
  """
  Hook returning the UOp for this object.

  The mixin has no implementation of its own and always raises `NotImplementedError`.
  """
  raise NotImplementedError
|
||||
def _wrap_uop(self, u:UOp) -> Self:
  """
  Hook turning a UOp back into an instance of the implementing type.

  The mixin has no implementation of its own and always raises `NotImplementedError`.
  """
  raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
|
||||
@@ -382,6 +385,28 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
def __matmul__(self, x:Self) -> Self:
  """Implements `self @ x` by forwarding to `self.matmul(x)`."""
  product = self.matmul(x)
  return product
|
||||
def __rmatmul__(self, x:Self) -> Self:
  """
  Implements the reflected `@` operator by forwarding to `self.matmul(x, True)`.

  NOTE(review): the positional `True` is presumably matmul's "reverse operands" flag -- confirm against matmul's signature.
  """
  product = self.matmul(x, True)
  return product
|
||||
|
||||
def gradient(self, *targets:Self, gradient:Self|None=None) -> list[Self]:
  """
  Computes the gradient of self with respect to each of the targets.

  `gradient` is the upstream gradient used to seed the computation. When it is omitted, self must be a scalar
  (shape `()`) and a constant `1.0` shaped like self is used instead. Both self and every target must be floating point,
  otherwise a `RuntimeError` is raised.

  The returned list has one entry per target, in the order the targets were passed. A target that is missing from the
  computed gradients gets a zero constant shaped like that target.

  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor.eye(3)
  y = Tensor([[2.0,0,-2.0]])
  z = y.matmul(x).sum()
  dx, dy = z.gradient(x, y)

  print(dx.tolist()) # dz/dx
  print(dy.tolist()) # dz/dy
  ```
  """
  assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
  if not (self.is_floating_point() and all(t.is_floating_point() for t in targets)): raise RuntimeError("only float Tensors have gradient")
  # imported here rather than at module scope -- presumably to avoid a circular import (TODO confirm)
  from tinygrad.gradient import compute_gradient
  if gradient is None: gradient = self.const_like(1.0)
  target_uops = [t._uop for t in targets]
  grads = compute_gradient(self._uop, gradient._uop, set(target_uops))
  # one dict lookup per target; targets without a computed gradient fall back to zeros
  return [self._wrap_uop(x.const_like(0) if (g:=grads.get(x)) is None else g) for x in target_uops]
|
||||
|
||||
def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Self:
|
||||
"""
|
||||
Returns the minimum value of the tensor along the specified axis or axes.
|
||||
|
||||
@@ -9,9 +9,8 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
|
||||
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped
|
||||
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.mixin import OpMixin
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
|
||||
from tinygrad.mixin import OpMixin
|
||||
from tinygrad.schedule import create_linear_with_vars
|
||||
from tinygrad.device import Buffer, canonicalize_device
|
||||
from tinygrad.engine.realize import run_linear
|
||||
@@ -157,6 +156,9 @@ class Tensor(OpMixin):
|
||||
|
||||
# alu and const_like are used by the mixins
|
||||
def alu(self, op: Ops, *src: Tensor) -> Tensor:
  """
  Applies `op` through `self._apply_uop`.

  The callback handed to `_apply_uop` calls `.alu(op, ...)` on the first UOp it receives, passing the remaining
  UOps as operands. `src` is forwarded to `_apply_uop` unchanged.
  """
  def _alu_on_uops(*uops):
    return uops[0].alu(op, *uops[1:])
  return self._apply_uop(_alu_on_uops, *src)
|
||||
@property
def _uop(self) -> UOp:
  """Returns the UOp stored in `self.uop`."""
  backing = self.uop
  return backing
|
||||
def _wrap_uop(self, u:UOp) -> Tensor:
  """Builds a new Tensor from the UOp `u`; `self` is not consulted."""
  wrapped = Tensor(u)
  return wrapped
|
||||
def const_like(self, b:ConstType) -> Tensor:
  """Returns a Tensor wrapping `self.uop.const_like(b)`."""
  const_uop = self.uop.const_like(b)
  return Tensor(const_uop)
|
||||
@staticmethod
|
||||
def const(dtype:DType, b:ConstType|UOp) -> Tensor:
|
||||
@@ -810,31 +812,6 @@ class Tensor(OpMixin):
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
|
||||
def gradient(self, *targets:Tensor, gradient:Tensor|None=None) -> list[Tensor]:
  """
  Computes the gradient of self with respect to each of the targets.

  `gradient` is the upstream gradient used to seed the computation. When it is omitted, self must be a scalar
  (shape `()`) and a `1.0` Tensor with self's dtype and device is used instead. Both self and every target must be
  floating point, otherwise a `RuntimeError` is raised.

  The returned list has one Tensor per target, in the order the targets were passed. A target that is missing from the
  computed gradients gets a zero constant shaped like that target.

  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor.eye(3)
  y = Tensor([[2.0,0,-2.0]])
  z = y.matmul(x).sum()
  dx, dy = z.gradient(x, y)

  print(dx.tolist()) # dz/dx
  print(dy.tolist()) # dz/dy
  ```
  """
  assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
  if not (self.is_floating_point() and all(t.is_floating_point() for t in targets)): raise RuntimeError("only float Tensors have gradient")
  if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device)
  target_uops = [x.uop for x in targets]
  grads = compute_gradient(self.uop, gradient.uop, set(target_uops))
  # targets without a computed gradient fall back to zeros
  return [Tensor(x.const_like(0) if (y:=grads.get(x)) is None else y) for x in target_uops]
|
||||
|
||||
def backward(self, gradient:Tensor|None=None) -> Tensor:
|
||||
"""
|
||||
Propagates the gradient of a tensor backwards through the computation graph.
|
||||
|
||||
@@ -491,6 +491,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
perm = src.permute(tuple([i for i in range(src.ndim) if i not in slice_idx] + slice_idx))
|
||||
return perm.index(*non_slice_args, ptr=True)
|
||||
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
|
||||
@property
def _uop(self) -> UOp:
  """Returns `self` unchanged."""
  me = self
  return me
|
||||
def _wrap_uop(self, u:UOp) -> UOp:
  """Returns `u` unchanged; `self` is not consulted."""
  result = u
  return result
|
||||
def const_like(self, b:ConstLike, dtype:DType|None=None):
  """
  Returns `UOp.const` of value `b` with `shape=self._shape`.

  The constant uses `dtype` when a truthy one is given, and `self.dtype.base` otherwise.
  """
  target_dtype = dtype if dtype else self.dtype.base
  return UOp.const(target_dtype, b, shape=self._shape)
|
||||
def ufix(self, x):
|
||||
|
||||
Reference in New Issue
Block a user