mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
change fetch to allow for local file selection (#1309)
This commit is contained in:
@@ -85,10 +85,7 @@ if __name__ == "__main__":
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
else:
|
||||
if url.startswith('http'):
|
||||
img = Image.open(io.BytesIO(fetch(url)))
|
||||
else:
|
||||
img = Image.open(url)
|
||||
img = Image.open(io.BytesIO(fetch(url)))
|
||||
st = time.time()
|
||||
out, _ = infer(model, img)
|
||||
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
||||
|
||||
@@ -14,7 +14,7 @@ WINDOWS = platform.system() == "Windows"
|
||||
def temp(x:str) -> str: return os.path.join(tempfile.gettempdir(), x)
|
||||
|
||||
def fetch(url):
|
||||
if url.startswith("/"):
|
||||
if url.startswith("/") or url.startswith("."):
|
||||
with open(url, "rb") as f:
|
||||
return f.read()
|
||||
import hashlib
|
||||
@@ -24,7 +24,7 @@ def fetch(url):
|
||||
return f.read()
|
||||
|
||||
def fetch_as_file(url):
|
||||
if url.startswith("/"):
|
||||
if url.startswith("/") or url.startswith("."):
|
||||
with open(url, "rb") as f:
|
||||
return f.read()
|
||||
import hashlib
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
import io, unittest
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import torch
|
||||
@@ -25,6 +26,27 @@ class TestFetch(unittest.TestCase):
|
||||
pimg = Image.open(io.BytesIO(img))
|
||||
assert pimg.size == (705, 1024)
|
||||
|
||||
class TestFetchRelative(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.working_dir = os.getcwd()
|
||||
self.tempdir = tempfile.TemporaryDirectory()
|
||||
os.chdir(self.tempdir.name)
|
||||
with open('test_file.txt', 'x') as f:
|
||||
f.write("12345")
|
||||
|
||||
def tearDown(self):
|
||||
os.chdir(self.working_dir)
|
||||
self.tempdir.cleanup()
|
||||
|
||||
#test ./
|
||||
def test_fetch_relative_dotslash(self):
|
||||
self.assertEqual(b'12345', fetch("./test_file.txt"))
|
||||
|
||||
#test ../
|
||||
def test_fetch_relative_dotdotslash(self):
|
||||
os.mkdir('test_file_path')
|
||||
os.chdir('test_file_path')
|
||||
self.assertEqual(b'12345', fetch("../test_file.txt"))
|
||||
|
||||
class TestDownloadFile(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user