more tensor.py docs (#4686)

wow much docs
This commit is contained in:
wozeparrot
2024-05-22 21:28:26 +00:00
committed by GitHub
parent 721f9f6acf
commit 6020595eb0
2 changed files with 50 additions and 2 deletions

View File

@@ -24,6 +24,7 @@
::: tinygrad.Tensor.shard_
::: tinygrad.Tensor.contiguous
::: tinygrad.Tensor.contiguous_backward
::: tinygrad.Tensor.backward
## Creation (basic)

View File

@@ -648,6 +648,17 @@ class Tensor:
return list(_walk(self, set()))
def backward(self) -> Tensor:
"""
Propagates the gradient of a tensor backwards through the computation graph.
Must be used on a scalar tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6.0, requires_grad=True)
t2 = t.sum()
t2.backward()
print(t.grad.numpy())
```
"""
assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"
# fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
@@ -668,14 +679,50 @@ class Tensor:
# ***** movement mlops *****
def view(self, *shape) -> Tensor: return self.reshape(shape) # in tinygrad, view and reshape are the same thing
def view(self, *shape) -> Tensor:
"""
`.view` is an alias for `.reshape`.
"""
return self.reshape(shape)
def reshape(self, shape, *args) -> Tensor:
"""
Returns a new tensor with the same data as the original tensor but with a different shape.
shape can be passed as a tuple or as separate arguments.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6)
print(t.reshape(2, 3).numpy())
```
"""
new_shape = argfix(shape, *args)
new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
def expand(self, shape, *args) -> Tensor:
"""
Returns a new tensor that is expanded to the shape that is specified.
Expand can also increase the number of dimensions that a tensor has.
Passing a `-1` or `None` to a dimension means that it's size will not be changed.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3])
print(t.expand(4, -1).numpy())
```
"""
return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
def permute(self, order, *args) -> Tensor: return F.Permute.apply(self, order=argfix(order, *args))
def permute(self, order, *args) -> Tensor:
"""
Returns a new tensor that is a permutation of the original tensor.
The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
order can be passed as a tuple or as separate arguments.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6).reshape(2, 3)
print(t.numpy(), "->")
print(t.permute(1, 0).numpy())
```
"""
return F.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return F.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self