Remove POW llop and add SQRT llop (#1104)

* fixed division by zero for fast operations

* made et closer to 0

* replace POW llop with SQRT

* updated mlops to swap SQRT and POW llops

* updated hlops to swap POW and SQRT

* added sqrt llop to cpu runtime

* added sqrt llop to cstyle codegen

* added POW llop to llvm ir codegen

* added SQRT llop to torch runtime

* moved pow from mlops to hlops

* found a better way to do reverse pow

* fixed indentation

* added SQRT llop to triton

* update docs to match new llops

* removed POW operator from assembly codegen

* added sqrt and rsqrt to pow hlop

* rewrote pow function in tensor.py

* Adjust tolerance

* Adjust for adamw

* Reduce for Adam too

* removed accidental leftover code

* removed all of accidental code

* added rsqrt test

* removed pow from mlops again

it was added back when resolving merge conflicts

---------

Co-authored-by: Jacky Lee <jla524@sfu.ca>
This commit is contained in:
Eli Frigo
2023-07-05 20:07:58 -05:00
committed by GitHub
parent b7369ffcff
commit 801564f31b
12 changed files with 34 additions and 36 deletions

View File

@@ -7,9 +7,9 @@ It's pretty easy to add a new accelerator to tinygrad. All you need to do is imp
These are the ops that you must implement for your accelerator of choice. Compiled Accelerators do not need to implement movement_ops, as they are handled by the ShapeTracker.
```
Buffer # class of memory on this device
unary_op (NOOP, EXP2, LOG2, CAST, SIN) # A -> A
unary_op (NOOP, EXP2, LOG2, CAST, SIN, SQRT) # A -> A
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size)
binary_op (ADD, SUB, MUL, DIV, CMPEQ, MAX) # A + A -> A (all the same size)
movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size)
load_op (EMPTY, RAND, CONST, FROM, CONTIGUOUS, CUSTOM) # -> A (initialize data on device)
fused_op [[optional]] (MULACC) # A * A -> B

View File

@@ -21,9 +21,9 @@ stream = cuda.Stream()
class TritonASTKernel(ASTKernel):
code_for_op : Dict[Op, str] = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "tl.maximum(A, 0.0)", UnaryOps.GT0: "tl.where(A>0,1,0)",
UnaryOps.EXP: "tl.exp(A)", UnaryOps.LOG: "tl.log(A)", UnaryOps.RECIPROCAL: "(1.0/A)",
UnaryOps.EXP: "tl.exp(A)", UnaryOps.LOG: "tl.log(A)", UnaryOps.RECIPROCAL: "(1.0/A)", UnaryOps.SQRT: "tl.sqrt(A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
BinaryOps.DIV: "(A/B)", BinaryOps.POW: "tl.exp(tl.log(A)*B)", BinaryOps.CMPEQ: "(A==B)",
BinaryOps.DIV: "(A/B)", BinaryOps.CMPEQ: "(A==B)",
ReduceOps.SUM: "A += B", ReduceOps.MAX: "A = tl.maximum(A,B)"
}
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "float('-inf')"}

View File

@@ -229,6 +229,9 @@ class TestOps(unittest.TestCase):
def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0)
def test_rsqrt(self):
helper_test_op([(45,65)], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0)
helper_test_op([()], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0)
def test_sin(self):
helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)

View File

@@ -71,10 +71,10 @@ class TestOptim(unittest.TestCase):
def test_adamw_high_lr(self): self._test_adamw(1, {'lr': 10}, 1e-5, 1e-5)
def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0)
def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 1e-5, 3e-4)
def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 2e-4, 5e-4)
def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0)
def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 1e-5, 3e-4)
def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 5e-4, 2e-3)
def test_duped_weights(self):
for Opt in [Adam, AdamW, SGD]:

View File

@@ -143,13 +143,6 @@ class AssemblyCodegen(Linearizer):
pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
elif args == BinaryOps.POW:
# TODO: add UnaryOps.SQRT
tmp = newreg((newvar, "exp_a"))
tmp2 = newreg((newvar, "exp_a_times_b"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]]], UnaryOps.LOG2))
ins.append(AssemblyInstruction(UOps.ALU, tmp2, [tmp, tor[vin[1]]], BinaryOps.MUL))
ins.append(AssemblyInstruction(UOps.ALU, out, [tmp2], UnaryOps.EXP2))
elif args == BinaryOps.DIV and self.no_div:
tmp = newreg((newvar, "rcp"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))

View File

@@ -50,9 +50,10 @@ code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.EXP2: lambda x: f"exp2({x})",
UnaryOps.LOG2: lambda x: f"log2({x})",
UnaryOps.SIN: lambda x: f"sin({x})",
UnaryOps.SQRT: lambda x: f"sqrt({x})",
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})",
BinaryOps.MAX: lambda a,b: f"max({a},{b})",
BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
}

View File

@@ -22,11 +22,11 @@ code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)),
UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)),
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)),
UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=('fast',)),
BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),
BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)),
BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()),
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)),
FusedOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)),

View File

@@ -55,6 +55,15 @@ class Exp(Function):
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return self.ret.binary_op(BinaryOps.MUL, grad_output)
class Sqrt(Function):
__slots__ = "ret"
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.unary_op(UnaryOps.SQRT)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.binary_op(BinaryOps.DIV, self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(2)))
# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
@@ -142,16 +151,6 @@ class Mul(Function):
return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
class Pow(Function):
__slots__ = 'x', 'y', 'ret'
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y)
return self.ret
def backward(self, grad_output:LazyBuffer):
return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \
grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
class Div(Function):
__slots__ = 'x', 'y'
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
@@ -217,4 +216,4 @@ class Flip(Function):
return x.stride(self.arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.stride(self.arg)
return grad_output.stride(self.arg)

View File

@@ -12,8 +12,8 @@ if TYPE_CHECKING:
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); RECIP = auto() # noqa: E702
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class FusedOps(Enum): MULACC = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702

View File

@@ -10,7 +10,7 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
base_fxn_for_op: Dict[Op, Callable] = {
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
@@ -32,7 +32,7 @@ def match_types(x, y):
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to),

View File

@@ -10,7 +10,7 @@ type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.
inverse_type_map = {v:k for k,v in type_map.items()}
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).type(torch.promote_types(x.dtype, y.dtype)),
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),

View File

@@ -487,6 +487,8 @@ class Tensor:
def relu(self): return mlops.Relu.apply(self)
def sigmoid(self): return mlops.Sigmoid.apply(self)
def sin(self): return mlops.Sin.apply(self)
def sqrt(self): return mlops.Sqrt.apply(self)
def rsqrt(self): return (1/self).sqrt()
def cos(self): return ((pi/2)-self).sin()
def tan(self): return self.sin() / self.cos()
@@ -504,8 +506,6 @@ class Tensor:
return (self < b).where(b-1, b)
def __neg__(self): return 0.0-self
def sqrt(self): return self.pow(0.5)
def rsqrt(self): return self.pow(-0.5)
def square(self): return self*self
def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
def abs(self): return self.relu() + (-self).relu()
@@ -552,13 +552,15 @@ class Tensor:
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if x.__class__ is Tensor or x or reverse else self
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if x.__class__ is Tensor or x != 1.0 else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
if x.__class__ is not Tensor and not reverse:
# simple pow identities
if x < 0: return (1.0/self).pow(-x)
if x == 2.0: return self*self
if x == -1.0: return 1/self
return self._broadcasted(mlops.Pow, x, reverse) if x.__class__ is Tensor or x != 1.0 or reverse else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
if x == 1.0: return self
if x == 0.5: return self.sqrt()
return self.log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(x)).exp()
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)