mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
lshift and rshift (#4591)
This commit is contained in:
@@ -230,8 +230,8 @@ def NF4Linear(block_size):
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
  """Unpack the 4-bit weight codes, rescale them, and apply the result to `x` as a linear layer.

  `self.weight` is stacked with a copy of itself shifted left by 4 bits; shifting the
  stack right by 4 bits then yields the indices used to look values up in `CODE`.
  NOTE(review): isolating the low nibble this way relies on the left shift wrapping at
  the weight's bit width (presumably uint8) -- confirm where `self.weight` is created.
  """
  high_bits = self.weight
  # Move the low nibble into the high position so one shared right shift extracts both halves.
  low_bits = self.weight.lshift(4).contiguous()
  unpacked = Tensor.stack([high_bits, low_bits], dim=-1).rshift(4)
  # Look up each code, then scale per block of `block_size` values.
  unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
  return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
|
||||
|
||||
|
||||
@@ -437,6 +437,28 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
|
||||
|
||||
def test_lshift(self):
  """Compare tinygrad's left shift on uint32 against torch's left shift on int32."""
  # NOTE(review): this was written `1<<31-1`, which Python parses as `1<<(31-1)`.
  # The value is kept as-is; confirm whether `(1<<31)-1` was the intended input.
  values = [[0, 1, 2], [1 << 8, 1 << 16, 1 << (31 - 1)]]
  tor = torch.tensor(values, dtype=torch.int)
  ten = Tensor(values, dtype=dtypes.uint32)
  # torch does not support uint32, so the tinygrad result is cast to int32 before comparing
  for amount in (0, 2, 31):
    helper_test_op([], lambda s=amount: tor << s, lambda s=amount: (ten << s).cast(dtypes.int32), forward_only=True)
  # the explicit dunder and the named method go through the same check
  helper_test_op([], lambda: tor.__lshift__(2), lambda: ten.__lshift__(2).cast(dtypes.int32), forward_only=True)
  helper_test_op([], lambda: tor.bitwise_left_shift(2), lambda: ten.lshift(2).cast(dtypes.int32), forward_only=True)
|
||||
|
||||
def test_rshift(self):
  """Compare tinygrad's right shift on uint32 against torch's right shift on int32."""
  # NOTE(review): this was written `1<<31-1`, which Python parses as `1<<(31-1)`.
  # The value is kept as-is; confirm whether `(1<<31)-1` was the intended input.
  values = [[0, 1, 2], [1 << 8, 1 << 16, 1 << (31 - 1)]]
  tor = torch.tensor(values, dtype=torch.int)
  ten = Tensor(values, dtype=dtypes.uint32)
  # torch does not support uint32, so the tinygrad result is cast to int32 before comparing
  for amount in (0, 2, 31):
    helper_test_op([], lambda s=amount: tor >> s, lambda s=amount: (ten >> s).cast(dtypes.int32), forward_only=True)
  # the explicit dunder and the named method go through the same check
  helper_test_op([], lambda: tor.__rshift__(2), lambda: ten.__rshift__(2).cast(dtypes.int32), forward_only=True)
  helper_test_op([], lambda: tor.bitwise_right_shift(2), lambda: ten.rshift(2).cast(dtypes.int32), forward_only=True)
|
||||
|
||||
def test_sin(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sin())
|
||||
helper_test_op([()], lambda x: x.sin())
|
||||
|
||||
@@ -344,9 +344,9 @@ class Tensor:
|
||||
|
||||
x = [(c := counts.chunk(2))[0] + ks[-1], c[1] + ks[0]]
|
||||
for i in range(5):
|
||||
for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] * (2 ** r)) + (x[1].div(2 ** (32 - r), upcast=False)))
|
||||
for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
|
||||
x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
|
||||
out = x[0].cat(x[1])[:num].div(2 ** 8, upcast=False).cast(dtypes.float32).div(2 ** 24)
|
||||
out = x[0].cat(x[1])[:num].rshift(8).cast(dtypes.float32).div(2 ** 24)
|
||||
out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
|
||||
out.requires_grad = kwargs.get("requires_grad")
|
||||
return out.contiguous()
|
||||
@@ -1218,6 +1218,14 @@ class Tensor:
|
||||
return F.Div.apply(numerator, denominator)
|
||||
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """Apply `F.Xor` to this tensor and `x`, after both have gone through `_broadcasted`."""
  # `reverse` is passed straight to `_broadcasted`
  operands = self._broadcasted(x, reverse)
  return F.Xor.apply(*operands)
|
||||
|
||||
def lshift(self, x:int) -> Tensor:
  """Shift every element left by `x` bits, implemented as a multiplication by `2 ** x`.

  Only unsigned dtypes and non-negative Python `int` shift amounts are accepted;
  anything else fails the assertion below.
  """
  assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
  return self.mul(2 ** x)
|
||||
|
||||
def rshift(self, x:int) -> Tensor:
  """Shift every element right by `x` bits, implemented as a division by `2 ** x` with `upcast=False`.

  Only unsigned dtypes and non-negative Python `int` shift amounts are accepted;
  anything else fails the assertion below.
  """
  assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
  return self.div(2 ** x, upcast=False)
|
||||
|
||||
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
x = self._to_const_val(x)
|
||||
if not isinstance(x, Tensor) and not reverse:
|
||||
@@ -1264,6 +1272,8 @@ class Tensor:
|
||||
def __truediv__(self, x) -> Tensor:
  # `self / x` forwards to div()
  quotient = self.div(x)
  return quotient
|
||||
def __matmul__(self, x) -> Tensor:
  # `self @ x` forwards to matmul()
  product = self.matmul(x)
  return product
|
||||
def __xor__(self, x) -> Tensor:
  # `self ^ x` forwards to xor()
  result = self.xor(x)
  return result
|
||||
def __lshift__(self, x) -> Tensor:
  # `self << x` forwards to lshift()
  shifted = self.lshift(x)
  return shifted
|
||||
def __rshift__(self, x) -> Tensor:
  # `self >> x` forwards to rshift()
  shifted = self.rshift(x)
  return shifted
|
||||
|
||||
def __radd__(self, x) -> Tensor:
  # reflected `x + self`: forwards to add() with True as the second argument
  total = self.add(x, True)
  return total
|
||||
def __rsub__(self, x) -> Tensor:
  # reflected `x - self`: forwards to sub() with True as the second argument
  difference = self.sub(x, True)
  return difference
|
||||
@@ -1280,6 +1290,8 @@ class Tensor:
|
||||
def __itruediv__(self, x) -> Tensor:
  # `self /= x`: compute the quotient, then hand it to assign()
  quotient = self.div(x)
  return self.assign(quotient)
|
||||
def __imatmul__(self, x) -> Tensor:
  # `self @= x`: compute the product, then hand it to assign()
  product = self.matmul(x)
  return self.assign(product)
|
||||
def __ixor__(self, x) -> Tensor:
  # `self ^= x`: compute the xor, then hand it to assign()
  mixed = self.xor(x)
  return self.assign(mixed)
|
||||
def __ilshift__(self, x) -> Tensor:
  # `self <<= x`: compute the shifted tensor, then hand it to assign()
  shifted = self.lshift(x)
  return self.assign(shifted)
|
||||
def __irshift__(self, x) -> Tensor:
  # `self >>= x`: compute the shifted tensor, then hand it to assign()
  shifted = self.rshift(x)
  return self.assign(shifted)
|
||||
|
||||
def __lt__(self, x) -> Tensor:
  # `self < x`: apply F.Less to the operands returned by _broadcasted(x, False)
  operands = self._broadcasted(x, False)
  return F.Less.apply(*operands)
|
||||
def __gt__(self, x) -> Tensor:
  # `self > x` reuses F.Less; the True flag presumably swaps the operands so this evaluates `x < self`
  operands = self._broadcasted(x, True)
  return F.Less.apply(*operands)
|
||||
|
||||
Reference in New Issue
Block a user