Fast DiskTensor to other Tensor (#916)

* make disktensors fast

* loading

* loader for sd and llama
This commit is contained in:
George Hotz
2023-06-03 12:25:41 -07:00
committed by GitHub
parent 791530045d
commit ed1963b899
11 changed files with 109 additions and 76 deletions

View File

@@ -207,6 +207,7 @@ if __name__ == "__main__":
args = parser.parse_args()
chatbot = args.prompt == None
"""
# load model (you have to find the weights yourself)
from extra.utils import fake_torch_load_zipped, get_child
@@ -262,18 +263,12 @@ if __name__ == "__main__":
get_child(model, k).assign(v).realize()
del weights
"""
# disktensor loader isn't fast yet
"""
from tinygrad.state import torch_load, get_state_dict
state_dict = torch_load(WEIGHTS_7B_FILENAME)
model = Transformer(**args_7B)
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
for k,v in (t := tqdm(get_state_dict(model).items())):
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k}")
if k not in state_dict: continue
v.assign(state_dict[k].to(v.device)).realize()
"""
from tinygrad.state import torch_load, load_state_dict
load_state_dict(model, torch_load(WEIGHTS_7B_FILENAME), strict=False)
# *** prompt engineers work here ****

View File

@@ -2,10 +2,7 @@
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
from pathlib import Path
import gzip
import argparse
import math
import re
import gzip, argparse, math, re
from functools import lru_cache
from collections import namedtuple
@@ -14,7 +11,8 @@ from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm
from extra.utils import fake_torch_load_zipped, get_child, download_file
from extra.utils import download_file
from tinygrad.state import torch_load, load_state_dict
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
@@ -613,24 +611,10 @@ if __name__ == "__main__":
model = StableDiffusion()
# load in weights
download_file(
'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
FILENAME
)
dat = fake_torch_load_zipped(open(FILENAME, "rb"))
for k,v in dat['state_dict'].items():
try:
w = get_child(model, k)
except (AttributeError, KeyError, IndexError):
#traceback.print_exc()
w = None
#print(f"{str(v.shape):30s}" if v is not None else v, w.shape if w is not None else w, k)
if w is not None:
assert w.shape == v.shape and w.dtype == v.dtype, f"shape or dtype mismatch. {w.shape} != {v.shape} or {w.dtype} != {v.dtype}"
w.assign(v)
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
# run through CLIP to get context
tokenizer = ClipTokenizer()
prompt = tokenizer.encode(args.prompt)
context = model.cond_stage_model.transformer.text_model(prompt).realize()

View File

@@ -10,11 +10,19 @@
#include <thread>
#include <chrono>
//#define FN "/dev/nvme0n1"
#define FN "../../weights/LLaMA/7B/consolidated.00.pth"
#define SZ (unsigned long long)(512*1024*1024)
#define CNT 10LL
void test_read() {
int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
#ifdef O_DIRECT
int f = open(FN, O_RDONLY|O_DIRECT);
#else
int f = open(FN, O_RDONLY);
//fcntl(f, F_NOCACHE, 1);
#endif
printf("open %d\n", f);
/*void *buf = malloc(CNT*SZ);
@@ -42,7 +50,11 @@ void test_read() {
}
void test_mmap() {
int f = open("/dev/nvme0n1", O_RDONLY|O_DIRECT);
#ifdef O_DIRECT
int f = open(FN, O_RDONLY|O_DIRECT);
#else
int f = open(FN, O_RDONLY);
#endif
printf("open %d\n", f);
void *dat = mmap(NULL, SZ*CNT, PROT_READ, MAP_PRIVATE, f, 0);
@@ -62,10 +74,13 @@ void test_mmap() {
}
int main() {
system("sync; echo 1 > /proc/sys/vm/drop_caches");
test_mmap();
//system("sync; echo 1 > /proc/sys/vm/drop_caches");
//system("sudo purge");
//test_mmap();
//system("sync; echo 1 > /proc/sys/vm/drop_caches");
//test_read();
system("sudo purge");
test_read();
test_read();
}

View File

@@ -1,11 +1,4 @@
import time
class Timing(object):
def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
def __enter__(self): self.st = time.perf_counter_ns()
def __exit__(self, exc_type, exc_val, exc_tb):
self.et = time.perf_counter_ns() - self.st
if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
from tinygrad.helpers import Timing
def enable_early_exec():
import subprocess, multiprocessing

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass, asdict
import os, math, functools
import os, math, functools, time
import numpy as np
from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
ShapeType = Tuple[int, ...]
@@ -39,6 +39,13 @@ class ContextVar:
DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)
class Timing(object):
def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
def __enter__(self): self.st = time.perf_counter_ns()
def __exit__(self, exc_type, exc_val, exc_tb):
self.et = time.perf_counter_ns() - self.st
if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
# **** tinygrad now supports dtypes! *****
class DType(NamedTuple):

View File

@@ -6,7 +6,8 @@ from weakref import WeakValueDictionary
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_lazyops, get_buffers, map_buffers
from tinygrad.runtime.lib import RawConst, RawBuffer
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped
from tinygrad.runtime.ops_disk import RawDiskBuffer
# lazy can recurse a lot
sys.setrecursionlimit(10000)
@@ -122,6 +123,14 @@ class LazyBuffer:
elif self.op.op == LoadOps.CUSTOM:
# this needs to immediately realize
self.realized = self.op.arg(self, *[x.realize() for x in self.op.src])
elif self.op.op == LoadOps.FROM:
rawbuf = self.op.src[0].realize()
# TODO: make this generic
if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[self.device].buffer, RawBufferMapped):
self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
rawbuf.realized.readinto(cast(RawBufferMapped, self.realized)._buffer())
else:
self.realized = Device[self.device].buffer.fromCPU(rawbuf.toCPU(), **self._device_extra_args())
elif self.optype == LoadOps:
if DEBUG >= 4: print(f"{self.op.op} {self.shape} {self.dtype} {self.op.arg}")
if self.op.op == LoadOps.EMPTY:
@@ -167,8 +176,8 @@ class LazyBuffer:
return self
@staticmethod
def loadop(op, shape, dtype, device, arg=None) -> LazyBuffer:
return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple(), arg), dtype)
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
# create a constant with the shape and dtype of self
def const_like(self, val) -> LazyBuffer:

View File

@@ -1,5 +1,5 @@
# sorted in order of increasing complexity
from typing import List, Dict
from typing import List
from tinygrad.tensor import Tensor
class Optimizer:
@@ -67,15 +67,6 @@ class LAMB(Optimizer):
t.assign(t.detach() - self.lr * r * up)
self.realize([self.t] + self.m + self.v)
from collections import OrderedDict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
state_dict = {}
if isinstance(obj, (list, tuple)):
for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
elif isinstance(obj, dict):
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
return state_dict
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
# TODO: remove this
from tinygrad.state import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401

View File

@@ -12,7 +12,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto();
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class FusedOps(Enum): MULACC = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]]

View File

@@ -11,11 +11,16 @@ class RawDiskBuffer(RawBufferMapped):
self.offset = offset # this is an offset in bytes
assert device is not None or buf is not None, "disk tensor needs a path or a buf"
if device is not None:
with open(device, "a+b") as f:
if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
buf = mmap.mmap(f.fileno(), size * dtype.itemsize)
f = open(device, "a+b")
if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
buf = [f, mmap.mmap(f.fileno(), size * dtype.itemsize), 1]
else:
buf[2] += 1
# NOTE: we don't call super since disk tensors don't use RAM
self.size, self.dtype, self._buf = size, dtype, buf
def __del__(self):
self._buf[2] -= 1
if self._buf[2] == 0: self._buf[0].close()
def cast(self, new_dtype:DType): return RawDiskBuffer(self.size, new_dtype, buf=self._buf, shape=self.shape, offset=self.offset)
def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
def shrink(self, arg):
@@ -23,7 +28,10 @@ class RawDiskBuffer(RawBufferMapped):
offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])
def _buffer(self): return memoryview(self._buf)[self.offset:self.offset+self.size*self.dtype.itemsize]
def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
def readinto(self, buf):
self._buf[0].seek(self.offset)
self._buf[0].readinto(buf)
disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.RESHAPE: RawDiskBuffer.reshape, MovementOps.SHRINK: RawDiskBuffer.shrink }

View File

@@ -1,7 +1,8 @@
import os, json, pathlib, zipfile, pickle
from typing import Dict, Union
from tqdm import tqdm
from typing import Dict, Union, List
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters
from tinygrad.shape.shapetracker import strides_for_shape
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
@@ -26,8 +27,30 @@ def safe_save(tensors:Dict[str, Tensor], fn:str):
t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8))
for k,v in safe_load(t).items(): v.assign(tensors[k])
# TODO: move get_state_dict and get_parameters here
from tinygrad.nn.optim import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401
# state dict
from collections import OrderedDict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
state_dict = {}
if isinstance(obj, (list, tuple)):
for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
elif isinstance(obj, dict):
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
return state_dict
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
def load_state_dict(model, state_dict, strict=True):
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
for k,v in (t := tqdm(get_state_dict(model).items())):
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
if k not in state_dict and not strict:
if DEBUG >= 2: print(f"WARNING: not loading {k}")
continue
v.assign(state_dict[k].to(v.device)).realize()
# torch support!
@@ -43,19 +66,28 @@ def torch_load(fn:str):
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
# 6 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
if DEBUG >= 2: print(f"WARNING: this torch load is slow. it has to convert to CPU to permute {permute_indexes}")
# TODO: find a nice way to support all shapetracker on disktensors
ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
return ret.reshape(size)
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
class Dummy: pass
class TorchPickle(pickle.Unpickler):
def find_class(self, module, name): return intercept[name] if module.startswith("torch") else super().find_class(module, name)
def find_class(self, module, name):
module_root = module.split(".")[0]
if module_root not in whitelist:
if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
return Dummy
return intercept[name] if module_root == "torch" else super().find_class(module, name)
def persistent_load(self, pid): return pid
if tuple(t[0:2].numpy()) == (0x50, 0x4b):

View File

@@ -39,14 +39,12 @@ class Tensor:
device = Device.canonicalize(device)
if isinstance(data, list):
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
elif isinstance(data, LazyBuffer) and data.device != device:
# TODO: this has to realize, it shouldn't have to
data = data.realize().toCPU()
if isinstance(data, LazyBuffer):
assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
lazydata = data
lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
elif isinstance(data, np.ndarray):
# TODO: create CPUBuffer directly
lazydata = LazyBuffer.loadop(LoadOps.FROMCPU, data.shape, dtypes.from_np(data.dtype), device, data)
elif isinstance(data, (int, float)):
lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype if dtype is not None else Tensor.default_type, device, data)
@@ -94,6 +92,7 @@ class Tensor:
self.lazydata.realize().realized._copyin(x.numpy()) # type: ignore
return self
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
# NOTE: we are currently allowing assignments from different dtypes
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
assert not x.requires_grad # self requires_grad is okay?
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")