diff --git a/test/external/external_llm_eval.py b/test/external/external_llm_eval.py index 619cd631f6..ed617cfae4 100644 --- a/test/external/external_llm_eval.py +++ b/test/external/external_llm_eval.py @@ -10,7 +10,7 @@ if __name__ == "__main__": model, kv = Transformer.from_gguf(Tensor.from_url(models["1B"]), max_context=4096) - tok = SimpleTokenizer(kv["tokenizer.ggml.tokens"]) + tok = SimpleTokenizer.from_gguf_kv(kv) bos_id: int = kv['tokenizer.ggml.bos_token_id'] eos_id: int = kv['tokenizer.ggml.eos_token_id'] diff --git a/test/external/external_test_simple_tokenizer.py b/test/external/external_test_simple_tokenizer.py index b3e33e0d4d..9c3ca8f420 100644 --- a/test/external/external_test_simple_tokenizer.py +++ b/test/external/external_test_simple_tokenizer.py @@ -1,17 +1,19 @@ from transformers import AutoTokenizer from datasets import load_dataset -from tinygrad.apps.llm import SimpleTokenizer -from tinygrad.helpers import tqdm, getenv +from tinygrad.apps.llm import SimpleTokenizer, gpt2_decode_vocab, get_llama_re +from tinygrad.helpers import tqdm, getenv, partition # use ALLOW_FAILED=-1 to go over the entire dataset without printing. if __name__ == "__main__": base_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") - vocab_words = [ word for word, _ in sorted(base_tokenizer.get_vocab().items(), key=lambda t: t[1]) ] + special_tokens, normal_tokens = partition(((t, tid) for t, tid in base_tokenizer.vocab.items()), + lambda e: e[1] in base_tokenizer.all_special_ids) inv_vocab = { tid: word for word, tid in base_tokenizer.get_vocab().items() } - simple_tokenizer = SimpleTokenizer(vocab_words) + simple_tokenizer = SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens)) color_codes = [ 91, 92, 94, 93, 95 ] - def color_tokens(tids): return "".join(f"\033[{color_codes[i%len(color_codes)]}m{inv_vocab[t]}" for i, t in enumerate(tids)) + "\033[0m" + def color_tokens(tids): + return "".join(f"\033[{color_codes[i%len(color_codes)]}m{base_tokenizer.decode([t])}" for i, t in enumerate(tids)) + "\033[0m" ds = load_dataset("OpenAssistant/oasst1") allow_failed = getenv("ALLOW_FAILED", 10) diff --git a/test/unit/test_llm_tokenizer.py b/test/unit/test_llm_tokenizer.py new file mode 100644 index 0000000000..fca60bd7bb --- /dev/null +++ b/test/unit/test_llm_tokenizer.py @@ -0,0 +1,57 @@ +import unittest, base64, functools +from tinygrad.apps.llm import SimpleTokenizer, get_llama_re +from tinygrad.helpers import fetch + +class TestLLMTokenizer(unittest.TestCase): + @functools.cached_property + def basic_tok(self): return SimpleTokenizer(".*", { b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"bc": 4 }, { "": 5, "": 6, "": 7 }) + + @functools.cached_property + def llama_tok(self): + # from https://github.com/tinygrad/tinygrad/blob/e0106b6b257ebc003eb3694144e3e198f7d8cc37/examples/llama3.py#L14 + model_file = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model") + with open(model_file, "rt") as fd: + str_vocab = [ line.split(maxsplit=1) for line in fd.read().splitlines() if line ] + normal_tokens = { base64.b64decode(stok): int(srank) for stok, srank in str_vocab } + + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", + ] + [ f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5) ] + return SimpleTokenizer(get_llama_re(), normal_tokens, { token: len(normal_tokens) + i for i, token in enumerate(special_tokens) }) + + def _test_coding(self, tok: SimpleTokenizer, text: str, expected_tokens: list[int]): + self.assertEqual(tok.encode(text), expected_tokens) + self.assertEqual(tok.decode(expected_tokens), text) + + def test_abc(self): self._test_coding(self.basic_tok, "abc", [ 3, 2 ]) + def test_abbc(self): self._test_coding(self.basic_tok, "abbc", [ 3, 4 ]) + def test_aabbbcc(self): self._test_coding(self.basic_tok, "aabbbcc", [ 0, 3, 1, 4, 2 ]) + def test_specials1(self): self._test_coding(self.basic_tok, "aaaa", [ 0, 5, 0, 6, 0, 7, 0 ]) + def test_specials2(self): self._test_coding(self.basic_tok, "aa", [ 5, 0, 6, 0, 7 ]) + def test_invalid_token(self): + with self.assertRaises(RuntimeError): self._test_coding(self.basic_tok, "L", []) + + def test_no_specials(self): self._test_coding(SimpleTokenizer(".*", { bytes([i]): i for i in range(256) }, {}), "abc", [97, 98, 99]) + + # NOTE: the correct tokenization for this can only be found by looking up the text chunk in the vocab, not by applying merges + def test_llama_early_tokenize(self): self._test_coding(self.llama_tok, " например", [ 111797 ]) + + def test_llama_basic(self): self._test_coding(self.llama_tok, "hello world", [ 15339, 1917 ]) + def test_llama_control_char(self): self._test_coding(self.llama_tok, " \x850", [ 220, 116360, 15 ]) + def test_llama_bytes(self): self._test_coding(self.llama_tok, " \xec\x8b\xa4\xed", [ 1717, 105, 116174, 82638, 2483 ]) + def test_llama_special1(self): self._test_coding(self.llama_tok, "hello <|end_of_text|>", [ 15339, 220, 128001 ]) + def test_llama_special2(self): self._test_coding(self.llama_tok, "<|start_header_id|>user<|end_header_id|>\n\n", [ 128006, 882, 128007, 271 ]) + def test_llama_repeat(self): self._test_coding(self.llama_tok, "00000000000000000", [ 931, 931, 931, 931, 931, 410 ]) + def test_llama_pat(self): self._test_coding(self.llama_tok, "today\n \n", [ 31213, 14211 ]) + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index baffe85aeb..4f80d4bf2d 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -1,33 +1,57 @@ from __future__ import annotations -import sys, argparse -from tinygrad import Tensor, nn, UOp, TinyJit, getenv +import sys, argparse, typing, re, itertools, unicodedata +from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers + +def gpt2_decode_vocab(voc: dict[str, int]): # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9 + c2b = { chr(cp): cp for cp in itertools.chain(range(ord("!"), ord("~")+1), range(ord("¡"), ord("¬")+1), range(ord("®"), ord("ÿ")+1)) } + c2b.update({ chr(256+off): cp for off, cp in enumerate(cp for cp in range(256) if chr(cp) not in c2b) }) + return { bytes(c2b[c] for c in tok): tid for tok, tid in voc.items() } + +def get_llama_re(): + def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre)) + r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L") + # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286 + return "(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \ + f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+" class SimpleTokenizer: - def __init__(self, vocab: list[str]): - self.vocab: list[str] = vocab - self.biggest_token: int = max(map(len, vocab)) - self.token_to_id: dict[str, int] = {tok: i for i, tok in enumerate(vocab)} - self.replace_space = "Ġ" - self.replace_newline = "Ċ" + def __init__(self, pat: str, normal_tokens: dict[bytes, int], special_tokens: dict[str, int]): + self._normal_tokens, self._special_tokens, self._pat = normal_tokens, special_tokens, re.compile(pat) + self._tok2str = { tid: tok.encode() for tok, tid in special_tokens.items() } | { tid: tok for tok, tid in normal_tokens.items() } + self._special_re = re.compile("|".join(re.escape(tok) for tok in self._special_tokens.keys()) if special_tokens else r"(?!)") - def encode(self, text:str) -> list[int]: - s = text.replace(" ", self.replace_space).replace("\n", self.replace_newline) - out: list[int] = [] - i = 0 - while i < len(s): - j = min(i+self.biggest_token, len(s)) - while i < j and (tid:=self.token_to_id.get(s[i:j])) is None: j -= 1 - if tid is None: raise RuntimeError(f"token not found in {s}") - assert tid is not None, f"token not found in {s}" - out.append(tid) - i = j - return out + @staticmethod + def from_gguf_kv(kv: dict): + # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820 + if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'") + vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"])) + normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1) + return SimpleTokenizer(get_llama_re(), gpt2_decode_vocab(dict(normal_tokens)), dict(special_tokens)) - def decode(self, ids: list[int]) -> str: - return ''.join(self.vocab[tid] for tid in ids).replace(self.replace_space, " ").replace(self.replace_newline, "\n") + def encode(self, text: str): + tokens: list[int] = [] + pos = 0 + for match in self._special_re.finditer(text): + tokens.extend(self._encode_sentence(text[pos:match.start(0)]) + [self._special_tokens[text[match.start(0):match.end(0)]]]) + pos = match.end(0) + return tokens + self._encode_sentence(text[pos:]) - def role(self, role:str): - return [t for x in ["<|start_header_id|>", role, "<|end_header_id|>\n\n"] for t in self.encode(x)] # llama style + def decode(self, ids: list[int]) -> str: return b''.join(self._tok2str[tid] for tid in ids).decode() + def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n") + + def _encode_sentence(self, chunk: str): return [ tok for word in self._pat.findall(chunk) for tok in self._encode_word(word.encode()) ] + def _encode_word(self, word: bytes): + if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token] + parts = [word[i:i+1] for i in range(len(word))] + while True: + min_tid, min_idx = 2**32, -1 + for idx, (p1, p2) in enumerate(zip(parts[:-1], parts[1:])): + tid = self._normal_tokens.get(p1 + p2, min_tid) + if tid < min_tid: min_tid, min_idx = tid, idx + if min_idx == -1: break + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx+1]] + parts[min_idx+2:] + try: return [ self._normal_tokens[p] for p in parts ] + except KeyError: raise RuntimeError("token not found") def apply_rope(x:Tensor, start_pos:int|UOp, base:int=10000): B, H, T, Hd = x.shape @@ -165,7 +189,7 @@ if __name__ == "__main__": model, kv = Transformer.from_gguf(Tensor.from_url(models[args.size]), args.max_context) # extract some metadata - tok = SimpleTokenizer(kv["tokenizer.ggml.tokens"]) + tok = SimpleTokenizer.from_gguf_kv(kv) bos_id: int = kv['tokenizer.ggml.bos_token_id'] eos_id: int = kv['tokenizer.ggml.eos_token_id']