tinygrad/examples/mlperf/model_spec.py
wozeparrot 01ae45a43c Add mlperf RNN-T model (#782)
* feat: initial rnn-t

* feat: working with BS>1

* feat: add lstm test

* feat: test passing hidden

* clean: cleanup

* feat: specify start

* feat: way faster lstm & model

* fix: default batch size

* feat: optimization

* fix: fix metrics

* fix: fix feature splicing

* feat: cleaner stacktime

* clean: remove unused import

* clean: remove extra prints

* fix: fix tests and happy llvm

* feat: have the librispeech dataset in its own dir

* clean: unused variable

* feat: no longer need numpy for the embedding + slightly more memory efficient lstm

* fix: forgot to remove something that broke tests

* feat: use relative paths

* feat: even faster

* feat: remove pointless transposes in StackTime

* fix: correct forward

* feat: switch to soundfile for loading and fix some leaks

* feat: add comment about initial dataset setup

* feat: jit more things

* feat: default batch size back to 1

batch sizes larger than 1 are broken again :(
and even the reference implementation gives worse results with them
2023-05-25 00:41:21 -07:00
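Several of the commit bullets mention feature splicing and StackTime. Both operations concatenate neighbouring time frames along the feature axis and then keep only every n-th frame. The sketch below is a NumPy illustration modelled on the MLPerf RNN-T reference code. It is not tinygrad's actual `StackTime`, and `stack_time` is a hypothetical name. It assumes time-major input of shape (T, B, F) and zero-fills frames that run past the end of the sequence. Under the MLPerf reference configuration (an assumption here, not stated in this file), splicing 80 log-mel features with factor 3 produces the 240 features used for `x` in the RNNT benchmark, and the encoder's StackTime factor of 2 then halves the 220 time steps.

```python
import numpy as np

def stack_time(x: np.ndarray, factor: int) -> np.ndarray:
  """Hypothetical helper: stack each frame with the next (factor - 1) frames along
  the feature axis, then keep every factor-th frame.
  x is time-major (T, B, F); the result is (ceil(T / factor), B, F * factor)."""
  T = x.shape[0]
  seq = [x]
  for i in range(1, factor):
    shifted = np.zeros_like(x)
    # frame t now holds frame t + i; the last i frames stay zero
    shifted[:max(T - i, 0)] = x[i:]
    seq.append(shifted)
  # concatenate along features, then subsample in time
  return np.concatenate(seq, axis=2)[::factor]

if __name__ == "__main__":
  x = np.random.randn(220, 1, 240).astype(np.float32)
  print(stack_time(x, 2).shape)  # (110, 1, 480)
```

For batch size 1 and a sequence length divisible by the factor, the result equals a plain reshape of (T, 1, F) to (T // factor, 1, F * factor), which avoids the extra copies. With a batch larger than 1, that reshape would interleave frames from different batch elements, so it is not a general replacement.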


# load each model here, quick benchmark
from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters

def test_model(model, *inputs):
  # reset the counters, then run one forward pass; .numpy() forces the lazy graph to execute
  GlobalCounters.reset()
  model(*inputs).numpy()
  # TODO: return event future to still get the time_sum_s without DEBUG=2
  # global_ops is a total op count, so "GOPS" here means billions of ops, not ops per second
  print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")

if __name__ == "__main__":
  # inference only for now
  Tensor.training = False
  Tensor.no_grad = True

  # Resnet50-v1.5
  from models.resnet import ResNet50
  mdl = ResNet50()
  img = Tensor.randn(1, 3, 224, 224)  # (batch, channels, height, width)
  test_model(mdl, img)

  # Retinanet

  # 3D UNET
  from models.unet3d import UNet3D
  mdl = UNet3D()
  #mdl.load_from_pretrained()
  img = Tensor.randn(1, 1, 5, 224, 224)  # (batch, channels, depth, height, width)
  test_model(mdl, img)

  # RNNT
  from models.rnnt import RNNT
  mdl = RNNT()
  mdl.load_from_pretrained()
  x = Tensor.randn(220, 1, 240)  # audio features: (time steps, batch, features)
  y = Tensor.randn(1, 220)       # labels: (batch, label positions)
  test_model(mdl, x, y)

  # BERT-large
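The TODO in `test_model` indicates that `GlobalCounters.time_sum_s` is only filled in when running with `DEBUG=2`. One host-side alternative, which is not what this file does, is to time the whole call with `time.perf_counter()`. The sketch below is a hypothetical `benchmark` helper, not part of tinygrad. It assumes, as `test_model` does, that the model returns an object whose `.numpy()` forces execution. Warmup calls absorb one-time costs such as kernel compilation, and the best of several timed runs is reported.

```python
import time

def benchmark(model, *inputs, warmup=1, runs=5):
  """Hypothetical helper: best wall-clock latency of model(*inputs).numpy(), in ms."""
  if runs < 1:
    raise ValueError("runs must be at least 1")
  # warmup calls absorb one-time costs such as compilation
  for _ in range(warmup):
    model(*inputs).numpy()
  times = []
  for _ in range(runs):
    start = time.perf_counter()
    model(*inputs).numpy()  # .numpy() forces execution and copies the result to the host
    times.append(time.perf_counter() - start)
  return min(times) * 1000  # best run, converted from seconds to milliseconds
```

Wall-clock time also includes Python overhead, scheduling, and the copy back to the host, so it will usually read higher than the summed kernel time that `time_sum_s` reports. With the models in this file, a call would look like `benchmark(ResNet50(), Tensor.randn(1, 3, 224, 224))`.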