mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Torch Backend Refinement (#9327)
* fix some torch tests * fixup * small change * fixup * fix test * use default function * add todo * bunch of small changes * fix tests * more tests * fix * fix * test fix * simplify
This commit is contained in:
committed by
GitHub
parent
23084fd850
commit
b4028e48ae
@@ -9,7 +9,7 @@ from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
|
||||
# https://pytorch.org/docs/stable/torch.compiler_ir.html
|
||||
|
||||
import torch.utils.cpp_extension
|
||||
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
|
||||
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[str(pathlib.Path(__file__).parent / "wrapped_tensor.cpp")])
|
||||
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
|
||||
def unwrap(x:torch.Tensor) -> Tensor:
|
||||
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
|
||||
@@ -95,7 +95,7 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F
|
||||
return wrap(ret)
|
||||
|
||||
@torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
|
||||
def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
|
||||
def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False):
|
||||
# TODO: supprt stride [] in tinygrad?
|
||||
if stride is not None and len(stride) == 0: stride = None
|
||||
# TODO: support return_indices in tinygrad
|
||||
@@ -104,12 +104,12 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
|
||||
return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
|
||||
|
||||
@torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
|
||||
def max_pool2d_with_indices_backward(grad_out:Tensor, self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
|
||||
def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
|
||||
if stride is not None and len(stride) == 0: stride = None
|
||||
# TODO: utilize input indices once they are correct
|
||||
grad_out, self = unwrap(grad_out), unwrap(self)
|
||||
out = Tensor.max_pool2d(self, kernel_size, stride, dilation, padding, ceil_mode)
|
||||
return wrap(out.gradient(self, gradient=grad_out)[0])
|
||||
self_ = unwrap(self)
|
||||
out = Tensor.max_pool2d(self_, kernel_size, stride, dilation, padding, ceil_mode)
|
||||
return wrap(out.gradient(self_, gradient=unwrap(grad_out))[0])
|
||||
|
||||
@torch.library.impl("aten::arange", "privateuseone")
|
||||
def arange(end, dtype=None, device=None, pin_memory=None):
|
||||
@@ -226,7 +226,7 @@ simple_tensor_methods = [
|
||||
# rounding
|
||||
"ceil", "round", "floor", "trunc",
|
||||
# binary
|
||||
"mul", "div", "maximum", "minimum",
|
||||
"mul", "div", "maximum", "minimum", "copysign",
|
||||
# modify
|
||||
"tril", "triu",
|
||||
# reduce
|
||||
@@ -244,7 +244,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
|
||||
"aten.remainder.Tensor_out": Tensor.mod,
|
||||
"aten.pow.Tensor_Tensor_out": Tensor.pow,
|
||||
"aten.pow.Tensor_Scalar_out": Tensor.pow,
|
||||
"aten.pow.Scalar_out": lambda x,y: x**y,
|
||||
"aten.pow.Scalar_out": lambda input,exponent: input**exponent,
|
||||
"aten.bitwise_and.Tensor_out": Tensor.bitwise_and,
|
||||
"aten.bitwise_or.Tensor_out": Tensor.bitwise_or,
|
||||
"aten.bitwise_xor.Tensor_out": Tensor.bitwise_xor,
|
||||
@@ -254,16 +254,20 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
|
||||
"aten.gt.Tensor_out": Tensor.__gt__, "aten.gt.Scalar_out": Tensor.__gt__,
|
||||
"aten.lt.Tensor_out": Tensor.__lt__, "aten.lt.Scalar_out": Tensor.__lt__,
|
||||
"aten.le.Tensor_out": Tensor.__le__, "aten.le.Scalar_out": Tensor.__le__,
|
||||
"aten.clamp_max.Tensor_out": lambda self,max_: self.clamp(max_=max_),
|
||||
"aten.clamp_min.Tensor_out": lambda self,min_: self.clamp(min_=min_),
|
||||
"aten.clamp_max.Tensor_out": lambda input,max_: input.clamp(max_=max_),
|
||||
"aten.clamp_min.Tensor_out": lambda input,min_: input.clamp(min_=min_),
|
||||
"aten.fmod.Tensor_out": lambda input,other: input-input.div(other, rounding_mode="trunc")*other,
|
||||
# TODO: this might result in overflow issues
|
||||
"aten.round.decimals_out": lambda self,decimals: (self*10**decimals).round()/10**decimals,
|
||||
# TODO: support this in tinygrad
|
||||
"aten.bitwise_left_shift.Tensor_out": lambda self, other: Tensor(self << other.numpy()),
|
||||
"aten.bitwise_right_shift.Tensor_out": lambda self, other: Tensor(self >> other.numpy()),
|
||||
"aten.bitwise_left_shift.Tensor_out": lambda input,other: Tensor(input << other.numpy()),
|
||||
"aten.bitwise_right_shift.Tensor_out": lambda input,other: Tensor(input >> other.numpy()),
|
||||
# not in tinygrad. are there decomps for these?
|
||||
"aten.log10.out": lambda self: self.log2() * (math.log(2) / math.log(10)),
|
||||
"aten.log1p.out": lambda self: (self+1).log(),
|
||||
"aten.expm1.out": lambda self: self.exp() - 1,
|
||||
"aten.copysign.out": Tensor.copysign,
|
||||
"aten.fmax.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.maximum(input, other))),
|
||||
"aten.fmin.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.minimum(input, other))),
|
||||
# TODO: this gets the shape wrong
|
||||
#"aten.arange.start_out": Tensor.arange,
|
||||
"aten.lerp.Scalar_out": Tensor.lerp,
|
||||
@@ -363,7 +367,7 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
|
||||
if param is None: continue
|
||||
tinygrad_tensors.append(param.data)
|
||||
for state_dict in optimizer.state.values():
|
||||
for key, value in state_dict.items():
|
||||
for _, value in state_dict.items():
|
||||
if torch.is_tensor(value): tinygrad_tensors.append(value)
|
||||
real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"]
|
||||
if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)
|
||||
|
||||
Reference in New Issue
Block a user