mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Fast DiskTensor to other Tensor (#916)
* make disktensors fast * loading * loader for sd and llama
This commit is contained in:
@@ -207,6 +207,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
chatbot = args.prompt == None
|
||||
|
||||
"""
|
||||
# load model (you have to find the weights yourself)
|
||||
from extra.utils import fake_torch_load_zipped, get_child
|
||||
|
||||
@@ -262,18 +263,12 @@ if __name__ == "__main__":
|
||||
get_child(model, k).assign(v).realize()
|
||||
|
||||
del weights
|
||||
"""
|
||||
|
||||
# disktensor loader isn't fast yet
|
||||
"""
|
||||
from tinygrad.state import torch_load, get_state_dict
|
||||
state_dict = torch_load(WEIGHTS_7B_FILENAME)
|
||||
model = Transformer(**args_7B)
|
||||
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
|
||||
for k,v in (t := tqdm(get_state_dict(model).items())):
|
||||
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k}")
|
||||
if k not in state_dict: continue
|
||||
v.assign(state_dict[k].to(v.device)).realize()
|
||||
"""
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
load_state_dict(model, torch_load(WEIGHTS_7B_FILENAME), strict=False)
|
||||
|
||||
# *** prompt engineers work here ****
|
||||
|
||||
|
||||
@@ -2,10 +2,7 @@
|
||||
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
|
||||
|
||||
from pathlib import Path
|
||||
import gzip
|
||||
import argparse
|
||||
import math
|
||||
import re
|
||||
import gzip, argparse, math, re
|
||||
from functools import lru_cache
|
||||
from collections import namedtuple
|
||||
|
||||
@@ -14,7 +11,8 @@ from tqdm import tqdm
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm
|
||||
from extra.utils import fake_torch_load_zipped, get_child, download_file
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
|
||||
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
|
||||
|
||||
@@ -613,24 +611,10 @@ if __name__ == "__main__":
|
||||
model = StableDiffusion()
|
||||
|
||||
# load in weights
|
||||
download_file(
|
||||
'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
|
||||
FILENAME
|
||||
)
|
||||
dat = fake_torch_load_zipped(open(FILENAME, "rb"))
|
||||
for k,v in dat['state_dict'].items():
|
||||
try:
|
||||
w = get_child(model, k)
|
||||
except (AttributeError, KeyError, IndexError):
|
||||
#traceback.print_exc()
|
||||
w = None
|
||||
#print(f"{str(v.shape):30s}" if v is not None else v, w.shape if w is not None else w, k)
|
||||
if w is not None:
|
||||
assert w.shape == v.shape and w.dtype == v.dtype, f"shape or dtype mismatch. {w.shape} != {v.shape} or {w.dtype} != {v.dtype}"
|
||||
w.assign(v)
|
||||
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
|
||||
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
|
||||
|
||||
# run through CLIP to get context
|
||||
|
||||
tokenizer = ClipTokenizer()
|
||||
prompt = tokenizer.encode(args.prompt)
|
||||
context = model.cond_stage_model.transformer.text_model(prompt).realize()
|
||||
|
||||
@@ -10,11 +10,19 @@
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
|
||||
//#define FN "/dev/nvme0n1"
|
||||
#define FN "../../weights/LLaMA/7B/consolidated.00.pth"
|
||||
|
||||
#define SZ (unsigned long long)(512*1024*1024)
|
||||
#define CNT 10LL
|
||||
|
||||
void test_read() {
|
||||
int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
|
||||
#ifdef O_DIRECT
|
||||
int f = open(FN, O_RDONLY|O_DIRECT);
|
||||
#else
|
||||
int f = open(FN, O_RDONLY);
|
||||
//fcntl(f, F_NOCACHE, 1);
|
||||
#endif
|
||||
printf("open %d\n", f);
|
||||
|
||||
/*void *buf = malloc(CNT*SZ);
|
||||
@@ -42,7 +50,11 @@ void test_read() {
|
||||
}
|
||||
|
||||
void test_mmap() {
|
||||
int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
|
||||
#ifdef O_DIRECT
|
||||
int f = open(FN, O_RDONLY|O_DIRECT);
|
||||
#else
|
||||
int f = open(FN, O_RDONLY);
|
||||
#endif
|
||||
printf("open %d\n", f);
|
||||
|
||||
void *dat = mmap(NULL, SZ*CNT, PROT_READ, MAP_PRIVATE, f, 0);
|
||||
@@ -62,10 +74,13 @@ void test_mmap() {
|
||||
}
|
||||
|
||||
int main() {
|
||||
system("sync; echo 1 > /proc/sys/vm/drop_caches");
|
||||
test_mmap();
|
||||
//system("sync; echo 1 > /proc/sys/vm/drop_caches");
|
||||
//system("sudo purge");
|
||||
//test_mmap();
|
||||
|
||||
//system("sync; echo 1 > /proc/sys/vm/drop_caches");
|
||||
//test_read();
|
||||
system("sudo purge");
|
||||
test_read();
|
||||
test_read();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,4 @@
|
||||
import time
|
||||
|
||||
class Timing(object):
|
||||
def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
|
||||
def __enter__(self): self.st = time.perf_counter_ns()
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.et = time.perf_counter_ns() - self.st
|
||||
if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
|
||||
from tinygrad.helpers import Timing
|
||||
|
||||
def enable_early_exec():
|
||||
import subprocess, multiprocessing
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, asdict
|
||||
import os, math, functools
|
||||
import os, math, functools, time
|
||||
import numpy as np
|
||||
from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
|
||||
ShapeType = Tuple[int, ...]
|
||||
@@ -39,6 +39,13 @@ class ContextVar:
|
||||
|
||||
DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)
|
||||
|
||||
class Timing(object):
|
||||
def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
|
||||
def __enter__(self): self.st = time.perf_counter_ns()
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.et = time.perf_counter_ns() - self.st
|
||||
if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
|
||||
|
||||
# **** tinygrad now supports dtypes! *****
|
||||
|
||||
class DType(NamedTuple):
|
||||
|
||||
@@ -6,7 +6,8 @@ from weakref import WeakValueDictionary
|
||||
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, DEBUG
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_lazyops, get_buffers, map_buffers
|
||||
from tinygrad.runtime.lib import RawConst, RawBuffer
|
||||
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
|
||||
# lazy can recurse a lot
|
||||
sys.setrecursionlimit(10000)
|
||||
@@ -122,6 +123,14 @@ class LazyBuffer:
|
||||
elif self.op.op == LoadOps.CUSTOM:
|
||||
# this needs to immediately realize
|
||||
self.realized = self.op.arg(self, *[x.realize() for x in self.op.src])
|
||||
elif self.op.op == LoadOps.FROM:
|
||||
rawbuf = self.op.src[0].realize()
|
||||
# TODO: make this generic
|
||||
if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[self.device].buffer, RawBufferMapped):
|
||||
self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
|
||||
rawbuf.realized.readinto(cast(RawBufferMapped, self.realized)._buffer())
|
||||
else:
|
||||
self.realized = Device[self.device].buffer.fromCPU(rawbuf.toCPU(), **self._device_extra_args())
|
||||
elif self.optype == LoadOps:
|
||||
if DEBUG >= 4: print(f"{self.op.op} {self.shape} {self.dtype} {self.op.arg}")
|
||||
if self.op.op == LoadOps.EMPTY:
|
||||
@@ -167,8 +176,8 @@ class LazyBuffer:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def loadop(op, shape, dtype, device, arg=None) -> LazyBuffer:
|
||||
return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple(), arg), dtype)
|
||||
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
|
||||
return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
|
||||
|
||||
# create a constant with the shape and dtype of self
|
||||
def const_like(self, val) -> LazyBuffer:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# sorted in order of increasing complexity
|
||||
from typing import List, Dict
|
||||
from typing import List
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class Optimizer:
|
||||
@@ -67,15 +67,6 @@ class LAMB(Optimizer):
|
||||
t.assign(t.detach() - self.lr * r * up)
|
||||
self.realize([self.t] + self.m + self.v)
|
||||
|
||||
from collections import OrderedDict
|
||||
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
|
||||
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
|
||||
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
|
||||
if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
|
||||
state_dict = {}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
|
||||
elif isinstance(obj, dict):
|
||||
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
|
||||
return state_dict
|
||||
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
|
||||
# TODO: remove this
|
||||
from tinygrad.state import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto();
|
||||
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
|
||||
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
|
||||
class FusedOps(Enum): MULACC = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
|
||||
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps]
|
||||
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]]
|
||||
|
||||
@@ -11,11 +11,16 @@ class RawDiskBuffer(RawBufferMapped):
|
||||
self.offset = offset # this is an offset in bytes
|
||||
assert device is not None or buf is not None, "disk tensor needs a path or a buf"
|
||||
if device is not None:
|
||||
with open(device, "a+b") as f:
|
||||
if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
|
||||
buf = mmap.mmap(f.fileno(), size * dtype.itemsize)
|
||||
f = open(device, "a+b")
|
||||
if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
|
||||
buf = [f, mmap.mmap(f.fileno(), size * dtype.itemsize), 1]
|
||||
else:
|
||||
buf[2] += 1
|
||||
# NOTE: we don't call super since disk tensors don't use RAM
|
||||
self.size, self.dtype, self._buf = size, dtype, buf
|
||||
def __del__(self):
|
||||
self._buf[2] -= 1
|
||||
if self._buf[2] == 0: self._buf[0].close()
|
||||
def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)
|
||||
def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
|
||||
def shrink(self, arg):
|
||||
@@ -23,7 +28,10 @@ class RawDiskBuffer(RawBufferMapped):
|
||||
offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
|
||||
size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
|
||||
return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])
|
||||
def _buffer(self): return memoryview(self._buf)[self.offset:self.offset+self.size*self.dtype.itemsize]
|
||||
def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
|
||||
def readinto(self, buf):
|
||||
self._buf[0].seek(self.offset)
|
||||
self._buf[0].readinto(buf)
|
||||
|
||||
disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.RESHAPE: RawDiskBuffer.reshape, MovementOps.SHRINK: RawDiskBuffer.shrink }
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os, json, pathlib, zipfile, pickle
|
||||
from typing import Dict, Union
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Union, List
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, prod, argsort
|
||||
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters
|
||||
from tinygrad.shape.shapetracker import strides_for_shape
|
||||
|
||||
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
|
||||
@@ -26,8 +27,30 @@ def safe_save(tensors:Dict[str, Tensor], fn:str):
|
||||
t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8))
|
||||
for k,v in safe_load(t).items(): v.assign(tensors[k])
|
||||
|
||||
# TODO: move get_state_dict and get_parameters here
|
||||
from tinygrad.nn.optim import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401
|
||||
# state dict
|
||||
|
||||
from collections import OrderedDict
|
||||
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
|
||||
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
|
||||
if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
|
||||
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
|
||||
if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
|
||||
state_dict = {}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
|
||||
elif isinstance(obj, dict):
|
||||
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
|
||||
return state_dict
|
||||
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
|
||||
|
||||
def load_state_dict(model, state_dict, strict=True):
|
||||
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
|
||||
for k,v in (t := tqdm(get_state_dict(model).items())):
|
||||
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
|
||||
if k not in state_dict and not strict:
|
||||
if DEBUG >= 2: print(f"WARNING: not loading {k}")
|
||||
continue
|
||||
v.assign(state_dict[k].to(v.device)).realize()
|
||||
|
||||
# torch support!
|
||||
|
||||
@@ -43,19 +66,28 @@ def torch_load(fn:str):
|
||||
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
|
||||
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
|
||||
|
||||
# 6 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
|
||||
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
|
||||
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
|
||||
permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
|
||||
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
|
||||
intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
|
||||
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
|
||||
if DEBUG >= 2: print(f"WARNING: this torch load is slow. it has to convert to CPU to permute {permute_indexes}")
|
||||
# TODO: find a nice way to support all shapetracker on disktensors
|
||||
ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
|
||||
|
||||
return ret.reshape(size)
|
||||
|
||||
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
|
||||
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
|
||||
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
|
||||
class Dummy: pass
|
||||
class TorchPickle(pickle.Unpickler):
|
||||
def find_class(self, module, name): return intercept[name] if module.startswith("torch") else super().find_class(module, name)
|
||||
def find_class(self, module, name):
|
||||
module_root = module.split(".")[0]
|
||||
if module_root not in whitelist:
|
||||
if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
|
||||
return Dummy
|
||||
return intercept[name] if module_root == "torch" else super().find_class(module, name)
|
||||
def persistent_load(self, pid): return pid
|
||||
|
||||
if tuple(t[0:2].numpy()) == (0x50, 0x4b):
|
||||
|
||||
@@ -39,14 +39,12 @@ class Tensor:
|
||||
device = Device.canonicalize(device)
|
||||
if isinstance(data, list):
|
||||
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
|
||||
elif isinstance(data, LazyBuffer) and data.device != device:
|
||||
# TODO: this has to realize, it shouldn't have to
|
||||
data = data.realize().toCPU()
|
||||
|
||||
if isinstance(data, LazyBuffer):
|
||||
assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
|
||||
lazydata = data
|
||||
lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
# TODO: create CPUBuffer directly
|
||||
lazydata = LazyBuffer.loadop(LoadOps.FROMCPU, data.shape, dtypes.from_np(data.dtype), device, data)
|
||||
elif isinstance(data, (int, float)):
|
||||
lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype if dtype is not None else Tensor.default_type, device, data)
|
||||
@@ -94,6 +92,7 @@ class Tensor:
|
||||
self.lazydata.realize().realized._copyin(x.numpy()) # type: ignore
|
||||
return self
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
|
||||
# NOTE: we are currently allowing assignments from different dtypes
|
||||
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
|
||||
assert not x.requires_grad # self requires_grad is okay?
|
||||
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
|
||||
|
||||
Reference in New Issue
Block a user