diff --git a/examples/whisper.py b/examples/whisper.py index 40e4009c06..54e341fed2 100644 --- a/examples/whisper.py +++ b/examples/whisper.py @@ -106,6 +106,7 @@ class Whisper: # TODO: this is tragic. remove this import functools +import itertools import torch import torchaudio import librosa @@ -158,10 +159,8 @@ def get_encoding(n_vocab_in): "<|notimestamps|>", *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], ] - special_tokens = {} - for token in specials: - special_tokens[token] = n_vocab - n_vocab += 1 + special_tokens = dict(zip(specials, itertools.count(n_vocab))) + n_vocab += len(specials) assert n_vocab == n_vocab_in import tiktoken return tiktoken.Encoding(