Update Whisper to use fetch helper (#2401)

* update whisper to use new fetch helper

* simplify file opening

* update name

* update key name to "downloads-cache"
This commit is contained in:
Francis Lata
2023-11-23 15:59:59 -05:00
committed by GitHub
parent 0505c5ea50
commit 6d672785db
2 changed files with 15 additions and 28 deletions

View File

@@ -81,13 +81,11 @@ jobs:
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
key: testing-packages-${{ hashFiles('**/setup.py') }}
- name: Cache model weights
- name: Cache downloads
uses: actions/cache@v3
with:
path: |
weights/whisper-tiny.en.pt
weights/gpt2.tiktoken
key: model-weights-v1
path: ~/.cache/tinygrad/downloads/
key: downloads-cache
- name: Install Dependencies
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Run Pytest
@@ -116,13 +114,11 @@ jobs:
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
key: testing-packages-${{ hashFiles('**/setup.py') }}
- name: Cache model weights
- name: Cache downloads
uses: actions/cache@v3
with:
path: |
weights/whisper-tiny.en.pt
weights/gpt2.tiktoken
key: model-weights-v1
path: ~/.cache/tinygrad/downloads/
key: downloads-cache
- name: Install Dependencies
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Run Pytest
@@ -229,13 +225,11 @@ jobs:
key: metal-webgpu-testing-packages-${{ hashFiles('**/setup.py') }}
- name: Install Dependencies
run: pip install -e '.[webgpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Cache model weights
- name: Cache downloads
uses: actions/cache@v3
with:
path: |
weights/whisper-tiny.en.pt
weights/gpt2.tiktoken
key: model-weights-v1
path: ~/Library/Caches/tinygrad/downloads/
key: downloads-cache
- name: Test LLaMA compile speed
run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py
#- name: Run dtype test
@@ -296,13 +290,11 @@ jobs:
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
key: ${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }}
- name: Cache model weights
- name: Cache downloads
uses: actions/cache@v3
with:
path: |
weights/whisper-tiny.en.pt
weights/gpt2.tiktoken
key: model-weights-v1
path: ~/Library/Caches/tinygrad/downloads/
key: downloads-cache
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas'}}" >> $GITHUB_ENV
- name: Install OpenCL

View File

@@ -6,10 +6,9 @@ import base64
import multiprocessing
import numpy as np
from typing import Optional, Union, Literal, List
from extra.utils import download_file
from tinygrad.jit import TinyJit
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.helpers import getenv, DEBUG, CI, fetch
import tinygrad.nn as nn
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
@@ -194,11 +193,8 @@ LANGUAGES = {
"as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
}
BASE = pathlib.Path(__file__).parents[1] / "weights"
def get_encoding(encoding_name):
filename = encoding_name + ".tiktoken"
download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/" + filename, BASE / filename)
with open(BASE / filename) as f:
with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
n_vocab = len(ranks)
specials = [
@@ -239,8 +235,7 @@ MODEL_URLS = {
def init_whisper(model_name="tiny.en", batch_size=1):
assert MODEL_URLS[model_name] is not None
filename = BASE / "whisper-{}.pt".format(model_name)
download_file(MODEL_URLS[model_name], filename)
filename = fetch(MODEL_URLS[model_name])
state = torch_load(filename)
model = Whisper(state['dims'], batch_size)
load_state_dict(model, state['model_state_dict'], strict=False)