mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-12 07:55:14 +08:00
use the TinyJit in the efficientnet runner, 200ms -> 20ms
This commit is contained in:
@@ -12,9 +12,22 @@ from PIL import Image
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.utils import fetch
|
||||
from extra.jit import TinyJit
|
||||
from models.efficientnet import EfficientNet
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
# TODO: you should be able to put these in the jitted function
|
||||
bias = Tensor([0.485, 0.456, 0.406])
|
||||
scale = Tensor([0.229, 0.224, 0.225])
|
||||
|
||||
@TinyJit
|
||||
def _infer(model, img):
|
||||
img = img.permute((2,0,1))
|
||||
img = img / 255.0
|
||||
img = img - bias.reshape((1,-1,1,1))
|
||||
img = img / scale.reshape((1,-1,1,1))
|
||||
return model.forward(img).realize()
|
||||
|
||||
def infer(model, img):
|
||||
# preprocess image
|
||||
aspect_ratio = img.size[0] / img.size[1]
|
||||
@@ -31,15 +44,8 @@ def infer(model, img):
|
||||
plt.show()
|
||||
"""
|
||||
|
||||
# low level preprocess
|
||||
img = np.moveaxis(img, [2,0,1], [0,1,2])
|
||||
img = img.astype(np.float32)[:3].reshape(1,3,224,224)
|
||||
img /= 255.0
|
||||
img -= np.array([0.485, 0.456, 0.406]).reshape((1,-1,1,1))
|
||||
img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1))
|
||||
|
||||
# run the net
|
||||
out = model.forward(Tensor(img)).cpu()
|
||||
out = _infer(model, Tensor(img)).numpy()
|
||||
|
||||
# if you want to look at the outputs
|
||||
"""
|
||||
@@ -67,8 +73,9 @@ if __name__ == "__main__":
|
||||
_ = cap.grab() # discard one frame to circumvent capture buffering
|
||||
ret, frame = cap.read()
|
||||
img = Image.fromarray(frame[:, :, [2,1,0]])
|
||||
lt = time.monotonic_ns()
|
||||
out, retimg = infer(model, img)
|
||||
print(np.argmax(out.numpy()), np.max(out.numpy()), lbls[np.argmax(out.numpy())])
|
||||
print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
||||
SCALE = 3
|
||||
simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
|
||||
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
|
||||
@@ -84,5 +91,5 @@ if __name__ == "__main__":
|
||||
img = Image.open(url)
|
||||
st = time.time()
|
||||
out, _ = infer(model, img)
|
||||
print(np.argmax(out.numpy()), np.max(out.numpy()), lbls[np.argmax(out.numpy())])
|
||||
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
||||
print(f"did inference in {(time.time()-st):2f}")
|
||||
|
||||
Reference in New Issue
Block a user