diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 22cee7ac89..03a8f9d37f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -81,13 +81,11 @@ jobs: with: path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages key: testing-packages-${{ hashFiles('**/setup.py') }} - - name: Cache model weights + - name: Cache downloads uses: actions/cache@v3 with: - path: | - weights/whisper-tiny.en.pt - weights/gpt2.tiktoken - key: model-weights-v1 + path: ~/.cache/tinygrad/downloads/ + key: downloads-cache - name: Install Dependencies run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu - name: Run Pytest @@ -116,13 +114,11 @@ jobs: with: path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages key: testing-packages-${{ hashFiles('**/setup.py') }} - - name: Cache model weights + - name: Cache downloads uses: actions/cache@v3 with: - path: | - weights/whisper-tiny.en.pt - weights/gpt2.tiktoken - key: model-weights-v1 + path: ~/.cache/tinygrad/downloads/ + key: downloads-cache - name: Install Dependencies run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu - name: Run Pytest @@ -229,13 +225,11 @@ jobs: key: metal-webgpu-testing-packages-${{ hashFiles('**/setup.py') }} - name: Install Dependencies run: pip install -e '.[webgpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu - - name: Cache model weights + - name: Cache downloads uses: actions/cache@v3 with: - path: | - weights/whisper-tiny.en.pt - weights/gpt2.tiktoken - key: model-weights-v1 + path: ~/Library/Caches/tinygrad/downloads/ + key: downloads-cache - name: Test LLaMA compile speed run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py #- name: Run dtype test @@ -296,13 +290,11 @@ jobs: with: path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages key: ${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }} - - name: Cache model weights + - name: Cache downloads uses: actions/cache@v3 with: - path: | - weights/whisper-tiny.en.pt - weights/gpt2.tiktoken - key: model-weights-v1 + path: ~/Library/Caches/tinygrad/downloads/ + key: downloads-cache - name: Set env run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas'}}" >> $GITHUB_ENV - name: Install OpenCL diff --git a/examples/whisper.py b/examples/whisper.py index c30843e7eb..e31f904b28 100644 --- a/examples/whisper.py +++ b/examples/whisper.py @@ -6,10 +6,9 @@ import base64 import multiprocessing import numpy as np from typing import Optional, Union, Literal, List -from extra.utils import download_file from tinygrad.jit import TinyJit from tinygrad.nn.state import torch_load, load_state_dict -from tinygrad.helpers import getenv, DEBUG, CI +from tinygrad.helpers import getenv, DEBUG, CI, fetch import tinygrad.nn as nn from tinygrad.shape.symbolic import Variable from tinygrad.tensor import Tensor @@ -194,11 +193,8 @@ LANGUAGES = { "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese", } -BASE = pathlib.Path(__file__).parents[1] / "weights" def get_encoding(encoding_name): - filename = encoding_name + ".tiktoken" - download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/" + filename, BASE / filename) - with open(BASE / filename) as f: + with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f: ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)} n_vocab = len(ranks) specials = [ @@ -239,8 +235,7 @@ MODEL_URLS = { def init_whisper(model_name="tiny.en", batch_size=1): assert MODEL_URLS[model_name] is not None - filename = BASE / "whisper-{}.pt".format(model_name) - download_file(MODEL_URLS[model_name], filename) + filename = fetch(MODEL_URLS[model_name]) state = torch_load(filename) model = Whisper(state['dims'], batch_size) load_state_dict(model, state['model_state_dict'], strict=False)