From 9ffa1eb7e2ce6f7fe58b5fc1387b44e01c516be9 Mon Sep 17 00:00:00 2001 From: Paolo Gavazzi Date: Wed, 2 Aug 2023 13:52:04 -0400 Subject: [PATCH] Removed dep of torch, torchaudio, kept librosa only (#1264) --- examples/whisper.py | 40 +++++++++++++++------------------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/examples/whisper.py b/examples/whisper.py index 54e341fed2..1692071127 100644 --- a/examples/whisper.py +++ b/examples/whisper.py @@ -11,6 +11,8 @@ from tinygrad.state import torch_load, load_state_dict from tinygrad.helpers import getenv import tinygrad.nn as nn from tinygrad.tensor import Tensor +import itertools +import librosa # TODO: you have written this fifteen times class MultiHeadAttention: @@ -104,30 +106,22 @@ class Whisper: def __call__(self, mel:Tensor, tokens:Tensor): return self.decoder(tokens, self.encoder(mel)) -# TODO: this is tragic. remove this -import functools -import itertools -import torch -import torchaudio -import librosa +RATE = 16000 +CHUNK = 1600 +RECORD_SECONDS = 10 -@functools.lru_cache(None) -def get_filters(sample_rate, n_fft, n_mels):return torch.tensor(librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels)) -@functools.lru_cache(None) -def get_window(n_fft): return torch.hann_window(n_fft) - -def prep_audio(waveform, sample_rate) -> Tensor: +def prep_audio(waveform=None, sr=RATE) -> Tensor: N_FFT = 400 HOP_LENGTH = 160 N_MELS = 80 - stft = torch.stft(waveform, N_FFT, HOP_LENGTH, window=get_window(N_FFT), return_complex=True) - magnitudes = stft[..., :-1].abs() ** 2 - mel_spec = get_filters(sample_rate, N_FFT, N_MELS) @ magnitudes - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + if waveform is None: waveform = np.zeros(N_FFT, dtype=np.float32) + stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.float32) + magnitudes = stft[..., :-1] ** 2 + mel_spec = librosa.filters.mel(sr=sr, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes + log_spec = np.log10(np.clip(mel_spec, 1e-10, mel_spec.max() + 1e8)) log_spec = (log_spec + 4.0) / 4.0 #print(waveform.shape, log_spec.shape) - return log_spec.numpy() + return log_spec LANGUAGES = { "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", @@ -175,12 +169,8 @@ def img(x): plt.imshow(x.numpy()) plt.show() -RATE = 16000 -CHUNK = 1600 -RECORD_SECONDS = 10 - def listener(q): - prep_audio(torch.zeros(300), RATE) + prep_audio() import pyaudio p = pyaudio.PyAudio() stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK) @@ -205,7 +195,7 @@ if __name__ == "__main__": if len(sys.argv) > 1: # offline - waveform, sample_rate = torchaudio.load(sys.argv[1], normalize=True) + waveform, sample_rate = librosa.load(sys.argv[1], normalize=True) log_spec = prep_audio(waveform, sample_rate) lst = [enc._special_tokens["<|startoftranscript|>"]] dat = model.encoder(Tensor(log_spec)).realize() @@ -234,7 +224,7 @@ if __name__ == "__main__": did_read = True if did_read: last_total = total.shape[1] - log_spec = prep_audio(torch.Tensor(total), RATE) + log_spec = prep_audio(waveform=Tensor(total).numpy(), sr=RATE) encoded_audio = model.encoder(Tensor(log_spec)).realize() out = model.decoder(Tensor([lst]), encoded_audio).realize() idx = out[0,-1].numpy().argmax()