diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 846e290fe9..5d0b06cf38 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -22,6 +22,41 @@ torch.utils.rename_privateuse1_backend("tiny")
 torch._register_device_module("tiny", TinyBackend())
 torch.utils.generate_methods_for_privateuse1_backend()
 
+# *** bad functions on CPU ***
+# Every op in this section copies its tensor inputs to CPU, runs the op there,
+# and moves the result back with .tiny(); they are grouped so they are easy to find and replace.
+
+@torch.library.impl("aten::masked_select", "privateuseone")
+def masked_select(self, mask):
+  # err, bad: the mask indexing is done in numpy on CPU copies of both tensors
+  return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
+
+@torch.library.impl("aten::topk", "privateuseone")
+def topk(self, k, dim=-1, largest=True, sorted=True):
+  # TODO: move to tinygrad
+  t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted)
+  return torch.return_types.topk((t1.tiny(), t2.tiny()))
+
+@torch.library.impl("aten::_index_put_impl_", "privateuseone")
+def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
+  # TODO: move to tinygrad
+  # non-tensor entries in indices are passed through as None, same as index_tensor does
+  cpu_indices = [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices]
+  # in-place op: write the CPU result back into self so callers that ignore the return value see the update
+  return self.copy_(aten._index_put_impl_(self.cpu(), cpu_indices, values.cpu(), accumulate, unsafe).tiny())
+
+@torch.library.impl("aten::index.Tensor", "privateuseone")
+def index_tensor(x, y):
+  # y is the list of indices; tensor entries are moved to CPU, anything else becomes None
+  return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
+
+@torch.library.impl("aten::randperm.generator_out", "privateuseone")
+def randperm_generator(n, generator=None, out=None):
+  # out variant: the permutation is generated on CPU, copied into out, and out is returned
+  return out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
+
+# *** end bad functions on CPU ***
+
 @torch.library.impl("aten::zero_", "privateuseone")
 def zero_(x):
   tt = unwrap(x)
@@ -35,11 +70,6 @@ def fill_scalar(x, y):
 @torch.library.impl("aten::_local_scalar_dense", "privateuseone")
 def _local_scalar_dense(tensor): return unwrap(tensor).item()
 
-@torch.library.impl("aten::masked_select", "privateuseone")
-def masked_select(self, mask):
-  # err, bad
-  return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
-
 @functools.lru_cache(None)
 def cached_to_movement_ops(shape, st) -> list:
   mops = to_movement_ops(st)
@@ -99,24 +129,6 @@ def arange_start(start, end, dtype=None, device=None, pin_memory=None):
 def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None):
   return wrap(Tensor.arange(start, end, step, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
 
-@torch.library.impl("aten::topk", "privateuseone")
-def topk(self, k, dim=-1, largest=True, sorted=True):
-  # TODO: move to tinygrad
-  t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted)
-  return torch.return_types.topk((t1.tiny(), t2.tiny()))
-
-@torch.library.impl("aten::_index_put_impl_", "privateuseone")
-def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
-  # TODO: move to tinygrad
-  return aten._index_put_impl_(self.cpu(), [x.cpu() for x in indices], values.cpu(), accumulate, unsafe).tiny()
-
-@torch.library.impl("aten::randperm.generator_out", "privateuseone")
-def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator).tiny())
-
-@torch.library.impl("aten::index.Tensor", "privateuseone")
-def index_tensor(x, y):
-  return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
-
 @torch.library.impl("aten::convolution_overrideable", "privateuseone")
 def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
   if TORCH_DEBUG >= 1: