Torch Backend Refinement (#9327)

* fix some torch tests

* fixup

* small change

* fixup

* fix test

* use default function

* add todo

* bunch of small changes

* fix tests

* more tests

* fix

* fix

* test fix

* simplify
This commit is contained in:
Friedrich Carl Eichenroth
2025-03-03 15:24:02 +00:00
committed by GitHub
parent 23084fd850
commit b4028e48ae

View File

@@ -9,7 +9,7 @@ from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
# https://pytorch.org/docs/stable/torch.compiler_ir.html
import torch.utils.cpp_extension
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[str(pathlib.Path(__file__).parent / "wrapped_tensor.cpp")])
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
def unwrap(x:torch.Tensor) -> Tensor:
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
@@ -95,7 +95,7 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F
return wrap(ret)
@torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False):
# TODO: support stride [] in tinygrad?
if stride is not None and len(stride) == 0: stride = None
# TODO: support return_indices in tinygrad
@@ -104,12 +104,12 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
@torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
if stride is not None and len(stride) == 0: stride = None
# TODO: utilize input indices once they are correct
grad_out, self = unwrap(grad_out), unwrap(self)
out = Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode)
return wrap(out.gradient(self, gradient=grad_out)[0])
self_ = unwrap(self)
out = Tensor.max_pool2d(self_, kernel_size, stride, dilation, padding, ceil_mode)
return wrap(out.gradient(self_, gradient=unwrap(grad_out))[0])
@torch.library.impl("aten::arange", "privateuseone")
def arange(end, dtype=None, device=None, pin_memory=None):
@@ -226,7 +226,7 @@ simple_tensor_methods = [
# rounding
"ceil", "round", "floor", "trunc",
# binary
"mul", "div", "maximum", "minimum",
"mul", "div", "maximum", "minimum", "copysign",
# modify
"tril", "triu",
# reduce
@@ -244,7 +244,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
"aten.remainder.Tensor_out": Tensor.mod,
"aten.pow.Tensor_Tensor_out": Tensor.pow,
"aten.pow.Tensor_Scalar_out": Tensor.pow,
"aten.pow.Scalar_out": lambda x,y: x**y,
"aten.pow.Scalar_out": lambda input,exponent: input**exponent,
"aten.bitwise_and.Tensor_out": Tensor.bitwise_and,
"aten.bitwise_or.Tensor_out": Tensor.bitwise_or,
"aten.bitwise_xor.Tensor_out": Tensor.bitwise_xor,
@@ -254,16 +254,20 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
"aten.gt.Tensor_out": Tensor.__gt__, "aten.gt.Scalar_out": Tensor.__gt__,
"aten.lt.Tensor_out": Tensor.__lt__, "aten.lt.Scalar_out": Tensor.__lt__,
"aten.le.Tensor_out": Tensor.__le__, "aten.le.Scalar_out": Tensor.__le__,
"aten.clamp_max.Tensor_out": lambda self,max_: self.clamp(max_=max_),
"aten.clamp_min.Tensor_out": lambda self,min_: self.clamp(min_=min_),
"aten.clamp_max.Tensor_out": lambda input,max_: input.clamp(max_=max_),
"aten.clamp_min.Tensor_out": lambda input,min_: input.clamp(min_=min_),
"aten.fmod.Tensor_out": lambda input,other: input-input.div(other, rounding_mode="trunc")*other,
# TODO: this might result in overflow issues
"aten.round.decimals_out": lambda self,decimals: (self*10**decimals).round()/10**decimals,
# TODO: support this in tinygrad
"aten.bitwise_left_shift.Tensor_out": lambda self, other: Tensor(self << other.numpy()),
"aten.bitwise_right_shift.Tensor_out": lambda self, other: Tensor(self >> other.numpy()),
"aten.bitwise_left_shift.Tensor_out": lambda input,other: Tensor(input << other.numpy()),
"aten.bitwise_right_shift.Tensor_out": lambda input,other: Tensor(input >> other.numpy()),
# not in tinygrad. are there decomps for these?
"aten.log10.out": lambda self: self.log2() * (math.log(2) / math.log(10)),
"aten.log1p.out": lambda self: (self+1).log(),
"aten.expm1.out": lambda self: self.exp() - 1,
"aten.copysign.out": Tensor.copysign,
"aten.fmax.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.maximum(input, other))),
"aten.fmin.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.minimum(input, other))),
# TODO: this gets the shape wrong
#"aten.arange.start_out": Tensor.arange,
"aten.lerp.Scalar_out": Tensor.lerp,
@@ -363,7 +367,7 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
if param is None: continue
tinygrad_tensors.append(param.data)
for state_dict in optimizer.state.values():
for key, value in state_dict.items():
for _, value in state_dict.items():
if torch.is_tensor(value): tinygrad_tensors.append(value)
real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"]
if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)