From ea3fa07c2aae0ffdb5ef9bb2a07dd7359000fd74 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 28 Feb 2023 18:07:03 -0800 Subject: [PATCH] bump tinygrad to 0.5, move reshape logic from mlops --- setup.py | 2 +- tinygrad/mlops.py | 4 +--- tinygrad/tensor.py | 5 ++++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 56d7bf6896..ce72d88ba0 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ with open(os.path.join(directory, 'README.md'), encoding='utf-8') as f: long_description = f.read() setup(name='tinygrad', - version='0.4.0', + version='0.5.0', description='You like pytorch? You like micrograd? You love tinygrad! heart', author='George Hotz', license='MIT', diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 19c5f50bb7..cc7f846658 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -1,4 +1,4 @@ -from tinygrad.helpers import prod, argsort +from tinygrad.helpers import argsort from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps from tinygrad.tensor import Function @@ -126,9 +126,7 @@ class Expand(Function): class Reshape(Function): def forward(self, x, shape): - assert len(shape) > 0 and all(x != 0 for x in shape), f"zeros not allowed in shape {shape}" self.input_shape = x.shape - shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape) return x.movement_op(MovementOps.RESHAPE, shape) def backward(self, grad_output): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index bc7865151f..9f6c50117f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -188,7 +188,10 @@ class Tensor: # ***** movement mlops ***** - def reshape(self, shape, *args) -> Tensor: return mlops.Reshape.apply(self, shape=argfix(shape, *args)) + def reshape(self, shape, *args) -> Tensor: + new_shape = argfix(shape, *args) + assert len(new_shape) > 0 and all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}" + return mlops.Reshape.apply(self, shape=tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape)) def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args)))) def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args)) def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args))