diff --git a/examples/mixtral.py b/examples/mixtral.py index 3266c8248e..c621d409e6 100644 --- a/examples/mixtral.py +++ b/examples/mixtral.py @@ -56,7 +56,7 @@ if __name__ == "__main__": with Profiling(sort="time", frac=0.1, enabled=args.profile): with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"): with WallTimeEvent(BenchEvent.STEP): - tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), args.temperature).item() + tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024-1).bind(start_pos), args.temperature).item() toks.append(tok) start_pos += 1 print(spp.decode(toks))