remove whisper +1-1 hack (#2360)

* remove whisper +1-1 hack

* Revert "remove whisper +1-1 hack"

This reverts commit 5db3800f09.

* update whisper tests

* comment context
This commit is contained in:
chenyu
2023-11-19 17:56:36 -05:00
committed by GitHub
parent a0890f4e6c
commit e9847be790
2 changed files with 24 additions and 15 deletions

View File

@@ -34,7 +34,7 @@ class MultiHeadAttention:
if not hasattr(self, 'cache_k'):
self.cache_k, self.cache_v = k, v
else:
# see test_jitted_read_assign in test_jit.py
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
self.cache_k.assign(k+1-1).realize()
self.cache_v.assign(v+1-1).realize()
else:

View File

@@ -4,6 +4,14 @@ from examples.whisper import init_whisper, load_file_waveform, transcribe_file,
from tinygrad.helpers import CI
from tinygrad.ops import Device
# Audio generated with the command on MacOS:
# say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test
# We use the WAVE type because it's easier to decode in CI test environments
TEST_FILE_1 = str(pathlib.Path(__file__).parent / "whisper/test.wav")
TRANSCRIPTION_1 = "Could you please let me out of the box?"
TEST_FILE_2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transcriptions of varying length."
@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others")
class TestWhisper(unittest.TestCase):
@classmethod
@@ -17,21 +25,22 @@ class TestWhisper(unittest.TestCase):
del cls.model
del cls.enc
def test_transcribe_file(self):
# Audio generated with the command on MacOS:
# say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test
# We use the WAVE type because it's easier to decode in CI test environments
filename = str(pathlib.Path(__file__).parent / "whisper/test.wav")
transcription = transcribe_file(self.model, self.enc, filename)
self.assertEqual("Could you please let me out of the box?", transcription)
def test_transcribe_file1(self):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
def test_transcribe_batch(self):
file1 = str(pathlib.Path(__file__).parent / "whisper/test.wav")
file2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
waveforms = [load_file_waveform(file1), load_file_waveform(file2)]
def test_transcribe_file2(self):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2)
def test_transcribe_batch12(self):
waveforms = [load_file_waveform(TEST_FILE_1), load_file_waveform(TEST_FILE_2)]
transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
self.assertEqual(2, len(transcriptions))
self.assertEqual("Could you please let me out of the box?", transcriptions[0])
self.assertEqual("a slightly longer audio file so that we can test batch transcriptions of varying length.", transcriptions[1])
self.assertEqual(TRANSCRIPTION_1, transcriptions[0])
self.assertEqual(TRANSCRIPTION_2, transcriptions[1])
def test_transcribe_batch21(self):
waveforms = [load_file_waveform(TEST_FILE_2), load_file_waveform(TEST_FILE_1)]
transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
self.assertEqual(2, len(transcriptions))
self.assertEqual(TRANSCRIPTION_2, transcriptions[0])
self.assertEqual(TRANSCRIPTION_1, transcriptions[1])