Removed dep of torch, torchaudio, kept librosa only (#1264)

This commit is contained in:
Paolo Gavazzi
2023-08-02 13:52:04 -04:00
committed by GitHub
parent fc2303e520
commit 9ffa1eb7e2

View File

@@ -11,6 +11,8 @@ from tinygrad.state import torch_load, load_state_dict
from tinygrad.helpers import getenv
import tinygrad.nn as nn
from tinygrad.tensor import Tensor
import itertools
import librosa
# TODO: you have written this fifteen times
class MultiHeadAttention:
@@ -104,30 +106,22 @@ class Whisper:
def __call__(self, mel:Tensor, tokens:Tensor):
# Run self.encoder on `mel`, then pass `tokens` together with that encoder output to self.decoder and return the decoder's result.
return self.decoder(tokens, self.encoder(mel))
# TODO: this is tragic. remove this
import functools
import itertools
import torch
import torchaudio
import librosa
RATE = 16000
CHUNK = 1600
RECORD_SECONDS = 10
@functools.lru_cache(None)
# Build the mel filterbank with librosa.filters.mel for the given sample rate, FFT size and mel-band count, and wrap it in a torch tensor.
# lru_cache(None) memoizes the result per (sample_rate, n_fft, n_mels) with no size bound.
# NOTE(review): this helper depends on torch, so given the commit title it is presumably on the removed side of the diff - confirm against the raw patch.
def get_filters(sample_rate, n_fft, n_mels):return torch.tensor(librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels))
@functools.lru_cache(None)
# Return torch.hann_window(n_fft); lru_cache(None) memoizes it per n_fft with no size bound.
# NOTE(review): this helper depends on torch, so given the commit title it is presumably on the removed side of the diff - confirm against the raw patch.
def get_window(n_fft): return torch.hann_window(n_fft)
def prep_audio(waveform, sample_rate) -> Tensor:
# NOTE(review): the +/- diff markers are missing from this view, so two versions of prep_audio are interleaved below.
# The torch-based lines are presumably the removed side and the librosa/numpy lines the added side (the commit title says torch and torchaudio were dropped) - confirm against the raw patch.
# Both versions follow the same pipeline: STFT -> squared values -> mel filterbank matmul -> log10 -> (x + 4.0) / 4.0.
def prep_audio(waveform=None, sr=RATE) -> Tensor:
# NOTE(review): the annotation says Tensor, but the torch version returns log_spec.numpy() and the librosa version returns the numpy result directly; the visible callers wrap the result in Tensor(...) themselves.
# STFT window size, hop between frames, and number of mel bands.
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
# --- torch version ---
stft = torch.stft(waveform, N_FFT, HOP_LENGTH, window=get_window(N_FFT), return_complex=True)
# Drop the last entry along the final axis, then take the squared magnitude of the complex STFT.
magnitudes = stft[..., :-1].abs() ** 2
mel_spec = get_filters(sample_rate, N_FFT, N_MELS) @ magnitudes
# Clamp to at least 1e-10 before log10 so the log never sees zero.
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
# Floor every value at (global max - 8.0).
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
# --- librosa/numpy version ---
# With no waveform given, fall back to N_FFT float32 zeros; listener() calls prep_audio() with no arguments.
if waveform is None: waveform = np.zeros(N_FFT, dtype=np.float32)
# NOTE(review): librosa.stft documents `dtype` as the complex output type; passing the real np.float32 looks wrong and may discard the imaginary part - confirm against the librosa.stft documentation (omitting dtype lets librosa infer it).
stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.float32)
# NOTE(review): the torch version applies .abs() before squaring; this line squares stft directly.
# If stft is complex this is not the power spectrum; np.abs(stft[..., :-1]) ** 2 would mirror the torch line.
magnitudes = stft[..., :-1] ** 2
mel_spec = librosa.filters.mel(sr=sr, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
# The upper clip bound mel_spec.max() + 1e8 lies above every element, so only the 1e-10 lower bound has any effect here.
# NOTE(review): the torch version's floor at (max - 8.0) has no counterpart among the librosa lines; np.maximum(log_spec, log_spec.max() - 8.0) would restore it - confirm whether the omission is intentional.
log_spec = np.log10(np.clip(mel_spec, 1e-10, mel_spec.max() + 1e8))
# Shift and scale the log values; this line appears once and is presumably shared by both versions.
log_spec = (log_spec + 4.0) / 4.0
#print(waveform.shape, log_spec.shape)
return log_spec.numpy()
return log_spec
LANGUAGES = {
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
@@ -175,12 +169,8 @@ def img(x):
plt.imshow(x.numpy())
plt.show()
RATE = 16000
CHUNK = 1600
RECORD_SECONDS = 10
def listener(q):
prep_audio(torch.zeros(300), RATE)
prep_audio()
import pyaudio
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
@@ -205,7 +195,7 @@ if __name__ == "__main__":
if len(sys.argv) > 1:
# offline
# A path was given on the command line: load that audio file, convert it to a log-mel spectrogram and run the encoder once.
# NOTE(review): the two load lines below are presumably the removed (torchaudio) and added (librosa) sides of the diff - confirm against the raw patch.
waveform, sample_rate = torchaudio.load(sys.argv[1], normalize=True)
# NOTE(review): `normalize=True` was carried over from the torchaudio call; librosa.load does not appear to accept a `normalize` keyword, so this call would likely raise TypeError - confirm against the librosa.load documentation.
# NOTE(review): librosa.load resamples to its own default rate unless `sr` is passed; librosa.load(sys.argv[1], sr=RATE) would keep the 16000 Hz used elsewhere in this file - TODO confirm.
waveform, sample_rate = librosa.load(sys.argv[1], normalize=True)
log_spec = prep_audio(waveform, sample_rate)
# Start the token list with the start-of-transcript special token.
lst = [enc._special_tokens["<|startoftranscript|>"]]
dat = model.encoder(Tensor(log_spec)).realize()
@@ -234,7 +224,7 @@ if __name__ == "__main__":
did_read = True
if did_read:
last_total = total.shape[1]
log_spec = prep_audio(torch.Tensor(total), RATE)
log_spec = prep_audio(waveform=Tensor(total).numpy(), sr=RATE)
encoded_audio = model.encoder(Tensor(log_spec)).realize()
out = model.decoder(Tensor([lst]), encoded_audio).realize()
idx = out[0,-1].numpy().argmax()