move fetch to helpers (#2363)

* switch datasets to new fetch

* add test_helpers

* fix convnext and delete old torch load
This commit is contained in:
George Hotz
2023-11-19 12:29:51 -08:00
committed by GitHub
parent 03968622a2
commit a0890f4e6c
6 changed files with 40 additions and 156 deletions

View File

@@ -2,8 +2,7 @@ import os, gzip, tarfile, pickle
import numpy as np
from pathlib import Path
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from extra.utils import download_file
from tinygrad.helpers import dtypes, fetch
def fetch_mnist(tensors=False):
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
@@ -36,8 +35,7 @@ def fetch_cifar():
assert idx == X.shape[0] and X.shape[0] == Y.shape[0]
print("downloading and extracting CIFAR...")
fn = Path(__file__).parent.resolve() / "cifar-10-python.tar.gz"
download_file('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fn)
fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
tt = tarfile.open(fn, mode='r:gz')
_load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)])
_load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")])

View File

@@ -1,5 +1,6 @@
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
from tinygrad.helpers import fetch
class Block:
def __init__(self, dim):
@@ -43,11 +44,12 @@ versions = {
def get_model(version, load_weights=False):
model = ConvNeXt(**versions[version])
if load_weights:
from extra.utils import fetch, fake_torch_load, get_child
weights = fake_torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
from extra.utils import get_child
from tinygrad.nn.state import torch_load
weights = torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
for k,v in weights.items():
mv = get_child(model, k)
mv.assign(v.reshape(mv.shape)).realize()
mv.assign(v.reshape(mv.shape).to(mv.device)).realize()
return model
if __name__ == "__main__":

View File

@@ -47,152 +47,6 @@ def download_file(url, fp, skip_if_exists=True):
f.close()
Path(f.name).rename(fp)
def my_unpickle(fb0):
key_prelookup = defaultdict(list)
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
ident, storage_type, obj_key, location, obj_size = storage[0:5]
assert ident == 'storage'
assert prod(size) <= (obj_size - storage_offset)
if storage_type not in [np.float16, np.float32]:
if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}")
ret = None
else:
ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type))
key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset))
return ret
def _rebuild_parameter(*args):
#print(args)
pass
class Dummy: pass
class MyPickle(pickle.Unpickler):
def find_class(self, module, name):
#print(module, name)
if name == 'FloatStorage': return np.float32
if name == 'LongStorage': return np.int64
if name == 'IntStorage': return np.int32
if name == 'HalfStorage': return np.float16
if module == "torch._utils":
if name == "_rebuild_tensor_v2": return _rebuild_tensor_v2
if name == "_rebuild_parameter": return _rebuild_parameter
else:
if module.startswith('pytorch_lightning'): return Dummy
try:
return super().find_class(module, name)
except Exception:
return Dummy
def persistent_load(self, pid):
return pid
return MyPickle(fb0).load(), key_prelookup
def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset, mmap_allowed=False):
bytes_size = np.dtype(dtype).itemsize
if t is None:
myfile.seek(prod(shape) * bytes_size, 1)
return
bytes_offset = 0
if storage_offset is not None:
bytes_offset = storage_offset * bytes_size
myfile.seek(bytes_offset)
assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}"
assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size
if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)):
# slow path
buffer_size = sum(strides[i]*t.dtype.itemsize * (shape[i] - 1) for i in range(len(shape)))
buffer_size += t.dtype.itemsize
np_array = np.frombuffer(myfile.read(buffer_size), t.dtype.np)
np_array = np.lib.stride_tricks.as_strided(
np_array, shape=shape, strides=[i*t.dtype.itemsize for i in strides])
lna = t.lazydata.op.arg
lna.fxn = lambda _: np_array
t.realize()
return
# ["METAL", "CLANG", "LLVM"] support readinto for more speed
# ["GPU", "CUDA"] use _mmap since they have to copy in to the GPU anyway
# this needs real APIs
if t.device in ["METAL", "CLANG", "LLVM"]:
del t.lazydata.op
t.lazydata.realized = Device[t.lazydata.device].buffer(prod(t.shape), dtype=t.dtype)
myfile.readinto(t.lazydata.realized._buffer())
else:
def _mmap(lna):
assert myfile._compress_type == 0, "compressed data can't be mmaped"
return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r', offset=myfile._orig_compress_start + bytes_offset, shape=lna.shape)
def _read(lna):
ret = np.empty(lna.shape, dtype=lna.dtype)
myfile.readinto(ret.data)
return ret
if mmap_allowed and not OSX and t.device in ["GPU", "CUDA"]: t.lazydata.op.arg.fxn = _mmap
else: t.lazydata.op.arg.fxn = _read
t.realize()
def fake_torch_load_zipped(fb0, load_weights=True, multithreaded=True):
if Device.DEFAULT in ["TORCH", "GPU", "CUDA"]: multithreaded = False # multithreaded doesn't work with CUDA or TORCH. for GPU it's a wash with _mmap
with zipfile.ZipFile(fb0, 'r') as myzip:
base_name = myzip.namelist()[0].split('/', 1)[0]
with myzip.open(f'{base_name}/data.pkl') as myfile:
ret = my_unpickle(myfile)
if load_weights:
def load_weight(k, vv):
with myzip.open(f'{base_name}/data/{k}') as myfile:
for v in vv:
load_single_weight(v[2], myfile, v[3], v[4], v[0], v[5], mmap_allowed=True)
if multithreaded:
# 2 seems fastest
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
futures = {executor.submit(load_weight, k, v):k for k,v in ret[1].items()}
for future in (t:=tqdm(concurrent.futures.as_completed(futures), total=len(futures))):
if future.exception() is not None: raise future.exception()
k = futures[future]
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
else:
for k,v in (t := tqdm(ret[1].items())):
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
load_weight(k,v)
return ret[0]
def fake_torch_load(b0):
# convert it to a file
fb0 = io.BytesIO(b0)
if b0[0:2] == b"\x50\x4b":
return fake_torch_load_zipped(fb0)
# skip three junk pickles
pickle.load(fb0)
pickle.load(fb0)
pickle.load(fb0)
ret, key_prelookup = my_unpickle(fb0)
# create key_lookup
key_lookup = pickle.load(fb0)
key_real = [None] * len(key_lookup)
for k,v in key_prelookup.items():
assert len(v) == 1
key_real[key_lookup.index(k)] = v[0]
# read in the actual data
for storage_type, obj_size, tensor, np_shape, np_strides, storage_offset in key_real:
ll = struct.unpack("Q", fb0.read(8))[0]
assert ll == obj_size, f"size mismatch {ll} != {obj_size}"
assert storage_offset == 0, "not implemented"
load_single_weight(tensor, fb0, np_shape, np_strides, storage_type, None)
return ret
def get_child(parent, key):
obj = parent
for k in key.split('.'):

View File

@@ -38,6 +38,7 @@ setup(name='tinygrad',
"pre-commit",
"ruff",
"types-tqdm",
"types-requests",
],
'testing': [
"torch",

View File

@@ -1,6 +1,7 @@
import unittest
import unittest, io
import numpy as np
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up
from PIL import Image
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up, fetch
from tinygrad.shape.symbolic import Variable, NumNode
VARIABLE = ContextVar("VARIABLE", 0)
@@ -147,5 +148,17 @@ class TestRoundUp(unittest.TestCase):
self.assertEqual(round_up(232, 24984), 24984)
self.assertEqual(round_up(24984, 232), 25056)
class TestFetch(unittest.TestCase):
def test_fetch_bad_http(self):
self.assertRaises(AssertionError, fetch, 'http://www.google.com/404')
def test_fetch_small(self):
assert(len(fetch('https://google.com').read_bytes())>0)
def test_fetch_img(self):
img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190")
with Image.open(img) as pimg:
assert pimg.size == (705, 1024)
if __name__ == '__main__':
unittest.main()

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, requests, tempfile, pathlib
import numpy as np
from tqdm import tqdm
from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
@@ -219,3 +220,18 @@ def diskcache(func):
return diskcache_put(table, key, func(*args, **kwargs))
setattr(wrapper, "__wrapped__", func)
return wrapper
# *** http support ***
def fetch(url:str) -> pathlib.Path:
fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / hashlib.md5(url.encode('utf-8')).hexdigest()
if not fp.is_file():
r = requests.get(url, stream=True, timeout=10)
assert r.status_code == 200
progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url)
(path := fp.parent).mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
for chunk in r.iter_content(chunk_size=16384): progress_bar.update(f.write(chunk))
f.close()
pathlib.Path(f.name).rename(fp)
return fp