diff --git a/extra/datasets/__init__.py b/extra/datasets/__init__.py index 8c263be979..085df38969 100644 --- a/extra/datasets/__init__.py +++ b/extra/datasets/__init__.py @@ -2,8 +2,7 @@ import os, gzip, tarfile, pickle import numpy as np from pathlib import Path from tinygrad.tensor import Tensor -from tinygrad.helpers import dtypes -from extra.utils import download_file +from tinygrad.helpers import dtypes, fetch def fetch_mnist(tensors=False): parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy() @@ -36,8 +35,7 @@ def fetch_cifar(): assert idx == X.shape[0] and X.shape[0] == Y.shape[0] print("downloading and extracting CIFAR...") - fn = Path(__file__).parent.resolve() / "cifar-10-python.tar.gz" - download_file('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fn) + fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz') tt = tarfile.open(fn, mode='r:gz') _load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)]) _load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")]) diff --git a/extra/models/convnext.py b/extra/models/convnext.py index d9d58ad8a6..3aa08ed90f 100644 --- a/extra/models/convnext.py +++ b/extra/models/convnext.py @@ -1,5 +1,6 @@ from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear +from tinygrad.helpers import fetch class Block: def __init__(self, dim): @@ -43,11 +44,12 @@ versions = { def get_model(version, load_weights=False): model = ConvNeXt(**versions[version]) if load_weights: - from extra.utils import fetch, fake_torch_load, get_child - weights = fake_torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model'] + from extra.utils import get_child + from tinygrad.nn.state import torch_load + weights = torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model'] for k,v in weights.items(): mv = get_child(model, k) - mv.assign(v.reshape(mv.shape)).realize() + mv.assign(v.reshape(mv.shape).to(mv.device)).realize() return model if __name__ == "__main__": diff --git a/extra/utils.py b/extra/utils.py index 4c9ace2d15..b261831136 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -47,152 +47,6 @@ def download_file(url, fp, skip_if_exists=True): f.close() Path(f.name).rename(fp) -def my_unpickle(fb0): - key_prelookup = defaultdict(list) - def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None): - #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata) - ident, storage_type, obj_key, location, obj_size = storage[0:5] - assert ident == 'storage' - assert prod(size) <= (obj_size - storage_offset) - - if storage_type not in [np.float16, np.float32]: - if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}") - ret = None - else: - ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type)) - key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset)) - return ret - - def _rebuild_parameter(*args): - #print(args) - pass - - class Dummy: pass - - class MyPickle(pickle.Unpickler): - def find_class(self, module, name): - #print(module, name) - if name == 'FloatStorage': return np.float32 - if name == 'LongStorage': return np.int64 - if name == 'IntStorage': return np.int32 - if name == 'HalfStorage': return np.float16 - if module == "torch._utils": - if name == "_rebuild_tensor_v2": return _rebuild_tensor_v2 - if name == "_rebuild_parameter": return _rebuild_parameter - else: - if module.startswith('pytorch_lightning'): return Dummy - try: - return super().find_class(module, name) - except Exception: - return Dummy - - def persistent_load(self, pid): - return pid - - return MyPickle(fb0).load(), key_prelookup - -def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset, mmap_allowed=False): - bytes_size = np.dtype(dtype).itemsize - if t is None: - myfile.seek(prod(shape) * bytes_size, 1) - return - - bytes_offset = 0 - if storage_offset is not None: - bytes_offset = storage_offset * bytes_size - myfile.seek(bytes_offset) - - assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}" - assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size - if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)): - # slow path - buffer_size = sum(strides[i]*t.dtype.itemsize * (shape[i] - 1) for i in range(len(shape))) - buffer_size += t.dtype.itemsize - np_array = np.frombuffer(myfile.read(buffer_size), t.dtype.np) - - np_array = np.lib.stride_tricks.as_strided( - np_array, shape=shape, strides=[i*t.dtype.itemsize for i in strides]) - - lna = t.lazydata.op.arg - lna.fxn = lambda _: np_array - t.realize() - return - - # ["METAL", "CLANG", "LLVM"] support readinto for more speed - # ["GPU", "CUDA"] use _mmap since they have to copy in to the GPU anyway - # this needs real APIs - if t.device in ["METAL", "CLANG", "LLVM"]: - del t.lazydata.op - t.lazydata.realized = Device[t.lazydata.device].buffer(prod(t.shape), dtype=t.dtype) - myfile.readinto(t.lazydata.realized._buffer()) - else: - def _mmap(lna): - assert myfile._compress_type == 0, "compressed data can't be mmaped" - return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r', offset=myfile._orig_compress_start + bytes_offset, shape=lna.shape) - def _read(lna): - ret = np.empty(lna.shape, dtype=lna.dtype) - myfile.readinto(ret.data) - return ret - if mmap_allowed and not OSX and t.device in ["GPU", "CUDA"]: t.lazydata.op.arg.fxn = _mmap - else: t.lazydata.op.arg.fxn = _read - t.realize() - -def fake_torch_load_zipped(fb0, load_weights=True, multithreaded=True): - if Device.DEFAULT in ["TORCH", "GPU", "CUDA"]: multithreaded = False # multithreaded doesn't work with CUDA or TORCH. for GPU it's a wash with _mmap - with zipfile.ZipFile(fb0, 'r') as myzip: - base_name = myzip.namelist()[0].split('/', 1)[0] - with myzip.open(f'{base_name}/data.pkl') as myfile: - ret = my_unpickle(myfile) - if load_weights: - def load_weight(k, vv): - with myzip.open(f'{base_name}/data/{k}') as myfile: - for v in vv: - load_single_weight(v[2], myfile, v[3], v[4], v[0], v[5], mmap_allowed=True) - if multithreaded: - # 2 seems fastest - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - futures = {executor.submit(load_weight, k, v):k for k,v in ret[1].items()} - for future in (t:=tqdm(concurrent.futures.as_completed(futures), total=len(futures))): - if future.exception() is not None: raise future.exception() - k = futures[future] - t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB") - else: - for k,v in (t := tqdm(ret[1].items())): - t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB") - load_weight(k,v) - return ret[0] - -def fake_torch_load(b0): - - # convert it to a file - fb0 = io.BytesIO(b0) - - if b0[0:2] == b"\x50\x4b": - return fake_torch_load_zipped(fb0) - - # skip three junk pickles - pickle.load(fb0) - pickle.load(fb0) - pickle.load(fb0) - - ret, key_prelookup = my_unpickle(fb0) - - # create key_lookup - key_lookup = pickle.load(fb0) - key_real = [None] * len(key_lookup) - for k,v in key_prelookup.items(): - assert len(v) == 1 - key_real[key_lookup.index(k)] = v[0] - - # read in the actual data - for storage_type, obj_size, tensor, np_shape, np_strides, storage_offset in key_real: - ll = struct.unpack("Q", fb0.read(8))[0] - assert ll == obj_size, f"size mismatch {ll} != {obj_size}" - assert storage_offset == 0, "not implemented" - load_single_weight(tensor, fb0, np_shape, np_strides, storage_type, None) - - return ret - def get_child(parent, key): obj = parent for k in key.split('.'): diff --git a/setup.py b/setup.py index b411371ffe..06404d9e5c 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ setup(name='tinygrad', "pre-commit", "ruff", "types-tqdm", + "types-requests", ], 'testing': [ "torch", diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py index 4d4fd5155d..c4f0582b88 100644 --- a/test/unit/test_helpers.py +++ b/test/unit/test_helpers.py @@ -1,6 +1,7 @@ -import unittest +import unittest, io import numpy as np -from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up +from PIL import Image +from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up, fetch from tinygrad.shape.symbolic import Variable, NumNode VARIABLE = ContextVar("VARIABLE", 0) @@ -147,5 +148,17 @@ class TestRoundUp(unittest.TestCase): self.assertEqual(round_up(232, 24984), 24984) self.assertEqual(round_up(24984, 232), 25056) +class TestFetch(unittest.TestCase): + def test_fetch_bad_http(self): + self.assertRaises(AssertionError, fetch, 'http://www.google.com/404') + + def test_fetch_small(self): + assert(len(fetch('https://google.com').read_bytes())>0) + + def test_fetch_img(self): + img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190") + with Image.open(img) as pimg: + assert pimg.size == (705, 1024) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index c8cc95de19..1fea84115f 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,6 +1,7 @@ from __future__ import annotations -import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats +import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, requests, tempfile, pathlib import numpy as np +from tqdm import tqdm from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 from typing_extensions import TypeGuard @@ -219,3 +220,18 @@ def diskcache(func): return diskcache_put(table, key, func(*args, **kwargs)) setattr(wrapper, "__wrapped__", func) return wrapper + +# *** http support *** + +def fetch(url:str) -> pathlib.Path: + fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / hashlib.md5(url.encode('utf-8')).hexdigest() + if not fp.is_file(): + r = requests.get(url, stream=True, timeout=10) + assert r.status_code == 200 + progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url) + (path := fp.parent).mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile(dir=path, delete=False) as f: + for chunk in r.iter_content(chunk_size=16384): progress_bar.update(f.write(chunk)) + f.close() + pathlib.Path(f.name).rename(fp) + return fp