diff --git a/docs/adding_new_accelerators.md b/docs/adding_new_accelerators.md index b0ac5b40a6..6f90cb7f1c 100644 --- a/docs/adding_new_accelerators.md +++ b/docs/adding_new_accelerators.md @@ -7,9 +7,9 @@ It's pretty easy to add a new accelerator to tinygrad. All you need to do is imp These are the ops that you must implement for your accelerator of choice. Compiled Accelerators do not need to implement movement_ops, as they are handled by the ShapeTracker. ``` Buffer # class of memory on this device -unary_op (NOOP, EXP2, LOG2, CAST, SIN) # A -> A +unary_op (NOOP, EXP2, LOG2, CAST, SIN, SQRT) # A -> A reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape) -binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size) +binary_op (ADD, SUB, MUL, DIV, CMPEQ, MAX) # A + A -> A (all the same size) movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size) load_op (EMPTY, RAND, CONST, FROM, CONTIGUOUS, CUSTOM) # -> A (initialize data on device) fused_op [[optional]] (MULACC) # A * A -> B diff --git a/extra/accel/triton/ops_triton.py b/extra/accel/triton/ops_triton.py index f9cc00ecbc..f38cc263ed 100644 --- a/extra/accel/triton/ops_triton.py +++ b/extra/accel/triton/ops_triton.py @@ -21,9 +21,9 @@ stream = cuda.Stream() class TritonASTKernel(ASTKernel): code_for_op : Dict[Op, str] = { UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "tl.maximum(A, 0.0)", UnaryOps.GT0: "tl.where(A>0,1,0)", - UnaryOps.EXP: "tl.exp(A)", UnaryOps.LOG: "tl.log(A)", UnaryOps.RECIPROCAL: "(1.0/A)", + UnaryOps.EXP: "tl.exp(A)", UnaryOps.LOG: "tl.log(A)", UnaryOps.RECIPROCAL: "(1.0/A)", UnaryOps.SQRT: "tl.sqrt(A)", BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", - BinaryOps.DIV: "(A/B)", BinaryOps.POW: "tl.exp(tl.log(A)*B)", BinaryOps.CMPEQ: "(A==B)", + BinaryOps.DIV: "(A/B)", BinaryOps.CMPEQ: "(A==B)", ReduceOps.SUM: "A += B", ReduceOps.MAX: "A = tl.maximum(A,B)" } start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "float('-inf')"} diff --git a/test/test_ops.py b/test/test_ops.py index 69a03a7ba8..a0e447adc3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -229,6 +229,9 @@ class TestOps(unittest.TestCase): def test_sqrt(self): helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0) helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0) + def test_rsqrt(self): + helper_test_op([(45,65)], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0) + helper_test_op([()], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0) def test_sin(self): helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0) diff --git a/test/test_optim.py b/test/test_optim.py index 1224182b1a..880a10141a 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -71,10 +71,10 @@ class TestOptim(unittest.TestCase): def test_adamw_high_lr(self): self._test_adamw(1, {'lr': 10}, 1e-5, 1e-5) def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0) - def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 1e-5, 3e-4) + def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 2e-4, 5e-4) def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0) - def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 1e-5, 3e-4) + def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 5e-4, 2e-3) def test_duped_weights(self): for Opt in [Adam, AdamW, SGD]: diff --git a/tinygrad/codegen/assembly.py b/tinygrad/codegen/assembly.py index 3997b349da..3075e126b3 100644 --- a/tinygrad/codegen/assembly.py +++ b/tinygrad/codegen/assembly.py @@ -143,13 +143,6 @@ class AssemblyCodegen(Linearizer): pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool) ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args)) ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args)) - elif args == BinaryOps.POW: - # TODO: add UnaryOps.SQRT - tmp = newreg((newvar, "exp_a")) - tmp2 = newreg((newvar, "exp_a_times_b")) - ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]]], UnaryOps.LOG2)) - ins.append(AssemblyInstruction(UOps.ALU, tmp2, [tmp, tor[vin[1]]], BinaryOps.MUL)) - ins.append(AssemblyInstruction(UOps.ALU, out, [tmp2], UnaryOps.EXP2)) elif args == BinaryOps.DIV and self.no_div: tmp = newreg((newvar, "rcp")) ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP)) diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index d900b78e9b..04834ce432 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -50,9 +50,10 @@ code_for_op: Final[Dict[Op, Callable]] = { UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", + UnaryOps.SQRT: lambda x: f"sqrt({x})", BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})", BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})", - BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})", + BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})" } diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py index f32fd520ed..4888d3f948 100644 --- a/tinygrad/codegen/llvmir.py +++ b/tinygrad/codegen/llvmir.py @@ -22,11 +22,11 @@ code_for_op: Final[Dict[Op, Callable]] = { UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)), UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)), UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)), + UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=('fast',)), BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)), BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)), BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)), BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)), - BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)), BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()), BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)), FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)), diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 1bf04ff4f9..f9801fd810 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -55,6 +55,15 @@ class Exp(Function): def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.binary_op(BinaryOps.MUL, grad_output) +class Sqrt(Function): + __slots__ = "ret" + def forward(self, x:LazyBuffer) -> LazyBuffer: + self.ret = x.unary_op(UnaryOps.SQRT) + return self.ret + + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: + return grad_output.binary_op(BinaryOps.DIV, self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(2))) + # NOTE: the implicit derivative of sigmoid is not stable # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e # TODO: have the backend automatically find this @@ -142,16 +151,6 @@ class Mul(Function): return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \ self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None -class Pow(Function): - __slots__ = 'x', 'y', 'ret' - def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y) - return self.ret - - def backward(self, grad_output:LazyBuffer): - return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \ - grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None - class Div(Function): __slots__ = 'x', 'y' def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: @@ -217,4 +216,4 @@ class Flip(Function): return x.stride(self.arg) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.stride(self.arg) \ No newline at end of file + return grad_output.stride(self.arg) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b100b8b338..e3655171c1 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -12,8 +12,8 @@ if TYPE_CHECKING: # the Enum class doesn't work with mypy, this is static. sorry it's ugly # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars # NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block -class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); RECIP = auto() # noqa: E702 -class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702 +class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702 +class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702 class FusedOps(Enum): MULACC = auto() # noqa: E702 class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702 diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 14013c2329..e027a29315 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -10,7 +10,7 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b) base_fxn_for_op: Dict[Op, Callable] = { - BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, + BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:], ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:], MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)], @@ -32,7 +32,7 @@ def match_types(x, y): numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin, - BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)), + BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)), UnaryOps.SQRT: np.sqrt, MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to, MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)], FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to), diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index 3320fc089f..2c30ad8bc3 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -10,7 +10,7 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch. inverse_type_map = {v:k for k,v in type_map.items()} torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ - UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin, + UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin, BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)), MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)), diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2105d81068..4430c048b7 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -487,6 +487,8 @@ class Tensor: def relu(self): return mlops.Relu.apply(self) def sigmoid(self): return mlops.Sigmoid.apply(self) def sin(self): return mlops.Sin.apply(self) + def sqrt(self): return mlops.Sqrt.apply(self) + def rsqrt(self): return (1/self).sqrt() def cos(self): return ((pi/2)-self).sin() def tan(self): return self.sin() / self.cos() @@ -504,8 +506,6 @@ class Tensor: return (self < b).where(b-1, b) def __neg__(self): return 0.0-self - def sqrt(self): return self.pow(0.5) - def rsqrt(self): return self.pow(-0.5) def square(self): return self*self def clip(self, min_, max_): return self.maximum(min_).minimum(max_) def abs(self): return self.relu() + (-self).relu() @@ -552,13 +552,15 @@ class Tensor: def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if x.__class__ is Tensor or x or reverse else self def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if x.__class__ is Tensor or x != 1.0 else self + def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x) def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: if x.__class__ is not Tensor and not reverse: # simple pow identities + if x < 0: return (1.0/self).pow(-x) if x == 2.0: return self*self - if x == -1.0: return 1/self - return self._broadcasted(mlops.Pow, x, reverse) if x.__class__ is Tensor or x != 1.0 or reverse else self - def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x) + if x == 1.0: return self + if x == 0.5: return self.sqrt() + return self.log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(x)).exp() def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x) def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)