mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
(bounty) Make mnist training run with torch backend (#9233)
* yml changes
* torch backend remove meta decomps and add test
* torch backend bump timeout for tests

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
@@ -146,7 +146,7 @@ jobs:
|
||||
torchbackend:
|
||||
name: Torch Backend Tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
@@ -169,10 +169,10 @@ jobs:
|
||||
run: PYTHONPATH=. python3 extra/torch_backend/test.py
|
||||
- name: Test one op in torch tests
|
||||
run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
|
||||
- name: Test beautiful_mnist in torch with TINY_BACKEND
|
||||
run: PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
|
||||
- name: Test Ops with TINY_BACKEND (expect failure)
|
||||
run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py || true
|
||||
- name: Test beautiful_mnist in torch with TINY_BACKEND (expect failure)
|
||||
run: PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py || true
|
||||
- name: Test some torch tests (expect failure)
|
||||
run: PYTHONPATH=. python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true
|
||||
|
||||
|
||||
@@ -73,6 +73,14 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
|
||||
# TODO: this is wrong
|
||||
return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
|
||||
|
||||
# Backward of max_pool2d for the "privateuseone" backend.
# The gradient is obtained by re-running the forward pool on a fresh copy of the
# input and backpropagating grad_out through it; the `indices` argument is
# unwrapped but not used for the computation (see the TODO below).
# Returns the wrapped gradient with respect to `self`.
@torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
|
||||
def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
|
||||
grad_out, self, indices = unwrap(grad_out), unwrap(self), unwrap(indices)  # indices is unwrapped here but never read again in this function
|
||||
# TODO: utilize input indices once they are correct
|
||||
self = self.detach().clone().requires_grad_(True)  # detached copy that requires grad, so backward() below populates self.grad on this copy
|
||||
Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode).backward(grad_out)  # NOTE(review): dilation is passed before padding here, unlike this function's own parameter order; presumably matches Tensor.max_pool2d's positional order, confirm
|
||||
return wrap(self.grad)
|
||||
|
||||
# "privateuseone" implementation of aten::arange for the single-argument form:
# builds Tensor.arange(0, end) and wraps it. `device` and `pin_memory` are
# accepted but not used in the body.
# NOTE(review): when dtype is None this falls back to torch.get_default_dtype()
# even for an integer `end`, whereas stock torch.arange is documented to infer
# int64 in that case. test_maxpool2d_backward calls requires_grad_() on
# torch.arange(3*3, device=device), so it presumably relies on this float
# fallback; confirm before changing.
@torch.library.impl("aten::arange", "privateuseone")
|
||||
def arange(end, dtype=None, device=None, pin_memory=None):
|
||||
return wrap(Tensor.arange(0, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
|
||||
@@ -111,6 +119,18 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, tra
|
||||
groups=groups, stride=stride, dilation=dilation, padding=padding))
|
||||
#raise NotImplementedError("need convolution")
|
||||
|
||||
# Backward of convolution for the "privateuseone" backend.
# Recomputes the forward Tensor.conv2d on fresh copies of input/weight plus a
# zero bias, backpropagates grad_out through it, and returns a 3-tuple of
# wrapped gradients for (input, weight, bias). An entry is None when the
# corresponding tensor has no .grad after backward.
# output_mask[0..2] decide which of input, weight and bias require grad.
# NOTE(review): `transposed` and `output_padding` are only printed in the debug
# line and never reach the conv2d call; presumably transposed convolutions are
# not handled here, confirm.
@torch.library.impl("aten::convolution_backward_overrideable", "privateuseone")
|
||||
def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
|
||||
if TORCH_DEBUG >= 1:
|
||||
print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
|
||||
grad_out, input, weight = unwrap(grad_out), unwrap(input), unwrap(weight)
|
||||
input = input.detach().clone().requires_grad_(output_mask[0])  # detached copy; requires grad only if the input gradient was requested
|
||||
weight = weight.detach().clone().requires_grad_(output_mask[1])  # same for the weight gradient
|
||||
bias = Tensor.zeros(weight.shape[0]).requires_grad_(output_mask[2])  # the original bias is not an argument, so a zero bias of length weight.shape[0] stands in to receive a gradient
|
||||
Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding).backward(grad_out)  # forward is recomputed purely to backpropagate grad_out through it
|
||||
return tuple(wrap(x.grad) if x.grad is not None else None for x in [input, weight, bias])
|
||||
#raise NotImplementedError("need convolution")
|
||||
|
||||
@torch.library.impl("aten::_copy_from", "privateuseone")
|
||||
def _copy_from(src, dest, non_blocking=False):
|
||||
if str(src.device) == "tiny" and str(dest.device) == "tiny":
|
||||
@@ -134,68 +154,57 @@ def cat_out(tensors, dim=0, out=None):
|
||||
# register some decompositions
|
||||
from torch._decomp import get_decompositions
|
||||
aten = torch.ops.aten
|
||||
decomps = {
|
||||
"post_autograd": [
|
||||
aten.native_batch_norm, aten.native_batch_norm_backward,
|
||||
aten.native_layer_norm_backward,
|
||||
aten.addmm,
|
||||
aten.addcmul,
|
||||
aten.addcdiv,
|
||||
aten._log_softmax_backward_data,
|
||||
aten.threshold_backward,
|
||||
aten.softplus_backward,
|
||||
aten.elu, # elu has a scale + input_scale param
|
||||
aten.softplus,
|
||||
aten.threshold,
|
||||
aten.nll_loss_forward,
|
||||
aten.nll_loss_backward,
|
||||
# AttributeError: 'int' object has no attribute '_broadcasted'
|
||||
aten.sigmoid_backward,
|
||||
aten.tanh_backward,
|
||||
aten.sinc,
|
||||
aten._prelu_kernel,
|
||||
aten.softshrink,
|
||||
aten.hardshrink,
|
||||
aten.log_sigmoid_forward,
|
||||
aten.isneginf,
|
||||
aten.isposinf,
|
||||
aten.nan_to_num,
|
||||
aten.logit,
|
||||
aten.rsub,
|
||||
aten.index_select,
|
||||
aten.native_dropout, aten.native_dropout_backward,
|
||||
aten._softmax_backward_data, aten.embedding_dense_backward,
|
||||
aten.linalg_vector_norm,
|
||||
aten.unfold,
|
||||
# activations
|
||||
aten.hardswish, aten.hardswish_backward,
|
||||
aten.hardtanh, aten.hardtanh_backward,
|
||||
aten.gelu, aten.gelu_backward,
|
||||
# NOTE: this uses index
|
||||
#aten.reflection_pad2d,
|
||||
# NOTE: many of these don't work or cause infinite loops
|
||||
#aten.var_mean,
|
||||
#aten.var,
|
||||
#aten.rsqrt,
|
||||
#aten.max_pool2d_with_indices,
|
||||
# NOTE: these are prims
|
||||
#aten.digamma,
|
||||
#aten.erfinv,
|
||||
#aten.lgamma,
|
||||
# this needs copy_strided
|
||||
#aten.lerp,
|
||||
],
|
||||
"meta": [
|
||||
aten.max_pool2d_with_indices_backward,
|
||||
aten.convolution_backward,
|
||||
],
|
||||
}
|
||||
|
||||
for dctype,lst in decomps.items():
|
||||
for k,v in get_decompositions(lst, type=dctype).items():
|
||||
key = str(k._schema).split("(")[0]
|
||||
if TORCH_DEBUG >= 2: print("register decomp for", k)
|
||||
torch.library.impl(key, "privateuseone")(v)
|
||||
# aten ops that are handed to get_decompositions() below; every decomposition
# it returns is registered as the "privateuseone" implementation of that op.
# Entries that are commented out are kept with the reason they are disabled.
decomps = [
|
||||
aten.native_batch_norm, aten.native_batch_norm_backward,
|
||||
aten.native_layer_norm_backward,
|
||||
aten.addmm,
|
||||
aten.addcmul,
|
||||
aten.addcdiv,
|
||||
aten._log_softmax_backward_data,
|
||||
aten.threshold_backward,
|
||||
aten.softplus_backward,
|
||||
aten.elu, # elu has a scale + input_scale param
|
||||
aten.softplus,
|
||||
aten.threshold,
|
||||
aten.nll_loss_forward,
|
||||
aten.nll_loss_backward,
|
||||
# AttributeError: 'int' object has no attribute '_broadcasted'
|
||||
aten.sigmoid_backward,
|
||||
aten.tanh_backward,
|
||||
aten.sinc,
|
||||
aten._prelu_kernel,
|
||||
aten.softshrink,
|
||||
aten.hardshrink,
|
||||
aten.log_sigmoid_forward,
|
||||
aten.isneginf,
|
||||
aten.isposinf,
|
||||
aten.nan_to_num,
|
||||
aten.logit,
|
||||
aten.rsub,
|
||||
aten.index_select,
|
||||
aten.native_dropout, aten.native_dropout_backward,
|
||||
aten._softmax_backward_data, aten.embedding_dense_backward,
|
||||
aten.linalg_vector_norm,
|
||||
# activations
|
||||
aten.hardswish, aten.hardswish_backward,
|
||||
aten.hardtanh, aten.hardtanh_backward,
|
||||
aten.gelu, aten.gelu_backward,
|
||||
# NOTE: many of these don't work or cause infinite loops
|
||||
#aten.var_mean,
|
||||
#aten.var,
|
||||
#aten.rsqrt,
|
||||
#aten.max_pool2d_with_indices,
|
||||
# NOTE: these are prims
|
||||
#aten.digamma,
|
||||
#aten.erfinv,
|
||||
#aten.lgamma,
|
||||
# this needs copy_strided
|
||||
#aten.lerp,
|
||||
]
|
||||
for k,v in get_decompositions(decomps).items():  # k is used for its ._schema, v is the callable that gets registered
|
||||
key = str(k._schema).split("(")[0]  # the schema text before the first "(" is used as the name to register under
|
||||
if TORCH_DEBUG >= 2: print("register decomp for", k)
|
||||
torch.library.impl(key, "privateuseone")(v)  # same registration call the decorators in this file use, applied directly to v
|
||||
|
||||
# NOTE: we should only implement the "out" form, it should be 0 overhead
|
||||
# TODO: due to issue with empty / is_realized, it is slow to use assign so we use replace
|
||||
|
||||
@@ -77,6 +77,11 @@ class TestTorchBackend(unittest.TestCase):
|
||||
c = a == b
|
||||
print(c.cpu().numpy())
|
||||
|
||||
# Checks the gradient of max_pool2d (kernel 2, stride 1) on a 1x1x3x3 input
# created on `device`, by summing the pooled output and calling backward().
def test_maxpool2d_backward(self):
|
||||
x = torch.arange(3*3, device=device).reshape(1, 1, 3, 3).requires_grad_(True)  # values 0..8 laid out row-major as a single 3x3 image
|
||||
torch.nn.functional.max_pool2d(x, kernel_size=2, stride=1).sum().backward()  # four overlapping 2x2 windows; sum() sends a gradient of 1 to each window's max
|
||||
np.testing.assert_equal(x.grad.squeeze().cpu().numpy(), [[0, 0, 0], [0, 1, 1], [0, 1, 1]])  # values grow along rows and columns, so each window's max is its bottom-right element
|
||||
|
||||
@unittest.skip("meh")
|
||||
def test_str(self):
|
||||
a = torch.ones(4, device=device)
|
||||
|
||||
Reference in New Issue
Block a user