lshift and rshift (#4591)

This commit is contained in:
chenyu
2024-05-14 19:16:31 -04:00
committed by GitHub
parent 45e7400e3c
commit 2b0ee74bb6
3 changed files with 38 additions and 4 deletions

View File

@@ -230,8 +230,8 @@ def NF4Linear(block_size):
def __call__(self, x: Tensor) -> Tensor:
high_bits = self.weight
low_bits = (self.weight * 2 ** 4).contiguous()
unpacked = Tensor.stack([high_bits, low_bits], dim=-1).div(2 ** 4, upcast=False)
low_bits = self.weight.lshift(4).contiguous()
unpacked = Tensor.stack([high_bits, low_bits], dim=-1).rshift(4)
unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
return x.linear(unscaled.reshape(self.out_features, self.in_features).T)

View File

@@ -437,6 +437,28 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
def test_lshift(self):
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.uint32)
# cast to int32 because torch does not support uint32
helper_test_op([], lambda: tor << 0, lambda: (ten << 0).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor << 2, lambda: (ten << 2).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor << 31, lambda: (ten << 31).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor.__lshift__(2), lambda: ten.__lshift__(2).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor.bitwise_left_shift(2), lambda: ten.lshift(2).cast(dtypes.int32), forward_only=True)
def test_rshift(self):
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.uint32)
# cast to int32 because torch does not support uint32
helper_test_op([], lambda: tor >> 0, lambda: (ten >> 0).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor >> 2, lambda: (ten >> 2).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor >> 31, lambda: (ten >> 31).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor.__rshift__(2), lambda: ten.__rshift__(2).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor.bitwise_right_shift(2), lambda: ten.rshift(2).cast(dtypes.int32), forward_only=True)
def test_sin(self):
helper_test_op([(45,65)], lambda x: x.sin())
helper_test_op([()], lambda x: x.sin())

View File

@@ -344,9 +344,9 @@ class Tensor:
x = [(c := counts.chunk(2))[0] + ks[-1], c[1] + ks[0]]
for i in range(5):
for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] * (2 ** r)) + (x[1].div(2 ** (32 - r), upcast=False)))
for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
out = x[0].cat(x[1])[:num].div(2 ** 8, upcast=False).cast(dtypes.float32).div(2 ** 24)
out = x[0].cat(x[1])[:num].rshift(8).cast(dtypes.float32).div(2 ** 24)
out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
out.requires_grad = kwargs.get("requires_grad")
return out.contiguous()
@@ -1218,6 +1218,14 @@ class Tensor:
return F.Div.apply(numerator, denominator)
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: return F.Xor.apply(*self._broadcasted(x, reverse))
def lshift(self, x:int):
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
return self.mul(2 ** x)
def rshift(self, x:int):
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
return self.div(2 ** x, upcast=False)
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
x = self._to_const_val(x)
if not isinstance(x, Tensor) and not reverse:
@@ -1264,6 +1272,8 @@ class Tensor:
def __truediv__(self, x) -> Tensor: return self.div(x)
def __matmul__(self, x) -> Tensor: return self.matmul(x)
def __xor__(self, x) -> Tensor: return self.xor(x)
def __lshift__(self, x) -> Tensor: return self.lshift(x)
def __rshift__(self, x) -> Tensor: return self.rshift(x)
def __radd__(self, x) -> Tensor: return self.add(x, True)
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
@@ -1280,6 +1290,8 @@ class Tensor:
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))