From 9744d512d93a8f60dfd45b464a0c449fddd7075e Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 21 May 2026 21:37:52 -0400 Subject: [PATCH] use more non-buffered const (#16327) --- tinygrad/llm/model.py | 12 +++++++----- tinygrad/mixin/__init__.py | 24 ++++++++++++------------ tinygrad/nn/__init__.py | 2 +- tinygrad/nn/onnx.py | 10 +++++----- tinygrad/tensor.py | 12 ++++++------ 5 files changed, 31 insertions(+), 29 deletions(-) diff --git a/tinygrad/llm/model.py b/tinygrad/llm/model.py index 6278e91bef..bc7fe59f0e 100644 --- a/tinygrad/llm/model.py +++ b/tinygrad/llm/model.py @@ -30,7 +30,7 @@ def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]: vals = Tensor.arange(n, device=x.device).reshape(1,1,n).cast(x.dtype).expand(x.shape) cmp = (x.unsqueeze(-1) > x.unsqueeze(-2)) | ((x.unsqueeze(-1) == x.unsqueeze(-2)) & \ (Tensor.arange(n, device=x.device).reshape(1,1,n,1) < Tensor.arange(n, device=x.device).reshape(1,1,1,n))) - sel = Tensor.zeros_like(x).scatter(-1, cmp.sum(axis=-1).cast('int32'), vals)[:,:,n-k:].cast('int32') + sel = x.const_like(0).scatter(-1, cmp.sum(axis=-1).cast('int32'), vals)[:,:,n-k:].cast('int32') return x.gather(-1, sel), sel @dataclass(frozen=True) @@ -177,7 +177,8 @@ class TransformerBlock(FFNBlock): # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True # TODO: this if statement should be removed and it shouldn't generate extra kernels - mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None + mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device, buffer=False).triu(start_pos+1) \ + if resolve(T != 1) else None attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd) attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D) return self.attn_output(attn if not self.config.attn_output_gate else (attn * gate.sigmoid())) @@ -222,7 +223,8 @@ class MLATransformerBlock(FFNBlock): k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :] v = k[..., :self.config.kv_lora_rank] - mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None + mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device, buffer=False).triu(start_pos+1) \ + if resolve(T != 1) else None attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5) if mask is not None: attn = attn + mask attn = attn.softmax(-1) @@ -282,8 +284,8 @@ class GatedDeltaNetBlock(FFNBlock): # recurrent state can't be partially reused after divergence, force a full rebuild def _state_reset_ops(self): - return [self.conv_state.assign(Tensor.zeros_like(self.conv_state)), - self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))] if hasattr(self, "conv_state") else [] + return [self.conv_state.assign(self.conv_state.const_like(0)), + self.recurrent_state.assign(self.recurrent_state.const_like(0))] if hasattr(self, "conv_state") else [] def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return 0 if prefix_len != cached_len else prefix_len def _init_state(self, x): diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index 487a8a1395..d24a7d9d60 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -118,7 +118,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): lo, hi = (start, stop-step) if step > 0 else (stop-step, start) if lo < (dt:=to_dtype(dtype)).min or dt.max < hi: raise OverflowError(f"arange [{start}, {stop}) is not representable in dtype {dtype}") # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs - if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, **kwargs) + if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, buffer=False, **kwargs) return (cls.full((output_len,), step, dtype=dtype, buffer=False, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype) @classmethod @@ -138,7 +138,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): """ if steps < 0: raise ValueError("number of steps must be non-negative") if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported") - if steps == 1: return cls.full((1,), start, dtype=dtype, **kwargs) + if steps == 1: return cls.full((1,), start, dtype=dtype, buffer=False, **kwargs) return (start + cls.arange(steps, dtype=dtypes.default_float, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype) @classmethod @@ -186,7 +186,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): print(t.triu(diagonal=-1).numpy()) ``` """ - return self._tri(self.shape[-2], self.shape[-1], diagonal, self.device).where(self, self.zeros_like()) + return self._tri(self.shape[-2], self.shape[-1], diagonal, self.device).where(self, self.const_like(0)) def tril(self, diagonal:sint=0) -> Self: """ @@ -209,7 +209,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): print(t.tril(diagonal=-1).numpy()) ``` """ - return self._tri(self.shape[-2], self.shape[-1], diagonal+1, self.device).where(self.zeros_like(), self) + return self._tri(self.shape[-2], self.shape[-1], diagonal+1, self.device).where(self.const_like(0), self) # ***** random ***** @@ -238,7 +238,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): _, nmant = dtypes.finfo(dtype) uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize] uint_bits = bits.bitcast(uint_dtype) - float_one_bits = uint_bits.ones_like(dtype=dtype).bitcast(uint_dtype) + float_one_bits = uint_bits.const_like(1).cast(dtype).bitcast(uint_dtype) return uint_bits.rshift(dtype.bitsize - nmant).bitwise_or(float_one_bits).bitcast(dtype)[:prod(shape)].sub(1).reshape(shape) def _pad_constant(self, pX, value:ConstType) -> Self: @@ -250,7 +250,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): base = MovementMixin.pad(X, pads) if value == 0: return base base = base.cast(least_upper_dtype(base.dtype, dtypes.from_py(value))) - return MovementMixin.pad(X.ones_like(dtype=dtypes.bool), pads).where(base, base.full_like(value)) + return MovementMixin.pad(X.const_like(1).cast(dtypes.bool), pads).where(base, base.const_like(value)) def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Self: if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.') @@ -715,7 +715,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): print(indices.numpy()) ``` """ - if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), type(self).zeros(self.shape, dtype=dtypes.int32, device=self.device) + if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), type(self).zeros(self.shape, dtype=dtypes.int32, device=self.device, buffer=False) values, n = self._split_cumalu(axis, Ops.MAX), int(self.shape[axis]) x, values_t = self.transpose(axis, -1), values.transpose(axis, -1) match = x.unsqueeze(-1).eq(values_t.unsqueeze(-2)) * type(self).ones(n, n, device=self.device, buffer=False).triu() @@ -838,7 +838,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): ``` """ x, dim = self, self._resolve_dim(dim) - if (orig_len := int(x.shape[dim])) <= 1: return x, x.zeros_like(dtype=dtypes.default_int) + if (orig_len := int(x.shape[dim])) <= 1: return x, x.const_like(0).cast(dtypes.default_int) # pad to power of 2 n_stages = (orig_len-1).bit_length() pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim)) @@ -1174,9 +1174,9 @@ class OpMixin(ElementwiseMixin, ReduceMixin): reg_pads = resolve_pool_pads(padding, len(k_)) pads = self._apply_ceil_mode(reg_pads, k_, s_, dilation) if ceil_mode else reg_pads if not count_include_pad: - return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis) + return pool(self, pads).sum(axis) / pool(self.const_like(1), pads).sum(axis) if not ceil_mode: return pool(self, pads).mean(axis) - return pool(self, pads).sum(axis) / pool(self._pad_constant(((0,0),)*(self.ndim-len(k_)) + flat_to_grouped(reg_pads), 0.0).ones_like(), + return pool(self, pads).sum(axis) / pool(self._pad_constant(((0,0),)*(self.ndim-len(k_)) + flat_to_grouped(reg_pads), 0.0).const_like(1), tuple(cp-rp for cp,rp in zip(pads, reg_pads))).sum(axis) def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, @@ -1407,7 +1407,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): if Y.device is not None and self.device is not None and Y.device != self.device: raise RuntimeError(f"expected Y and self on the same device, {Y.device=}, {self.device=}") log_probs = self.log_softmax() - loss_mask = Y.ne(ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool) + loss_mask = Y.ne(ignore_index) if ignore_index != -1 else Y.const_like(1).cast(dtypes.bool) y = Y.unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1) smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask) unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing) @@ -1459,7 +1459,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): print(t.log_softmax().nll_loss(Y, reduction='none').numpy()) ``` """ - weight = Y.ones_like() if weight is None else weight.gather(0, Y.flatten()).reshape(Y.shape) + weight = Y.const_like(1) if weight is None else weight.gather(0, Y.flatten()).reshape(Y.shape) masked_weight = weight if ignore_index is None else weight * Y.ne(ignore_index) nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction) diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index b1b4cc211a..d8c120466d 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -411,7 +411,7 @@ class LSTMCell: self.bias_hh: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None def __call__(self, x:Tensor, hc:tuple[Tensor, Tensor]|None=None) -> tuple[Tensor, Tensor]: - if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2 + if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device, buffer=False),)*2 gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh) i, f, g, o = gates.chunk(4, dim=1) i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid() diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index 39262bba3a..b117f286d7 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -785,9 +785,9 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT def _apply_transformation(input_sz, output_sz, scale_dim, mode): index = Tensor.arange(output_sz, device=X.device) if mode == "half_pixel": return (index + 0.5) / scale_dim - 0.5 - if mode == "align_corners": return index * (input_sz - 1) / (output_sz - 1) if output_sz != 1 else Tensor.zeros_like(index) + if mode == "align_corners": return index * (input_sz - 1) / (output_sz - 1) if output_sz != 1 else index.const_like(0) if mode == "asymmetric": return index / scale_dim - if mode == "pytorch_half_pixel": return ((index + 0.5) / scale_dim - 0.5) if output_sz != 1 else Tensor.zeros_like(index) + if mode == "pytorch_half_pixel": return ((index + 0.5) / scale_dim - 0.5) if output_sz != 1 else index.const_like(0) if mode == "half_pixel_symmetric": output_dim_scaled = input_sz * scale_dim return (input_sz / 2) * (1 - (output_sz / output_dim_scaled)) + (index + 0.5) / scale_dim - 0.5 @@ -964,7 +964,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT # Reimplemented here because you need legacy RNG for passing ONNX tests. def dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): import numpy as np - if not training_mode: return data, data.full_like(True, dtype=dtypes.bool) + if not training_mode: return data, data.const_like(True).cast(dtypes.bool) if seed is not None: rand = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)), dtype=data.dtype, device=data.device) else: @@ -1043,7 +1043,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT attn_scores = mask.where(attn_scores, mask_filter_value) if unidirectional: - causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool, device=attn_scores.device).tril() + causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool, device=attn_scores.device, buffer=False).tril() attn_scores = causal_mask.where(attn_scores, mask_filter_value) output = attn_scores.softmax(-1) @ v @@ -1075,7 +1075,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT qk_matmul_return_val = scores if is_causal: - causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool).tril(0) + causal_mask = Tensor.ones(Q.shape[-2], K.shape[-2], device=Q.device, dtype=dtypes.bool, buffer=False).tril(0) scores = scores.masked_fill(causal_mask.logical_not(), -float("inf")) if attn_mask is not None: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index aa66cc6d0b..837701db65 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -62,7 +62,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp: return ret def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...]|None, dtype:DType) -> list[list[Tensor]]: - return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim) + return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype, buffer=False) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308 @@ -1063,9 +1063,9 @@ class Tensor(OpMixin): x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten() mask_cumsum = mask.cumsum() if size is None: - counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device) + counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device, buffer=False) return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()] - counts = Tensor.zeros(size, dtype=dtypes.int32, device=self.device).scatter(0, mask_cumsum, 1, reduce='add') + counts = Tensor.zeros(size, dtype=dtypes.int32, device=self.device, buffer=False).scatter(0, mask_cumsum, 1, reduce='add') return (Tensor.arange(size, device=self.device) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype) def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor: @@ -1141,7 +1141,7 @@ class Tensor(OpMixin): data = (data.flatten(1) ^ pad_mask).reshape(*data.shape[:2], 200).bitcast(dtypes.uint64) - state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64) + state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64, buffer=False) for k in range(int(data.shape[1])): state = state ^ data[:, k] for i in range(24): # f1600 @@ -1348,7 +1348,7 @@ class Tensor(OpMixin): """ if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]") if not Tensor.training or p == 0: return self - if p == 1: return self.zeros_like() + if p == 1: return self.const_like(0) return (Tensor.rand_like(self, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p) def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, @@ -1376,7 +1376,7 @@ class Tensor(OpMixin): # handle attention mask if is_causal: if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True") - attn_mask = qk.ones_like(dtype=dtypes.bool).tril() + attn_mask = qk.const_like(1).cast(dtypes.bool).tril() if attn_mask is not None: if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf")) qk = qk + attn_mask