mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
remove _force_unique from Tensor init (#16277)
This commit is contained in:
@@ -99,10 +99,10 @@ class TestTinygrad(unittest.TestCase):
|
||||
return second_derivative.numpy()
|
||||
|
||||
def test_tinygrad():
|
||||
x_val = Tensor(2.0)
|
||||
x_val = Tensor([2.0])
|
||||
f = x_val**3
|
||||
first_derivative = f.gradient(x_val)[0]
|
||||
second_derivative = first_derivative.gradient(x_val)[0]
|
||||
first_derivative = f.sum().gradient(x_val)[0]
|
||||
second_derivative = first_derivative.sum().gradient(x_val)[0]
|
||||
return second_derivative.numpy()
|
||||
|
||||
np.testing.assert_allclose(test_tinygrad(), test_pytorch(), atol=1e-5)
|
||||
@@ -171,8 +171,8 @@ class TestTinygrad(unittest.TestCase):
|
||||
return w1.grad, w2.grad
|
||||
|
||||
def test_tinygrad():
|
||||
w1 = Tensor(init)
|
||||
w2 = Tensor(init)
|
||||
w1 = Tensor(init).clone()
|
||||
w2 = Tensor(init).clone()
|
||||
out = w1.add(w2)
|
||||
out.backward()
|
||||
return w1.grad.numpy(), w2.grad.numpy()
|
||||
@@ -191,8 +191,8 @@ class TestTinygrad(unittest.TestCase):
|
||||
return w1.grad.numpy(), w2.grad.numpy()
|
||||
|
||||
def test_tinygrad():
|
||||
w1 = Tensor(init)
|
||||
w2 = Tensor(init)
|
||||
w1 = Tensor(init).clone()
|
||||
w2 = Tensor(init).clone()
|
||||
assert w1.requires_grad is True and w2.requires_grad is True
|
||||
nn.optim.SGD([w1, w2], lr=0.01)
|
||||
assert w1.requires_grad is True and w2.requires_grad is True
|
||||
|
||||
@@ -98,6 +98,12 @@ class TestTensorGradient(unittest.TestCase):
|
||||
x = Tensor.randn(4, 4)
|
||||
np.testing.assert_allclose(x.pad(((1,0),(0,0))).gradient(x, gradient=g2)[0].numpy(), np.zeros((4, 4)))
|
||||
|
||||
def test_bare_const_skipped_by_backward(self):
|
||||
Tensor.manual_seed(0)
|
||||
w = Tensor(1.0)
|
||||
(Tensor.rand(()) + w).backward()
|
||||
self.assertIsNone(w.grad)
|
||||
|
||||
class TestMultiOutputGradient(unittest.TestCase):
|
||||
@staticmethod
|
||||
def addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp:
|
||||
|
||||
@@ -153,7 +153,7 @@ class TestIndexing(unittest.TestCase):
|
||||
n = random.randint(1, 10)
|
||||
z = Tensor.randn([m, n])
|
||||
a = 1.0
|
||||
w = Tensor(a)
|
||||
w = Tensor([a])
|
||||
z[:, 0] = w
|
||||
z.sum().backward()
|
||||
numpy_testing_assert_equal_helper(w.grad, m * a)
|
||||
|
||||
@@ -1295,8 +1295,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
return [X, V]
|
||||
|
||||
def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
|
||||
intermediate_tensors[y].backward()
|
||||
return tuple([t.grad for t in inputs])
|
||||
return tuple(intermediate_tensors[y].gradient(*inputs))
|
||||
|
||||
return {
|
||||
# Tensor ops
|
||||
|
||||
@@ -92,7 +92,7 @@ class Tensor(OpMixin):
|
||||
training: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
|
||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool=True, _force_unique:bool=False):
|
||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool=True):
|
||||
if device is None:
|
||||
if isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
|
||||
elif isinstance(data, UOp): device = data.device
|
||||
@@ -113,8 +113,7 @@ class Tensor(OpMixin):
|
||||
elif data is None:
|
||||
data = UOp.const(_dtype or dtypes.default_float, 0, _device)
|
||||
elif isinstance(data, get_args(ConstType)):
|
||||
dt = _dtype or dtypes.from_py(data)
|
||||
data = UOp.unique_const(data, dt, _device) if _force_unique or (requires_grad and dtypes.is_float(dt)) else UOp.const(dt, data, _device)
|
||||
data = UOp.const(_dtype or dtypes.from_py(data), data, _device)
|
||||
elif isinstance(data, bytes): data = _frompy(data, _dtype or dtypes.uint8, _device)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
if _dtype is None:
|
||||
@@ -161,11 +160,13 @@ class Tensor(OpMixin):
|
||||
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
|
||||
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.uop.const_like(b), requires_grad=False)
|
||||
@staticmethod
|
||||
def unique_const(fill_value:ConstType|UOp, **kwargs) -> Tensor: return Tensor(fill_value, _force_unique=True, **kwargs)
|
||||
def unique_const(fill_value:ConstType|UOp, **kwargs) -> Tensor:
|
||||
if isinstance(fill_value, UOp): return Tensor(fill_value, **kwargs)
|
||||
dtype, device = kwargs.pop("dtype", None), kwargs.pop("device", None)
|
||||
return Tensor(UOp.unique_const(fill_value, dtype, device), **kwargs)
|
||||
|
||||
def requires_grad_(self, requires_grad:bool=True) -> Tensor:
|
||||
# make the UOp unique if it's a CONST to prevent gradient accumulation bugs with cached const UOps
|
||||
if requires_grad and self.uop.op is Ops.CONST: self.replace(Tensor(self.uop.arg, device=self.device, dtype=self.dtype, requires_grad=True))
|
||||
if requires_grad and self.uop.op is Ops.CONST: self.replace(self.clone())
|
||||
self.requires_grad = requires_grad
|
||||
return self
|
||||
|
||||
@@ -859,9 +860,9 @@ class Tensor(OpMixin):
|
||||
```
|
||||
"""
|
||||
all_uops = self.uop.toposort()
|
||||
# backward fills .grad for every in-scope float tensor; .detach() only blocks gradient flow to the source
|
||||
# backward fills .grad for every in-scope non-CONST float tensor
|
||||
tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
|
||||
t.uop in all_uops and t.is_floating_point()]
|
||||
t.uop in all_uops and t.is_floating_point() and t.uop.op is not Ops.CONST]
|
||||
# clear contexts
|
||||
for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient)):
|
||||
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
|
||||
|
||||
Reference in New Issue
Block a user