diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ea0870666e..9883fb1b99 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -170,7 +170,7 @@ jobs: - name: Test one op in torch tests run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32 - name: Test beautiful_mnist in torch with TINY_BACKEND - run: PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py + run: PYTHONPATH=. LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py - name: Test Ops with TINY_BACKEND (expect failure) run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py || true - name: Test some torch tests (expect failure) diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index f3e14d03e2..7c0b9ccf8e 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -2,7 +2,7 @@ from tinygrad import Tensor, dtypes from tinygrad.helpers import DEBUG, getenv, prod import torch.lib TORCH_DEBUG = getenv("TORCH_DEBUG") -import torch, pathlib, math, operator +import torch, pathlib, math, operator, functools torch.autograd.grad_mode.set_multithreading_enabled(False) from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype @@ -40,18 +40,22 @@ def masked_select(self, mask): # err, bad return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()])) +@functools.lru_cache(None) +def cached_to_movement_ops(shape, st) -> list: + mops = to_movement_ops(st) + if mops[0] == (MovementOps.RESHAPE, shape): mops = mops[1:] + return mops + from tinygrad.shape.shapetracker import ShapeTracker, View from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps @torch.library.impl("aten::as_strided", "privateuseone") def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None): # TODO: this is heavyweight - st = ShapeTracker([View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)]) + st = ShapeTracker((View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset))) ret = unwrap(tensor) if prod(size) == 1: return wrap(ret.flatten()[storage_offset].reshape(size)) if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st) - mops = to_movement_ops(st) - if mops[0] == (MovementOps.RESHAPE, tuple(tensor.shape)): mops = mops[1:] - for mo in mops: ret = apply_mop(ret, mo) + for mo in cached_to_movement_ops(tuple(tensor.shape), st): ret = apply_mop(ret, mo) return wrap(ret) @torch.library.impl("aten::empty_strided", "privateuseone") @@ -75,11 +79,10 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di @torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone") def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None): - grad_out, self, indices = unwrap(grad_out), unwrap(self), unwrap(indices) # TODO: utilize input indices once they are correct - self = self.detach().clone().requires_grad_(True) - Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode).backward(grad_out) - return wrap(self.grad) + grad_out, self = unwrap(grad_out), unwrap(self) + out = Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode) + return wrap(out.gradient(self, gradient=grad_out)[0]) @torch.library.impl("aten::arange", "privateuseone") def arange(end, dtype=None, device=None, pin_memory=None): @@ -117,19 +120,15 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, tra print(f"convolution {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}") return wrap(unwrap(input).conv2d(unwrap(weight), unwrap(bias) if bias is not None else None, groups=groups, stride=stride, dilation=dilation, padding=padding)) - #raise NotImplementedError("need convolution") @torch.library.impl("aten::convolution_backward_overrideable", "privateuseone") def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask): if TORCH_DEBUG >= 1: print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}") - grad_out, input, weight = unwrap(grad_out), unwrap(input), unwrap(weight) - input = input.detach().clone().requires_grad_(output_mask[0]) - weight = weight.detach().clone().requires_grad_(output_mask[1]) - bias = Tensor.zeros(weight.shape[0]).requires_grad_(output_mask[2]) - Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding).backward(grad_out) - return tuple(wrap(x.grad) if x.grad is not None else None for x in [input, weight, bias]) - #raise NotImplementedError("need convolution") + grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0]) + out = Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding) + grads = out.gradient(*[t for t,m in zip([input, weight, bias], output_mask) if m], gradient=grad_out) + return tuple([wrap(grads.pop(0)) if m else None for m in output_mask]) @torch.library.impl("aten::_copy_from", "privateuseone") def _copy_from(src, dest, non_blocking=False): @@ -141,9 +140,6 @@ def _copy_from(src, dest, non_blocking=False): dest.copy_(torch.from_numpy(unwrap(src).numpy())) elif str(src.device) == "cpu" and str(dest.device) == "tiny": unwrap(dest).assign(Tensor(src.numpy())) - #if 0 in dest.stride(): - # print(dest.shape, dest.stride()) - # exit(0) else: raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")