mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Removed dep of torch, torchaudio, kept librosa only (#1264)
This commit is contained in:
@@ -11,6 +11,8 @@ from tinygrad.state import torch_load, load_state_dict
|
||||
from tinygrad.helpers import getenv
|
||||
import tinygrad.nn as nn
|
||||
from tinygrad.tensor import Tensor
|
||||
import itertools
|
||||
import librosa
|
||||
|
||||
# TODO: you have written this fifteen times
|
||||
class MultiHeadAttention:
|
||||
@@ -104,30 +106,22 @@ class Whisper:
|
||||
def __call__(self, mel:Tensor, tokens:Tensor):
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
# TODO: this is tragic. remove this
|
||||
import functools
|
||||
import itertools
|
||||
import torch
|
||||
import torchaudio
|
||||
import librosa
|
||||
RATE = 16000
|
||||
CHUNK = 1600
|
||||
RECORD_SECONDS = 10
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_filters(sample_rate, n_fft, n_mels):return torch.tensor(librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels))
|
||||
@functools.lru_cache(None)
|
||||
def get_window(n_fft): return torch.hann_window(n_fft)
|
||||
|
||||
def prep_audio(waveform, sample_rate) -> Tensor:
|
||||
def prep_audio(waveform=None, sr=RATE) -> Tensor:
|
||||
N_FFT = 400
|
||||
HOP_LENGTH = 160
|
||||
N_MELS = 80
|
||||
stft = torch.stft(waveform, N_FFT, HOP_LENGTH, window=get_window(N_FFT), return_complex=True)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
mel_spec = get_filters(sample_rate, N_FFT, N_MELS) @ magnitudes
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
if waveform is None: waveform = np.zeros(N_FFT, dtype=np.float32)
|
||||
stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.float32)
|
||||
magnitudes = stft[..., :-1] ** 2
|
||||
mel_spec = librosa.filters.mel(sr=sr, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
|
||||
log_spec = np.log10(np.clip(mel_spec, 1e-10, mel_spec.max() + 1e8))
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
#print(waveform.shape, log_spec.shape)
|
||||
return log_spec.numpy()
|
||||
return log_spec
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
|
||||
@@ -175,12 +169,8 @@ def img(x):
|
||||
plt.imshow(x.numpy())
|
||||
plt.show()
|
||||
|
||||
RATE = 16000
|
||||
CHUNK = 1600
|
||||
RECORD_SECONDS = 10
|
||||
|
||||
def listener(q):
|
||||
prep_audio(torch.zeros(300), RATE)
|
||||
prep_audio()
|
||||
import pyaudio
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
@@ -205,7 +195,7 @@ if __name__ == "__main__":
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
# offline
|
||||
waveform, sample_rate = torchaudio.load(sys.argv[1], normalize=True)
|
||||
waveform, sample_rate = librosa.load(sys.argv[1], normalize=True)
|
||||
log_spec = prep_audio(waveform, sample_rate)
|
||||
lst = [enc._special_tokens["<|startoftranscript|>"]]
|
||||
dat = model.encoder(Tensor(log_spec)).realize()
|
||||
@@ -234,7 +224,7 @@ if __name__ == "__main__":
|
||||
did_read = True
|
||||
if did_read:
|
||||
last_total = total.shape[1]
|
||||
log_spec = prep_audio(torch.Tensor(total), RATE)
|
||||
log_spec = prep_audio(waveform=Tensor(total).numpy(), sr=RATE)
|
||||
encoded_audio = model.encoder(Tensor(log_spec)).realize()
|
||||
out = model.decoder(Tensor([lst]), encoded_audio).realize()
|
||||
idx = out[0,-1].numpy().argmax()
|
||||
|
||||
Reference in New Issue
Block a user