Enable whisper test in CI for more backends (#2355)

This commit is contained in:
mmmkkaaayy
2023-11-18 14:52:50 -08:00
committed by GitHub
parent d7d078c7f9
commit 08d09eb666
2 changed files with 23 additions and 1 deletions

View File

@@ -81,6 +81,13 @@ jobs:
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
key: testing-packages-${{ hashFiles('**/setup.py') }}
- name: Cache model weights
uses: actions/cache@v3
with:
path: |
weights/whisper-tiny.en.pt
weights/gpt2.tiktoken
key: model-weights-v1
- name: Install Dependencies
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Run Pytest
@@ -109,6 +116,13 @@ jobs:
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
key: testing-packages-${{ hashFiles('**/setup.py') }}
- name: Cache model weights
uses: actions/cache@v3
with:
path: |
weights/whisper-tiny.en.pt
weights/gpt2.tiktoken
key: model-weights-v1
- name: Install Dependencies
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Run Pytest
@@ -285,6 +299,13 @@ jobs:
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
key: ${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }}
- name: Cache model weights
uses: actions/cache@v3
with:
path: |
weights/whisper-tiny.en.pt
weights/gpt2.tiktoken
key: model-weights-v1
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas'}}" >> $GITHUB_ENV
- name: Install OpenCL

View File

@@ -1,9 +1,10 @@
import unittest
import pathlib
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
from tinygrad.helpers import CI
from tinygrad.ops import Device
@unittest.skipUnless(Device.DEFAULT == "METAL", "Some non-metal backends spend too long trying to allocate a 20GB array")
@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others")
class TestWhisper(unittest.TestCase):
@classmethod
def setUpClass(cls):