mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix external test + speed
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -132,7 +132,7 @@ jobs:
|
||||
- name: Install Dependencies
|
||||
run: pip install -e '.[gpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
- name: Run Optimizer Test
|
||||
run: GPU=1 python test/external_test_opt.py
|
||||
run: GPU=1 python test/external/external_test_opt.py
|
||||
- name: Run Pytest (default)
|
||||
run: GPU=1 python -m pytest -s -v -n=auto
|
||||
|
||||
|
||||
@@ -102,9 +102,17 @@ class TestOnnxModel(unittest.TestCase):
|
||||
input_name, input_new = "images:0", True
|
||||
self._test_model(dat, input_name, input_new)
|
||||
|
||||
def test_shufflenet(self):
|
||||
dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx")
|
||||
print(f"shufflenet downloaded : {len(dat)/1e6:.2f} MB")
|
||||
input_name, input_new = "gpu_0/data_0", False
|
||||
self._test_model(dat, input_name, input_new)
|
||||
|
||||
@unittest.skip("test is very slow")
|
||||
def test_resnet(self):
|
||||
# NOTE: many onnx models can't be run right now due to max pool with strides != kernel_size
|
||||
dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx")
|
||||
print(f"resnet downloaded : {len(dat)/1e6:.2f} MB")
|
||||
input_name, input_new = "data", False
|
||||
self._test_model(dat, input_name, input_new)
|
||||
|
||||
@@ -124,7 +132,7 @@ class TestOnnxModel(unittest.TestCase):
|
||||
assert _LABELS[cls] == "hen" or _LABELS[cls] == "cock"
|
||||
cls = run(car_img)
|
||||
print(cls, _LABELS[cls])
|
||||
assert "car" in _LABELS[cls]
|
||||
assert "car" in _LABELS[cls] or _LABELS[cls] == "convertible"
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user