mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
use more non-buffered const (#16327)
This commit is contained in:
@@ -30,7 +30,7 @@ def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]:
|
||||
vals = Tensor.arange(n, device=x.device).reshape(1,1,n).cast(x.dtype).expand(x.shape)
|
||||
cmp = (x.unsqueeze(-1) > x.unsqueeze(-2)) | ((x.unsqueeze(-1) == x.unsqueeze(-2)) & \
|
||||
(Tensor.arange(n, device=x.device).reshape(1,1,n,1) < Tensor.arange(n, device=x.device).reshape(1,1,1,n)))
|
||||
sel = Tensor.zeros_like(x).scatter(-1, cmp.sum(axis=-1).cast('int32'), vals)[:,:,n-k:].cast('int32')
|
||||
sel = x.const_like(0).scatter(-1, cmp.sum(axis=-1).cast('int32'), vals)[:,:,n-k:].cast('int32')
|
||||
return x.gather(-1, sel), sel
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -177,7 +177,8 @@ class TransformerBlock(FFNBlock):
|
||||
|
||||
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
|
||||
# TODO: this if statement should be removed and it shouldn't generate extra kernels
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device, buffer=False).triu(start_pos+1) \
|
||||
if resolve(T != 1) else None
|
||||
attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
|
||||
attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
|
||||
return self.attn_output(attn if not self.config.attn_output_gate else (attn * gate.sigmoid()))
|
||||
@@ -222,7 +223,8 @@ class MLATransformerBlock(FFNBlock):
|
||||
k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :]
|
||||
v = k[..., :self.config.kv_lora_rank]
|
||||
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device, buffer=False).triu(start_pos+1) \
|
||||
if resolve(T != 1) else None
|
||||
attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5)
|
||||
if mask is not None: attn = attn + mask
|
||||
attn = attn.softmax(-1)
|
||||
@@ -282,8 +284,8 @@ class GatedDeltaNetBlock(FFNBlock):
|
||||
|
||||
# recurrent state can't be partially reused after divergence, force a full rebuild
|
||||
def _state_reset_ops(self):
|
||||
return [self.conv_state.assign(Tensor.zeros_like(self.conv_state)),
|
||||
self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))] if hasattr(self, "conv_state") else []
|
||||
return [self.conv_state.assign(self.conv_state.const_like(0)),
|
||||
self.recurrent_state.assign(self.recurrent_state.const_like(0))] if hasattr(self, "conv_state") else []
|
||||
def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return 0 if prefix_len != cached_len else prefix_len
|
||||
|
||||
def _init_state(self, x):
|
||||
|
||||
@@ -118,7 +118,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
lo, hi = (start, stop-step) if step > 0 else (stop-step, start)
|
||||
if lo < (dt:=to_dtype(dtype)).min or dt.max < hi: raise OverflowError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
|
||||
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
|
||||
if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, **kwargs)
|
||||
if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, buffer=False, **kwargs)
|
||||
return (cls.full((output_len,), step, dtype=dtype, buffer=False, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
|
||||
|
||||
@classmethod
|
||||
@@ -138,7 +138,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
"""
|
||||
if steps < 0: raise ValueError("number of steps must be non-negative")
|
||||
if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported")
|
||||
if steps == 1: return cls.full((1,), start, dtype=dtype, **kwargs)
|
||||
if steps == 1: return cls.full((1,), start, dtype=dtype, buffer=False, **kwargs)
|
||||
return (start + cls.arange(steps, dtype=dtypes.default_float, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
|
||||
|
||||
@classmethod
|
||||
@@ -186,7 +186,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
print(t.triu(diagonal=-1).numpy())
|
||||
```
|
||||
"""
|
||||
return self._tri(self.shape[-2], self.shape[-1], diagonal, self.device).where(self, self.zeros_like())
|
||||
return self._tri(self.shape[-2], self.shape[-1], diagonal, self.device).where(self, self.const_like(0))
|
||||
|
||||
def tril(self, diagonal:sint=0) -> Self:
|
||||
"""
|
||||
@@ -209,7 +209,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
print(t.tril(diagonal=-1).numpy())
|
||||
```
|
||||
"""
|
||||
return self._tri(self.shape[-2], self.shape[-1], diagonal+1, self.device).where(self.zeros_like(), self)
|
||||
return self._tri(self.shape[-2], self.shape[-1], diagonal+1, self.device).where(self.const_like(0), self)
|
||||
|
||||
# ***** random *****
|
||||
|
||||
@@ -238,7 +238,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
_, nmant = dtypes.finfo(dtype)
|
||||
uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
|
||||
uint_bits = bits.bitcast(uint_dtype)
|
||||
float_one_bits = uint_bits.ones_like(dtype=dtype).bitcast(uint_dtype)
|
||||
float_one_bits = uint_bits.const_like(1).cast(dtype).bitcast(uint_dtype)
|
||||
return uint_bits.rshift(dtype.bitsize - nmant).bitwise_or(float_one_bits).bitcast(dtype)[:prod(shape)].sub(1).reshape(shape)
|
||||
|
||||
def _pad_constant(self, pX, value:ConstType) -> Self:
|
||||
@@ -250,7 +250,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
base = MovementMixin.pad(X, pads)
|
||||
if value == 0: return base
|
||||
base = base.cast(least_upper_dtype(base.dtype, dtypes.from_py(value)))
|
||||
return MovementMixin.pad(X.ones_like(dtype=dtypes.bool), pads).where(base, base.full_like(value))
|
||||
return MovementMixin.pad(X.const_like(1).cast(dtypes.bool), pads).where(base, base.const_like(value))
|
||||
|
||||
def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Self:
|
||||
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.')
|
||||
@@ -715,7 +715,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
print(indices.numpy())
|
||||
```
|
||||
"""
|
||||
if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), type(self).zeros(self.shape, dtype=dtypes.int32, device=self.device)
|
||||
if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), type(self).zeros(self.shape, dtype=dtypes.int32, device=self.device, buffer=False)
|
||||
values, n = self._split_cumalu(axis, Ops.MAX), int(self.shape[axis])
|
||||
x, values_t = self.transpose(axis, -1), values.transpose(axis, -1)
|
||||
match = x.unsqueeze(-1).eq(values_t.unsqueeze(-2)) * type(self).ones(n, n, device=self.device, buffer=False).triu()
|
||||
@@ -838,7 +838,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
```
|
||||
"""
|
||||
x, dim = self, self._resolve_dim(dim)
|
||||
if (orig_len := int(x.shape[dim])) <= 1: return x, x.zeros_like(dtype=dtypes.default_int)
|
||||
if (orig_len := int(x.shape[dim])) <= 1: return x, x.const_like(0).cast(dtypes.default_int)
|
||||
# pad to power of 2
|
||||
n_stages = (orig_len-1).bit_length()
|
||||
pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
|
||||
@@ -1174,9 +1174,9 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
reg_pads = resolve_pool_pads(padding, len(k_))
|
||||
pads = self._apply_ceil_mode(reg_pads, k_, s_, dilation) if ceil_mode else reg_pads
|
||||
if not count_include_pad:
|
||||
return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
|
||||
return pool(self, pads).sum(axis) / pool(self.const_like(1), pads).sum(axis)
|
||||
if not ceil_mode: return pool(self, pads).mean(axis)
|
||||
return pool(self, pads).sum(axis) / pool(self._pad_constant(((0,0),)*(self.ndim-len(k_)) + flat_to_grouped(reg_pads), 0.0).ones_like(),
|
||||
return pool(self, pads).sum(axis) / pool(self._pad_constant(((0,0),)*(self.ndim-len(k_)) + flat_to_grouped(reg_pads), 0.0).const_like(1),
|
||||
tuple(cp-rp for cp,rp in zip(pads, reg_pads))).sum(axis)
|
||||
|
||||
def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
|
||||
@@ -1407,7 +1407,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
if Y.device is not None and self.device is not None and Y.device != self.device:
|
||||
raise RuntimeError(f"expected Y and self on the same device, {Y.device=}, {self.device=}")
|
||||
log_probs = self.log_softmax()
|
||||
loss_mask = Y.ne(ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
|
||||
loss_mask = Y.ne(ignore_index) if ignore_index != -1 else Y.const_like(1).cast(dtypes.bool)
|
||||
y = Y.unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
|
||||
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
|
||||
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
|
||||
@@ -1459,7 +1459,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
|
||||
```
|
||||
"""
|
||||
weight = Y.ones_like() if weight is None else weight.gather(0, Y.flatten()).reshape(Y.shape)
|
||||
weight = Y.const_like(1) if weight is None else weight.gather(0, Y.flatten()).reshape(Y.shape)
|
||||
masked_weight = weight if ignore_index is None else weight * Y.ne(ignore_index)
|
||||
nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
|
||||
return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
|
||||
|
||||
@@ -411,7 +411,7 @@ class LSTMCell:
|
||||
self.bias_hh: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None
|
||||
|
||||
def __call__(self, x:Tensor, hc:tuple[Tensor, Tensor]|None=None) -> tuple[Tensor, Tensor]:
|
||||
if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
|
||||
if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device, buffer=False),)*2
|
||||
gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
|
||||
i, f, g, o = gates.chunk(4, dim=1)
|
||||
i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
|
||||
|
||||
@@ -785,9 +785,9 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
def _apply_transformation(input_sz, output_sz, scale_dim, mode):
|
||||
index = Tensor.arange(output_sz, device=X.device)
|
||||
if mode == "half_pixel": return (index + 0.5) / scale_dim - 0.5
|
||||
if mode == "align_corners": return index * (input_sz - 1) / (output_sz - 1) if output_sz != 1 else Tensor.zeros_like(index)
|
||||
if mode == "align_corners": return index * (input_sz - 1) / (output_sz - 1) if output_sz != 1 else index.const_like(0)
|
||||
if mode == "asymmetric": return index / scale_dim
|
||||
if mode == "pytorch_half_pixel": return ((index + 0.5) / scale_dim - 0.5) if output_sz != 1 else Tensor.zeros_like(index)
|
||||
if mode == "pytorch_half_pixel": return ((index + 0.5) / scale_dim - 0.5) if output_sz != 1 else index.const_like(0)
|
||||
if mode == "half_pixel_symmetric":
|
||||
output_dim_scaled = input_sz * scale_dim
|
||||
return (input_sz / 2) * (1 - (output_sz / output_dim_scaled)) + (index + 0.5) / scale_dim - 0.5
|
||||
@@ -964,7 +964,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
# Reimplemented here because you need legacy RNG for passing ONNX tests.
|
||||
def dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
|
||||
import numpy as np
|
||||
if not training_mode: return data, data.full_like(True, dtype=dtypes.bool)
|
||||
if not training_mode: return data, data.const_like(True).cast(dtypes.bool)
|
||||
if seed is not None:
|
||||
rand = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)), dtype=data.dtype, device=data.device)
|
||||
else:
|
||||
@@ -1043,7 +1043,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
attn_scores = mask.where(attn_scores, mask_filter_value)
|
||||
|
||||
if unidirectional:
|
||||
causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool, device=attn_scores.device).tril()
|
||||
causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool, device=attn_scores.device, buffer=False).tril()
|
||||
attn_scores = causal_mask.where(attn_scores, mask_filter_value)
|
||||
|
||||
output = attn_scores.softmax(-1) @ v
|
||||
@@ -1075,7 +1075,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
qk_matmul_return_val = scores
|
||||
|
||||
if is_causal:
|
||||
causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool).tril(0)
|
||||
causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool, buffer=False).tril(0)
|
||||
scores = scores.masked_fill(causal_mask.logical_not(), -float("inf"))
|
||||
|
||||
if attn_mask is not None:
|
||||
|
||||
@@ -62,7 +62,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp:
|
||||
return ret
|
||||
|
||||
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...]|None, dtype:DType) -> list[list[Tensor]]:
|
||||
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
|
||||
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype, buffer=False) for m in mat], dim=dim)
|
||||
for k in range(len(mat[0]))] for dim in range(dims)]
|
||||
|
||||
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
|
||||
@@ -1063,9 +1063,9 @@ class Tensor(OpMixin):
|
||||
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
|
||||
mask_cumsum = mask.cumsum()
|
||||
if size is None:
|
||||
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device)
|
||||
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device, buffer=False)
|
||||
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
|
||||
counts = Tensor.zeros(size, dtype=dtypes.int32, device=self.device).scatter(0, mask_cumsum, 1, reduce='add')
|
||||
counts = Tensor.zeros(size, dtype=dtypes.int32, device=self.device, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
|
||||
return (Tensor.arange(size, device=self.device) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
|
||||
|
||||
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor:
|
||||
@@ -1141,7 +1141,7 @@ class Tensor(OpMixin):
|
||||
|
||||
data = (data.flatten(1) ^ pad_mask).reshape(*data.shape[:2], 200).bitcast(dtypes.uint64)
|
||||
|
||||
state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64)
|
||||
state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64, buffer=False)
|
||||
for k in range(int(data.shape[1])):
|
||||
state = state ^ data[:, k]
|
||||
for i in range(24): # f1600
|
||||
@@ -1348,7 +1348,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
|
||||
if not Tensor.training or p == 0: return self
|
||||
if p == 1: return self.zeros_like()
|
||||
if p == 1: return self.const_like(0)
|
||||
return (Tensor.rand_like(self, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
|
||||
|
||||
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
|
||||
@@ -1376,7 +1376,7 @@ class Tensor(OpMixin):
|
||||
# handle attention mask
|
||||
if is_causal:
|
||||
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
||||
attn_mask = qk.ones_like(dtype=dtypes.bool).tril()
|
||||
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
|
||||
qk = qk + attn_mask
|
||||
|
||||
Reference in New Issue
Block a user