From ebc5390c9a455d3d20da72d2bdb360a5809f3a56 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 8 Jun 2026 09:24:49 -0400 Subject: [PATCH] advance indexing to mixin [PR] (#16532) --- test/null/test_helpers.py | 4 +- test/null/test_tensor_uop_mixin.py | 41 ++++++++++ tinygrad/helpers.py | 5 ++ tinygrad/mixin/__init__.py | 99 +++++++++++++++++++++++- tinygrad/tensor.py | 120 ++--------------------------- tinygrad/uop/ops.py | 16 +++- 6 files changed, 167 insertions(+), 118 deletions(-) diff --git a/test/null/test_helpers.py b/test/null/test_helpers.py index ff7adc64cc..44c49e8891 100644 --- a/test/null/test_helpers.py +++ b/test/null/test_helpers.py @@ -2,8 +2,8 @@ import ctypes, gzip, unittest, timeit, pickle from tinygrad import Variable from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count, all_same from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits -from tinygrad.helpers import ceildiv, ansistrip -from tinygrad.tensor import Tensor, get_shape +from tinygrad.helpers import ceildiv, ansistrip, get_shape +from tinygrad.tensor import Tensor import numpy as np VARIABLE = ContextVar("VARIABLE", 0) diff --git a/test/null/test_tensor_uop_mixin.py b/test/null/test_tensor_uop_mixin.py index 9c33622b4a..02871b7945 100644 --- a/test/null/test_tensor_uop_mixin.py +++ b/test/null/test_tensor_uop_mixin.py @@ -10,6 +10,9 @@ def _strip_unique(u: UOp) -> UOp: return graph_rewrite(u, _strip_unique_pm) def _t(*shape): return Tensor.arange(math.prod(shape)).reshape(*shape) +def _ti(data): # int32 index tensor + return Tensor(data, dtype=dtypes.int32) + # Tensor().func().uop should be the same as UOp.func() def _check(tc: unittest.TestCase, t: Tensor, fn): tc.assertIs(fn(t).uop, fn(t.uop), f"\ntensor.uop = {fn(t).uop}\nuop = {fn(t.uop)}") @@ -110,6 +113,44 @@ class TestTensorUOpGetitem(unittest.TestCase): def test_mixed_slice_slice(self): _check(self, _t(3, 4, 5), lambda x: x[1:3, :, 0:2]) def test_high_rank_combo(self): _check(self, _t(4, 5, 6), lambda x: x[1:3, :, -1, None]) + # ---- advanced indexing: UOp index on UOp must match Tensor index on Tensor (same uop) ---- + def _check_adv(self, t, idx): + # idx is a Tensor or a tuple mixing Tensor index arrays with slice/int/None/Ellipsis + ui = idx.uop if isinstance(idx, Tensor) else (tuple(i.uop if isinstance(i,Tensor) else i for i in idx) if isinstance(idx, tuple) else idx) + self.assertIs(t[idx].uop, t.uop[ui]) + + def test_adv_single(self): self._check_adv(_t(5), _ti([2,1,0,1,2])) + def test_adv_negative(self): self._check_adv(_t(5), _ti([-1,-2,0])) + def test_adv_out_of_bounds(self): self._check_adv(_t(5), _ti([4,7,2])) # oob -> 0 + def test_adv_2d_dim0(self): self._check_adv(_t(3,4), _ti([2,0,1])) + def test_adv_2d_index_array(self): self._check_adv(_t(4,5), _ti([[0,1],[2,3]])) # shaped index + def test_adv_two_consecutive(self): self._check_adv(_t(3,4), (_ti([2,0,1]), _ti([1,2,3]))) # linear path + def test_adv_two_broadcast(self): self._check_adv(_t(3,4), (_ti([2,0,1]), _ti([[1],[2],[3]]))) + def test_adv_three_consecutive(self):self._check_adv(_t(3,4,5), (_ti([2,0]), _ti([1,3]), _ti([4,2]))) + def test_adv_after_slice(self): self._check_adv(_t(2,3,4), (slice(None), _ti([2,0,1]))) + def test_adv_non_consec_permute(self): self._check_adv(_t(2,3,4), (_ti([1,0]), slice(None), _ti([2,0]))) + def test_adv_idx_then_int(self): self._check_adv(_t(3,4), (_ti([2,0,1]), 2)) + def test_adv_int_then_idx(self): self._check_adv(_t(3,4), (1, _ti([2,0,1]))) + def test_adv_idx_then_none(self): self._check_adv(_t(3,4), (_ti([2,0,1]), None)) + def test_adv_none_then_idx(self): self._check_adv(_t(3,4), (None, _ti([2,0,1]))) + def test_adv_ellipsis_then_idx(self):self._check_adv(_t(2,3,4), (Ellipsis, _ti([2,0,1]))) + def test_adv_idx_slice_mix(self): self._check_adv(_t(4,5,6), (_ti([1,3]), slice(1,4), _ti([2,0]))) + + # bool index is unsupported + def test_adv_bool_index_rejected(self): + with self.assertRaises(IndexError): _t(5)[_t(5) > 2] + with self.assertRaises(IndexError): _t(5).uop[(_t(5) > 2).uop] + + # python list/tuple indices + def test_adv_python_list(self): + self.assertIs(_strip_unique(_t(5)[[2,1,0]].uop), _strip_unique(_t(5).uop[[2,1,0]])) + def test_adv_python_list_negative(self): + self.assertIs(_strip_unique(_t(5)[[-1,-2,0]].uop), _strip_unique(_t(5).uop[[-1,-2,0]])) + def test_adv_python_list_nested(self): + self.assertIs(_strip_unique(_t(3,4)[[0,1],[2,0]].uop), _strip_unique(_t(3,4).uop[[0,1],[2,0]])) + def test_adv_python_list_2d_index(self): + self.assertIs(_strip_unique(_t(4,5)[[[0,1],[2,3]]].uop), _strip_unique(_t(4,5).uop[[[0,1],[2,3]]])) + class TestTensorUOpCumalu(unittest.TestCase): def test_cumsum_1d(self): _check(self, _t(5), lambda x: x.cumsum()) def test_cumsum_2d(self): _check(self, _t(3, 4), lambda x: x.cumsum(1)) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 7b515e9421..38333a9201 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -29,6 +29,11 @@ def argfix(*x): # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) def all_same(items:Sequence): return all(x == items[0] for x in items) # works for empty input +def get_shape(x) -> tuple[int, ...]: + # NOTE: str is special because iterating it still yields strs + if not hasattr(x, "__len__") or isinstance(x, str) or getattr(x, "shape", None) == (): return () + if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}") + return (len(subs),) + (subs[0] if subs else ()) def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t) def colored(st, color:str|None, background=False): # replace the termcolor library if NO_COLOR: return st diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index af1cdbfbe8..37ded556bc 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -7,7 +7,8 @@ from tinygrad.mixin.reduce import ReduceMixin from tinygrad.uop import Ops from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype -from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up +from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod +from tinygrad.helpers import resolve_pool_pads, round_up if TYPE_CHECKING: from tinygrad.uop.ops import sint, UOp @@ -50,6 +51,102 @@ class OpMixin(ElementwiseMixin, ReduceMixin): val = val.reshape((1,)*len(new_shape)).expand(new_shape) return val.clone(device=device) if buffer else val + def __getitem__(self, indices) -> Self: return self._getitem(indices) + + def _getitem(self, indices, v=None) -> Self: + from tinygrad.uop.ops import UOp + def is_adv(i): return isinstance(i,(list,tuple)) or (isinstance(i,type(self)) and (not isinstance(i,UOp) or i.shape != ())) + # view-only indexing (no Tensor/list indices, no setitem) is handled by MovementMixin.__getitem__ + if v is None and not any(is_adv(i) for i in (indices if isinstance(indices,tuple) else (indices,))): + return super().__getitem__(indices) + # wrap single index into a list + if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices] + indices_parsed, dim = [], 0 + for index in self._normalize_indices(list(indices)): + size = 1 if index is None else self.shape[dim] + parsed = {"size":size, "boundary":(0, size), "stride":1, "collapse_dim":False} + if isinstance(index,(list,tuple)): + flat = fully_flatten(index) + inferred = dtypes.bool if (flat and all(isinstance(s,bool) for s in flat)) else \ + (dtypes.default_int if flat and all_int(flat) else dtypes.default_float) + if not dtypes.is_int(inferred): raise IndexError(f"{index=} contains non-int element") + index = self._wrap_uop(UOp._frompy([i+size if i<0 else i for i in flat], inferred, self.device)).reshape(get_shape(index)) + elif is_adv(index): + if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported") + if index.device is not None and self.device is not None and index.device != self.device: + raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}") + assert isinstance(size, int), "size must be an int" + index = (index < 0).where(index+size, index) # treat negative index values + else: parsed = self._parse_view_index(index, size) + indices_parsed.append({**parsed, "index":index}) + if index is not None: dim += 1 + + # apply view ops then dim injection (None) and collapse (int) + x = self._apply_view_ops(mops) if (mops := [p for p in indices_parsed if p["index"] is not None]) else self + x_dims = [p for p in indices_parsed if not p["collapse_dim"]] + x = x.reshape(tuple(p["size"] for p in x_dims)) + + # tensor indexing + if tops := [(d, p) for d, p in enumerate(x_dims) if is_adv(p['index'])]: + dims, tensors, masks = [d for d, _ in tops], [p['index'] for _, p in tops], [] + big_shape = _broadcast_shape(*(t.shape for t in tensors)) + + # consecutive tensor indices with int shapes: use linear indexing instead of one-hot masks + consecutive = dims == list(range(dims[0], dims[0] + len(dims))) + if v is None and len(dims) > 1 and consecutive and all_int(ishp := tuple(x.shape[d] for d in dims)): + strides = tuple(prod(ishp[i+1:]) for i in range(len(dims))) + try: linear_idx = type(self).usum(*[t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides)]) + except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err + valid = type(self).uprod(*[(t >= 0) & (t < s) for t, s in zip(tensors, ishp)]) + pre, post = x.shape[:dims[0]], x.shape[dims[-1]+1:] + x = x.reshape(pre + (prod(ishp),) + post)[tuple([slice(None)] * len(pre)) + (valid.where(linear_idx, 0),)] + return valid.reshape((1,) * len(pre) + big_shape + (1,) * len(post)).where(x, 0) + + pre_reduce_shape = x.shape[:dims[0]] + big_shape + x.shape[dims[0]:] + + # create index masks + for dim, tensor in zip(dims, tensors): + try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape) + except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err + masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim))) + + # reduce masks to 1 mask + mask = type(self).uprod(*masks) + + # inject 1's for the extra dims added in create masks + reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:] + # sum reduce the extra dims introduced in create masks + x_pre = x # save collapsed shape for advanced setitem + x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype) + + # special permute case + if (permuted := dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1))): + mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x)) + + if v is None: return x # advanced getitem + # advanced setitem: resolve tensor dims in collapsed space, then fall through to basic setitem path + vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape)) + for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum + start = dims[0] if not permuted else 0 + vb = x_pre._masked_merge(vb, mask, tuple(range(start, start + len(big_shape)))) + elif v is None: return x # basic getitem + # basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims) + else: vb = v.cast(self.dtype)._broadcast_to(x.shape) + vb = vb.reshape(tuple(1 if p['collapse_dim'] else p['size'] for p in indices_parsed if p['index'] is not None)) + per_dim = [] + for d, m in enumerate(mops): + (s, e), st = m['boundary'], abs(m['stride']) + if st != 1 and vb.shape[d] > 1: # un-stride: interleave with zeros + vb = vb.unsqueeze(d+1) + vb = vb.pad_to(tuple(st if j == d+1 else None for j in range(vb.ndim))) + vb = vb.reshape(vb.shape[:d] + (vb.shape[d]*vb.shape[d+1],) + vb.shape[d+2:]) + vb = vb.shrink_to(tuple(e-s if j == d else None for j in range(self.ndim))) + idx = type(self).arange(self.shape[d]).reshape([1]*d + [self.shape[d]] + [1]*(self.ndim - d - 1)) + per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0)) + vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0)) + vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops))) + return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self) + @classmethod def invalids(cls, *shape, **kwargs) -> Self: """ diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 256a258b7e..44ee2bd45f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1,12 +1,12 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations -import time, math, itertools, functools, struct, sys, inspect, pathlib, hashlib, weakref +import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref from contextlib import ContextDecorator from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING if TYPE_CHECKING: import numpy -from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype, truncate +from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid -from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped +from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile from tinygrad.helpers import suppress_finalizing, disable_gc from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape @@ -43,23 +43,6 @@ def _fromnp(x: 'numpy.ndarray') -> UOp: ret.buffer.allocate(x) return ret.reshape(x.shape) -def get_shape(x) -> tuple[int, ...]: - # NOTE: str is special because iterating it still yields strs - if not hasattr(x, "__len__") or isinstance(x, str) or getattr(x, "shape", None) == (): return () - if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}") - return (len(subs),) + (subs[0] if subs else ()) - -def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp: - if isinstance(x, bytes): ret, data = UOp.new_buffer("PYTHON", len(x)//dtype.itemsize, dtype), x - else: - ret = UOp.empty(shape:=get_shape(x), dtype, "PYTHON") - assert dtype.fmt is not None, f"{dtype=} has None fmt" - truncate_function = truncate[dtype] - data = struct.pack(f"{prod(shape)}{dtype.fmt}", *[truncate_function(dtype.const(xi)) for xi in fully_flatten(x)]) - # fake realize. if target device is PYTHON it needs bytearray to be writable - ret.buffer.allocate(memoryview(data if device != "PYTHON" else bytearray(data))) - return ret - def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], dtype:DType) -> list[list[Tensor]]: return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), dtype=dtype, buffer=False) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] @@ -113,13 +96,12 @@ class Tensor(OpMixin): data = UOp.const(_dtype or dtypes.default_float, 0) elif isinstance(data, get_args(ConstType)): data = UOp.const(_dtype or dtypes.from_py(data), data) - elif isinstance(data, bytes): data = _frompy(data, _dtype or dtypes.uint8, _device) + elif isinstance(data, bytes): data = UOp._frompy(data, _dtype or dtypes.uint8, _device) elif isinstance(data, (list, tuple)): if _dtype is None: if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): _dtype = dtypes.bool else: _dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True - if _dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = _frompy(data, dtypes.float32, _device).cast(_dtype) - else: data = _frompy(data, _dtype, _device) + data = UOp._frompy(data, _dtype, _device) elif is_numpy_ndarray(data): import numpy as np assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}" @@ -838,96 +820,6 @@ class Tensor(OpMixin): def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg) def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis) - def _getitem(self, indices, v: Tensor|None = None) -> Tensor: - # view-only indexing (no Tensor/list indices, no setitem) is handled by MovementMixin.__getitem__ - if v is None and not any(isinstance(i, (Tensor, list, tuple)) for i in (indices if isinstance(indices, tuple) else (indices,))): - return super().__getitem__(indices) - # wrap single index into a list - if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices] - indices_parsed, dim = [], 0 - for index in self._normalize_indices(list(indices)): - size = 1 if index is None else self.shape[dim] - parsed = {"size":size, "boundary":(0, size), "stride":1, "collapse_dim":False} - match index: - case Tensor(): - if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported") - if index.device is not None and self.device is not None and index.device != self.device: - raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}") - assert isinstance(size, int), "size must be an int" - index = (index < 0).where(index+size, index) # treat negative index values - case list() | tuple(): - if not dtypes.is_int((ti:=Tensor(index)).dtype): raise IndexError(f"{index=} contains non-int element") - index = Tensor([i+size if i<0 else i for i in fully_flatten(index)], self.device).reshape(ti.shape) - case _: parsed = self._parse_view_index(index, size) - indices_parsed.append({**parsed, "index":index}) - if index is not None: dim += 1 - - # apply view ops then dim injection (None) and collapse (int) - x = self._apply_view_ops(mops) if (mops := [p for p in indices_parsed if p["index"] is not None]) else self - x_dims = [p for p in indices_parsed if not p["collapse_dim"]] - x = x.reshape(tuple(p["size"] for p in x_dims)) - - # tensor indexing - if tops := [(d, p) for d, p in enumerate(x_dims) if isinstance(p['index'], Tensor)]: - dims, tensors, masks = [d for d, _ in tops], cast(list[Tensor], [p['index'] for _, p in tops]), [] - big_shape = _broadcast_shape(*(t.shape for t in tensors)) - - # consecutive tensor indices with int shapes: use linear indexing instead of one-hot masks - consecutive = dims == list(range(dims[0], dims[0] + len(dims))) - if v is None and len(dims) > 1 and consecutive and all_int(ishp := tuple(x.shape[d] for d in dims)): - strides = tuple(prod(ishp[i+1:]) for i in range(len(dims))) - try: linear_idx = Tensor.usum(*[t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides)]) - except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err - valid = Tensor.uprod(*[(t >= 0) & (t < s) for t, s in zip(tensors, ishp)]) - pre, post = x.shape[:dims[0]], x.shape[dims[-1]+1:] - x = x.reshape(pre + (prod(ishp),) + post)[tuple([slice(None)] * len(pre)) + (valid.where(linear_idx, 0),)] - return valid.reshape((1,) * len(pre) + big_shape + (1,) * len(post)).where(x, 0) - - pre_reduce_shape = x.shape[:dims[0]] + big_shape + x.shape[dims[0]:] - - # create index masks - for dim, tensor in zip(dims, tensors): - try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape) - except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err - masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim))) - - # reduce masks to 1 mask - mask: Tensor = Tensor.uprod(*masks) - - # inject 1's for the extra dims added in create masks - reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:] - # sum reduce the extra dims introduced in create masks - x_pre = x # save collapsed shape for advanced setitem - x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype) - - # special permute case - if (permuted := dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1))): - mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x)) - - if v is None: return x # advanced getitem - # advanced setitem: resolve tensor dims in collapsed space, then fall through to basic setitem path - vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape)) - for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum - start = dims[0] if not permuted else 0 - vb = x_pre._masked_merge(vb, mask, tuple(range(start, start + len(big_shape)))) - elif v is None: return x # basic getitem - # basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims) - else: vb = v.cast(self.dtype)._broadcast_to(x.shape) - vb = vb.reshape(tuple(1 if isinstance(p['index'], sint) else p['size'] for p in indices_parsed if p['index'] is not None)) - per_dim = [] - for d, m in enumerate(mops): - (s, e), st = m['boundary'], abs(m['stride']) - if st != 1 and vb.shape[d] > 1: # un-stride: interleave with zeros - vb = vb.unsqueeze(d+1) - vb = vb.pad_to(tuple(st if j == d+1 else None for j in range(vb.ndim))) - vb = vb.reshape(vb.shape[:d] + (vb.shape[d]*vb.shape[d+1],) + vb.shape[d+2:]) - vb = vb.shrink_to(tuple(e-s if j == d else None for j in range(self.ndim))) - idx = Tensor.arange(self.shape[d]).reshape([1]*d + [self.shape[d]] + [1]*(self.ndim - d - 1)) - per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0)) - vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0)) - vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops))) - return (Tensor.uprod(*per_dim) if per_dim else Tensor(True, dtype=dtypes.bool, device=self.device)).where(vb, self) - def __getitem__(self, indices) -> Tensor: """ Retrieves a sub-tensor using indexing. @@ -966,7 +858,7 @@ class Tensor(OpMixin): print(t[Tensor([4, 3, 2])].numpy()) ``` """ - return self._getitem(indices) + return super().__getitem__(indices) def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None: if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}") diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 594b09d112..9b89cf88cd 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -9,7 +9,7 @@ from tinygrad.dtype import ConstFloat, PyConst, storage_fmt_for_dtype, to_storag from tinygrad.device import Buffer, MultiBuffer, canonicalize_device from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, floordiv, floormod, diskcache_put, to_function_name, cpu_profile, TracingKey -from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY, DISALLOW_BROADCAST +from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY, DISALLOW_BROADCAST, get_shape, fully_flatten from tinygrad.helpers import colored, ansilen, printable if TYPE_CHECKING: from tinygrad.renderer import Estimates @@ -744,6 +744,20 @@ class UOp(OpMixin, metaclass=UOpMetaClass): device = canonicalize_device(self.device if device is None else device) axis = self.axis if isinstance(device, tuple) else None return UOp.empty(self.shard_shape if axis is not None else self.shape, self.dtype if dtype is None else dtype, device, axis) + @staticmethod + def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str, ...]|None=None) -> UOp: + device = canonicalize_device(device) + if isinstance(x, bytes): ret, data = UOp.new_buffer("PYTHON", len(x)//dtype.itemsize, dtype), x + else: + # bfloat16 and fp8 have no struct format, so pack a float32 buffer and cast + bdtype = dtypes.float32 if dtype in [dtypes.bfloat16, *dtypes.fp8s] else dtype + assert bdtype.fmt is not None, f"{bdtype=} has None fmt" + ret = UOp.empty(shape:=get_shape(x), bdtype, "PYTHON") + data = struct.pack(f"{prod(shape)}{bdtype.fmt}", *[truncate[bdtype](bdtype.const(xi)) for xi in fully_flatten(x)]) + # fake realize. if target device is PYTHON it needs bytearray to be writable + ret.buffer.allocate(memoryview(data if device != "PYTHON" else bytearray(data))) + if ret.dtype != dtype: ret = ret.cast(dtype) + return ret if ret.device == device else ret.copy_to_device(device) def clone(self, device=None) -> UOp: device = device or self.device ret = self.empty_like(device=device)