diff --git a/examples/llama.py b/examples/llama.py index 1b17fee5f3..a9cdd0b702 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -155,7 +155,8 @@ class LLaMa: sp_model = SentencePieceProcessor(model_file=str(tokenizer_path)) assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}" - model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT) + jit = bool(getenv("JIT", 1)) + model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT, jit=jit) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT, jit=jit) if model_path.is_dir(): weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]]) diff --git a/extra/models/llama.py b/extra/models/llama.py index d8604c0e13..8d1e84fd4f 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -119,7 +119,7 @@ class Transformer: def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0): # TODO: better way to handle the first call v.s. the rest? - if tokens.shape[0:2] == (1,1) and self.forward_jit and getenv("JIT", 1): + if tokens.shape[0:2] == (1,1) and self.forward_jit is not None: assert start_pos > 0 return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature) return self.forward(tokens, start_pos, temperature)