mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Update Whisper to use fetch helper (#2401)
* update whisper to use new fetch helper * simplify file opening * update name * update key name to "downloads-cache"
This commit is contained in:
32
.github/workflows/test.yml
vendored
32
.github/workflows/test.yml
vendored
@@ -81,13 +81,11 @@ jobs:
|
||||
with:
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
|
||||
key: testing-packages-${{ hashFiles('**/setup.py') }}
|
||||
- name: Cache model weights
|
||||
- name: Cache downloads
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
weights/whisper-tiny.en.pt
|
||||
weights/gpt2.tiktoken
|
||||
key: model-weights-v1
|
||||
path: ~/.cache/tinygrad/downloads/
|
||||
key: downloads-cache
|
||||
- name: Install Dependencies
|
||||
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
- name: Run Pytest
|
||||
@@ -116,13 +114,11 @@ jobs:
|
||||
with:
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
|
||||
key: testing-packages-${{ hashFiles('**/setup.py') }}
|
||||
- name: Cache model weights
|
||||
- name: Cache downloads
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
weights/whisper-tiny.en.pt
|
||||
weights/gpt2.tiktoken
|
||||
key: model-weights-v1
|
||||
path: ~/.cache/tinygrad/downloads/
|
||||
key: downloads-cache
|
||||
- name: Install Dependencies
|
||||
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
- name: Run Pytest
|
||||
@@ -229,13 +225,11 @@ jobs:
|
||||
key: metal-webgpu-testing-packages-${{ hashFiles('**/setup.py') }}
|
||||
- name: Install Dependencies
|
||||
run: pip install -e '.[webgpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
- name: Cache model weights
|
||||
- name: Cache downloads
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
weights/whisper-tiny.en.pt
|
||||
weights/gpt2.tiktoken
|
||||
key: model-weights-v1
|
||||
path: ~/Library/Caches/tinygrad/downloads/
|
||||
key: downloads-cache
|
||||
- name: Test LLaMA compile speed
|
||||
run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py
|
||||
#- name: Run dtype test
|
||||
@@ -296,13 +290,11 @@ jobs:
|
||||
with:
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
|
||||
key: ${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }}
|
||||
- name: Cache model weights
|
||||
- name: Cache downloads
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
weights/whisper-tiny.en.pt
|
||||
weights/gpt2.tiktoken
|
||||
key: model-weights-v1
|
||||
path: ~/Library/Caches/tinygrad/downloads/
|
||||
key: downloads-cache
|
||||
- name: Set env
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas'}}" >> $GITHUB_ENV
|
||||
- name: Install OpenCL
|
||||
|
||||
@@ -6,10 +6,9 @@ import base64
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
from typing import Optional, Union, Literal, List
|
||||
from extra.utils import download_file
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
from tinygrad.helpers import getenv, DEBUG, CI
|
||||
from tinygrad.helpers import getenv, DEBUG, CI, fetch
|
||||
import tinygrad.nn as nn
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.tensor import Tensor
|
||||
@@ -194,11 +193,8 @@ LANGUAGES = {
|
||||
"as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
|
||||
}
|
||||
|
||||
BASE = pathlib.Path(__file__).parents[1] / "weights"
|
||||
def get_encoding(encoding_name):
|
||||
filename = encoding_name + ".tiktoken"
|
||||
download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/" + filename, BASE / filename)
|
||||
with open(BASE / filename) as f:
|
||||
with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
|
||||
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
|
||||
n_vocab = len(ranks)
|
||||
specials = [
|
||||
@@ -239,8 +235,7 @@ MODEL_URLS = {
|
||||
def init_whisper(model_name="tiny.en", batch_size=1):
|
||||
assert MODEL_URLS[model_name] is not None
|
||||
|
||||
filename = BASE / "whisper-{}.pt".format(model_name)
|
||||
download_file(MODEL_URLS[model_name], filename)
|
||||
filename = fetch(MODEL_URLS[model_name])
|
||||
state = torch_load(filename)
|
||||
model = Whisper(state['dims'], batch_size)
|
||||
load_state_dict(model, state['model_state_dict'], strict=False)
|
||||
|
||||
Reference in New Issue
Block a user