mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Enable whisper test in CI for more backends (#2355)
This commit is contained in:
21
.github/workflows/test.yml
vendored
21
.github/workflows/test.yml
vendored
@@ -81,6 +81,13 @@ jobs:
|
||||
with:
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
|
||||
key: testing-packages-${{ hashFiles('**/setup.py') }}
|
||||
- name: Cache model weights
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
weights/whisper-tiny.en.pt
|
||||
weights/gpt2.tiktoken
|
||||
key: model-weights-v1
|
||||
- name: Install Dependencies
|
||||
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
- name: Run Pytest
|
||||
@@ -109,6 +116,13 @@ jobs:
|
||||
with:
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
|
||||
key: testing-packages-${{ hashFiles('**/setup.py') }}
|
||||
- name: Cache model weights
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
weights/whisper-tiny.en.pt
|
||||
weights/gpt2.tiktoken
|
||||
key: model-weights-v1
|
||||
- name: Install Dependencies
|
||||
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
- name: Run Pytest
|
||||
@@ -285,6 +299,13 @@ jobs:
|
||||
with:
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
|
||||
key: ${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }}
|
||||
- name: Cache model weights
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
weights/whisper-tiny.en.pt
|
||||
weights/gpt2.tiktoken
|
||||
key: model-weights-v1
|
||||
- name: Set env
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas'}}" >> $GITHUB_ENV
|
||||
- name: Install OpenCL
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import unittest
|
||||
import pathlib
|
||||
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.ops import Device
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "Some non-metal backends spend too long trying to allocate a 20GB array")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others")
|
||||
class TestWhisper(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
Reference in New Issue
Block a user