fix return dtype of getitem Tensor indexing (#4158)

the use of sum can auto-upcast the result. fixed by using the data dtype as the acc_dtype
This commit is contained in:
chenyu
2024-04-12 15:55:02 -04:00
committed by GitHub
parent f6c8032e5d
commit d9c5a2b1bb
2 changed files with 9 additions and 2 deletions

View File

@@ -439,6 +439,13 @@ class TestTypeSpec(unittest.TestCase):
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
X_data = Tensor.rand(60000, 1, 28, 28, dtype=data_dtype)
indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
X = X_data[indices]
assert X.dtype == X_data.dtype
class TestTypePromotion(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))
def test_self_promo_to_self(self, dtype):

View File

@@ -502,12 +502,12 @@ class Tensor:
masks.append(i == a)
# reduce masks to 1 mask
mask = functools.reduce(lambda x,y: x.mul(y), masks)
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
# inject 1's for the extra dims added in create masks
sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
# sum reduce the extra dims introduced in create masks
ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()))
ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
# special permute case
if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):