diff --git a/docs/quickstart.md b/docs/quickstart.md index 57546abce7..89ed8e5012 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -133,7 +133,7 @@ For our loss function we will be using sparse categorical cross entropy loss. Th ```python def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor: loss_mask = Y != ignore_index - y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) + y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32).unsqueeze(0).expand(Y.numel(), self.shape[-1]) y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) return self.log_softmax().mul(y).sum() / loss_mask.sum() ``` diff --git a/examples/audio_helpers.py b/examples/audio_helpers.py index 6ea2cea3ab..603a9cdc43 100644 --- a/examples/audio_helpers.py +++ b/examples/audio_helpers.py @@ -4,10 +4,10 @@ from tinygrad.dtype import DTypeLike, dtypes import math # rewritten from numpy -def rfftfreq(n: int, d: float = 1.0, device=None) -> Tensor: +def rfftfreq(n: int, d: float = 1.0) -> Tensor: val = 1.0 / (n * d) N = n // 2 + 1 - results = Tensor.arange(N, device=device) + results = Tensor.arange(N) return results * val # just like in librosa diff --git a/examples/llama3.py b/examples/llama3.py index f55fdb3273..27eda2ffcc 100644 --- a/examples/llama3.py +++ b/examples/llama3.py @@ -102,7 +102,7 @@ class Int8Embedding: self.weight, self.scale = Tensor.ones(vocab_size, embed_size, dtype=dtypes.int8), Tensor.ones(vocab_size, dtype=dtypes.half) def __call__(self, idx:Tensor) -> Tensor: - if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).unsqueeze(-1) + if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz).unsqueeze(-1) big_shp = idx.shape+(self.vocab_sz, self.embed_sz) arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), (self.weight.cast(self.scale.dtype).T*self.scale).T return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype) diff --git a/examples/llm.c/train_gpt2.py b/examples/llm.c/train_gpt2.py index e0ba635a94..792e231927 100755 --- a/examples/llm.c/train_gpt2.py +++ b/examples/llm.c/train_gpt2.py @@ -99,7 +99,7 @@ class GPT: def __call__(self, idx:Tensor, targets=None): b, t = idx.shape - pos = Tensor.arange(0, t, device=idx.device) + pos = Tensor.arange(0, t) tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd) pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd) diff --git a/examples/mlperf/initializers.py b/examples/mlperf/initializers.py index 338d283d89..8e1bb38a43 100644 --- a/examples/mlperf/initializers.py +++ b/examples/mlperf/initializers.py @@ -57,7 +57,7 @@ class EmbeddingBert(nn.Embedding): def __call__(self, idx:Tensor) -> Tensor: if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device) arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,) - if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).reshape(arange_shp) + if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz).reshape(arange_shp) arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp) return (arange == idx).where(vals, 0).sum(2, dtype=vals.dtype) diff --git a/extra/models/bert.py b/extra/models/bert.py index df619724f0..bd515d8c98 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -52,7 +52,7 @@ class BertForPretraining: # Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315 def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1): log_probs, loss_mask = predictions.log_softmax(dtype=dtypes.float), (labels != ignore_index) - y_counter = Tensor.arange(predictions.shape[-1], device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1]) + y_counter = Tensor.arange(predictions.shape[-1]).unsqueeze(0).expand(labels.numel(), predictions.shape[-1]) y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1]) return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero @@ -159,7 +159,7 @@ class BertPooler: return self.dense(hidden_states[:, 0]).tanh() def gather(prediction_logits:Tensor, masked_lm_positions:Tensor): - counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1]) + counter = Tensor.arange(prediction_logits.shape[1]).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1]) onehot = counter == masked_lm_positions.unsqueeze(2).expand(*masked_lm_positions.shape, prediction_logits.shape[1]) return onehot @ prediction_logits @@ -189,7 +189,7 @@ class BertEmbeddings: input_shape = input_ids.shape seq_length = input_shape[1] - position_ids = Tensor.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(*input_shape) + position_ids = Tensor.arange(seq_length).unsqueeze(0).expand(*input_shape) words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) diff --git a/extra/models/clip.py b/extra/models/clip.py index bd59e609c3..5960c111a3 100644 --- a/extra/models/clip.py +++ b/extra/models/clip.py @@ -466,7 +466,7 @@ class OpenClipEncoder: x = x + self.positional_embedding x = self.transformer(x, attn_mask=self.attn_mask) x = self.ln_final(x) - x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)] + x = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1)] x = x @ self.text_projection return x diff --git a/extra/models/llama.py b/extra/models/llama.py index abb9d63a06..4bd2d0d06f 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -164,7 +164,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float): # softmax t = (logits / temp).softmax() - counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous() + counter, counter2 = Tensor.arange(t.numel()).contiguous(), Tensor.arange(t.numel() - 1, -1, -1).contiguous() # top k if k: output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous() diff --git a/extra/models/mask_rcnn.py b/extra/models/mask_rcnn.py index d61b811e75..09d334f1b7 100644 --- a/extra/models/mask_rcnn.py +++ b/extra/models/mask_rcnn.py @@ -776,7 +776,7 @@ def _bilinear_interpolate( y = Tensor.where(ymask[:, None, :], y, 0) x = Tensor.where(xmask[:, None, :], x, 0) key1 = roi_batch_ind[:, None, None, None, None, None] - key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None] + key2 = Tensor.arange(channels)[None, :, None, None, None, None] key3 = y[:, None, :, None, :, None] key4 = x[:, None, None, :, None, :] return tensor_getitem(input,key1,key2,key3,key4) # [K, C, PH, PW, IY, IX] @@ -802,8 +802,8 @@ def _bilinear_interpolate( def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): orig_dtype = input.dtype _, _, height, width = input.shape - ph = Tensor.arange(pooled_height, device=input.device) - pw = Tensor.arange(pooled_width, device=input.device) + ph = Tensor.arange(pooled_height) + pw = Tensor.arange(pooled_width) roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous() offset = 0.5 if aligned else 0.0 @@ -827,14 +827,14 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling if exact_sampling: count = max(roi_bin_grid_h * roi_bin_grid_w, 1) - iy = Tensor.arange(roi_bin_grid_h, device=input.device) - ix = Tensor.arange(roi_bin_grid_w, device=input.device) + iy = Tensor.arange(roi_bin_grid_h) + ix = Tensor.arange(roi_bin_grid_w) ymask = None xmask = None else: count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1) - iy = Tensor.arange(height, device=input.device) - ix = Tensor.arange(width, device=input.device) + iy = Tensor.arange(height) + ix = Tensor.arange(width) ymask = iy[None, :] < roi_bin_grid_h[:, None] xmask = ix[None, :] < roi_bin_grid_w[:, None] diff --git a/extra/models/t5.py b/extra/models/t5.py index fd0ba34af5..7a0fff23f2 100644 --- a/extra/models/t5.py +++ b/extra/models/t5.py @@ -164,12 +164,10 @@ class T5Attention: relative_buckets += Tensor.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None) -> Tensor: + def compute_bias(self, query_length, key_length) -> Tensor: """Compute binned relative position bias""" - if device is None: - device = self.relative_attention_bias.weight.device - context_position = Tensor.arange(query_length, dtype=dtypes.long, device=device)[:, None] - memory_position = Tensor.arange(key_length, dtype=dtypes.long, device=device)[None, :] + context_position = Tensor.arange(query_length, dtype=dtypes.long)[:, None] + memory_position = Tensor.arange(key_length, dtype=dtypes.long)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) @@ -212,7 +210,7 @@ class T5Attention: scores = Tensor.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: - position_bias = self.compute_bias(key_length, key_length, device=scores.device) + position_bias = self.compute_bias(key_length, key_length) scores += position_bias attn_weights = Tensor.softmax(scores.float(), axis=-1).cast(scores.dtype) # (batch_size, n_heads, seq_length, key_length) diff --git a/extra/models/unet.py b/extra/models/unet.py index 88be107981..f1b488cef4 100644 --- a/extra/models/unet.py +++ b/extra/models/unet.py @@ -9,7 +9,7 @@ attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Te # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207 def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000): half = dim // 2 - freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp() + freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp() args = timesteps.unsqueeze(1) * freqs.unsqueeze(0) out = Tensor.cat(args.cos(), args.sin(), dim=-1) return out.cast(mixed_precision_dtype) if mixed_precision_dtype in Device[Device.DEFAULT].renderer.supported_dtypes() else out diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index 073136915f..e5fad7db23 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -94,7 +94,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): return cls.full(argfix(*shape), 1.0, **kwargs) @classmethod - def arange(cls, start, stop=None, step=1, **kwargs) -> Self: + def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self: """ Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`. @@ -116,15 +116,15 @@ class OpMixin(ElementwiseMixin, ReduceMixin): ``` """ if stop is None: stop, start = start, 0 - dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int) + if dtype is None: dtype = dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int lo, hi = (start, stop-step) if step > 0 else (stop-step, start) if lo < (dt:=to_dtype(dtype)).min or dt.max < hi: raise OverflowError(f"arange [{start}, {stop}) is not representable in dtype {dtype}") # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs - if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, buffer=False, **kwargs) - return (cls.full((output_len,), step, dtype=dtype, buffer=False, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype) + if (output_len:=ceildiv(stop-start, step)) <= 0: return cls.full((0,), 0, dtype=dtype, buffer=False) + return (cls.full((output_len,), step, dtype=dtype, buffer=False)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype) @classmethod - def linspace(cls, start:int|float, stop:int|float, steps:int, **kwargs) -> Self: + def linspace(cls, start:int|float, stop:int|float, steps:int, dtype:DTypeLike|None=None) -> Self: """ Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive. @@ -136,9 +136,9 @@ class OpMixin(ElementwiseMixin, ReduceMixin): ``` """ if steps < 0: raise ValueError("number of steps must be non-negative") - if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported") - if steps == 1: return cls.full((1,), start, dtype=dtype, buffer=False, **kwargs) - return (start + cls.arange(steps, dtype=dtypes.default_float, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype) + if (dtype := to_dtype(dtype or dtypes.default_float)) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported") + if steps == 1: return cls.full((1,), start, dtype=dtype, buffer=False) + return (start + cls.arange(steps, dtype=dtypes.default_float) * ((stop - start) / (steps - 1))).cast(dtype) @classmethod def eye(cls, n:int, m:int|None=None, dtype:DTypeLike|None=None) -> Self: