From d9c5a2b1bbfc439d3d9901dc58d7bb4f8de8f774 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 12 Apr 2024 15:55:02 -0400 Subject: [PATCH] fix return dtype of getitem Tensor indexing (#4158) the use of sum can auto-upcast the result. fixed by using the data dtype as the acc_dtype --- test/test_dtype.py | 7 +++++++ tinygrad/tensor.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 13438d52a8..18953b6117 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -439,6 +439,13 @@ class TestTypeSpec(unittest.TestCase): assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32 assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32 + @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints)) + def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype): + X_data = Tensor.rand(60000, 1, 28, 28, dtype=data_dtype) + indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype) + X = X_data[indices] + assert X.dtype == X_data.dtype + class TestTypePromotion(unittest.TestCase): @given(strat.sampled_from(core_dtypes)) def test_self_promo_to_self(self, dtype): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7a16a93a0e..76318a6f50 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -502,12 +502,12 @@ class Tensor: masks.append(i == a) # reduce masks to 1 mask - mask = functools.reduce(lambda x,y: x.mul(y), masks) + mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks) # inject 1's for the extra dims added in create masks sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:] # sum reduce the extra dims introduced in create masks - ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys())) + ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype) # special permute case if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):