From 8a04107d30babd3dfcdf8fc7b7a8ea24aadf5c30 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 20 Dec 2023 23:50:37 -0500 Subject: [PATCH] move the op casting logic from mlops to tensor try 2 (#2887) * unary works * where works * add sub mul * xor div * CMPLT * sparse_categorical_crossentropy * image const * sparse_categorical_crossentropy --- extra/onnx_ops.py | 4 ++-- tinygrad/lazy.py | 4 +++- tinygrad/mlops.py | 57 ++++++++++++++++++++++------------------------ tinygrad/tensor.py | 27 ++++++++++++---------- 4 files changed, 47 insertions(+), 45 deletions(-) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 3588f73032..f87c750c9e 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -107,7 +107,7 @@ def Atan(y: Tensor): def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1): k = int(k.numpy().item()) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0 - return x.triu(k) if upper else x.tril(k) + return x.triu(k).cast(dtypes.int64) if upper else x.tril(k).cast(dtypes.int64) def Squeeze(data: Tensor, axes): if isinstance(axes, Tensor): axes = safe_numpy(axes) @@ -122,7 +122,7 @@ def Unsqueeze(data: Tensor, axes): new_shape[i] = next(ptr) return data.reshape(new_shape) -def Binarizer(input, threshold=0.0): return input > threshold +def Binarizer(input, threshold=0.0): return (input > threshold).cast(dtypes.float32) def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0): axis = axis + x.ndim if axis < 0 else axis diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index dbb0f3936a..345ef2ce08 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -59,7 +59,9 @@ class LazyBuffer: return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else ()) def const(self, val:Union[float, int]) -> LazyBuffer: - return LazyBuffer.loadop(LoadOps.CONST, (), self.dtype, self.device, val).reshape((1,)*len(self.shape)).expand(self.shape) + # NOTE: we force the image dtype const to be a float32 + const_dtype = self.dtype if not isinstance(self.dtype, ImageDType) else dtypes.float32 + return LazyBuffer.loadop(LoadOps.CONST, tuple(), const_dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape) def contiguous(self): if not self.st.contiguous or self.st.size() != self.base.st.size() or self.is_unrealized_const(): diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 001fca6f28..0649ec41ce 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -1,6 +1,6 @@ import math from typing import Tuple, Optional, cast -from tinygrad.helpers import argsort, DType, least_upper_float, dtypes, least_upper_dtype +from tinygrad.helpers import argsort, DType, dtypes from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps from tinygrad.tensor import Function from tinygrad.lazy import LazyBuffer @@ -35,7 +35,7 @@ class Neg(Function): class Sin(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.x = x - return x.cast(least_upper_float(x.dtype)).e(UnaryOps.SIN) + return x.e(UnaryOps.SIN) def backward(self, grad:LazyBuffer) -> LazyBuffer: return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad) @@ -47,19 +47,19 @@ class Relu(Function): return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output) + return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output) class Log(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.x = x - return x.cast(ftype:= least_upper_float(x.dtype)).e(UnaryOps.LOG2).e(BinaryOps.MUL, x.cast(ftype).const(math.log(2))) + return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2))) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x) class Exp(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: - self.ret = x.cast(ftype:=least_upper_float(x.dtype)).e(BinaryOps.MUL, x.cast(ftype).const(1/math.log(2))).e(UnaryOps.EXP2) + self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: @@ -67,7 +67,7 @@ class Exp(Function): class Sqrt(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: - self.ret = x.cast(least_upper_float(x.dtype)).e(UnaryOps.SQRT) + self.ret = x.e(UnaryOps.SQRT) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: @@ -78,7 +78,6 @@ class Sqrt(Function): # TODO: have the backend automatically find this class Sigmoid(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: - x = x.cast(least_upper_float(x.dtype)) self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2))) return self.ret @@ -89,60 +88,58 @@ class Sigmoid(Function): class Less(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - output_dtype = least_upper_dtype(x.dtype, y.dtype) - return x.cast(output_dtype).e(BinaryOps.CMPLT, y.cast(output_dtype)) + # in webgpu bool cannot be used as a storage buffer type + return x.e(BinaryOps.CMPLT, y).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool) class Xor(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - output_dtype = least_upper_dtype(x.dtype, y.dtype) - return x.cast(output_dtype).e(BinaryOps.XOR, y.cast(output_dtype)) + return x.e(BinaryOps.XOR, y) class Add(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - self.x_dtype, self.y_dtype, output_dtype = x.dtype, y.dtype, least_upper_dtype(x.dtype, y.dtype) - return x.cast(output_dtype).e(BinaryOps.ADD, y.cast(output_dtype)) + return x.e(BinaryOps.ADD, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: - return grad_output.cast(self.x_dtype) if self.needs_input_grad[0] else None, \ - grad_output.cast(self.y_dtype) if self.needs_input_grad[1] else None + return grad_output if self.needs_input_grad[0] else None, \ + grad_output if self.needs_input_grad[1] else None class Sub(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - self.x_dtype, self.y_dtype, output_dtype = x.dtype, y.dtype, least_upper_dtype(x.dtype, y.dtype) - return x.cast(output_dtype).e(BinaryOps.SUB, y.cast(output_dtype)) + self.x_dtype, self.y_dtype = x.dtype, y.dtype + return x.e(BinaryOps.SUB, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: - return grad_output.cast(self.x_dtype) if self.needs_input_grad[0] else None, \ - grad_output.cast(self.y_dtype).e(UnaryOps.NEG) if self.needs_input_grad[1] else None + return grad_output if self.needs_input_grad[0] else None, \ + grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None class Mul(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - self.x, self.y, output_dtype = x, y, least_upper_dtype(x.dtype, y.dtype) - return x.cast(output_dtype).e(BinaryOps.MUL, y.cast(output_dtype)) + self.x, self.y = x, y + return x.e(BinaryOps.MUL, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: - return self.y.cast(self.x.dtype).e(BinaryOps.MUL, grad_output.cast(self.x.dtype)) if self.needs_input_grad[0] else None, \ - self.x.cast(self.y.dtype).e(BinaryOps.MUL, grad_output.cast(self.y.dtype)) if self.needs_input_grad[1] else None + return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \ + self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None class Div(Function): def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - self.x, self.y, output_dtype = x, y, least_upper_dtype(x.dtype, y.dtype) - return x.cast(output_dtype).e(BinaryOps.DIV, y.cast(output_dtype)) + self.x, self.y = x, y + return x.e(BinaryOps.DIV, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: - return grad_output.cast(self.x.dtype).e(BinaryOps.DIV, self.y.cast(self.x.dtype)) if self.needs_input_grad[0] else None, \ - grad_output.cast(self.y.dtype).e(UnaryOps.NEG).e(BinaryOps.MUL, self.x.cast(self.y.dtype)).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501 + return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \ + grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501 # ************* ternary ops ************* class Where(Function): def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer: - self.x, self.y_dtype, self.z_dtype, output_type = x.cast(dtypes.bool), y.dtype, z.dtype, least_upper_dtype(y.dtype, z.dtype) - return self.x.e(TernaryOps.WHERE, y.cast(output_type), z.cast(output_type)) + self.x = x + return self.x.e(TernaryOps.WHERE, y, z) def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]: return None, \ - self.x.e(TernaryOps.WHERE, grad_output.cast(self.y_dtype), grad_output.cast(self.y_dtype).const(0)) if self.needs_input_grad[1] else None, \ + self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \ self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None # ************* reduce ops ************* diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 97327e081e..ededc24d9f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -7,7 +7,7 @@ from functools import partialmethod, reduce from itertools import accumulate import numpy as np -from tinygrad.helpers import DType, dtypes, ImageDType +from tinygrad.helpers import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, flatten, prod, all_int, round_up, merge_dicts, fully_flatten from tinygrad.lazy import LazyBuffer, create_schedule from tinygrad.ops import LoadOps @@ -249,7 +249,7 @@ class Tensor: # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous # this is "implicit gradient creation" - self.grad = Tensor(1, device=self.device, requires_grad=False) + self.grad = Tensor(1.0, device=self.device, requires_grad=False) for t0 in reversed(self.deepwalk()): assert (t0.grad is not None) @@ -670,14 +670,14 @@ class Tensor: def neg(self): return mlops.Neg.apply(self) def contiguous(self): return mlops.Contiguous.apply(self) def contiguous_backward(self): return mlops.ContiguousBackward.apply(self) - def log(self): return mlops.Log.apply(self) - def log2(self): return mlops.Log.apply(self)/math.log(2) - def exp(self): return mlops.Exp.apply(self) + def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype))) + def log2(self): return self.log()/math.log(2) + def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype))) def exp2(self): return mlops.Exp.apply(self*math.log(2)) def relu(self): return mlops.Relu.apply(self) - def sigmoid(self): return mlops.Sigmoid.apply(self) - def sin(self): return mlops.Sin.apply(self) - def sqrt(self): return mlops.Sqrt.apply(self) + def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype))) + def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype))) + def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype))) def rsqrt(self): return self.reciprocal().sqrt() def cos(self): return ((math.pi/2)-self).sin() def tan(self): return self.sin() / self.cos() @@ -729,6 +729,9 @@ class Tensor: x = x.cast(y_dtype) y = Tensor(y, self.device, y_dtype, requires_grad=False) + output_dtype = least_upper_dtype(x.dtype, y.dtype) + x, y = x.cast(output_dtype), y.cast(output_dtype) + if reverse: x, y = y, x # left pad shape with 1s @@ -784,7 +787,7 @@ class Tensor: def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]): x_,y = self._broadcasted(input_) x,z = x_._broadcasted(other) - return mlops.Where.apply(x, *y._broadcasted(z)) + return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z)) # ***** op wrappers (wasted lines to make the typechecker happy) ***** @@ -860,9 +863,9 @@ class Tensor: def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor: # NOTE: self is a logits input - loss_mask = Y != ignore_index - y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) # noqa: E501 - y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) + loss_mask = (Y != ignore_index).cast(dtypes.float) + y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) # noqa: E501 + y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) return self.log_softmax().mul(y).sum() / loss_mask.sum() # ***** cast ops *****