use more non-buffered const (#16327)

This commit is contained in:
chenyu
2026-05-21 21:37:52 -04:00
committed by GitHub
parent 150a82de1f
commit 9744d512d9
5 changed files with 31 additions and 29 deletions

View File

@@ -30,7 +30,7 @@ def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]:
vals = Tensor.arange(n, device=x.device).reshape(1,1,n).cast(x.dtype).expand(x.shape)
cmp = (x.unsqueeze(-1) > x.unsqueeze(-2)) | ((x.unsqueeze(-1) == x.unsqueeze(-2)) & \
(Tensor.arange(n, device=x.device).reshape(1,1,n,1) < Tensor.arange(n, device=x.device).reshape(1,1,1,n)))
sel = Tensor.zeros_like(x).scatter(-1, cmp.sum(axis=-1).cast('int32'), vals)[:,:,n-k:].cast('int32')
sel = x.const_like(0).scatter(-1, cmp.sum(axis=-1).cast('int32'), vals)[:,:,n-k:].cast('int32')
return x.gather(-1, sel), sel
@dataclass(frozen=True)
@@ -177,7 +177,8 @@ class TransformerBlock(FFNBlock):
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
# TODO: this if statement should be removed and it shouldn't generate extra kernels
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device, buffer=False).triu(start_pos+1) \
if resolve(T != 1) else None
attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
return self.attn_output(attn if not self.config.attn_output_gate else (attn * gate.sigmoid()))
@@ -222,7 +223,8 @@ class MLATransformerBlock(FFNBlock):
k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :]
v = k[..., :self.config.kv_lora_rank]
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device, buffer=False).triu(start_pos+1) \
if resolve(T != 1) else None
attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5)
if mask is not None: attn = attn + mask
attn = attn.softmax(-1)
@@ -282,8 +284,8 @@ class GatedDeltaNetBlock(FFNBlock):
# recurrent state can't be partially reused after divergence, force a full rebuild
def _state_reset_ops(self):
return [self.conv_state.assign(Tensor.zeros_like(self.conv_state)),
self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))] if hasattr(self, "conv_state") else []
return [self.conv_state.assign(self.conv_state.const_like(0)),
self.recurrent_state.assign(self.recurrent_state.const_like(0))] if hasattr(self, "conv_state") else []
def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return 0 if prefix_len != cached_len else prefix_len
def _init_state(self, x):

View File

@@ -118,7 +118,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
lo, hi = (start, stop-step) if step > 0 else (stop-step, start)
if lo < (dt:=to_dtype(dtype)).min or dt.max < hi: raise OverflowError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, **kwargs)
if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, buffer=False, **kwargs)
return (cls.full((output_len,), step, dtype=dtype, buffer=False, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
@classmethod
@@ -138,7 +138,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
"""
if steps < 0: raise ValueError("number of steps must be non-negative")
if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported")
if steps == 1: return cls.full((1,), start, dtype=dtype, **kwargs)
if steps == 1: return cls.full((1,), start, dtype=dtype, buffer=False, **kwargs)
return (start + cls.arange(steps, dtype=dtypes.default_float, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
@classmethod
@@ -186,7 +186,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
print(t.triu(diagonal=-1).numpy())
```
"""
return self._tri(self.shape[-2], self.shape[-1], diagonal, self.device).where(self, self.zeros_like())
return self._tri(self.shape[-2], self.shape[-1], diagonal, self.device).where(self, self.const_like(0))
def tril(self, diagonal:sint=0) -> Self:
"""
@@ -209,7 +209,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
print(t.tril(diagonal=-1).numpy())
```
"""
return self._tri(self.shape[-2], self.shape[-1], diagonal+1, self.device).where(self.zeros_like(), self)
return self._tri(self.shape[-2], self.shape[-1], diagonal+1, self.device).where(self.const_like(0), self)
# ***** random *****
@@ -238,7 +238,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
_, nmant = dtypes.finfo(dtype)
uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
uint_bits = bits.bitcast(uint_dtype)
float_one_bits = uint_bits.ones_like(dtype=dtype).bitcast(uint_dtype)
float_one_bits = uint_bits.const_like(1).cast(dtype).bitcast(uint_dtype)
return uint_bits.rshift(dtype.bitsize - nmant).bitwise_or(float_one_bits).bitcast(dtype)[:prod(shape)].sub(1).reshape(shape)
def _pad_constant(self, pX, value:ConstType) -> Self:
@@ -250,7 +250,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
base = MovementMixin.pad(X, pads)
if value == 0: return base
base = base.cast(least_upper_dtype(base.dtype, dtypes.from_py(value)))
return MovementMixin.pad(X.ones_like(dtype=dtypes.bool), pads).where(base, base.full_like(value))
return MovementMixin.pad(X.const_like(1).cast(dtypes.bool), pads).where(base, base.const_like(value))
def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Self:
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.')
@@ -715,7 +715,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
print(indices.numpy())
```
"""
if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), type(self).zeros(self.shape, dtype=dtypes.int32, device=self.device)
if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), type(self).zeros(self.shape, dtype=dtypes.int32, device=self.device, buffer=False)
values, n = self._split_cumalu(axis, Ops.MAX), int(self.shape[axis])
x, values_t = self.transpose(axis, -1), values.transpose(axis, -1)
match = x.unsqueeze(-1).eq(values_t.unsqueeze(-2)) * type(self).ones(n, n, device=self.device, buffer=False).triu()
@@ -838,7 +838,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
```
"""
x, dim = self, self._resolve_dim(dim)
if (orig_len := int(x.shape[dim])) <= 1: return x, x.zeros_like(dtype=dtypes.default_int)
if (orig_len := int(x.shape[dim])) <= 1: return x, x.const_like(0).cast(dtypes.default_int)
# pad to power of 2
n_stages = (orig_len-1).bit_length()
pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
@@ -1174,9 +1174,9 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
reg_pads = resolve_pool_pads(padding, len(k_))
pads = self._apply_ceil_mode(reg_pads, k_, s_, dilation) if ceil_mode else reg_pads
if not count_include_pad:
return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
return pool(self, pads).sum(axis) / pool(self.const_like(1), pads).sum(axis)
if not ceil_mode: return pool(self, pads).mean(axis)
return pool(self, pads).sum(axis) / pool(self._pad_constant(((0,0),)*(self.ndim-len(k_)) + flat_to_grouped(reg_pads), 0.0).ones_like(),
return pool(self, pads).sum(axis) / pool(self._pad_constant(((0,0),)*(self.ndim-len(k_)) + flat_to_grouped(reg_pads), 0.0).const_like(1),
tuple(cp-rp for cp,rp in zip(pads, reg_pads))).sum(axis)
def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
@@ -1407,7 +1407,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
if Y.device is not None and self.device is not None and Y.device != self.device:
raise RuntimeError(f"expected Y and self on the same device, {Y.device=}, {self.device=}")
log_probs = self.log_softmax()
loss_mask = Y.ne(ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
loss_mask = Y.ne(ignore_index) if ignore_index != -1 else Y.const_like(1).cast(dtypes.bool)
y = Y.unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
@@ -1459,7 +1459,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
```
"""
weight = Y.ones_like() if weight is None else weight.gather(0, Y.flatten()).reshape(Y.shape)
weight = Y.const_like(1) if weight is None else weight.gather(0, Y.flatten()).reshape(Y.shape)
masked_weight = weight if ignore_index is None else weight * Y.ne(ignore_index)
nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)

View File

@@ -411,7 +411,7 @@ class LSTMCell:
self.bias_hh: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None
def __call__(self, x:Tensor, hc:tuple[Tensor, Tensor]|None=None) -> tuple[Tensor, Tensor]:
if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device, buffer=False),)*2
gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
i, f, g, o = gates.chunk(4, dim=1)
i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()

View File

@@ -785,9 +785,9 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
def _apply_transformation(input_sz, output_sz, scale_dim, mode):
index = Tensor.arange(output_sz, device=X.device)
if mode == "half_pixel": return (index + 0.5) / scale_dim - 0.5
if mode == "align_corners": return index * (input_sz - 1) / (output_sz - 1) if output_sz != 1 else Tensor.zeros_like(index)
if mode == "align_corners": return index * (input_sz - 1) / (output_sz - 1) if output_sz != 1 else index.const_like(0)
if mode == "asymmetric": return index / scale_dim
if mode == "pytorch_half_pixel": return ((index + 0.5) / scale_dim - 0.5) if output_sz != 1 else Tensor.zeros_like(index)
if mode == "pytorch_half_pixel": return ((index + 0.5) / scale_dim - 0.5) if output_sz != 1 else index.const_like(0)
if mode == "half_pixel_symmetric":
output_dim_scaled = input_sz * scale_dim
return (input_sz / 2) * (1 - (output_sz / output_dim_scaled)) + (index + 0.5) / scale_dim - 0.5
@@ -964,7 +964,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
import numpy as np
if not training_mode: return data, data.full_like(True, dtype=dtypes.bool)
if not training_mode: return data, data.const_like(True).cast(dtypes.bool)
if seed is not None:
rand = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)), dtype=data.dtype, device=data.device)
else:
@@ -1043,7 +1043,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
attn_scores = mask.where(attn_scores, mask_filter_value)
if unidirectional:
causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool, device=attn_scores.device).tril()
causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool, device=attn_scores.device, buffer=False).tril()
attn_scores = causal_mask.where(attn_scores, mask_filter_value)
output = attn_scores.softmax(-1) @ v
@@ -1075,7 +1075,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
qk_matmul_return_val = scores
if is_causal:
causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool).tril(0)
causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool, buffer=False).tril(0)
scores = scores.masked_fill(causal_mask.logical_not(), -float("inf"))
if attn_mask is not None:

View File

@@ -62,7 +62,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp:
return ret
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...]|None, dtype:DType) -> list[list[Tensor]]:
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype, buffer=False) for m in mat], dim=dim)
for k in range(len(mat[0]))] for dim in range(dims)]
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
@@ -1063,9 +1063,9 @@ class Tensor(OpMixin):
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
mask_cumsum = mask.cumsum()
if size is None:
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device)
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device, buffer=False)
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
counts = Tensor.zeros(size, dtype=dtypes.int32, device=self.device).scatter(0, mask_cumsum, 1, reduce='add')
counts = Tensor.zeros(size, dtype=dtypes.int32, device=self.device, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
return (Tensor.arange(size, device=self.device) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor:
@@ -1141,7 +1141,7 @@ class Tensor(OpMixin):
data = (data.flatten(1) ^ pad_mask).reshape(*data.shape[:2], 200).bitcast(dtypes.uint64)
state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64)
state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64, buffer=False)
for k in range(int(data.shape[1])):
state = state ^ data[:, k]
for i in range(24): # f1600
@@ -1348,7 +1348,7 @@ class Tensor(OpMixin):
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not Tensor.training or p == 0: return self
if p == 1: return self.zeros_like()
if p == 1: return self.const_like(0)
return (Tensor.rand_like(self, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
@@ -1376,7 +1376,7 @@ class Tensor(OpMixin):
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = qk.ones_like(dtype=dtypes.bool).tril()
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask