diff --git a/examples/gpt2.py b/examples/gpt2.py index 562312cb3f..fb690c9f69 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -15,7 +15,7 @@ from tinygrad.nn import Embedding, Linear from tinygrad.jit import TinyJit from tinygrad.shape.symbolic import Variable -from examples.llama import sample +MAX_CONTEXT = 128 class LayerNorm: def __init__(self, dim, eps=1e-5): @@ -81,7 +81,7 @@ class TransformerBlock: if start_pos > 0 and mask is None and getenv("JIT"): seqlen = x.shape[1] - pos = Variable("pos", 1, 128) # max context + pos = Variable("pos", 1, MAX_CONTEXT) self.cache_k = self.cache_k.reshape(self.cache_k.shape[0], pos, self.cache_k.shape[2], self.cache_k.shape[3]) self.cache_v = self.cache_v.reshape(self.cache_v.shape[0], pos, self.cache_v.shape[2], self.cache_v.shape[3]) @@ -104,18 +104,34 @@ class Transformer: self.ln_f = LayerNorm(dim, norm_eps) self.lm_head = linear(dim, vocab_size, bias=False) - def __call__(self, tokens:Tensor, start_pos:int): - _bsz, seqlen = tokens.shape + self.embed_jitted = TinyJit(self.embed) + self.postprocess_jitted = TinyJit(self.postprocess) + + def embed(self, tokens, pos): tok_emb = self.wte(tokens) - pos = Tensor.arange(start_pos, start_pos + seqlen).reshape(1, -1) pos_emb = self.wpe(pos) h = tok_emb + pos_emb + return h.realize() - # get only the part we are using. making it contiguous avoids more kernel calls - mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None - h = h.sequential([functools.partial(layer, start_pos=start_pos, mask=mask) for layer in self.h]) - h = self.ln_f(h) - return self.lm_head(h) + def postprocess(self, x, temperature:Optional[float]): + logits = self.lm_head(self.ln_f(x)) + if temperature is not None: return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize() + return logits.realize() + + def __call__(self, tokens:Tensor, start_pos:int, temperature:Optional[float]): + _bsz, seqlen = tokens.shape + if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize() + if seqlen == 1 and start_pos > 0 and getenv("JIT"): + start_pos_var = Variable("start_pos", 1, MAX_CONTEXT) + pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var+seqlen))) + pos.lazydata.st.var_vals[start_pos_var] = start_pos + h = self.embed_jitted(tokens, pos).sequential([functools.partial(layer, start_pos=start_pos, mask=None) for layer in self.h]) + return self.postprocess_jitted(h, temperature) + else: + pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos, start_pos+seqlen))) + mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() + h = self.embed(tokens, pos).sequential([functools.partial(layer, start_pos=start_pos, mask=mask) for layer in self.h]) + return self.postprocess(h, temperature) # **** files and arguments **** @@ -163,10 +179,13 @@ class GPT2: GlobalCounters.reset() if args.timing: print("") st = GlobalCounters.time_sum_s - with Timing(f"ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU, {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB") if DEBUG else None, enabled=timing): - logits = self.model(Tensor([toks[start_pos:]]), start_pos)[:, -1, :].realize() + with Timing(f"ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"+ + f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+ + f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s") if DEBUG else None, enabled=timing): + probs = self.model(Tensor([toks[start_pos:]]), start_pos, temperature) with Timing("sync in ", enabled=timing): - tok = sample(logits, temperature) + probs_np = probs.numpy() + tok = int(np.random.choice(len(probs_np), p=probs_np)) start_pos = len(toks) toks.append(tok) output = self.tokenizer.decode(toks) diff --git a/test/external/external_test_embedding.py b/test/external/external_test_embedding.py new file mode 100644 index 0000000000..9d6bd7f2b0 --- /dev/null +++ b/test/external/external_test_embedding.py @@ -0,0 +1,8 @@ +from tinygrad.tensor import Tensor +from tinygrad.nn import Embedding + +if __name__ == "__main__": + vocab_size = 50257 + dim = 128 + test = Embedding(vocab_size, dim) + ret = test(Tensor([[1,2,3]])).numpy() diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index ae3ad3ebb3..5bde6b3233 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -213,6 +213,11 @@ class Linearizer: # print early if DEBUG >= 5: self.printbufs("early") + def has_variable_shape(self) -> bool: + for b in self.bufs: + if any(not isinstance(x, int) for x in b.st.shape): return True + return False + def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py index 2c59d844e9..6d556b4bcc 100644 --- a/tinygrad/codegen/optimizer.py +++ b/tinygrad/codegen/optimizer.py @@ -72,7 +72,7 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg): if global_db is not None and skey in global_db: choice = global_db[skey] - elif any(not isinstance(x, int) for x in k.full_shape): + elif k.has_variable_shape(): # don't optimize variable shapes choice = "BASELINE" else: @@ -260,6 +260,10 @@ def hand_coded_optimizations(k:Linearizer): # no more opt if we are grouping if k.group_for_reduce: return + # no more opt if there's non ints in any shapes + # TODO: this is due to a bug. repro by commenting this one while running GPT-2 with the JIT + if k.has_variable_shape(): return + # **** below this line need to be optional and benchmarked **** # potentially do more upcasts of non reduce axes based on a heuristic diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index db34204fa6..977c8740df 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -120,5 +120,5 @@ class Embedding: self.weight = Tensor.glorot_uniform(vocab_size, embed_size) def __call__(self, idx:Tensor) -> Tensor: - vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size).expand(*idx.shape, self.vocab_size) - return (vocab_counter == idx.unsqueeze(2).expand(*idx.shape, self.vocab_size)) @ self.weight + if not hasattr(self, 'vocab_counter'): self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size) + return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight