mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-09 14:35:33 +08:00
advance indexing to mixin [PR] (#16532)
This commit is contained in:
@@ -2,8 +2,8 @@ import ctypes, gzip, unittest, timeit, pickle
|
||||
from tinygrad import Variable
|
||||
from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count, all_same
|
||||
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
|
||||
from tinygrad.helpers import ceildiv, ansistrip
|
||||
from tinygrad.tensor import Tensor, get_shape
|
||||
from tinygrad.helpers import ceildiv, ansistrip, get_shape
|
||||
from tinygrad.tensor import Tensor
|
||||
import numpy as np
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
|
||||
@@ -10,6 +10,9 @@ def _strip_unique(u: UOp) -> UOp: return graph_rewrite(u, _strip_unique_pm)
|
||||
def _t(*shape):
|
||||
return Tensor.arange(math.prod(shape)).reshape(*shape)
|
||||
|
||||
def _ti(data): # int32 index tensor
|
||||
return Tensor(data, dtype=dtypes.int32)
|
||||
|
||||
# Tensor().func().uop should be the same as UOp.func()
|
||||
def _check(tc: unittest.TestCase, t: Tensor, fn):
|
||||
tc.assertIs(fn(t).uop, fn(t.uop), f"\ntensor.uop = {fn(t).uop}\nuop = {fn(t.uop)}")
|
||||
@@ -110,6 +113,44 @@ class TestTensorUOpGetitem(unittest.TestCase):
|
||||
def test_mixed_slice_slice(self): _check(self, _t(3, 4, 5), lambda x: x[1:3, :, 0:2])
|
||||
def test_high_rank_combo(self): _check(self, _t(4, 5, 6), lambda x: x[1:3, :, -1, None])
|
||||
|
||||
# ---- advanced indexing: UOp index on UOp must match Tensor index on Tensor (same uop) ----
|
||||
def _check_adv(self, t, idx):
|
||||
# idx is a Tensor or a tuple mixing Tensor index arrays with slice/int/None/Ellipsis
|
||||
ui = idx.uop if isinstance(idx, Tensor) else (tuple(i.uop if isinstance(i,Tensor) else i for i in idx) if isinstance(idx, tuple) else idx)
|
||||
self.assertIs(t[idx].uop, t.uop[ui])
|
||||
|
||||
def test_adv_single(self): self._check_adv(_t(5), _ti([2,1,0,1,2]))
|
||||
def test_adv_negative(self): self._check_adv(_t(5), _ti([-1,-2,0]))
|
||||
def test_adv_out_of_bounds(self): self._check_adv(_t(5), _ti([4,7,2])) # oob -> 0
|
||||
def test_adv_2d_dim0(self): self._check_adv(_t(3,4), _ti([2,0,1]))
|
||||
def test_adv_2d_index_array(self): self._check_adv(_t(4,5), _ti([[0,1],[2,3]])) # shaped index
|
||||
def test_adv_two_consecutive(self): self._check_adv(_t(3,4), (_ti([2,0,1]), _ti([1,2,3]))) # linear path
|
||||
def test_adv_two_broadcast(self): self._check_adv(_t(3,4), (_ti([2,0,1]), _ti([[1],[2],[3]])))
|
||||
def test_adv_three_consecutive(self):self._check_adv(_t(3,4,5), (_ti([2,0]), _ti([1,3]), _ti([4,2])))
|
||||
def test_adv_after_slice(self): self._check_adv(_t(2,3,4), (slice(None), _ti([2,0,1])))
|
||||
def test_adv_non_consec_permute(self): self._check_adv(_t(2,3,4), (_ti([1,0]), slice(None), _ti([2,0])))
|
||||
def test_adv_idx_then_int(self): self._check_adv(_t(3,4), (_ti([2,0,1]), 2))
|
||||
def test_adv_int_then_idx(self): self._check_adv(_t(3,4), (1, _ti([2,0,1])))
|
||||
def test_adv_idx_then_none(self): self._check_adv(_t(3,4), (_ti([2,0,1]), None))
|
||||
def test_adv_none_then_idx(self): self._check_adv(_t(3,4), (None, _ti([2,0,1])))
|
||||
def test_adv_ellipsis_then_idx(self):self._check_adv(_t(2,3,4), (Ellipsis, _ti([2,0,1])))
|
||||
def test_adv_idx_slice_mix(self): self._check_adv(_t(4,5,6), (_ti([1,3]), slice(1,4), _ti([2,0])))
|
||||
|
||||
# bool index is unsupported
|
||||
def test_adv_bool_index_rejected(self):
|
||||
with self.assertRaises(IndexError): _t(5)[_t(5) > 2]
|
||||
with self.assertRaises(IndexError): _t(5).uop[(_t(5) > 2).uop]
|
||||
|
||||
# python list/tuple indices
|
||||
def test_adv_python_list(self):
|
||||
self.assertIs(_strip_unique(_t(5)[[2,1,0]].uop), _strip_unique(_t(5).uop[[2,1,0]]))
|
||||
def test_adv_python_list_negative(self):
|
||||
self.assertIs(_strip_unique(_t(5)[[-1,-2,0]].uop), _strip_unique(_t(5).uop[[-1,-2,0]]))
|
||||
def test_adv_python_list_nested(self):
|
||||
self.assertIs(_strip_unique(_t(3,4)[[0,1],[2,0]].uop), _strip_unique(_t(3,4).uop[[0,1],[2,0]]))
|
||||
def test_adv_python_list_2d_index(self):
|
||||
self.assertIs(_strip_unique(_t(4,5)[[[0,1],[2,3]]].uop), _strip_unique(_t(4,5).uop[[[0,1],[2,3]]]))
|
||||
|
||||
class TestTensorUOpCumalu(unittest.TestCase):
|
||||
def test_cumsum_1d(self): _check(self, _t(5), lambda x: x.cumsum())
|
||||
def test_cumsum_2d(self): _check(self, _t(3, 4), lambda x: x.cumsum(1))
|
||||
|
||||
@@ -29,6 +29,11 @@ def argfix(*x):
|
||||
# https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
||||
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__))
|
||||
def all_same(items:Sequence): return all(x == items[0] for x in items) # works for empty input
|
||||
def get_shape(x) -> tuple[int, ...]:
|
||||
# NOTE: str is special because iterating it still yields strs
|
||||
if not hasattr(x, "__len__") or isinstance(x, str) or getattr(x, "shape", None) == (): return ()
|
||||
if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
|
||||
return (len(subs),) + (subs[0] if subs else ())
|
||||
def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
|
||||
def colored(st, color:str|None, background=False): # replace the termcolor library
|
||||
if NO_COLOR: return st
|
||||
|
||||
@@ -7,7 +7,8 @@ from tinygrad.mixin.reduce import ReduceMixin
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
|
||||
from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up
|
||||
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
|
||||
from tinygrad.helpers import resolve_pool_pads, round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.uop.ops import sint, UOp
|
||||
@@ -50,6 +51,102 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
|
||||
return val.clone(device=device) if buffer else val
|
||||
|
||||
def __getitem__(self, indices) -> Self: return self._getitem(indices)
|
||||
|
||||
def _getitem(self, indices, v=None) -> Self:
|
||||
from tinygrad.uop.ops import UOp
|
||||
def is_adv(i): return isinstance(i,(list,tuple)) or (isinstance(i,type(self)) and (not isinstance(i,UOp) or i.shape != ()))
|
||||
# view-only indexing (no Tensor/list indices, no setitem) is handled by MovementMixin.__getitem__
|
||||
if v is None and not any(is_adv(i) for i in (indices if isinstance(indices,tuple) else (indices,))):
|
||||
return super().__getitem__(indices)
|
||||
# wrap single index into a list
|
||||
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
|
||||
indices_parsed, dim = [], 0
|
||||
for index in self._normalize_indices(list(indices)):
|
||||
size = 1 if index is None else self.shape[dim]
|
||||
parsed = {"size":size, "boundary":(0, size), "stride":1, "collapse_dim":False}
|
||||
if isinstance(index,(list,tuple)):
|
||||
flat = fully_flatten(index)
|
||||
inferred = dtypes.bool if (flat and all(isinstance(s,bool) for s in flat)) else \
|
||||
(dtypes.default_int if flat and all_int(flat) else dtypes.default_float)
|
||||
if not dtypes.is_int(inferred): raise IndexError(f"{index=} contains non-int element")
|
||||
index = self._wrap_uop(UOp._frompy([i+size if i<0 else i for i in flat], inferred, self.device)).reshape(get_shape(index))
|
||||
elif is_adv(index):
|
||||
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
|
||||
if index.device is not None and self.device is not None and index.device != self.device:
|
||||
raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
|
||||
assert isinstance(size, int), "size must be an int"
|
||||
index = (index < 0).where(index+size, index) # treat negative index values
|
||||
else: parsed = self._parse_view_index(index, size)
|
||||
indices_parsed.append({**parsed, "index":index})
|
||||
if index is not None: dim += 1
|
||||
|
||||
# apply view ops then dim injection (None) and collapse (int)
|
||||
x = self._apply_view_ops(mops) if (mops := [p for p in indices_parsed if p["index"] is not None]) else self
|
||||
x_dims = [p for p in indices_parsed if not p["collapse_dim"]]
|
||||
x = x.reshape(tuple(p["size"] for p in x_dims))
|
||||
|
||||
# tensor indexing
|
||||
if tops := [(d, p) for d, p in enumerate(x_dims) if is_adv(p['index'])]:
|
||||
dims, tensors, masks = [d for d, _ in tops], [p['index'] for _, p in tops], []
|
||||
big_shape = _broadcast_shape(*(t.shape for t in tensors))
|
||||
|
||||
# consecutive tensor indices with int shapes: use linear indexing instead of one-hot masks
|
||||
consecutive = dims == list(range(dims[0], dims[0] + len(dims)))
|
||||
if v is None and len(dims) > 1 and consecutive and all_int(ishp := tuple(x.shape[d] for d in dims)):
|
||||
strides = tuple(prod(ishp[i+1:]) for i in range(len(dims)))
|
||||
try: linear_idx = type(self).usum(*[t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides)])
|
||||
except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err
|
||||
valid = type(self).uprod(*[(t >= 0) & (t < s) for t, s in zip(tensors, ishp)])
|
||||
pre, post = x.shape[:dims[0]], x.shape[dims[-1]+1:]
|
||||
x = x.reshape(pre + (prod(ishp),) + post)[tuple([slice(None)] * len(pre)) + (valid.where(linear_idx, 0),)]
|
||||
return valid.reshape((1,) * len(pre) + big_shape + (1,) * len(post)).where(x, 0)
|
||||
|
||||
pre_reduce_shape = x.shape[:dims[0]] + big_shape + x.shape[dims[0]:]
|
||||
|
||||
# create index masks
|
||||
for dim, tensor in zip(dims, tensors):
|
||||
try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
|
||||
except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err
|
||||
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
|
||||
|
||||
# reduce masks to 1 mask
|
||||
mask = type(self).uprod(*masks)
|
||||
|
||||
# inject 1's for the extra dims added in create masks
|
||||
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
|
||||
# sum reduce the extra dims introduced in create masks
|
||||
x_pre = x # save collapsed shape for advanced setitem
|
||||
x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
|
||||
|
||||
# special permute case
|
||||
if (permuted := dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1))):
|
||||
mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x))
|
||||
|
||||
if v is None: return x # advanced getitem
|
||||
# advanced setitem: resolve tensor dims in collapsed space, then fall through to basic setitem path
|
||||
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
|
||||
for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum
|
||||
start = dims[0] if not permuted else 0
|
||||
vb = x_pre._masked_merge(vb, mask, tuple(range(start, start + len(big_shape))))
|
||||
elif v is None: return x # basic getitem
|
||||
# basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims)
|
||||
else: vb = v.cast(self.dtype)._broadcast_to(x.shape)
|
||||
vb = vb.reshape(tuple(1 if p['collapse_dim'] else p['size'] for p in indices_parsed if p['index'] is not None))
|
||||
per_dim = []
|
||||
for d, m in enumerate(mops):
|
||||
(s, e), st = m['boundary'], abs(m['stride'])
|
||||
if st != 1 and vb.shape[d] > 1: # un-stride: interleave with zeros
|
||||
vb = vb.unsqueeze(d+1)
|
||||
vb = vb.pad_to(tuple(st if j == d+1 else None for j in range(vb.ndim)))
|
||||
vb = vb.reshape(vb.shape[:d] + (vb.shape[d]*vb.shape[d+1],) + vb.shape[d+2:])
|
||||
vb = vb.shrink_to(tuple(e-s if j == d else None for j in range(self.ndim)))
|
||||
idx = type(self).arange(self.shape[d]).reshape([1]*d + [self.shape[d]] + [1]*(self.ndim - d - 1))
|
||||
per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0))
|
||||
vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0))
|
||||
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
|
||||
return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self)
|
||||
|
||||
@classmethod
|
||||
def invalids(cls, *shape, **kwargs) -> Self:
|
||||
"""
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, hashlib, weakref
|
||||
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
|
||||
from contextlib import ContextDecorator
|
||||
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||
if TYPE_CHECKING: import numpy
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
|
||||
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped
|
||||
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
|
||||
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
|
||||
@@ -43,23 +43,6 @@ def _fromnp(x: 'numpy.ndarray') -> UOp:
|
||||
ret.buffer.allocate(x)
|
||||
return ret.reshape(x.shape)
|
||||
|
||||
def get_shape(x) -> tuple[int, ...]:
|
||||
# NOTE: str is special because iterating it still yields strs
|
||||
if not hasattr(x, "__len__") or isinstance(x, str) or getattr(x, "shape", None) == (): return ()
|
||||
if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
|
||||
return (len(subs),) + (subs[0] if subs else ())
|
||||
|
||||
def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp:
|
||||
if isinstance(x, bytes): ret, data = UOp.new_buffer("PYTHON", len(x)//dtype.itemsize, dtype), x
|
||||
else:
|
||||
ret = UOp.empty(shape:=get_shape(x), dtype, "PYTHON")
|
||||
assert dtype.fmt is not None, f"{dtype=} has None fmt"
|
||||
truncate_function = truncate[dtype]
|
||||
data = struct.pack(f"{prod(shape)}{dtype.fmt}", *[truncate_function(dtype.const(xi)) for xi in fully_flatten(x)])
|
||||
# fake realize. if target device is PYTHON it needs bytearray to be writable
|
||||
ret.buffer.allocate(memoryview(data if device != "PYTHON" else bytearray(data)))
|
||||
return ret
|
||||
|
||||
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], dtype:DType) -> list[list[Tensor]]:
|
||||
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), dtype=dtype, buffer=False) for m in mat], dim=dim)
|
||||
for k in range(len(mat[0]))] for dim in range(dims)]
|
||||
@@ -113,13 +96,12 @@ class Tensor(OpMixin):
|
||||
data = UOp.const(_dtype or dtypes.default_float, 0)
|
||||
elif isinstance(data, get_args(ConstType)):
|
||||
data = UOp.const(_dtype or dtypes.from_py(data), data)
|
||||
elif isinstance(data, bytes): data = _frompy(data, _dtype or dtypes.uint8, _device)
|
||||
elif isinstance(data, bytes): data = UOp._frompy(data, _dtype or dtypes.uint8, _device)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
if _dtype is None:
|
||||
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): _dtype = dtypes.bool
|
||||
else: _dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
|
||||
if _dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = _frompy(data, dtypes.float32, _device).cast(_dtype)
|
||||
else: data = _frompy(data, _dtype, _device)
|
||||
data = UOp._frompy(data, _dtype, _device)
|
||||
elif is_numpy_ndarray(data):
|
||||
import numpy as np
|
||||
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
|
||||
@@ -838,96 +820,6 @@ class Tensor(OpMixin):
|
||||
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
|
||||
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
|
||||
|
||||
def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
|
||||
# view-only indexing (no Tensor/list indices, no setitem) is handled by MovementMixin.__getitem__
|
||||
if v is None and not any(isinstance(i, (Tensor, list, tuple)) for i in (indices if isinstance(indices, tuple) else (indices,))):
|
||||
return super().__getitem__(indices)
|
||||
# wrap single index into a list
|
||||
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
|
||||
indices_parsed, dim = [], 0
|
||||
for index in self._normalize_indices(list(indices)):
|
||||
size = 1 if index is None else self.shape[dim]
|
||||
parsed = {"size":size, "boundary":(0, size), "stride":1, "collapse_dim":False}
|
||||
match index:
|
||||
case Tensor():
|
||||
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
|
||||
if index.device is not None and self.device is not None and index.device != self.device:
|
||||
raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
|
||||
assert isinstance(size, int), "size must be an int"
|
||||
index = (index < 0).where(index+size, index) # treat negative index values
|
||||
case list() | tuple():
|
||||
if not dtypes.is_int((ti:=Tensor(index)).dtype): raise IndexError(f"{index=} contains non-int element")
|
||||
index = Tensor([i+size if i<0 else i for i in fully_flatten(index)], self.device).reshape(ti.shape)
|
||||
case _: parsed = self._parse_view_index(index, size)
|
||||
indices_parsed.append({**parsed, "index":index})
|
||||
if index is not None: dim += 1
|
||||
|
||||
# apply view ops then dim injection (None) and collapse (int)
|
||||
x = self._apply_view_ops(mops) if (mops := [p for p in indices_parsed if p["index"] is not None]) else self
|
||||
x_dims = [p for p in indices_parsed if not p["collapse_dim"]]
|
||||
x = x.reshape(tuple(p["size"] for p in x_dims))
|
||||
|
||||
# tensor indexing
|
||||
if tops := [(d, p) for d, p in enumerate(x_dims) if isinstance(p['index'], Tensor)]:
|
||||
dims, tensors, masks = [d for d, _ in tops], cast(list[Tensor], [p['index'] for _, p in tops]), []
|
||||
big_shape = _broadcast_shape(*(t.shape for t in tensors))
|
||||
|
||||
# consecutive tensor indices with int shapes: use linear indexing instead of one-hot masks
|
||||
consecutive = dims == list(range(dims[0], dims[0] + len(dims)))
|
||||
if v is None and len(dims) > 1 and consecutive and all_int(ishp := tuple(x.shape[d] for d in dims)):
|
||||
strides = tuple(prod(ishp[i+1:]) for i in range(len(dims)))
|
||||
try: linear_idx = Tensor.usum(*[t._broadcast_to(big_shape) * s for t, s in zip(tensors, strides)])
|
||||
except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err
|
||||
valid = Tensor.uprod(*[(t >= 0) & (t < s) for t, s in zip(tensors, ishp)])
|
||||
pre, post = x.shape[:dims[0]], x.shape[dims[-1]+1:]
|
||||
x = x.reshape(pre + (prod(ishp),) + post)[tuple([slice(None)] * len(pre)) + (valid.where(linear_idx, 0),)]
|
||||
return valid.reshape((1,) * len(pre) + big_shape + (1,) * len(post)).where(x, 0)
|
||||
|
||||
pre_reduce_shape = x.shape[:dims[0]] + big_shape + x.shape[dims[0]:]
|
||||
|
||||
# create index masks
|
||||
for dim, tensor in zip(dims, tensors):
|
||||
try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
|
||||
except ValueError as err: raise IndexError(f"cannot broadcast indices: {err}") from err
|
||||
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
|
||||
|
||||
# reduce masks to 1 mask
|
||||
mask: Tensor = Tensor.uprod(*masks)
|
||||
|
||||
# inject 1's for the extra dims added in create masks
|
||||
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
|
||||
# sum reduce the extra dims introduced in create masks
|
||||
x_pre = x # save collapsed shape for advanced setitem
|
||||
x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
|
||||
|
||||
# special permute case
|
||||
if (permuted := dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1))):
|
||||
mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x))
|
||||
|
||||
if v is None: return x # advanced getitem
|
||||
# advanced setitem: resolve tensor dims in collapsed space, then fall through to basic setitem path
|
||||
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
|
||||
for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum
|
||||
start = dims[0] if not permuted else 0
|
||||
vb = x_pre._masked_merge(vb, mask, tuple(range(start, start + len(big_shape))))
|
||||
elif v is None: return x # basic getitem
|
||||
# basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims)
|
||||
else: vb = v.cast(self.dtype)._broadcast_to(x.shape)
|
||||
vb = vb.reshape(tuple(1 if isinstance(p['index'], sint) else p['size'] for p in indices_parsed if p['index'] is not None))
|
||||
per_dim = []
|
||||
for d, m in enumerate(mops):
|
||||
(s, e), st = m['boundary'], abs(m['stride'])
|
||||
if st != 1 and vb.shape[d] > 1: # un-stride: interleave with zeros
|
||||
vb = vb.unsqueeze(d+1)
|
||||
vb = vb.pad_to(tuple(st if j == d+1 else None for j in range(vb.ndim)))
|
||||
vb = vb.reshape(vb.shape[:d] + (vb.shape[d]*vb.shape[d+1],) + vb.shape[d+2:])
|
||||
vb = vb.shrink_to(tuple(e-s if j == d else None for j in range(self.ndim)))
|
||||
idx = Tensor.arange(self.shape[d]).reshape([1]*d + [self.shape[d]] + [1]*(self.ndim - d - 1))
|
||||
per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0))
|
||||
vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0))
|
||||
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
|
||||
return (Tensor.uprod(*per_dim) if per_dim else Tensor(True, dtype=dtypes.bool, device=self.device)).where(vb, self)
|
||||
|
||||
def __getitem__(self, indices) -> Tensor:
|
||||
"""
|
||||
Retrieves a sub-tensor using indexing.
|
||||
@@ -966,7 +858,7 @@ class Tensor(OpMixin):
|
||||
print(t[Tensor([4, 3, 2])].numpy())
|
||||
```
|
||||
"""
|
||||
return self._getitem(indices)
|
||||
return super().__getitem__(indices)
|
||||
|
||||
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
|
||||
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.dtype import ConstFloat, PyConst, storage_fmt_for_dtype, to_storag
|
||||
from tinygrad.device import Buffer, MultiBuffer, canonicalize_device
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, floordiv, floormod, diskcache_put, to_function_name, cpu_profile, TracingKey
|
||||
from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY, DISALLOW_BROADCAST
|
||||
from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY, DISALLOW_BROADCAST, get_shape, fully_flatten
|
||||
from tinygrad.helpers import colored, ansilen, printable
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.renderer import Estimates
|
||||
@@ -744,6 +744,20 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
device = canonicalize_device(self.device if device is None else device)
|
||||
axis = self.axis if isinstance(device, tuple) else None
|
||||
return UOp.empty(self.shard_shape if axis is not None else self.shape, self.dtype if dtype is None else dtype, device, axis)
|
||||
@staticmethod
|
||||
def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str, ...]|None=None) -> UOp:
|
||||
device = canonicalize_device(device)
|
||||
if isinstance(x, bytes): ret, data = UOp.new_buffer("PYTHON", len(x)//dtype.itemsize, dtype), x
|
||||
else:
|
||||
# bfloat16 and fp8 have no struct format, so pack a float32 buffer and cast
|
||||
bdtype = dtypes.float32 if dtype in [dtypes.bfloat16, *dtypes.fp8s] else dtype
|
||||
assert bdtype.fmt is not None, f"{bdtype=} has None fmt"
|
||||
ret = UOp.empty(shape:=get_shape(x), bdtype, "PYTHON")
|
||||
data = struct.pack(f"{prod(shape)}{bdtype.fmt}", *[truncate[bdtype](bdtype.const(xi)) for xi in fully_flatten(x)])
|
||||
# fake realize. if target device is PYTHON it needs bytearray to be writable
|
||||
ret.buffer.allocate(memoryview(data if device != "PYTHON" else bytearray(data)))
|
||||
if ret.dtype != dtype: ret = ret.cast(dtype)
|
||||
return ret if ret.device == device else ret.copy_to_device(device)
|
||||
def clone(self, device=None) -> UOp:
|
||||
device = device or self.device
|
||||
ret = self.empty_like(device=device)
|
||||
|
||||
Reference in New Issue
Block a user