(bounty) Make mnist training run with torch backend (#9233)

* yml changes

* torch backend remove meta decomps and add test

* torch backend bump timeout for tests

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Priyank Patel
2025-02-26 19:32:25 -08:00
committed by GitHub
parent 67ba073c55
commit a0764f0dc0
3 changed files with 79 additions and 65 deletions

View File

@@ -146,7 +146,7 @@ jobs:
torchbackend:
name: Torch Backend Tests
runs-on: ubuntu-latest
timeout-minutes: 10
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -169,10 +169,10 @@ jobs:
run: PYTHONPATH=. python3 extra/torch_backend/test.py
- name: Test one op in torch tests
run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
- name: Test beautiful_mnist in torch with TINY_BACKEND
run: PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
- name: Test Ops with TINY_BACKEND (expect failure)
run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py || true
- name: Test beautiful_mnist in torch with TINY_BACKEND (expect failure)
run: PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py || true
- name: Test some torch tests (expect failure)
run: PYTHONPATH=. python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true

View File

@@ -73,6 +73,14 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
# TODO: this is wrong
return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
@torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
grad_out, self, indices = unwrap(grad_out), unwrap(self), unwrap(indices)
# TODO: utilize input indices once they are correct
self = self.detach().clone().requires_grad_(True)
Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode).backward(grad_out)
return wrap(self.grad)
@torch.library.impl("aten::arange", "privateuseone")
def arange(end, dtype=None, device=None, pin_memory=None):
return wrap(Tensor.arange(0, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
@@ -111,6 +119,18 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, tra
groups=groups, stride=stride, dilation=dilation, padding=padding))
#raise NotImplementedError("need convolution")
@torch.library.impl("aten::convolution_backward_overrideable", "privateuseone")
def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
if TORCH_DEBUG >= 1:
print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
grad_out, input, weight = unwrap(grad_out), unwrap(input), unwrap(weight)
input = input.detach().clone().requires_grad_(output_mask[0])
weight = weight.detach().clone().requires_grad_(output_mask[1])
bias = Tensor.zeros(weight.shape[0]).requires_grad_(output_mask[2])
Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding).backward(grad_out)
return tuple(wrap(x.grad) if x.grad is not None else None for x in [input, weight, bias])
#raise NotImplementedError("need convolution")
@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(src, dest, non_blocking=False):
if str(src.device) == "tiny" and str(dest.device) == "tiny":
@@ -134,68 +154,57 @@ def cat_out(tensors, dim=0, out=None):
# register some decompositions
from torch._decomp import get_decompositions
aten = torch.ops.aten
decomps = {
"post_autograd": [
aten.native_batch_norm, aten.native_batch_norm_backward,
aten.native_layer_norm_backward,
aten.addmm,
aten.addcmul,
aten.addcdiv,
aten._log_softmax_backward_data,
aten.threshold_backward,
aten.softplus_backward,
aten.elu, # elu has a scale + input_scale param
aten.softplus,
aten.threshold,
aten.nll_loss_forward,
aten.nll_loss_backward,
# AttributeError: 'int' object has no attribute '_broadcasted'
aten.sigmoid_backward,
aten.tanh_backward,
aten.sinc,
aten._prelu_kernel,
aten.softshrink,
aten.hardshrink,
aten.log_sigmoid_forward,
aten.isneginf,
aten.isposinf,
aten.nan_to_num,
aten.logit,
aten.rsub,
aten.index_select,
aten.native_dropout, aten.native_dropout_backward,
aten._softmax_backward_data, aten.embedding_dense_backward,
aten.linalg_vector_norm,
aten.unfold,
# activations
aten.hardswish, aten.hardswish_backward,
aten.hardtanh, aten.hardtanh_backward,
aten.gelu, aten.gelu_backward,
# NOTE: this uses index
#aten.reflection_pad2d,
# NOTE: many of these don't work or cause infinite loops
#aten.var_mean,
#aten.var,
#aten.rsqrt,
#aten.max_pool2d_with_indices,
# NOTE: these are prims
#aten.digamma,
#aten.erfinv,
#aten.lgamma,
# this needs copy_strided
#aten.lerp,
],
"meta": [
aten.max_pool2d_with_indices_backward,
aten.convolution_backward,
],
}
for dctype,lst in decomps.items():
for k,v in get_decompositions(lst, type=dctype).items():
key = str(k._schema).split("(")[0]
if TORCH_DEBUG >= 2: print("register decomp for", k)
torch.library.impl(key, "privateuseone")(v)
decomps = [
aten.native_batch_norm, aten.native_batch_norm_backward,
aten.native_layer_norm_backward,
aten.addmm,
aten.addcmul,
aten.addcdiv,
aten._log_softmax_backward_data,
aten.threshold_backward,
aten.softplus_backward,
aten.elu, # elu has a scale + input_scale param
aten.softplus,
aten.threshold,
aten.nll_loss_forward,
aten.nll_loss_backward,
# AttributeError: 'int' object has no attribute '_broadcasted'
aten.sigmoid_backward,
aten.tanh_backward,
aten.sinc,
aten._prelu_kernel,
aten.softshrink,
aten.hardshrink,
aten.log_sigmoid_forward,
aten.isneginf,
aten.isposinf,
aten.nan_to_num,
aten.logit,
aten.rsub,
aten.index_select,
aten.native_dropout, aten.native_dropout_backward,
aten._softmax_backward_data, aten.embedding_dense_backward,
aten.linalg_vector_norm,
# activations
aten.hardswish, aten.hardswish_backward,
aten.hardtanh, aten.hardtanh_backward,
aten.gelu, aten.gelu_backward,
# NOTE: many of these don't work or cause infinite loops
#aten.var_mean,
#aten.var,
#aten.rsqrt,
#aten.max_pool2d_with_indices,
# NOTE: these are prims
#aten.digamma,
#aten.erfinv,
#aten.lgamma,
# this needs copy_strided
#aten.lerp,
]
for k,v in get_decompositions(decomps).items():
key = str(k._schema).split("(")[0]
if TORCH_DEBUG >= 2: print("register decomp for", k)
torch.library.impl(key, "privateuseone")(v)
# NOTE: we should only implement the "out" form, it should be 0 overhead
# TODO: due to issue with empty / is_realized, it is slow to use assign so we use replace

View File

@@ -77,6 +77,11 @@ class TestTorchBackend(unittest.TestCase):
c = a == b
print(c.cpu().numpy())
def test_maxpool2d_backward(self):
x = torch.arange(3*3, device=device).reshape(1, 1, 3, 3).requires_grad_(True)
torch.nn.functional.max_pool2d(x, kernel_size=2, stride=1).sum().backward()
np.testing.assert_equal(x.grad.squeeze().cpu().numpy(), [[0, 0, 0], [0, 1, 1], [0, 1, 1]])
@unittest.skip("meh")
def test_str(self):
a = torch.ones(4, device=device)