Files
tinygrad/tinygrad/runtime/ops_disk.py
George Hotz 73a6ed7862 Apply ShapeTracker in interpreted backends (#1846)
* applying st

* tests pass

* minor cleanups

* torch too

* hack

* contiguous

* move mops

* contig in BN

* tests should pass

* make torch fast

* make zeros and ones contig by default

* no contig there

* fix padding with expanding

* might fix tests

* still doesn't fix bug, but should be there

* Revert "still doesn't fix bug, but should be there"

This reverts commit 8ea92f3e07.

* minor cleanups
2023-09-23 10:05:13 +08:00

41 lines
2.3 KiB
Python

import os, mmap
from typing import Optional
from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps
class RawDiskBuffer(RawBufferMapped):
  """Raw buffer whose storage is a memory-mapped file on disk.

  ``_buf`` is a three-element list ``[file_object, mmap_object, refcount]``.
  Views created by ``cast`` / ``reshape`` / ``shrink`` / ``as_strided`` share the
  same list and bump the refcount; the file is closed when the count hits zero.
  ``offset`` is the byte offset of this view inside the mapping and ``size`` is
  its length in elements of ``dtype``.
  """

  def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0):  # pylint: disable=super-init-not-called
    """Open (or attach to) the backing file.

    Args:
      size: number of elements in this view.
      dtype: element type; ``dtype.itemsize`` converts elements to bytes.
      device: path of the backing file. When given, the file is opened in
        ``"a+b"`` mode, grown to ``size * dtype.itemsize`` bytes if it is
        smaller, and memory-mapped.
      buf: an existing ``[file, mmap, refcount]`` list to share; used only when
        ``device`` is None.
      shape: shape of the view; defaults to ``(size,)``.
      offset: byte offset of the view inside the mapping.

    Raises:
      AssertionError: if neither ``device`` nor ``buf`` is provided.
    """
    self.shape = (size, ) if shape is None else shape
    self.offset = offset  # this is an offset in bytes
    assert device is not None or buf is not None, "disk tensor needs a path or a buf"
    if device is not None:
      nbytes = size * dtype.itemsize
      f = open(device, "a+b")
      try:
        if os.path.getsize(device) < nbytes: os.ftruncate(f.fileno(), nbytes)
        buf = [f, mmap.mmap(f.fileno(), nbytes), 1]
      except Exception:
        # don't leak the file handle if growing or mapping the file fails
        f.close()
        raise
    else:
      buf[2] += 1
    # NOTE: we don't call super since disk tensors don't use RAM
    self.size, self.dtype, self._buf = size, dtype, buf

  def __del__(self):
    """Drop this view's reference and close the file when it was the last one."""
    # __init__ may have raised before _buf was assigned; there is nothing to release then
    shared = getattr(self, "_buf", None)
    if shared is None: return
    shared[2] -= 1
    if shared[2] == 0: shared[0].close()

  def cast(self, arg:Tuple[DType, bool]):
    """Return a view over the same bytes with dtype ``arg[0]``; ``arg[1]`` is not used here.

    NOTE(review): ``size`` (an element count) is passed through unchanged, so the
    byte span of the view changes if the new itemsize differs -- confirm intended.
    """
    return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)

  def reshape(self, arg):
    """Return a view with the same size, dtype and offset but shape ``arg``."""
    return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)

  def shrink(self, arg):
    """Return a view restricted to ``arg[0] = (start, end)`` along the first dimension.

    Every other dimension must be requested in full, i.e. ``(0, dim)``.

    Raises:
      AssertionError: if any dimension other than the first is actually sliced.
    """
    assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
    start, end = arg[0][0], arg[0][1]
    row_elems = prod(self.shape[1:])  # elements per index of the first dimension
    offset = start*row_elems*self.dtype.itemsize
    size = (end-start) * row_elems
    return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(end-start,)+self.shape[1:])

  def as_strided(self, arg):
    """Return a view of shape ``arg[0]`` starting ``arg[2]`` elements past this view's offset.

    ``arg[1]`` is ignored (presumably the strides -- the result is laid out as
    ``prod(arg[0])`` consecutive elements; verify against the caller).
    """
    return RawDiskBuffer(prod(arg[0]), self.dtype, buf=self._buf, offset=self.offset+arg[2]*self.dtype.itemsize, shape=arg[0])

  def _buffer(self):
    """Return a memoryview over this view's bytes inside the mmap."""
    return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]

  def readinto(self, buf):
    """Seek the backing file to this view's byte offset and read from it into ``buf``."""
    self._buf[0].seek(self.offset)
    self._buf[0].readinto(buf)
# Ops the disk backend handles, each mapped to the callable that runs it on a RawDiskBuffer.
disk_fxn_for_op: Dict[Op, Callable] = {
  UnaryOps.NOOP: lambda x: x,  # returns the buffer unchanged
  UnaryOps.CAST: RawDiskBuffer.cast,
  MovementOps.AS_STRIDED: RawDiskBuffer.as_strided,
}

# Interpreted backend over RawDiskBuffer; both conversion hooks are identity functions.
DiskBuffer = Interpreted(
  RawDiskBuffer,
  disk_fxn_for_op,
  to_underlying=lambda x:x,
  from_underlying=lambda x:x,
)