From b4028e48aefcdf60342fb64124a8aec8d980c3c9 Mon Sep 17 00:00:00 2001 From: Friedrich Carl Eichenroth Date: Mon, 3 Mar 2025 15:24:02 +0000 Subject: [PATCH] Torch Backend Refinement (#9327) * fix some torch tests * fixup * small change * fixup * fix test * use default function * add todo * bunch of small changes * fix tests * more tests * fix * fix * test fix * simplify --- extra/torch_backend/backend.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index 6bd575b3a8..63c8e2c4b4 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -9,7 +9,7 @@ from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype # https://pytorch.org/docs/stable/torch.compiler_ir.html import torch.utils.cpp_extension -mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"]) +mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[str(pathlib.Path(__file__).parent / "wrapped_tensor.cpp")]) def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype)) def unwrap(x:torch.Tensor) -> Tensor: assert isinstance(x, torch.Tensor), f"x isn't {type(x)}" @@ -95,7 +95,7 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F return wrap(ret) @torch.library.impl("aten::max_pool2d_with_indices", "privateuseone") -def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False): +def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False): # TODO: supprt stride [] in tinygrad? if stride is not None and len(stride) == 0: stride = None # TODO: support return_indices in tinygrad @@ -104,12 +104,12 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64))) @torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone") -def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None): +def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None): if stride is not None and len(stride) == 0: stride = None # TODO: utilize input indices once they are correct - grad_out, self = unwrap(grad_out), unwrap(self) - out = Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode) - return wrap(out.gradient(self, gradient=grad_out)[0]) + self_ = unwrap(self) + out = Tensor.max_pool2d(self_, kernel_size, stride, dilation, padding, ceil_mode) + return wrap(out.gradient(self_, gradient=unwrap(grad_out))[0]) @torch.library.impl("aten::arange", "privateuseone") def arange(end, dtype=None, device=None, pin_memory=None): @@ -226,7 +226,7 @@ simple_tensor_methods = [ # rounding "ceil", "round", "floor", "trunc", # binary - "mul", "div", "maximum", "minimum", + "mul", "div", "maximum", "minimum", "copysign", # modify "tril", "triu", # reduce @@ -244,7 +244,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_ "aten.remainder.Tensor_out": Tensor.mod, "aten.pow.Tensor_Tensor_out": Tensor.pow, "aten.pow.Tensor_Scalar_out": Tensor.pow, - "aten.pow.Scalar_out": lambda x,y: x**y, + "aten.pow.Scalar_out": lambda input,exponent: input**exponent, "aten.bitwise_and.Tensor_out": Tensor.bitwise_and, "aten.bitwise_or.Tensor_out": Tensor.bitwise_or, "aten.bitwise_xor.Tensor_out": Tensor.bitwise_xor, @@ -254,16 +254,20 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_ "aten.gt.Tensor_out": Tensor.__gt__, "aten.gt.Scalar_out": Tensor.__gt__, "aten.lt.Tensor_out": Tensor.__lt__, "aten.lt.Scalar_out": Tensor.__lt__, "aten.le.Tensor_out": Tensor.__le__, "aten.le.Scalar_out": Tensor.__le__, - "aten.clamp_max.Tensor_out": lambda self,max_: self.clamp(max_=max_), - "aten.clamp_min.Tensor_out": lambda self,min_: self.clamp(min_=min_), + "aten.clamp_max.Tensor_out": lambda input,max_: input.clamp(max_=max_), + "aten.clamp_min.Tensor_out": lambda input,min_: input.clamp(min_=min_), + "aten.fmod.Tensor_out": lambda input,other: input-input.div(other, rounding_mode="trunc")*other, + # TODO: this might result in overflow issues + "aten.round.decimals_out": lambda self,decimals: (self*10**decimals).round()/10**decimals, # TODO: support this in tinygrad - "aten.bitwise_left_shift.Tensor_out": lambda self, other: Tensor(self << other.numpy()), - "aten.bitwise_right_shift.Tensor_out": lambda self, other: Tensor(self >> other.numpy()), + "aten.bitwise_left_shift.Tensor_out": lambda input,other: Tensor(input << other.numpy()), + "aten.bitwise_right_shift.Tensor_out": lambda input,other: Tensor(input >> other.numpy()), # not in tinygrad. are there decomps for these? "aten.log10.out": lambda self: self.log2() * (math.log(2) / math.log(10)), "aten.log1p.out": lambda self: (self+1).log(), "aten.expm1.out": lambda self: self.exp() - 1, - "aten.copysign.out": Tensor.copysign, + "aten.fmax.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.maximum(input, other))), + "aten.fmin.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.minimum(input, other))), # TODO: this gets the shape wrong #"aten.arange.start_out": Tensor.arange, "aten.lerp.Scalar_out": Tensor.lerp, @@ -363,7 +367,7 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs): if param is None: continue tinygrad_tensors.append(param.data) for state_dict in optimizer.state.values(): - for key, value in state_dict.items(): + for _, value in state_dict.items(): if torch.is_tensor(value): tinygrad_tensors.append(value) real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"] if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)