mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-08 05:54:59 +08:00
remove kwargs from arange and linspace [PR] (#16505)
`arange` and `linspace` used to accept `requires_grad` and `device` through `**kwargs`; both are now removed, and `dtype` is the only remaining keyword argument.
This commit is contained in:
@@ -133,7 +133,7 @@ For our loss function we will be using sparse categorical cross entropy loss. Th
|
||||
```python
|
||||
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
|
||||
loss_mask = Y != ignore_index
|
||||
y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
|
||||
y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32).unsqueeze(0).expand(Y.numel(), self.shape[-1])
|
||||
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
|
||||
return self.log_softmax().mul(y).sum() / loss_mask.sum()
|
||||
```
|
||||
|
||||
@@ -4,10 +4,10 @@ from tinygrad.dtype import DTypeLike, dtypes
|
||||
import math
|
||||
|
||||
# rewritten from numpy
|
||||
def rfftfreq(n: int, d: float = 1.0, device=None) -> Tensor:
|
||||
def rfftfreq(n: int, d: float = 1.0) -> Tensor:
|
||||
val = 1.0 / (n * d)
|
||||
N = n // 2 + 1
|
||||
results = Tensor.arange(N, device=device)
|
||||
results = Tensor.arange(N)
|
||||
return results * val
|
||||
|
||||
# just like in librosa
|
||||
|
||||
@@ -102,7 +102,7 @@ class Int8Embedding:
|
||||
self.weight, self.scale = Tensor.ones(vocab_size, embed_size, dtype=dtypes.int8), Tensor.ones(vocab_size, dtype=dtypes.half)
|
||||
|
||||
def __call__(self, idx:Tensor) -> Tensor:
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).unsqueeze(-1)
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz).unsqueeze(-1)
|
||||
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
|
||||
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), (self.weight.cast(self.scale.dtype).T*self.scale).T
|
||||
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
|
||||
|
||||
@@ -99,7 +99,7 @@ class GPT:
|
||||
|
||||
def __call__(self, idx:Tensor, targets=None):
|
||||
b, t = idx.shape
|
||||
pos = Tensor.arange(0, t, device=idx.device)
|
||||
pos = Tensor.arange(0, t)
|
||||
|
||||
tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
|
||||
pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
|
||||
|
||||
@@ -57,7 +57,7 @@ class EmbeddingBert(nn.Embedding):
|
||||
def __call__(self, idx:Tensor) -> Tensor:
|
||||
if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device)
|
||||
arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).reshape(arange_shp)
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz).reshape(arange_shp)
|
||||
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
|
||||
return (arange == idx).where(vals, 0).sum(2, dtype=vals.dtype)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class BertForPretraining:
|
||||
# Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
|
||||
def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1):
|
||||
log_probs, loss_mask = predictions.log_softmax(dtype=dtypes.float), (labels != ignore_index)
|
||||
y_counter = Tensor.arange(predictions.shape[-1], device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
|
||||
y_counter = Tensor.arange(predictions.shape[-1]).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
|
||||
y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
|
||||
return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero
|
||||
|
||||
@@ -159,7 +159,7 @@ class BertPooler:
|
||||
return self.dense(hidden_states[:, 0]).tanh()
|
||||
|
||||
def gather(prediction_logits:Tensor, masked_lm_positions:Tensor):
|
||||
counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
|
||||
counter = Tensor.arange(prediction_logits.shape[1]).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
|
||||
onehot = counter == masked_lm_positions.unsqueeze(2).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
|
||||
return onehot @ prediction_logits
|
||||
|
||||
@@ -189,7 +189,7 @@ class BertEmbeddings:
|
||||
input_shape = input_ids.shape
|
||||
seq_length = input_shape[1]
|
||||
|
||||
position_ids = Tensor.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(*input_shape)
|
||||
position_ids = Tensor.arange(seq_length).unsqueeze(0).expand(*input_shape)
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
@@ -466,7 +466,7 @@ class OpenClipEncoder:
|
||||
x = x + self.positional_embedding
|
||||
x = self.transformer(x, attn_mask=self.attn_mask)
|
||||
x = self.ln_final(x)
|
||||
x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)]
|
||||
x = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1)]
|
||||
x = x @ self.text_projection
|
||||
return x
|
||||
|
||||
|
||||
@@ -164,7 +164,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
|
||||
# softmax
|
||||
t = (logits / temp).softmax()
|
||||
|
||||
counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
|
||||
counter, counter2 = Tensor.arange(t.numel()).contiguous(), Tensor.arange(t.numel() - 1, -1, -1).contiguous()
|
||||
# top k
|
||||
if k:
|
||||
output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
|
||||
|
||||
@@ -776,7 +776,7 @@ def _bilinear_interpolate(
|
||||
y = Tensor.where(ymask[:, None, :], y, 0)
|
||||
x = Tensor.where(xmask[:, None, :], x, 0)
|
||||
key1 = roi_batch_ind[:, None, None, None, None, None]
|
||||
key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
|
||||
key2 = Tensor.arange(channels)[None, :, None, None, None, None]
|
||||
key3 = y[:, None, :, None, :, None]
|
||||
key4 = x[:, None, None, :, None, :]
|
||||
return tensor_getitem(input,key1,key2,key3,key4) # [K, C, PH, PW, IY, IX]
|
||||
@@ -802,8 +802,8 @@ def _bilinear_interpolate(
|
||||
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
|
||||
orig_dtype = input.dtype
|
||||
_, _, height, width = input.shape
|
||||
ph = Tensor.arange(pooled_height, device=input.device)
|
||||
pw = Tensor.arange(pooled_width, device=input.device)
|
||||
ph = Tensor.arange(pooled_height)
|
||||
pw = Tensor.arange(pooled_width)
|
||||
|
||||
roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous()
|
||||
offset = 0.5 if aligned else 0.0
|
||||
@@ -827,14 +827,14 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
|
||||
|
||||
if exact_sampling:
|
||||
count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
|
||||
iy = Tensor.arange(roi_bin_grid_h, device=input.device)
|
||||
ix = Tensor.arange(roi_bin_grid_w, device=input.device)
|
||||
iy = Tensor.arange(roi_bin_grid_h)
|
||||
ix = Tensor.arange(roi_bin_grid_w)
|
||||
ymask = None
|
||||
xmask = None
|
||||
else:
|
||||
count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1)
|
||||
iy = Tensor.arange(height, device=input.device)
|
||||
ix = Tensor.arange(width, device=input.device)
|
||||
iy = Tensor.arange(height)
|
||||
ix = Tensor.arange(width)
|
||||
ymask = iy[None, :] < roi_bin_grid_h[:, None]
|
||||
xmask = ix[None, :] < roi_bin_grid_w[:, None]
|
||||
|
||||
|
||||
@@ -164,12 +164,10 @@ class T5Attention:
|
||||
relative_buckets += Tensor.where(is_small, relative_position, relative_position_if_large)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length, device=None) -> Tensor:
|
||||
def compute_bias(self, query_length, key_length) -> Tensor:
|
||||
"""Compute binned relative position bias"""
|
||||
if device is None:
|
||||
device = self.relative_attention_bias.weight.device
|
||||
context_position = Tensor.arange(query_length, dtype=dtypes.long, device=device)[:, None]
|
||||
memory_position = Tensor.arange(key_length, dtype=dtypes.long, device=device)[None, :]
|
||||
context_position = Tensor.arange(query_length, dtype=dtypes.long)[:, None]
|
||||
memory_position = Tensor.arange(key_length, dtype=dtypes.long)[None, :]
|
||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position, # shape (query_length, key_length)
|
||||
@@ -212,7 +210,7 @@ class T5Attention:
|
||||
scores = Tensor.matmul(query_states, key_states.transpose(3, 2))
|
||||
|
||||
if position_bias is None:
|
||||
position_bias = self.compute_bias(key_length, key_length, device=scores.device)
|
||||
position_bias = self.compute_bias(key_length, key_length)
|
||||
|
||||
scores += position_bias
|
||||
attn_weights = Tensor.softmax(scores.float(), axis=-1).cast(scores.dtype) # (batch_size, n_heads, seq_length, key_length)
|
||||
|
||||
@@ -9,7 +9,7 @@ attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Te
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
|
||||
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
|
||||
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
|
||||
return out.cast(mixed_precision_dtype) if mixed_precision_dtype in Device[Device.DEFAULT].renderer.supported_dtypes() else out
|
||||
|
||||
@@ -94,7 +94,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
return cls.full(argfix(*shape), 1.0, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def arange(cls, start, stop=None, step=1, **kwargs) -> Self:
|
||||
def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self:
|
||||
"""
|
||||
Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.
|
||||
|
||||
@@ -116,15 +116,15 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
```
|
||||
"""
|
||||
if stop is None: stop, start = start, 0
|
||||
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
|
||||
if dtype is None: dtype = dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int
|
||||
lo, hi = (start, stop-step) if step > 0 else (stop-step, start)
|
||||
if lo < (dt:=to_dtype(dtype)).min or dt.max < hi: raise OverflowError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
|
||||
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
|
||||
if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, buffer=False, **kwargs)
|
||||
return (cls.full((output_len,), step, dtype=dtype, buffer=False, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
|
||||
if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, buffer=False)
|
||||
return (cls.full((output_len,), step, dtype=dtype, buffer=False)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
|
||||
|
||||
@classmethod
|
||||
def linspace(cls, start:int|float, stop:int|float, steps:int, **kwargs) -> Self:
|
||||
def linspace(cls, start:int|float, stop:int|float, steps:int, dtype:DTypeLike|None=None) -> Self:
|
||||
"""
|
||||
Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.
|
||||
|
||||
@@ -136,9 +136,9 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
```
|
||||
"""
|
||||
if steps < 0: raise ValueError("number of steps must be non-negative")
|
||||
if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported")
|
||||
if steps == 1: return cls.full((1,), start, dtype=dtype, buffer=False, **kwargs)
|
||||
return (start + cls.arange(steps, dtype=dtypes.default_float, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
|
||||
if (dtype := to_dtype(dtype or dtypes.default_float)) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported")
|
||||
if steps == 1: return cls.full((1,), start, dtype=dtype, buffer=False)
|
||||
return (start + cls.arange(steps, dtype=dtypes.default_float) * ((stop - start) / (steps - 1))).cast(dtype)
|
||||
|
||||
@classmethod
|
||||
def eye(cls, n:int, m:int|None=None, dtype:DTypeLike|None=None) -> Self:
|
||||
|
||||
Reference in New Issue
Block a user