Files
carrotpilot/tinygrad_repo/tinygrad/nn/state.py

401 lines
20 KiB
Python

import json, pathlib, zipfile, pickle, tarfile, struct, functools, io, zlib
from collections import OrderedDict
from typing import Any, Callable, BinaryIO, Iterable, cast
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, GlobalCounters, tqdm, round_up, T, strides_for_shape
class TensorIO(io.RawIOBase, BinaryIO):
  """
  Read-only, seekable file-like object over a 1D uint8 Tensor.

  Lets stdlib readers (zipfile, tarfile, pickle, io.BufferedReader, ...) consume bytes that live in a Tensor.
  Writing is not supported.
  """
  def __init__(self, t: Tensor):
    if not (t.ndim == 1 and t.dtype == dtypes.uint8): raise ValueError("Tensor must be 1d and of dtype uint8!")
    self._tensor = t
    self._position = 0
  def readable(self) -> bool: return True
  def read(self, size: int = -1) -> bytes:
    chunk = super().read(size)
    # RawIOBase.read can only produce None if readinto returns None, which ours never does
    if chunk is None: raise ValueError("io.RawIOBase.read returned None")
    return chunk
  def readinto(self, buffer: Any) -> int:
    start = self._position
    chunk = self._tensor[start:start+len(buffer)].data()
    nread = len(chunk)
    buffer[:nread] = chunk
    self._position = start + nread
    return nread
  def seekable(self) -> bool: return True
  def seek(self, offset: int, whence: int = 0) -> int:
    size = len(self._tensor)
    # reference point for whence = 0 (start), 1 (current position), 2 (end)
    bases = (0, self._position, size)
    # the resulting position is clamped into [0, size]
    self._position = max(0, min(size, bases[whence] + offset))
    return self._position
  # the remaining methods are required to correctly implement BinaryIO
  def __enter__(self): return self
  def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
  def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")
# safetensors header dtype string -> tinygrad dtype
safe_dtypes = {
  "BOOL": dtypes.bool,
  "I8": dtypes.int8, "U8": dtypes.uint8,
  "I16": dtypes.int16, "U16": dtypes.uint16,
  "I32": dtypes.int, "U32": dtypes.uint,
  "I64": dtypes.int64, "U64": dtypes.uint64,
  "F16": dtypes.float16, "BF16": dtypes.bfloat16, "F32": dtypes.float32, "F64": dtypes.float64,
}
# tinygrad dtype -> safetensors header dtype string (used when saving)
inverse_safe_dtypes = {dt: name for name, dt in safe_dtypes.items()}
def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Tensor|str|pathlib.Path], T]:
  """
  Decorator that lets a Tensor-taking loader also accept a file path.

  A `str` or `pathlib.Path` argument is wrapped as `Tensor(pathlib.Path(fn))` before `func` is called;
  a Tensor argument is passed through unchanged.
  """
  @functools.wraps(func)
  def wrapper(fn: Tensor|str|pathlib.Path) -> T:
    if isinstance(fn, Tensor): return func(fn)
    return func(Tensor(pathlib.Path(fn)))
  return wrapper
@accept_filename
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
  """
  Loads a .safetensor file, returning the source tensor, data start position, and metadata.

  The first 8 bytes are read as a little-endian integer giving the length of the JSON header that follows;
  tensor data starts right after that header.
  """
  header_len = int.from_bytes(t[0:8].data(), "little")
  data_start = 8 + header_len
  header_bytes = t[8:data_start].data().tobytes()
  return t, data_start, json.loads(header_bytes)
def safe_load(fn:Tensor|str|pathlib.Path) -> dict[str, Tensor]:
  """
  Loads a .safetensor file, returning the `state_dict`.

  Each entry is a lazy view into the file: its byte range is bitcast to the header's dtype and reshaped to the
  header's shape. The free-form `__metadata__` header entry is skipped.
  ```python
  state_dict = nn.state.safe_load("test.safetensor")
  ```
  """
  t, data_start, metadata = safe_load_metadata(fn)
  data = t[data_start:]
  state_dict: dict[str, Tensor] = {}
  for name, info in metadata.items():
    if name == "__metadata__": continue
    begin, end = info['data_offsets'][0], info['data_offsets'][1]
    state_dict[name] = data[begin:end].bitcast(safe_dtypes[info['dtype']]).reshape(info['shape'])
  return state_dict
def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=None):
  """
  Saves a `state_dict` to disk in a .safetensor file with optional metadata.

  The file holds an 8-byte (int64) header length and a compact JSON header, space-padded to a multiple of 8 bytes.
  After the header, every tensor's bytes follow back to back in `tensors` iteration order.
  Any existing file at `fn` is deleted first.
  ```python
  t = Tensor([1, 2, 3])
  nn.state.safe_save({'t':t}, "test.safetensor")
  ```
  """
  headers: dict[str, Any] = {'__metadata__': metadata} if metadata else {}
  cursor = 0
  for name, tensor in tensors.items():
    size = tensor.nbytes()
    headers[name] = {'dtype': inverse_safe_dtypes[tensor.dtype], 'shape': list(tensor.shape), 'data_offsets':[cursor, cursor+size]}
    cursor += size
  # json.dumps escapes non-ASCII by default, so len(header) is also its UTF-8 byte length
  header = json.dumps(headers, separators=(',', ':'))
  header = header.ljust(round_up(len(header), 8))
  pathlib.Path(fn).unlink(missing_ok=True)
  out = Tensor.empty(8+len(header)+cursor, dtype=dtypes.uint8, device=f"disk:{fn}")
  out[0:8].bitcast(dtypes.int64).assign([len(header)])
  out[8:8+len(header)].assign(list(header.encode('utf-8')))
  # re-parse the freshly written header to get typed views of each tensor's slot, then fill them in
  for name, slot in safe_load(out).items(): slot.assign(tensors[name])
# state dict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
  """
  Returns a `state_dict` of the object, with optional prefix.

  Recursively walks `obj` and collects every instance of `tensor_type`, keyed by its dotted access path.
  Namedtuples, OrderedDicts and objects with a `__dict__` are walked by their fields. Lists and tuples are walked
  by index, and dicts by key. Anything else contributes nothing.
  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)
  net = Net()
  print(nn.state.get_state_dict(net).keys())
  ```
  """
  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
  # normalize mapping-like objects and generic objects into a plain dict, then recurse
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type)  # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  if isinstance(obj, (list, tuple)): children = enumerate(obj)
  elif isinstance(obj, dict): children = obj.items()
  else: return {}
  state_dict: dict[str, Tensor] = {}
  for key, child in children: state_dict.update(get_state_dict(child, f"{prefix}{str(key)}.", tensor_type))
  return state_dict
def get_parameters(obj) -> list[Tensor]:
  """
  Returns every Tensor reachable from `obj`, in the same order as `get_state_dict(obj)`.
  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)
  net = Net()
  print(len(nn.state.get_parameters(net)))
  ```
  """
  state_dict = get_state_dict(obj)
  return [*state_dict.values()]
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> list[Tensor]:
  """
  Loads a `state_dict` into a model. Return the loaded Tensors.

  Every Tensor found by `get_state_dict(model)` is replaced (`Tensor.replace`) by the entry with the same key.
  Sharded (tuple-device) model tensors receive a sharded source; others receive the source moved to their device.

  Args:
    model: object whose Tensors are discovered with `get_state_dict`.
    state_dict: name -> Tensor to load. Mutated in place: `()` <-> `(1,)` reshapes are written back, and entries
      are deleted when `consume=True`.
    strict: if True, every model key must be present in `state_dict`; if False, missing keys are skipped.
    verbose: show load timing and a progress bar.
    consume: delete each entry from `state_dict` once it has been loaded.
    realize: realize each model Tensor right after replacing it.
  Returns:
    The model Tensors that were loaded, in `get_state_dict(model)` order.
  Raises:
    KeyError: `strict` is True and a model key is missing from `state_dict`.
    ValueError: shapes differ (other than the interchangeable `()` / `(1,)` case).
  ```python
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)
  net = Net()
  state_dict = nn.state.get_state_dict(net)
  nn.state.load_state_dict(net, state_dict)
  ```
  """
  start_mem_used = GlobalCounters.mem_used
  ret = []
  with Timing("loaded weights in ",
              lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
    model_state_dict = get_state_dict(model)
    # compare key sets rather than lengths: a state_dict can carry extra keys while also missing others
    if DEBUG >= 1 and (unused := state_dict.keys() - model_state_dict.keys()):
      print("WARNING: unused weights in state_dict", sorted(unused))
    for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
      t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
      if k not in state_dict:
        if strict: raise KeyError(f"`{k}` is missing from state_dict (pass strict=False to skip missing keys)")
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
      if v.shape != state_dict[k].shape:
        # a scalar and a 1-element vector are treated as interchangeable
        if {(), (1,)} == {state_dict[k].shape, v.shape}: state_dict[k] = state_dict[k].reshape(v.shape)
        else: raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
      if isinstance(v.device, tuple):
        # multi-device target: shard the source along the target's axis unless it is already sharded
        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
        else: v.replace(state_dict[k].shard(v.device, v.uop.axis))
      else: v.replace(state_dict[k].to(v.device))
      if realize: v.realize()
      if consume: del state_dict[k]
      ret.append(v)
  return ret
@accept_filename
def zip_extract(t: Tensor) -> dict[str, Tensor]:
  """
  Extracts every entry of a zip archive held in `t`, returning entry name -> Tensor.

  Stored (uncompressed) entries are returned as slices of `t`. Deflated entries are decompressed eagerly into new
  Tensors. Any other compression method raises NotImplementedError.
  """
  files: dict[str, Tensor] = {}
  with zipfile.ZipFile(TensorIO(t), "r") as archive:
    entries = archive.filelist
    # sadly, the extra length needs to be read from the local header of each file.
    # this is a limitation of the zip file format: bytes 26..30 of each local header hold two uint16 lengths
    # (file name, extra field), so fetch them for all entries in one realize
    length_fields = [t[e.header_offset+26:e.header_offset+30].bitcast(dtypes.uint16).to('CPU') for e in entries]
    Tensor.realize(*length_fields)
    for entry, lengths in zip(entries, length_fields):
      # payload starts after the 30-byte fixed local header, the file name, and the extra field
      start = entry.header_offset + 30 + sum(cast(list[int], lengths.tolist()))
      payload = t[start:start+entry.compress_size]
      if entry.compress_type == zipfile.ZIP_STORED:
        files[entry.filename] = payload
      elif entry.compress_type == zipfile.ZIP_DEFLATED:
        # TODO: we need a zlib UOp so this can be lazy
        # wbits=-15 selects a raw deflate stream (no zlib header)
        files[entry.filename] = Tensor(zlib.decompress(payload.data(), -15))
      else:
        raise NotImplementedError(f"compression {entry.compress_type} not supported")
  return files
@accept_filename
def tar_extract(t: Tensor) -> dict[str, Tensor]:
  """
  ```python
  tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```
  Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).
  Each value is a slice of `t` covering that member's data, so no bytes are copied.
  Regular-file members are included (type REGTYPE, the old-style AREGTYPE, and CONTTYPE). Directories, links and
  other special members are skipped.
  ```python
  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
  ```
  """
  # AREGTYPE (b"\0") is how pre-POSIX tar writers mark regular files; tarfile keeps that type as-is, so matching
  # only REGTYPE silently dropped those members. Sparse members are excluded on purpose: their data isn't contiguous.
  file_types = (tarfile.REGTYPE, tarfile.AREGTYPE, tarfile.CONTTYPE)
  with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type in file_types}
# torch support!
@accept_filename
def torch_load(t:Tensor) -> dict[str, Tensor]:
  """
  ```python
  torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```
  Loads a torch .pth file, returning the `state_dict`.

  Three container layouts are recognized:
  - zip: `<base>/data.pkl` plus one `<base>/data/<key>` file per storage.
  - tar: `storages`, `tensors` and `pickle` members.
  - a bare sequence of pickles followed by raw storage bytes.
  Unpickling resolves classes only from the `torch`, `collections`, `numpy` and `_codecs` modules; anything else
  becomes a placeholder class. This whitelist is for speed, not security: do not load untrusted files.
  ```python
  state_dict = nn.state.torch_load("test.pth")
  ```
  """
  storage_source: dict[str|int, Tensor] = {}  # storage key -> raw uint8 bytes of that storage
  lens: dict[str|int, int] = {}  # storage key -> byte length, recorded by _rebuild_tensor_v2 (used by the bare-pickle path)
  def _rebuild_tensor(storage, storage_offset, size, stride):
    return _rebuild_tensor_v2(storage, storage_offset, size, stride)
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
    # `storage` is indexed as: storage[1] = dtype, storage[2] = storage key, storage[4] = element count
    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    lens[storage[2]] = storage[4] * storage[1].itemsize
    # storage bytes not located yet (first pass of the bare-pickle path): only record the length
    if storage[2] not in storage_source: return None
    byte_start, byte_end = storage_offset*storage[1].itemsize, (storage_offset + prod(size))*storage[1].itemsize
    ret = storage_source[storage[2]][byte_start:byte_end].bitcast(storage[1])
    # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
      if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
      assert storage[1] != dtypes.bfloat16, "can't permute BF16"
      # TODO: find a nice way to support all movement ops on disktensors
      ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)
    return ret.reshape(size)

  class Parameter:
    def __setstate__(self, state): self.tensor = state[0]

  deserialized_objects: dict[str, Any] = {}
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
               "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
               "LongStorage": dtypes.int64, "_rebuild_tensor": _rebuild_tensor, "_rebuild_tensor_v2": _rebuild_tensor_v2,
               "FloatTensor": None, "Parameter": Parameter}
  whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      module_root = module.split(".")[0]
      if module_root not in whitelist:
        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
        return Dummy
      return intercept[name] if module_root == "torch" else super().find_class(module, name)
    def persistent_load(self, pid): return deserialized_objects.get(pid, pid)

  fobj = io.BufferedReader(TensorIO(t))
  def passthrough_reset(v: bool): return fobj.seek(0, 0) or v

  if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
    files = zip_extract(t)
    base_name = next(iter(files)).split('/', 1)[0]
    # keyed by persistent_id in pickle file
    storage_source = {fn.split("/")[-1]: data for fn, data in files.items() if fn.startswith(f"{base_name}/data/") and not fn.endswith(".pkl")}
    return TorchPickle(io.BufferedReader(TensorIO(files[f"{base_name}/data.pkl"]), 1_000_000)).load()
  elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
    files = tar_extract(t)
    f = io.BufferedReader(TensorIO(files["storages"]), 1_000_000)
    # storage key -> dtype, so each tensor below is rebuilt with the dtype of the storage it actually points into
    storage_dtypes: dict[str|int, Any] = {}
    # slice source tensor t
    for _ in range(TorchPickle(f).load()):
      (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
      byte_offset = f.tell()
      storage_source[key] = files["storages"][byte_offset:byte_offset + sz * storage_type.itemsize]
      storage_dtypes[key] = storage_type
      f.seek(sz * storage_type.itemsize, 1)
    f = io.BufferedReader(TensorIO(files["tensors"]), 1_000_000)
    # get tensor metadata
    for _ in range(TorchPickle(f).load()):
      (key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
      size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
      storage_offset = struct.unpack('<q', f.read(8))[0]
      # look the dtype up by storage_id instead of reusing the loop variable from the last storage (which
      # mis-typed tensors whenever the checkpoint mixed dtypes)
      deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_dtypes[storage_id], storage_id, None, -1), storage_offset, size, stride)
    pkl_data = TorchPickle(io.BufferedReader(TensorIO(files["pickle"]), 1_000_000)).load()
    return {k: v.tensor if isinstance(v, Parameter) else v for k, v in pkl_data.items()}
  else:
    pkl = TorchPickle(fobj)
    # three leading pickles are skipped; `rwd` marks the start of the object pickle, whose first load only fills
    # `lens` (storage_source is still empty). Then comes the list of storage keys, and storage bytes begin at base_offset.
    _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
    # slice source tensor t: each storage's bytes are preceded by an 8-byte field that is skipped
    for i in ids:
      storage_source[i] = t[base_offset + 8:base_offset + 8 + lens[i]]
      base_offset += 8 + lens[i]
    fobj.seek(rwd)
    return TorchPickle(fobj).load()
def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
  """
  Converts ggml tensor data to a tinygrad tensor.
  Supported native types: float32 (id: 0), float16 (id: 1), bfloat16 (id: 30),
  int8 (id: 16), int16 (id: 17), int32 (id: 18)
  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q5_0 (id: 6),
  Q5_1 (id: 7), Q8_0 (id: 8), Q4_K (id: 12), Q5_K (id: 13),
  Q6_K (id: 14), MXFP4 (id: 39)

  Args:
    t: bytes starting at this tensor's data; only the prefix needed for `n` elements is read.
      NOTE(review): treated as a flat uint8 buffer (sliced, then reshaped into rows of block bytes) — confirm callers
      always pass a 1D uint8 tensor.
    n: number of elements to decode.
    ggml_type: ggml type id (listed above).
  Returns:
    Native types: `n` elements bitcast to that dtype.
    Quantized types other than MXFP4: dequantized values shaped (num_blocks, elements_per_block).
    MXFP4: dequantized values, flattened and truncated to `n`.
  Raises:
    ValueError: `ggml_type` is not supported.
  """
  # https://github.com/ggerganov/ggml/blob/323951f1bdcdfbd5b5ff3a9a7c3770e63b1a560e/include/ggml.h#L356
  # native types
  if (dtype := {
    0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8,
    17: dtypes.int16, 18: dtypes.int32, 30: dtypes.bfloat16,
  }.get(ggml_type)) is not None:
    return t[:dtype.itemsize * n].contiguous().bitcast(dtype)
  # Unpacks every byte of `t` into 8//b fields of b bits, lowest bits first. With L = t.shape[-1], the output's
  # last axis has (8//b)*L entries and out[..., i*L + j] = (t[..., j] >> (i*b)) & (2**b - 1), i.e. field i of
  # all bytes comes before field i+1.
  def q_to_uint8(t: Tensor, b: int) -> Tensor:
    # TODO: rewrite with arange?
    shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
    return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
  # map to (number of elements, number of bytes)
  # i.e. ggml type id -> (elements per block, bytes per block)
  if (nelements_nbytes := {
    2:(32,18), 3:(32,20), 6:(32,22), 7:(32,24), 8:(32,34),
    12:(256,144), 13:(256,176), 14:(256,210), 39:(32,17),
  }.get(ggml_type)) is not None:
    # one row per block; any trailing partial block (n not a multiple of the block size) is dropped
    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1])).contiguous()
    # Q4_0: fp16 scale d in bytes 0-1, then 16 bytes of 4-bit quants; value = (q - 8) * d
    if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
    # Q4_1: fp16 d in bytes 0-1, fp16 m in bytes 2-3, then 4-bit quants; value = q * d + m
    if ggml_type == 3:
      d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
      return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
    # Q5_0 / Q5_1: fp16 d (Q5_1 adds an fp16 offset at bytes 2-3), 4 bytes of high bits, then 16 bytes of 4-bit low quants
    if ggml_type in (6, 7):
      d = blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
      qh_off = 2 if ggml_type == 6 else 4
      # reorder the bit-major q_to_uint8 output so qh[:, k] is bit k of the 4 high-bit bytes (byte k//8, bit k%8)
      qh = q_to_uint8(blocks[:,qh_off:qh_off+4], 1).reshape((-1, 8, 4)).transpose(-1, -2).flatten(-2).bitcast(dtypes.int8)
      # 5-bit quant = low 4 bits + high bit * 16
      q = q_to_uint8(blocks[:,qh_off+4:], 4).bitcast(dtypes.int8) + qh * 16
      # Q5_1: q * d + offset; Q5_0: q * d - 16 * d, i.e. (q - 16) * d
      return q * d + (blocks[:,2:4].bitcast(dtypes.float16).cast(dtypes.float32) if ggml_type == 7 else -16 * d)
    # Q8_0: fp16 d in bytes 0-1, then 32 int8 quants; value = q * d
    if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
    # Q4_K: 256 elements per 144-byte block (d:2, dmin:2, scales:12, qs:128)
    # Q5_K: 256 elements per 176-byte block (d:2, dmin:2, scales:12, qh:32, qs:128)
    if ggml_type in (12, 13):
      d, dmin = (blocks[:,i:i+2].bitcast(dtypes.float16).cast(dtypes.float32).unsqueeze(-1) for i in [0, 2])
      s = blocks[:,4:16] # 12 bytes: 6-bit scales[0-3], 6-bit mins[0-3], high bits[4-7]
      # 8 six-bit scales (sc) and mins (mn), one per 32-element sub-block; entries 4-7 take their low 4 bits from
      # bytes 8-11 and their top 2 bits from the top 2 bits of bytes 0-3 (scales) / 4-7 (mins)
      sc = s[:,0:4].bitwise_and(63).cat(s[:,8:12].bitwise_and(0xF).bitwise_or(s[:,0:4].rshift(6).lshift(4)), dim=-1)
      mn = s[:,4:8].bitwise_and(63).cat(s[:,8:12].rshift(4).bitwise_or(s[:,4:8].rshift(6).lshift(4)), dim=-1)
      qs_off = 48 if ggml_type == 13 else 16
      # (blocks, 8, 32): sub-block 2c holds the low nibbles of 32-byte chunk c, sub-block 2c+1 its high nibbles
      q = Tensor.stack((qs:=blocks[:,qs_off:qs_off+128].reshape(-1,4,32)).bitwise_and(0xF), qs.rshift(4), dim=2).reshape(-1,8,32)
      # Q5_K: bit i of qh byte j supplies the 5th bit (worth 16) of element j in sub-block i
      if ggml_type == 13: q = q + q_to_uint8(blocks[:,16:48], 1).reshape(-1, 8, 32) * 16
      # value = d * scale * q - dmin * min, per sub-block
      return (d * sc.unsqueeze(-1) * q - dmin * mn.unsqueeze(-1)).flatten(-2)
    # Q6_K: 210-byte block = low 4 bits (bytes 0-127), high 2 bits (bytes 128-191), 16 int8 scales (192-207), fp16 d (last 2)
    if ggml_type == 14:
      xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
      # each int8 scale covers 16 consecutive elements
      scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
      d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
      # 6-bit quant recentered by -32; value = d * (q - 32) * scale
      return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
    # MXFP4: byte 0 is a shared exponent e, bytes 1-16 hold 32 four-bit codes
    if ggml_type == 39:
      e = blocks[:, 0].cast(dtypes.uint32)
      # d = 2**(e-128) written directly as float32 bits: denormal patterns for e < 2, exponent field (e-1) otherwise
      small_bits = Tensor([0x00200000, 0x00400000], dtype=dtypes.uint32, device=t.device)[e.clip(0, 1).cast(dtypes.int32)] # e = 0 or e = 1 case
      d = (e < 2).where(small_bits, ((e - 1) * 0x00800000).cast(dtypes.uint32)).bitcast(dtypes.float32).unsqueeze(-1)
      codes = q_to_uint8(blocks[:, 1:17], 4)
      # code -> value table; codes 8-15 are the negated values of codes 0-7
      fp4_lut = Tensor([0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0,
                        -0.0,-1.0,-2.0,-3.0,-4.0,-6.0,-8.0,-12.0],
                       dtype=dtypes.float32, device=t.device)
      fp4_val = fp4_lut[codes]
      return (fp4_val * d).flatten(-2)[:n]
  raise ValueError(f"GGML type '{ggml_type}' is not supported!")
@accept_filename
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
  """
  Loads a .gguf file, returning the `kv_data` and `state_dict`.
  ```python
  gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
  kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
  ```
  NOTE: The provided tensor must be on a device that supports execution.

  Only GGUF versions 2 and 3 are accepted; anything else (or a bad magic) raises ValueError.
  Tensor data begins at the first multiple of `general.alignment` (default 32) after the tensor-info table.
  Each tensor is decoded with `ggml_data_to_tensor` and reshaped to its dims in reverse order.
  """
  reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
  def read_str(): return str(reader.read(read_uint64()), "utf-8")
  def read_arr():
    # array value: element type id, element count, then the elements themselves
    elem_reader, n = readers[read_int32()], read_uint64()
    return [ elem_reader() for _ in range(n) ]
  # GGUF value type id -> reader. Type 0 is UINT8: unpack it with "B" (an int); "c" would return a 1-byte bytes object.
  readers: dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ tid: functools.partial(read_unpack, "<"+f, nb) for tid,f,nb in \
    [ (0,"B",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
  read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]
  magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64()
  if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
  for _ in range(n_kv):
    k, typ = read_str(), read_int32()
    kv_data[k] = readers[typ]()
  # tensor info: (name, dims, ggml type id, byte offset relative to the aligned data section)
  t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
  alignment, pos = kv_data.get("general.alignment", 32), reader.tell()
  data_start = round_up(pos, alignment)
  for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))
  return kv_data, state_dict