mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix return dtype of getitem Tensor indexing (#4158)
the use of sum can auto-upcast the result. fixed by using the data dtype as the acc_dtype
This commit is contained in:
@@ -439,6 +439,13 @@ class TestTypeSpec(unittest.TestCase):
|
||||
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
|
||||
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
|
||||
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
|
||||
X_data = Tensor.rand(60000, 1, 28, 28, dtype=data_dtype)
|
||||
indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
|
||||
X = X_data[indices]
|
||||
assert X.dtype == X_data.dtype
|
||||
|
||||
class TestTypePromotion(unittest.TestCase):
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_self_promo_to_self(self, dtype):
|
||||
|
||||
@@ -502,12 +502,12 @@ class Tensor:
|
||||
masks.append(i == a)
|
||||
|
||||
# reduce masks to 1 mask
|
||||
mask = functools.reduce(lambda x,y: x.mul(y), masks)
|
||||
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
|
||||
|
||||
# inject 1's for the extra dims added in create masks
|
||||
sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
|
||||
# sum reduce the extra dims introduced in create masks
|
||||
ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()))
|
||||
ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
|
||||
|
||||
# special permute case
|
||||
if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
|
||||
|
||||
Reference in New Issue
Block a user