diff --git a/examples/whisper.py b/examples/whisper.py index 8406b0a466..159adbb873 100644 --- a/examples/whisper.py +++ b/examples/whisper.py @@ -34,7 +34,7 @@ class MultiHeadAttention: if not hasattr(self, 'cache_k'): self.cache_k, self.cache_v = k, v else: - # see test_jitted_read_assign in test_jit.py + # see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994 self.cache_k.assign(k+1-1).realize() self.cache_v.assign(v+1-1).realize() else: diff --git a/test/models/test_whisper.py b/test/models/test_whisper.py index 70bb8aa67a..49acab4200 100644 --- a/test/models/test_whisper.py +++ b/test/models/test_whisper.py @@ -4,6 +4,14 @@ from examples.whisper import init_whisper, load_file_waveform, transcribe_file, from tinygrad.helpers import CI from tinygrad.ops import Device +# Audio generated with the command on MacOS: +# say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test +# We use the WAVE type because it's easier to decode in CI test environments +TEST_FILE_1 = str(pathlib.Path(__file__).parent / "whisper/test.wav") +TRANSCRIPTION_1 = "Could you please let me out of the box?" +TEST_FILE_2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav") +TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transcriptions of varying length." + @unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others") class TestWhisper(unittest.TestCase): @classmethod @@ -17,21 +25,22 @@ class TestWhisper(unittest.TestCase): del cls.model del cls.enc - def test_transcribe_file(self): - # Audio generated with the command on MacOS: - # say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test - # We use the WAVE type because it's easier to decode in CI test environments - filename = str(pathlib.Path(__file__).parent / "whisper/test.wav") - transcription = transcribe_file(self.model, self.enc, filename) - self.assertEqual("Could you please let me out of the box?", transcription) + def test_transcribe_file1(self): + self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1) - def test_transcribe_batch(self): - file1 = str(pathlib.Path(__file__).parent / "whisper/test.wav") - file2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav") - - waveforms = [load_file_waveform(file1), load_file_waveform(file2)] + def test_transcribe_file2(self): + self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2) + def test_transcribe_batch12(self): + waveforms = [load_file_waveform(TEST_FILE_1), load_file_waveform(TEST_FILE_2)] transcriptions = transcribe_waveform(self.model, self.enc, waveforms) self.assertEqual(2, len(transcriptions)) - self.assertEqual("Could you please let me out of the box?", transcriptions[0]) - self.assertEqual("a slightly longer audio file so that we can test batch transcriptions of varying length.", transcriptions[1]) + self.assertEqual(TRANSCRIPTION_1, transcriptions[0]) + self.assertEqual(TRANSCRIPTION_2, transcriptions[1]) + + def test_transcribe_batch21(self): + waveforms = [load_file_waveform(TEST_FILE_2), load_file_waveform(TEST_FILE_1)] + transcriptions = transcribe_waveform(self.model, self.enc, waveforms) + self.assertEqual(2, len(transcriptions)) + self.assertEqual(TRANSCRIPTION_2, transcriptions[0]) + self.assertEqual(TRANSCRIPTION_1, transcriptions[1])