mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
single model (#1560)
This commit is contained in:
10
test/external/external_model_benchmark.py
vendored
10
test/external/external_model_benchmark.py
vendored
@@ -39,7 +39,7 @@ def benchmark(mnm, nm, fxn):
|
||||
st = time.perf_counter_ns()
|
||||
ret = fxn()
|
||||
tms.append(time.perf_counter_ns() - st)
|
||||
print(f"{m:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms")
|
||||
print(f"{mnm:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms")
|
||||
CSV[nm] = min(tms)*1e-6
|
||||
return min(tms), ret
|
||||
|
||||
@@ -63,7 +63,7 @@ def benchmark_model(m, validate_outs=False):
|
||||
# print input names
|
||||
if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])
|
||||
|
||||
for device in ["METAL" if OSX else "GPU", "CLANG"]:
|
||||
for device in ["METAL" if OSX else "GPU", "CLANG"]: # + (["CUDA"] if torch.cuda.is_available() else []):
|
||||
Device.DEFAULT = device
|
||||
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
|
||||
tinygrad_model = get_run_onnx(onnx_model)
|
||||
@@ -95,7 +95,7 @@ def benchmark_model(m, validate_outs=False):
|
||||
provider = backend+"ExecutionProvider"
|
||||
if provider not in ort.get_available_providers(): continue
|
||||
ort_sess = ort.InferenceSession(str(fn), ort_options, [provider])
|
||||
benchmark(m, f"onnxruntime_{backend}", lambda: ort_sess.run(output_names, np_inputs))
|
||||
benchmark(m, f"onnxruntime_{backend.lower()}", lambda: ort_sess.run(output_names, np_inputs))
|
||||
del ort_sess
|
||||
|
||||
if validate_outs:
|
||||
@@ -124,4 +124,6 @@ def assert_allclose(tiny_out:dict, onnx_out:dict, rtol=1e-5, atol=1e-5):
|
||||
else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tiny_out.keys()}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
for m in MODELS: benchmark_model(m, True)
|
||||
if getenv("MODEL", "") != "": benchmark_model(getenv("MODEL", ""), True)
|
||||
else:
|
||||
for m in MODELS: benchmark_model(m, True)
|
||||
|
||||
Reference in New Issue
Block a user