advance indexing to mixin [PR] (#16532)

This commit is contained in:
chenyu
2026-06-08 09:24:49 -04:00
committed by GitHub
parent 95d63d6c07
commit ebc5390c9a
6 changed files with 167 additions and 118 deletions

View File

@@ -2,8 +2,8 @@ import ctypes, gzip, unittest, timeit, pickle
from tinygrad import Variable
from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count, all_same
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
from tinygrad.helpers import ceildiv, ansistrip
from tinygrad.tensor import Tensor, get_shape
from tinygrad.helpers import ceildiv, ansistrip, get_shape
from tinygrad.tensor import Tensor
import numpy as np
VARIABLE = ContextVar("VARIABLE", 0)

View File

@@ -10,6 +10,9 @@ def _strip_unique(u: UOp) -> UOp: return graph_rewrite(u, _strip_unique_pm)
def _t(*shape):
return Tensor.arange(math.prod(shape)).reshape(*shape)
def _ti(data): # int32 index tensor
return Tensor(data, dtype=dtypes.int32)
# Tensor().func().uop should be the same as UOp.func()
def _check(tc: unittest.TestCase, t: Tensor, fn):
tc.assertIs(fn(t).uop, fn(t.uop), f"\ntensor.uop = {fn(t).uop}\nuop = {fn(t.uop)}")
@@ -110,6 +113,44 @@ class TestTensorUOpGetitem(unittest.TestCase):
def test_mixed_slice_slice(self): _check(self, _t(3, 4, 5), lambda x: x[1:3, :, 0:2])
def test_high_rank_combo(self): _check(self, _t(4, 5, 6), lambda x: x[1:3, :, -1, None])
# ---- advanced indexing: UOp index on UOp must match Tensor index on Tensor (same uop) ----
def _check_adv(self, t, idx):
# idx is a Tensor or a tuple mixing Tensor index arrays with slice/int/None/Ellipsis
ui = idx.uop if isinstance(idx, Tensor) else (tuple(i.uop if isinstance(i,Tensor) else i for i in idx) if isinstance(idx, tuple) else idx)
self.assertIs(t[idx].uop, t.uop[ui])
def test_adv_single(self): self._check_adv(_t(5), _ti([2,1,0,1,2]))
def test_adv_negative(self): self._check_adv(_t(5), _ti([-1,-2,0]))
def test_adv_out_of_bounds(self): self._check_adv(_t(5), _ti([4,7,2])) # oob -> 0
def test_adv_2d_dim0(self): self._check_adv(_t(3,4), _ti([2,0,1]))
def test_adv_2d_index_array(self): self._check_adv(_t(4,5), _ti([[0,1],[2,3]])) # shaped index
def test_adv_two_consecutive(self): self._check_adv(_t(3,4), (_ti([2,0,1]), _ti([1,2,3]))) # linear path
def test_adv_two_broadcast(self): self._check_adv(_t(3,4), (_ti([2,0,1]), _ti([[1],[2],[3]])))
def test_adv_three_consecutive(self):self._check_adv(_t(3,4,5), (_ti([2,0]), _ti([1,3]), _ti([4,2])))
def test_adv_after_slice(self): self._check_adv(_t(2,3,4), (slice(None), _ti([2,0,1])))
def test_adv_non_consec_permute(self): self._check_adv(_t(2,3,4), (_ti([1,0]), slice(None), _ti([2,0])))
def test_adv_idx_then_int(self): self._check_adv(_t(3,4), (_ti([2,0,1]), 2))
def test_adv_int_then_idx(self): self._check_adv(_t(3,4), (1, _ti([2,0,1])))
def test_adv_idx_then_none(self): self._check_adv(_t(3,4), (_ti([2,0,1]), None))
def test_adv_none_then_idx(self): self._check_adv(_t(3,4), (None, _ti([2,0,1])))
def test_adv_ellipsis_then_idx(self):self._check_adv(_t(2,3,4), (Ellipsis, _ti([2,0,1])))
def test_adv_idx_slice_mix(self): self._check_adv(_t(4,5,6), (_ti([1,3]), slice(1,4), _ti([2,0])))
# bool index is unsupported
def test_adv_bool_index_rejected(self):
with self.assertRaises(IndexError): _t(5)[_t(5) > 2]
with self.assertRaises(IndexError): _t(5).uop[(_t(5) > 2).uop]
# python list/tuple indices
def test_adv_python_list(self):
self.assertIs(_strip_unique(_t(5)[[2,1,0]].uop), _strip_unique(_t(5).uop[[2,1,0]]))
def test_adv_python_list_negative(self):
self.assertIs(_strip_unique(_t(5)[[-1,-2,0]].uop), _strip_unique(_t(5).uop[[-1,-2,0]]))
def test_adv_python_list_nested(self):
self.assertIs(_strip_unique(_t(3,4)[[0,1],[2,0]].uop), _strip_unique(_t(3,4).uop[[0,1],[2,0]]))
def test_adv_python_list_2d_index(self):
self.assertIs(_strip_unique(_t(4,5)[[[0,1],[2,3]]].uop), _strip_unique(_t(4,5).uop[[[0,1],[2,3]]]))
class TestTensorUOpCumalu(unittest.TestCase):
def test_cumsum_1d(self): _check(self, _t(5), lambda x: x.cumsum())
def test_cumsum_2d(self): _check(self, _t(3, 4), lambda x: x.cumsum(1))

View File

@@ -29,6 +29,11 @@ def argfix(*x):
# https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__))
def all_same(items:Sequence): return all(x == items[0] for x in items) # works for empty input
def get_shape(x) -> tuple[int, ...]:
# NOTE: str is special because iterating it still yields strs
if not hasattr(x, "__len__") or isinstance(x, str) or getattr(x, "shape", None) == (): return ()
if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
return (len(subs),) + (subs[0] if subs else ())
def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
def colored(st, color:str|None, background=False): # replace the termcolor library
if NO_COLOR: return st

View File

@@ -7,7 +7,8 @@ from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
from tinygrad.helpers import resolve_pool_pads, round_up
if TYPE_CHECKING:
from tinygrad.uop.ops import sint, UOp
@@ -50,6 +51,102 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def __getitem__(self, indices) -> Self: return self._getitem(indices)
def _getitem(self, indices, v=None) -> Self:
from tinygrad.uop.ops import UOp
def is_adv(i): return isinstance(i,(list,tuple)) or (isinstance(i,type(self)) and (not isinstance(i,UOp) or i.shape != ()))
# view-only indexing (no Tensor/list indices, no setitem) is handled by MovementMixin.__getitem__
if v is None and not any(is_adv(i) for i in (indices if isinstance(indices,tuple) else (indices,))):
return super().__getitem__(indices)
# wrap single index into a list
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
indices_parsed, dim = [], 0
for index in self._normalize_indices(list(indices)):
size = 1 if index is None else self.shape[dim]
parsed = {"size":size, "boundary":(0, size), "stride":1, "collapse_dim":False}
if isinstance(index,(list,tuple)):
flat = fully_flatten(index)
inferred = dtypes.bool if (flat and all(isinstance(s,bool) for s in flat)) else \
(dtypes.default_int if flat and all_int(flat) else dtypes.default_float)
if not dtypes.is_int(inferred): raise IndexError(f"{index=} contains non-int element")
index = self._wrap_uop(UOp._frompy([i+size if i<0 else i for i in flat], inferred, self.device)).reshape(get_shape(index))
elif is_adv(index):
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
if index.device is not None and self.device is not None and index.device != self.device:
raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
assert isinstance(size, int), "size must be an int"
index = (index < 0).where(index+size, index) # treat negative index values
else: parsed = self._parse_view_index(index, size)
indices_parsed.append({**parsed, "index":index})
if index is not None: dim += 1
# apply view ops then dim injection (None) and collapse (int)
x = self._apply_view_ops(mops) if (mops := [p for p in indices_parsed if p["index"] is not None]) else self
x_dims = [p for p in indices_parsed if not p["collapse_dim"]]
x = x.reshape(tuple(p["size"] for p in x_dims))
# tensor indexing
if tops := [(d, p) for d, p in enumerate(x_dims) if is_adv(p['index'])]:
dims, tensors, masks = [d for d, _ in tops], [p['index'] for _, p in tops], []
big_shape = _broadcast_shape(*(t.shape for t in tensors))
# consecutive tensor indices with int shapes: use linear indexing instead of one-hot masks
consecutive = dims == list(range(dims[0], dims[0] + len(dims)))
if v is None and len(dims) > 1 and consecutive and all_int(ishp := tuple(x.shape[d] for d in dims)):
strides = tuple(prod(ishp[i+1:]) for i in range(len(dims)))
try: linear_idx = type(self).usum(*[t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides)])
except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err
valid = type(self).uprod(*[(t >= 0) & (t < s) for t, s in zip(tensors, ishp)])
pre, post = x.shape[:dims[0]], x.shape[dims[-1]+1:]
x = x.reshape(pre + (prod(ishp),) + post)[tuple([slice(None)] * len(pre)) + (valid.where(linear_idx, 0),)]
return valid.reshape((1,) * len(pre) + big_shape + (1,) * len(post)).where(x, 0)
pre_reduce_shape = x.shape[:dims[0]] + big_shape + x.shape[dims[0]:]
# create index masks
for dim, tensor in zip(dims, tensors):
try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
# reduce masks to 1 mask
mask = type(self).uprod(*masks)
# inject 1's for the extra dims added in create masks
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
# sum reduce the extra dims introduced in create masks
x_pre = x # save collapsed shape for advanced setitem
x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
# special permute case
if (permuted := dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1))):
mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x))
if v is None: return x # advanced getitem
# advanced setitem: resolve tensor dims in collapsed space, then fall through to basic setitem path
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum
start = dims[0] if not permuted else 0
vb = x_pre._masked_merge(vb, mask, tuple(range(start, start + len(big_shape))))
elif v is None: return x # basic getitem
# basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims)
else: vb = v.cast(self.dtype)._broadcast_to(x.shape)
vb = vb.reshape(tuple(1 if p['collapse_dim'] else p['size'] for p in indices_parsed if p['index'] is not None))
per_dim = []
for d, m in enumerate(mops):
(s, e), st = m['boundary'], abs(m['stride'])
if st != 1 and vb.shape[d] > 1: # un-stride: interleave with zeros
vb = vb.unsqueeze(d+1)
vb = vb.pad_to(tuple(st if j == d+1 else None for j in range(vb.ndim)))
vb = vb.reshape(vb.shape[:d] + (vb.shape[d]*vb.shape[d+1],) + vb.shape[d+2:])
vb = vb.shrink_to(tuple(e-s if j == d else None for j in range(self.ndim)))
idx = type(self).arange(self.shape[d]).reshape([1]*d + [self.shape[d]] + [1]*(self.ndim - d - 1))
per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0))
vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0))
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self)
@classmethod
def invalids(cls, *shape, **kwargs) -> Self:
"""

View File

@@ -1,12 +1,12 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, hashlib, weakref
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
from contextlib import ContextDecorator
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype, truncate
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
@@ -43,23 +43,6 @@ def _fromnp(x: 'numpy.ndarray') -> UOp:
ret.buffer.allocate(x)
return ret.reshape(x.shape)
def get_shape(x) -> tuple[int, ...]:
# NOTE: str is special because iterating it still yields strs
if not hasattr(x, "__len__") or isinstance(x, str) or getattr(x, "shape", None) == (): return ()
if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
return (len(subs),) + (subs[0] if subs else ())
def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp:
if isinstance(x, bytes): ret, data = UOp.new_buffer("PYTHON", len(x)//dtype.itemsize, dtype), x
else:
ret = UOp.empty(shape:=get_shape(x), dtype, "PYTHON")
assert dtype.fmt is not None, f"{dtype=} has None fmt"
truncate_function = truncate[dtype]
data = struct.pack(f"{prod(shape)}{dtype.fmt}", *[truncate_function(dtype.const(xi)) for xi in fully_flatten(x)])
# fake realize. if target device is PYTHON it needs bytearray to be writable
ret.buffer.allocate(memoryview(data if device != "PYTHON" else bytearray(data)))
return ret
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], dtype:DType) -> list[list[Tensor]]:
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), dtype=dtype, buffer=False) for m in mat], dim=dim)
for k in range(len(mat[0]))] for dim in range(dims)]
@@ -113,13 +96,12 @@ class Tensor(OpMixin):
data = UOp.const(_dtype or dtypes.default_float, 0)
elif isinstance(data, get_args(ConstType)):
data = UOp.const(_dtype or dtypes.from_py(data), data)
elif isinstance(data, bytes): data = _frompy(data, _dtype or dtypes.uint8, _device)
elif isinstance(data, bytes): data = UOp._frompy(data, _dtype or dtypes.uint8, _device)
elif isinstance(data, (list, tuple)):
if _dtype is None:
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): _dtype = dtypes.bool
else: _dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
if _dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = _frompy(data, dtypes.float32, _device).cast(_dtype)
else: data = _frompy(data, _dtype, _device)
data = UOp._frompy(data, _dtype, _device)
elif is_numpy_ndarray(data):
import numpy as np
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
@@ -838,96 +820,6 @@ class Tensor(OpMixin):
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
# view-only indexing (no Tensor/list indices, no setitem) is handled by MovementMixin.__getitem__
if v is None and not any(isinstance(i, (Tensor, list, tuple)) for i in (indices if isinstance(indices, tuple) else (indices,))):
return super().__getitem__(indices)
# wrap single index into a list
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
indices_parsed, dim = [], 0
for index in self._normalize_indices(list(indices)):
size = 1 if index is None else self.shape[dim]
parsed = {"size":size, "boundary":(0, size), "stride":1, "collapse_dim":False}
match index:
case Tensor():
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
if index.device is not None and self.device is not None and index.device != self.device:
raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
assert isinstance(size, int), "size must be an int"
index = (index < 0).where(index+size, index) # treat negative index values
case list() | tuple():
if not dtypes.is_int((ti:=Tensor(index)).dtype): raise IndexError(f"{index=} contains non-int element")
index = Tensor([i+size if i<0 else i for i in fully_flatten(index)], self.device).reshape(ti.shape)
case _: parsed = self._parse_view_index(index, size)
indices_parsed.append({**parsed, "index":index})
if index is not None: dim += 1
# apply view ops then dim injection (None) and collapse (int)
x = self._apply_view_ops(mops) if (mops := [p for p in indices_parsed if p["index"] is not None]) else self
x_dims = [p for p in indices_parsed if not p["collapse_dim"]]
x = x.reshape(tuple(p["size"] for p in x_dims))
# tensor indexing
if tops := [(d, p) for d, p in enumerate(x_dims) if isinstance(p['index'], Tensor)]:
dims, tensors, masks = [d for d, _ in tops], cast(list[Tensor], [p['index'] for _, p in tops]), []
big_shape = _broadcast_shape(*(t.shape for t in tensors))
# consecutive tensor indices with int shapes: use linear indexing instead of one-hot masks
consecutive = dims == list(range(dims[0], dims[0] + len(dims)))
if v is None and len(dims) > 1 and consecutive and all_int(ishp := tuple(x.shape[d] for d in dims)):
strides = tuple(prod(ishp[i+1:]) for i in range(len(dims)))
try: linear_idx = Tensor.usum(*[t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides)])
except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err
valid = Tensor.uprod(*[(t >= 0) & (t < s) for t, s in zip(tensors, ishp)])
pre, post = x.shape[:dims[0]], x.shape[dims[-1]+1:]
x = x.reshape(pre + (prod(ishp),) + post)[tuple([slice(None)] * len(pre)) + (valid.where(linear_idx, 0),)]
return valid.reshape((1,) * len(pre) + big_shape + (1,) * len(post)).where(x, 0)
pre_reduce_shape = x.shape[:dims[0]] + big_shape + x.shape[dims[0]:]
# create index masks
for dim, tensor in zip(dims, tensors):
try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
# reduce masks to 1 mask
mask: Tensor = Tensor.uprod(*masks)
# inject 1's for the extra dims added in create masks
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
# sum reduce the extra dims introduced in create masks
x_pre = x # save collapsed shape for advanced setitem
x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
# special permute case
if (permuted := dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1))):
mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x))
if v is None: return x # advanced getitem
# advanced setitem: resolve tensor dims in collapsed space, then fall through to basic setitem path
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum
start = dims[0] if not permuted else 0
vb = x_pre._masked_merge(vb, mask, tuple(range(start, start + len(big_shape))))
elif v is None: return x # basic getitem
# basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims)
else: vb = v.cast(self.dtype)._broadcast_to(x.shape)
vb = vb.reshape(tuple(1 if isinstance(p['index'], sint) else p['size'] for p in indices_parsed if p['index'] is not None))
per_dim = []
for d, m in enumerate(mops):
(s, e), st = m['boundary'], abs(m['stride'])
if st != 1 and vb.shape[d] > 1: # un-stride: interleave with zeros
vb = vb.unsqueeze(d+1)
vb = vb.pad_to(tuple(st if j == d+1 else None for j in range(vb.ndim)))
vb = vb.reshape(vb.shape[:d] + (vb.shape[d]*vb.shape[d+1],) + vb.shape[d+2:])
vb = vb.shrink_to(tuple(e-s if j == d else None for j in range(self.ndim)))
idx = Tensor.arange(self.shape[d]).reshape([1]*d + [self.shape[d]] + [1]*(self.ndim - d - 1))
per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0))
vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0))
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
return (Tensor.uprod(*per_dim) if per_dim else Tensor(True, dtype=dtypes.bool, device=self.device)).where(vb, self)
def __getitem__(self, indices) -> Tensor:
"""
Retrieves a sub-tensor using indexing.
@@ -966,7 +858,7 @@ class Tensor(OpMixin):
print(t[Tensor([4, 3, 2])].numpy())
```
"""
return self._getitem(indices)
return super().__getitem__(indices)
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")

View File

@@ -9,7 +9,7 @@ from tinygrad.dtype import ConstFloat, PyConst, storage_fmt_for_dtype, to_storag
from tinygrad.device import Buffer, MultiBuffer, canonicalize_device
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, floordiv, floormod, diskcache_put, to_function_name, cpu_profile, TracingKey
from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY, DISALLOW_BROADCAST
from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY, DISALLOW_BROADCAST, get_shape, fully_flatten
from tinygrad.helpers import colored, ansilen, printable
if TYPE_CHECKING:
from tinygrad.renderer import Estimates
@@ -744,6 +744,20 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
device = canonicalize_device(self.device if device is None else device)
axis = self.axis if isinstance(device, tuple) else None
return UOp.empty(self.shard_shape if axis is not None else self.shape, self.dtype if dtype is None else dtype, device, axis)
@staticmethod
def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str, ...]|None=None) -> UOp:
device = canonicalize_device(device)
if isinstance(x, bytes): ret, data = UOp.new_buffer("PYTHON", len(x)//dtype.itemsize, dtype), x
else:
# bfloat16 and fp8 have no struct format, so pack a float32 buffer and cast
bdtype = dtypes.float32 if dtype in [dtypes.bfloat16, *dtypes.fp8s] else dtype
assert bdtype.fmt is not None, f"{bdtype=} has None fmt"
ret = UOp.empty(shape:=get_shape(x), bdtype, "PYTHON")
data = struct.pack(f"{prod(shape)}{bdtype.fmt}", *[truncate[bdtype](bdtype.const(xi)) for xi in fully_flatten(x)])
# fake realize. if target device is PYTHON it needs bytearray to be writable
ret.buffer.allocate(memoryview(data if device != "PYTHON" else bytearray(data)))
if ret.dtype != dtype: ret = ret.cast(dtype)
return ret if ret.device == device else ret.copy_to_device(device)
def clone(self, device=None) -> UOp:
device = device or self.device
ret = self.empty_like(device=device)