mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
remove whisper +1-1 hack (#2360)
* remove whisper +1-1 hack
* Revert "remove whisper +1-1 hack"
This reverts commit 5db3800f09.
* update whisper tests
* comment context
This commit is contained in:
@@ -34,7 +34,7 @@ class MultiHeadAttention:
|
||||
if not hasattr(self, 'cache_k'):
|
||||
self.cache_k, self.cache_v = k, v
|
||||
else:
|
||||
# see test_jitted_read_assign in test_jit.py
|
||||
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
|
||||
self.cache_k.assign(k+1-1).realize()
|
||||
self.cache_v.assign(v+1-1).realize()
|
||||
else:
|
||||
|
||||
@@ -4,6 +4,14 @@ from examples.whisper import init_whisper, load_file_waveform, transcribe_file,
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.ops import Device
|
||||
|
||||
# Audio generated with the command on MacOS:
|
||||
# say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test
|
||||
# We use the WAVE type because it's easier to decode in CI test environments
|
||||
TEST_FILE_1 = str(pathlib.Path(__file__).parent / "whisper/test.wav")
|
||||
TRANSCRIPTION_1 = "Could you please let me out of the box?"
|
||||
TEST_FILE_2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
|
||||
TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transcriptions of varying length."
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others")
|
||||
class TestWhisper(unittest.TestCase):
|
||||
@classmethod
|
||||
@@ -17,21 +25,22 @@ class TestWhisper(unittest.TestCase):
|
||||
del cls.model
|
||||
del cls.enc
|
||||
|
||||
def test_transcribe_file(self):
|
||||
# Audio generated with the command on MacOS:
|
||||
# say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test
|
||||
# We use the WAVE type because it's easier to decode in CI test environments
|
||||
filename = str(pathlib.Path(__file__).parent / "whisper/test.wav")
|
||||
transcription = transcribe_file(self.model, self.enc, filename)
|
||||
self.assertEqual("Could you please let me out of the box?", transcription)
|
||||
def test_transcribe_file1(self):
|
||||
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
|
||||
|
||||
def test_transcribe_batch(self):
|
||||
file1 = str(pathlib.Path(__file__).parent / "whisper/test.wav")
|
||||
file2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav")
|
||||
|
||||
waveforms = [load_file_waveform(file1), load_file_waveform(file2)]
|
||||
def test_transcribe_file2(self):
|
||||
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2)
|
||||
|
||||
def test_transcribe_batch12(self):
|
||||
waveforms = [load_file_waveform(TEST_FILE_1), load_file_waveform(TEST_FILE_2)]
|
||||
transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
|
||||
self.assertEqual(2, len(transcriptions))
|
||||
self.assertEqual("Could you please let me out of the box?", transcriptions[0])
|
||||
self.assertEqual("a slightly longer audio file so that we can test batch transcriptions of varying length.", transcriptions[1])
|
||||
self.assertEqual(TRANSCRIPTION_1, transcriptions[0])
|
||||
self.assertEqual(TRANSCRIPTION_2, transcriptions[1])
|
||||
|
||||
def test_transcribe_batch21(self):
|
||||
waveforms = [load_file_waveform(TEST_FILE_2), load_file_waveform(TEST_FILE_1)]
|
||||
transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
|
||||
self.assertEqual(2, len(transcriptions))
|
||||
self.assertEqual(TRANSCRIPTION_2, transcriptions[0])
|
||||
self.assertEqual(TRANSCRIPTION_1, transcriptions[1])
|
||||
|
||||
Reference in New Issue
Block a user