From 2d4e182294e96f0d3df64db8d638dc3f71ac83a1 Mon Sep 17 00:00:00 2001 From: Cole Sutyak <76934850+bigfoot1144@users.noreply.github.com> Date: Sun, 23 Jul 2023 15:00:16 -0400 Subject: [PATCH] change fetch to allow for local file selection (#1309) --- examples/efficientnet.py | 5 +---- extra/utils.py | 4 ++-- test/extra/test_utils.py | 22 ++++++++++++++++++++++ 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/examples/efficientnet.py b/examples/efficientnet.py index 412814c76d..95e060bd7a 100644 --- a/examples/efficientnet.py +++ b/examples/efficientnet.py @@ -85,10 +85,7 @@ if __name__ == "__main__": cap.release() cv2.destroyAllWindows() else: - if url.startswith('http'): - img = Image.open(io.BytesIO(fetch(url))) - else: - img = Image.open(url) + img = Image.open(io.BytesIO(fetch(url))) st = time.time() out, _ = infer(model, img) print(np.argmax(out), np.max(out), lbls[np.argmax(out)]) diff --git a/extra/utils.py b/extra/utils.py index 9e2331cf1c..f8ca0b0f96 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -14,7 +14,7 @@ WINDOWS = platform.system() == "Windows" def temp(x:str) -> str: return os.path.join(tempfile.gettempdir(), x) def fetch(url): - if url.startswith("/"): + if url.startswith("/") or url.startswith("."): with open(url, "rb") as f: return f.read() import hashlib @@ -24,7 +24,7 @@ def fetch(url): return f.read() def fetch_as_file(url): - if url.startswith("/"): + if url.startswith("/") or url.startswith("."): with open(url, "rb") as f: return f.read() import hashlib diff --git a/test/extra/test_utils.py b/test/extra/test_utils.py index 62d4c3c0b6..6350403ade 100644 --- a/test/extra/test_utils.py +++ b/test/extra/test_utils.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import io, unittest import os +import tempfile from unittest.mock import patch, MagicMock import torch @@ -25,6 +26,27 @@ class TestFetch(unittest.TestCase): pimg = Image.open(io.BytesIO(img)) assert pimg.size == (705, 1024) +class TestFetchRelative(unittest.TestCase): + def setUp(self): + self.working_dir = os.getcwd() + self.tempdir = tempfile.TemporaryDirectory() + os.chdir(self.tempdir.name) + with open('test_file.txt', 'x') as f: + f.write("12345") + + def tearDown(self): + os.chdir(self.working_dir) + self.tempdir.cleanup() + + #test ./ + def test_fetch_relative_dotslash(self): + self.assertEqual(b'12345', fetch("./test_file.txt")) + + #test ../ + def test_fetch_relative_dotdotslash(self): + os.mkdir('test_file_path') + os.chdir('test_file_path') + self.assertEqual(b'12345', fetch("../test_file.txt")) class TestDownloadFile(unittest.TestCase): def setUp(self):