mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
move fetch to helpers (#2363)
* switch datasets to new fetch * add test_helpers * fix convnext and delete old torch load
This commit is contained in:
@@ -2,8 +2,7 @@ import os, gzip, tarfile, pickle
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes
|
||||
from extra.utils import download_file
|
||||
from tinygrad.helpers import dtypes, fetch
|
||||
|
||||
def fetch_mnist(tensors=False):
|
||||
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
|
||||
@@ -36,8 +35,7 @@ def fetch_cifar():
|
||||
assert idx == X.shape[0] and X.shape[0] == Y.shape[0]
|
||||
|
||||
print("downloading and extracting CIFAR...")
|
||||
fn = Path(__file__).parent.resolve() / "cifar-10-python.tar.gz"
|
||||
download_file('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fn)
|
||||
fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
|
||||
tt = tarfile.open(fn, mode='r:gz')
|
||||
_load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)])
|
||||
_load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")])
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
class Block:
|
||||
def __init__(self, dim):
|
||||
@@ -43,11 +44,12 @@ versions = {
|
||||
def get_model(version, load_weights=False):
|
||||
model = ConvNeXt(**versions[version])
|
||||
if load_weights:
|
||||
from extra.utils import fetch, fake_torch_load, get_child
|
||||
weights = fake_torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
|
||||
from extra.utils import get_child
|
||||
from tinygrad.nn.state import torch_load
|
||||
weights = torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
|
||||
for k,v in weights.items():
|
||||
mv = get_child(model, k)
|
||||
mv.assign(v.reshape(mv.shape)).realize()
|
||||
mv.assign(v.reshape(mv.shape).to(mv.device)).realize()
|
||||
return model
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
146
extra/utils.py
146
extra/utils.py
@@ -47,152 +47,6 @@ def download_file(url, fp, skip_if_exists=True):
|
||||
f.close()
|
||||
Path(f.name).rename(fp)
|
||||
|
||||
def my_unpickle(fb0):
|
||||
key_prelookup = defaultdict(list)
|
||||
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
|
||||
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
|
||||
ident, storage_type, obj_key, location, obj_size = storage[0:5]
|
||||
assert ident == 'storage'
|
||||
assert prod(size) <= (obj_size - storage_offset)
|
||||
|
||||
if storage_type not in [np.float16, np.float32]:
|
||||
if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}")
|
||||
ret = None
|
||||
else:
|
||||
ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type))
|
||||
key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset))
|
||||
return ret
|
||||
|
||||
def _rebuild_parameter(*args):
|
||||
#print(args)
|
||||
pass
|
||||
|
||||
class Dummy: pass
|
||||
|
||||
class MyPickle(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
#print(module, name)
|
||||
if name == 'FloatStorage': return np.float32
|
||||
if name == 'LongStorage': return np.int64
|
||||
if name == 'IntStorage': return np.int32
|
||||
if name == 'HalfStorage': return np.float16
|
||||
if module == "torch._utils":
|
||||
if name == "_rebuild_tensor_v2": return _rebuild_tensor_v2
|
||||
if name == "_rebuild_parameter": return _rebuild_parameter
|
||||
else:
|
||||
if module.startswith('pytorch_lightning'): return Dummy
|
||||
try:
|
||||
return super().find_class(module, name)
|
||||
except Exception:
|
||||
return Dummy
|
||||
|
||||
def persistent_load(self, pid):
|
||||
return pid
|
||||
|
||||
return MyPickle(fb0).load(), key_prelookup
|
||||
|
||||
def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset, mmap_allowed=False):
|
||||
bytes_size = np.dtype(dtype).itemsize
|
||||
if t is None:
|
||||
myfile.seek(prod(shape) * bytes_size, 1)
|
||||
return
|
||||
|
||||
bytes_offset = 0
|
||||
if storage_offset is not None:
|
||||
bytes_offset = storage_offset * bytes_size
|
||||
myfile.seek(bytes_offset)
|
||||
|
||||
assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}"
|
||||
assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size
|
||||
if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)):
|
||||
# slow path
|
||||
buffer_size = sum(strides[i]*t.dtype.itemsize * (shape[i] - 1) for i in range(len(shape)))
|
||||
buffer_size += t.dtype.itemsize
|
||||
np_array = np.frombuffer(myfile.read(buffer_size), t.dtype.np)
|
||||
|
||||
np_array = np.lib.stride_tricks.as_strided(
|
||||
np_array, shape=shape, strides=[i*t.dtype.itemsize for i in strides])
|
||||
|
||||
lna = t.lazydata.op.arg
|
||||
lna.fxn = lambda _: np_array
|
||||
t.realize()
|
||||
return
|
||||
|
||||
# ["METAL", "CLANG", "LLVM"] support readinto for more speed
|
||||
# ["GPU", "CUDA"] use _mmap since they have to copy in to the GPU anyway
|
||||
# this needs real APIs
|
||||
if t.device in ["METAL", "CLANG", "LLVM"]:
|
||||
del t.lazydata.op
|
||||
t.lazydata.realized = Device[t.lazydata.device].buffer(prod(t.shape), dtype=t.dtype)
|
||||
myfile.readinto(t.lazydata.realized._buffer())
|
||||
else:
|
||||
def _mmap(lna):
|
||||
assert myfile._compress_type == 0, "compressed data can't be mmaped"
|
||||
return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r', offset=myfile._orig_compress_start + bytes_offset, shape=lna.shape)
|
||||
def _read(lna):
|
||||
ret = np.empty(lna.shape, dtype=lna.dtype)
|
||||
myfile.readinto(ret.data)
|
||||
return ret
|
||||
if mmap_allowed and not OSX and t.device in ["GPU", "CUDA"]: t.lazydata.op.arg.fxn = _mmap
|
||||
else: t.lazydata.op.arg.fxn = _read
|
||||
t.realize()
|
||||
|
||||
def fake_torch_load_zipped(fb0, load_weights=True, multithreaded=True):
|
||||
if Device.DEFAULT in ["TORCH", "GPU", "CUDA"]: multithreaded = False # multithreaded doesn't work with CUDA or TORCH. for GPU it's a wash with _mmap
|
||||
with zipfile.ZipFile(fb0, 'r') as myzip:
|
||||
base_name = myzip.namelist()[0].split('/', 1)[0]
|
||||
with myzip.open(f'{base_name}/data.pkl') as myfile:
|
||||
ret = my_unpickle(myfile)
|
||||
if load_weights:
|
||||
def load_weight(k, vv):
|
||||
with myzip.open(f'{base_name}/data/{k}') as myfile:
|
||||
for v in vv:
|
||||
load_single_weight(v[2], myfile, v[3], v[4], v[0], v[5], mmap_allowed=True)
|
||||
if multithreaded:
|
||||
# 2 seems fastest
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
futures = {executor.submit(load_weight, k, v):k for k,v in ret[1].items()}
|
||||
for future in (t:=tqdm(concurrent.futures.as_completed(futures), total=len(futures))):
|
||||
if future.exception() is not None: raise future.exception()
|
||||
k = futures[future]
|
||||
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
|
||||
else:
|
||||
for k,v in (t := tqdm(ret[1].items())):
|
||||
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
|
||||
load_weight(k,v)
|
||||
return ret[0]
|
||||
|
||||
def fake_torch_load(b0):
|
||||
|
||||
# convert it to a file
|
||||
fb0 = io.BytesIO(b0)
|
||||
|
||||
if b0[0:2] == b"\x50\x4b":
|
||||
return fake_torch_load_zipped(fb0)
|
||||
|
||||
# skip three junk pickles
|
||||
pickle.load(fb0)
|
||||
pickle.load(fb0)
|
||||
pickle.load(fb0)
|
||||
|
||||
ret, key_prelookup = my_unpickle(fb0)
|
||||
|
||||
# create key_lookup
|
||||
key_lookup = pickle.load(fb0)
|
||||
key_real = [None] * len(key_lookup)
|
||||
for k,v in key_prelookup.items():
|
||||
assert len(v) == 1
|
||||
key_real[key_lookup.index(k)] = v[0]
|
||||
|
||||
# read in the actual data
|
||||
for storage_type, obj_size, tensor, np_shape, np_strides, storage_offset in key_real:
|
||||
ll = struct.unpack("Q", fb0.read(8))[0]
|
||||
assert ll == obj_size, f"size mismatch {ll} != {obj_size}"
|
||||
assert storage_offset == 0, "not implemented"
|
||||
load_single_weight(tensor, fb0, np_shape, np_strides, storage_type, None)
|
||||
|
||||
return ret
|
||||
|
||||
def get_child(parent, key):
|
||||
obj = parent
|
||||
for k in key.split('.'):
|
||||
|
||||
1
setup.py
1
setup.py
@@ -38,6 +38,7 @@ setup(name='tinygrad',
|
||||
"pre-commit",
|
||||
"ruff",
|
||||
"types-tqdm",
|
||||
"types-requests",
|
||||
],
|
||||
'testing': [
|
||||
"torch",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
import unittest, io
|
||||
import numpy as np
|
||||
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up
|
||||
from PIL import Image
|
||||
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up, fetch
|
||||
from tinygrad.shape.symbolic import Variable, NumNode
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
@@ -147,5 +148,17 @@ class TestRoundUp(unittest.TestCase):
|
||||
self.assertEqual(round_up(232, 24984), 24984)
|
||||
self.assertEqual(round_up(24984, 232), 25056)
|
||||
|
||||
class TestFetch(unittest.TestCase):
|
||||
def test_fetch_bad_http(self):
|
||||
self.assertRaises(AssertionError, fetch, 'http://www.google.com/404')
|
||||
|
||||
def test_fetch_small(self):
|
||||
assert(len(fetch('https://google.com').read_bytes())>0)
|
||||
|
||||
def test_fetch_img(self):
|
||||
img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190")
|
||||
with Image.open(img) as pimg:
|
||||
assert pimg.size == (705, 1024)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, requests, tempfile, pathlib
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
from typing_extensions import TypeGuard
|
||||
@@ -219,3 +220,18 @@ def diskcache(func):
|
||||
return diskcache_put(table, key, func(*args, **kwargs))
|
||||
setattr(wrapper, "__wrapped__", func)
|
||||
return wrapper
|
||||
|
||||
# *** http support ***
|
||||
|
||||
def fetch(url:str) -> pathlib.Path:
|
||||
fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / hashlib.md5(url.encode('utf-8')).hexdigest()
|
||||
if not fp.is_file():
|
||||
r = requests.get(url, stream=True, timeout=10)
|
||||
assert r.status_code == 200
|
||||
progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url)
|
||||
(path := fp.parent).mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
|
||||
for chunk in r.iter_content(chunk_size=16384): progress_bar.update(f.write(chunk))
|
||||
f.close()
|
||||
pathlib.Path(f.name).rename(fp)
|
||||
return fp
|
||||
|
||||
Reference in New Issue
Block a user