diff --git a/examples/transformer.py b/examples/transformer.py index 85cc4e8527..4eb3a7b334 100755 --- a/examples/transformer.py +++ b/examples/transformer.py @@ -48,14 +48,16 @@ class TransformerBlock: query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size) key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, T) + score = query.dot(key) + print(query.shape) + print(key.shape) + print(score.shape) #score = query.reshape(shape=(-1, self.projection_dim)).dot( # key.reshape(shape=(-1, self.projection_dim)).transpose(order=(1,0))) #scaled_score = score * (1/np.sqrt(self.projection_dim)) - print(query.shape) - print(key.shape) #print(value.shape) #print(scaled_score.shape) diff --git a/test/test_ops.py b/test/test_ops.py index 7ca73c3710..10b9c411e8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -39,6 +39,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0 print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp)) +# TODO: everywhere you see this, make the op work on GPU def cpu_only(func): def wrapper(self): if self.device == Device.CPU: @@ -70,6 +71,9 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device) def test_dot(self): helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device) + @cpu_only + def test_multidot(self): + helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device) def test_sum(self): helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device) def test_sum_axis(self): @@ -113,7 +117,7 @@ class TestOps(unittest.TestCase): def test_pad2d(self): helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device) - @cpu_only # TODO: transpose for GPU + @cpu_only def test_transpose(self): helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device) diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py index 7b8258eec9..bac1c867d9 100644 --- a/tinygrad/ops_cpu.py +++ b/tinygrad/ops_cpu.py @@ -77,13 +77,13 @@ class Dot(Function): @staticmethod def forward(ctx, input, weight): ctx.save_for_backward(input, weight) - return input.dot(weight) + return input @ weight @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors - grad_input = grad_output.dot(weight.T) - grad_weight = input.T.dot(grad_output) + grad_input = grad_output @ np.swapaxes(weight, -2, -1) + grad_weight = np.swapaxes(input, -2, -1) @ grad_output return grad_input, grad_weight register('dot', Dot)