diff --git a/test/test_dtype.py b/test/test_dtype.py index 13438d52a8..18953b6117 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -439,6 +439,13 @@ class TestTypeSpec(unittest.TestCase): assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32 assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32 + @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints)) + def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype): + X_data = Tensor.rand(60000, 1, 28, 28, dtype=data_dtype) + indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype) + X = X_data[indices] + assert X.dtype == X_data.dtype + class TestTypePromotion(unittest.TestCase): @given(strat.sampled_from(core_dtypes)) def test_self_promo_to_self(self, dtype): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7a16a93a0e..76318a6f50 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -502,12 +502,12 @@ class Tensor: masks.append(i == a) # reduce masks to 1 mask - mask = functools.reduce(lambda x,y: x.mul(y), masks) + mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks) # inject 1's for the extra dims added in create masks sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:] # sum reduce the extra dims introduced in create masks - ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys())) + ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype) # special permute case if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):