def load(fn:str):
    """Load a weight dictionary from a checkpoint file.

    The loader is chosen from the name of ``fn``:

    * ``*.index.json`` -- a JSON index whose ``'weight_map'`` entry maps each
      tensor name to the shard file that stores it. Every distinct shard is
      loaded once (recursively, through this function) and the result maps
      each tensor name to the entry of that name in its shard. Only the base
      name of a shard is used, so shards are always looked up in the
      directory that holds the index.
    * ``*.safetensors`` -- handed to ``safe_load``.
    * anything else -- handed to ``torch_load``.

    Args:
        fn: Path of the file to load (``str`` or ``os.PathLike``).

    Returns:
        For an index file, a ``dict`` of tensor name -> entry taken from the
        owning shard; otherwise whatever ``safe_load`` / ``torch_load``
        returns for ``fn``.

    Raises:
        KeyError: If the index has no ``'weight_map'`` entry, or if it lists
            a tensor name that the referenced shard does not contain.
    """
    fn = os.fspath(fn)

    if fn.endswith('.index.json'):
        with open(fn) as fp:
            weight_map = json.load(fp)['weight_map']
        # Build shard paths with os.path.join instead of gluing with '/':
        # for a bare index filename dirname() is '', and '' + '/' + shard
        # would point at the filesystem root rather than the current dir.
        shard_dir = os.path.dirname(fn)
        parts = {
            n: load(os.path.join(shard_dir, os.path.basename(n)))
            for n in set(weight_map.values())
        }
        return {k: parts[n][k] for k, n in weight_map.items()}

    if fn.endswith('.safetensors'):
        return safe_load(fn)
    return torch_load(fn)


def convert_from_huggingface(weights, model):
    """Translate HuggingFace-style weight names to this file's naming scheme.

    Names such as ``model.layers.N.self_attn.q_proj.weight`` become
    ``layers.N.attention.wq.weight``; the full table is built below for
    ``N`` in ``range(len(model.layers))``.

    Args:
        weights: Mapping of HuggingFace weight name -> value. Values are
            passed through untouched.
        model: Object with a ``layers`` attribute; only ``len(model.layers)``
            is read, to know how many per-layer names to generate.

    Returns:
        A new ``dict`` holding the same values under the translated names,
        in the iteration order of ``weights``. Entries whose name contains
        ``'.rotary_emb.'`` are dropped.

    Raises:
        KeyError: If a remaining name has no translation, e.g. a layer index
            that is not below ``len(model.layers)``.
    """
    keymap = {
        'model.embed_tokens.weight': 'tok_embeddings.weight',
        'model.norm.weight': 'norm.weight',
        'lm_head.weight': 'output.weight',
    }
    mlp_names = {'gate': '1', 'down': '2', 'up': '3'}
    for l in range(len(model.layers)):
        src, dst = f'model.layers.{l}', f'layers.{l}'
        keymap[f'{src}.input_layernorm.weight'] = f'{dst}.attention_norm.weight'
        keymap[f'{src}.post_attention_layernorm.weight'] = f'{dst}.ffn_norm.weight'
        for x in ('q', 'k', 'v', 'o'):
            keymap[f'{src}.self_attn.{x}_proj.weight'] = f'{dst}.attention.w{x}.weight'
        for x, y in mlp_names.items():
            keymap[f'{src}.mlp.{x}_proj.weight'] = f'{dst}.feed_forward.w{y}.weight'

    converted = {}
    for k, v in weights.items():
        if '.rotary_emb.' in k:
            continue
        if k not in keymap:
            # Same exception type as a plain dict lookup, but say which
            # name could not be translated and for how many layers.
            raise KeyError(
                f'no translation for weight {k!r} '
                f'(model has {len(model.layers)} layers)'
            )
        converted[keymap[k]] = v
    return converted
LLAMA_SUFFIX = {1: "", 2: "-2"}[args.gen] - WEIGHTS_DIR = Path(__file__).parent.parent / f"weights/LLaMA{LLAMA_SUFFIX}/" - TOKENIZER_FILENAME = WEIGHTS_DIR / "tokenizer.model" + MODEL_PATH = args.model or Path(__file__).parent.parent / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}" + TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model" print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model") - llama = LLaMa.build(WEIGHTS_DIR, TOKENIZER_FILENAME, model_gen=args.gen, model_size=args.size, quantize=args.quantize) + llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize) if chatbot: # encode pre prompt