use the TinyJit in the efficientnet runner, 200ms -> 20ms

This commit is contained in:
George Hotz
2023-02-20 19:58:16 -08:00
parent 714bf4b108
commit d9fa47ecc9

View File

@@ -12,9 +12,22 @@ from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.utils import fetch
from extra.jit import TinyJit
from models.efficientnet import EfficientNet
# Print floats in fixed-point form instead of scientific notation.
np.set_printoptions(suppress=True)
# TODO: you should be able to put these in the jitted function
# Normalization constants consumed by _infer: `bias` is subtracted from the
# image and the result is divided by `scale`, each after a reshape to
# (1,-1,1,1). They live at module level because of the TODO above.
# NOTE(review): the values look like the usual ImageNet mean/std — confirm.
bias = Tensor([0.485, 0.456, 0.406])
scale = Tensor([0.229, 0.224, 0.225])
@TinyJit
def _infer(model, img):
  """Normalize an image tensor and run it through ``model``.

  The last axis of ``img`` is moved to the front with ``permute((2,0,1))``,
  the values are divided by 255.0, and the module-level ``bias`` and
  ``scale`` tensors (each reshaped to ``(1,-1,1,1)``) are subtracted and
  divided out, in that order. The normalized tensor is passed to
  ``model.forward`` and the result is realized before being returned.

  NOTE(review): presumably ``img`` is laid out (height, width, channel) —
  confirm against the caller.
  """
  broadcast_shape = (1,-1,1,1)
  channels_first = img.permute((2,0,1))
  # Same arithmetic order as before: /255, then -bias, then /scale.
  normalized = (channels_first / 255.0 - bias.reshape(broadcast_shape)) / scale.reshape(broadcast_shape)
  prediction = model.forward(normalized)
  return prediction.realize()
def infer(model, img):
# preprocess image
aspect_ratio = img.size[0] / img.size[1]
@@ -31,15 +44,8 @@ def infer(model, img):
plt.show()
"""
# low level preprocess
img = np.moveaxis(img, [2,0,1], [0,1,2])
img = img.astype(np.float32)[:3].reshape(1,3,224,224)
img /= 255.0
img -= np.array([0.485, 0.456, 0.406]).reshape((1,-1,1,1))
img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1))
# run the net
out = model.forward(Tensor(img)).cpu()
out = _infer(model, Tensor(img)).numpy()
# if you want to look at the outputs
"""
@@ -67,8 +73,9 @@ if __name__ == "__main__":
_ = cap.grab() # discard one frame to circumvent capture buffering
ret, frame = cap.read()
img = Image.fromarray(frame[:, :, [2,1,0]])
lt = time.monotonic_ns()
out, retimg = infer(model, img)
print(np.argmax(out.numpy()), np.max(out.numpy()), lbls[np.argmax(out.numpy())])
print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
SCALE = 3
simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
@@ -84,5 +91,5 @@ if __name__ == "__main__":
img = Image.open(url)
st = time.time()
out, _ = infer(model, img)
print(np.argmax(out.numpy()), np.max(out.numpy()), lbls[np.argmax(out.numpy())])
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
print(f"did inference in {(time.time()-st):2f}")