mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
subbuffer support (#4397)
* subbuffer support * diskbuffer offset * cuda subbuffer works * use subbuffer * more subbuffer tests * consecutive * cast * consec * offset * view is a better name * offset is in nbytes * fix view + memory planner * delete unused DiskRunner * reverse order * no subbuffers on unrealized consts * only enabled for disk * don't reverse memory * view supported devices * pickle buffer view * ring jit * support extra view inputs in jit * fix JIT=2 issue * test copy jit * p2p isn't an option anymore * fix dep tracking issue * fix mypy * fix pickle * from_nv is contents now
This commit is contained in:
@@ -26,6 +26,13 @@ N = 128
|
||||
# shard_x is "data parallel"
|
||||
# shard_w is "model parallel"
|
||||
|
||||
def _test_allreduce(t:Tensor):
  """Compute a SUM all-reduce of `t` in two ways and return both tensors as (reference, all_reduced).

  The reference adds the four 64-row slices of `t` and repeats the sum four times along axis 0.
  The second result shards `t` along axis 0 over d0..d3 and runs all_reduce on the per-device buffers.
  Callers compare the two; the slicing assumes `t` has 256 rows.
  """
  # reference result: sum of the four row slices, tiled back to the full row count
  slice_sum = t[0:64] + t[64:128] + t[128:192] + t[192:256]
  reference = slice_sum.repeat([4,1]).realize()

  # sharded result: one shard per device, reduced with ReduceOps.SUM
  sharded = t.shard((d0, d1, d2, d3), 0).realize()
  reduced_lbs = all_reduce(ReduceOps.SUM, sharded.lazydata.lbs)
  all_reduced = Tensor(MultiLazyBuffer(reduced_lbs, 0))
  all_reduced.realize()
  return reference, all_reduced
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
|
||||
class TestMultiTensor(unittest.TestCase):
|
||||
def test_to(self):
|
||||
@@ -132,19 +139,37 @@ class TestMultiTensor(unittest.TestCase):
|
||||
fn = f(n)
|
||||
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
|
||||
|
||||
def _test_allreduce(self):
|
||||
t = Tensor.rand(256, 256)
|
||||
aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).numpy()
|
||||
ts = t.shard(tuple([d0, d1, d2, d3]), 0).realize()
|
||||
b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, ts.lazydata.lbs), 0))
|
||||
b.realize()
|
||||
np.testing.assert_almost_equal(aa, b.numpy(), decimal=5)
|
||||
|
||||
def test_allreduce_naive(self):
|
||||
with Context(RING=0): self._test_allreduce()
|
||||
with Context(RING=0):
|
||||
a,b = _test_allreduce(Tensor.rand(256, 256))
|
||||
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
|
||||
|
||||
def test_allreduce_ring(self):
|
||||
with Context(RING=2): self._test_allreduce()
|
||||
with Context(RING=2):
|
||||
a,b = _test_allreduce(Tensor.rand(256, 256))
|
||||
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
|
||||
|
||||
def test_copy_jit(self):
  """A jitted copy to device index 1 followed by +1 must equal the eager `t+1` on each of five calls."""
  @TinyJit
  def copy_tensor(x:Tensor):
    # keep the backend name of x's device and target its index 1
    backend = x.device.split(':')[0]
    return (x.to(f"{backend}:1") + 1)

  for _ in range(5):
    src = Tensor.rand(256).realize()
    moved = copy_tensor(src)
    np.testing.assert_equal((src+1).numpy(), moved.numpy())
|
||||
|
||||
def test_allreduce_naive_jit(self):
  """Run _test_allreduce through TinyJit five times with RING=0 and compare reference vs all-reduced output."""
  with Context(RING=0):
    jit_allreduce = TinyJit(_test_allreduce)
    for _ in range(5):
      reference, reduced = jit_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(reference.numpy(), reduced.numpy(), decimal=5)
|
||||
|
||||
def test_allreduce_ring_jit(self):
  """Run _test_allreduce through TinyJit five times with RING=2 and compare reference vs all-reduced output."""
  with Context(RING=2):
    jit_allreduce = TinyJit(_test_allreduce)
    for _ in range(5):
      reference, reduced = jit_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(reference.numpy(), reduced.numpy(), decimal=5)
|
||||
|
||||
@unittest.skip("slow")
|
||||
def test_fuzz_allreduce(self):
|
||||
|
||||
@@ -16,6 +16,17 @@ class TestPickle(unittest.TestCase):
|
||||
t2:Tensor = pickle.loads(st)
|
||||
np.testing.assert_equal(t.numpy(), t2.numpy())
|
||||
|
||||
def test_pickle_buffer_view(self):
  """Pickle a realized slice of a CLANG tensor, drop the originals, and check the unpickled values match."""
  source = Tensor.arange(10, device="CLANG").contiguous().realize()
  sliced = source[3:5].contiguous().realize()
  assert hasattr(sliced.lazydata.buffer, 'base')
  expected = sliced.tolist()
  payload = pickle.dumps(sliced)
  # release both tensors so the round trip cannot rely on the live objects
  del source, sliced
  restored = pickle.loads(payload)
  assert hasattr(restored.lazydata.buffer, 'base')
  assert expected == restored.tolist()
|
||||
|
||||
def test_pickle_numpy(self):
|
||||
t = Tensor(np.array([1,2,3,4.]))
|
||||
st = pickle.dumps(t)
|
||||
|
||||
52
test/test_subbuffer.py
Normal file
52
test/test_subbuffer.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import unittest
|
||||
from tinygrad import Device, dtypes, Tensor
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.lazy import view_supported_devices
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
class TestSubBuffer(unittest.TestCase):
  """Exercises Buffer.view() on the default device, using a 10-element uint8 base holding 0..9."""

  def setUp(self):
    # base buffer shared by the view tests: bytes 0, 1, ..., 9
    contents = bytearray(range(10))
    self.buf = Buffer(Device.DEFAULT, 10, dtypes.uint8).ensure_allocated()
    self.buf.copyin(memoryview(contents))

  def test_subbuffer(self):
    # two uint8 elements starting at byte offset 3
    sub = self.buf.view(2, dtypes.uint8, offset=3).ensure_allocated()
    values = sub.as_buffer().tolist()
    assert values == [3, 4]

  def test_subbuffer_cast(self):
    # NOTE: bitcast depends on endianness
    sub = self.buf.view(2, dtypes.uint16, offset=3).ensure_allocated()
    values = sub.as_buffer().cast("H").tolist()
    assert values == [3|(4<<8), 5|(6<<8)]

  def test_subbuffer_double(self):
    # a view of a view: offsets 3 and 1 land on base bytes 4 and 5
    outer = self.buf.view(4, dtypes.uint8, offset=3).ensure_allocated()
    inner = outer.view(2, dtypes.uint8, offset=1).ensure_allocated()
    values = inner.as_buffer().tolist()
    assert values == [4, 5]

  def test_subbuffer_len(self):
    # the exported memoryview has the view's length on both as_buffer paths
    sub = self.buf.view(5, dtypes.uint8, 2).ensure_allocated()
    copied = sub.as_buffer()
    assert len(copied) == 5
    zero_copy = sub.as_buffer(allow_zero_copy=True)
    assert len(zero_copy) == 5

  def test_subbuffer_used(self):
    base_t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize()
    # TODO: why does it need contiguous
    sliced = base_t[2:4].contiguous().realize()
    result = (sliced + 100).tolist()
    assert result == [102, 103]

  @unittest.skipIf(Device.DEFAULT != "CUDA" or CI, "only CUDA")
  def test_subbuffer_transfer(self):
    # move a realized slice to a second CUDA device
    base_t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize()
    sliced = base_t[2:5].contiguous().realize()
    result = sliced.to("CUDA:1").realize().tolist()
    assert result == [2, 3, 4]
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -231,6 +231,15 @@ class TestDiskTensor(unittest.TestCase):
|
||||
tout = [(x//256, x%256) for x in out]
|
||||
assert tout == list([(x+1,x) for x in range(32,64,2)])
|
||||
|
||||
def test_simple_read_bitcast_alt(self):
  """Bitcast a uint8 disk tensor to uint16, then index row 1, and check the (high, low) byte pairs."""
  path = pathlib.Path(temp("dt1"))
  path.unlink(missing_ok=True)
  # 512 bytes: the sequence 0..255 written twice
  path.write_bytes(bytes(range(256))*2)
  disk_t = Tensor.empty(16, 16*2, device=f"disk:{temp('dt1')}", dtype=dtypes.uint8)
  row = disk_t.bitcast(dtypes.uint16)[1].to(Device.DEFAULT).tolist()
  # split every 16-bit value into (high byte, low byte)
  byte_pairs = [(v//256, v%256) for v in row]
  assert byte_pairs == [(lo+1, lo) for lo in range(32,64,2)]
|
||||
|
||||
def test_write_ones(self):
|
||||
pathlib.Path(temp("dt2")).unlink(missing_ok=True)
|
||||
|
||||
@@ -269,6 +278,12 @@ class TestDiskTensor(unittest.TestCase):
|
||||
|
||||
np.testing.assert_array_equal(t.numpy(), np.array([3] * 10))
|
||||
|
||||
def test_bitcast(self):
  """Load an int16 disk tensor, copy it to CLANG, bitcast to uint16, add one, and check the values."""
  # ten bytes 10..19 back the five int16 elements
  with open(temp('bf16'), "wb") as out_file:
    out_file.write(bytes(range(10,20)))
  disk_t = Tensor.empty(5, dtype=dtypes.int16, device=f"disk:{temp('bf16')}")
  result = t_plus_one = disk_t.to("CLANG").bitcast(dtypes.uint16) + 1
  # first expected value: 10 | (11 << 8) == 2826, plus one
  assert t_plus_one.tolist() == [2827, 3341, 3855, 4369, 4883]
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "RHIP", "no real HIP device exists in CI")
|
||||
def test_bf16_disk_write_read(self):
|
||||
t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
|
||||
@@ -279,8 +294,9 @@ class TestDiskTensor(unittest.TestCase):
|
||||
adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
|
||||
with open(temp('bf16'), "wb") as f: f.write(adat)
|
||||
|
||||
t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm_bf16_cast(dtypes.float)
|
||||
assert t.numpy().tolist() == [9984., -1, -1000, -9984, 20]
|
||||
t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}")
|
||||
ct = t.llvm_bf16_cast(dtypes.float)
|
||||
assert ct.numpy().tolist() == [9984., -1, -1000, -9984, 20]
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -13,25 +13,46 @@ class BufferOptions:
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
|
||||
initial_value:Optional[bytes]=None, lb_refcount=0):
|
||||
initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
|
||||
assert isinstance(dtype, DType)
|
||||
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
|
||||
self.device, self.size, self.dtype, self.options, self.lb_refcount = device, size, dtype, options, lb_refcount
|
||||
if opaque is not None: self.allocate(opaque)
|
||||
if initial_value is not None:
|
||||
self.allocate()
|
||||
self.copyin(memoryview(initial_value))
|
||||
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
|
||||
if base is None:
|
||||
assert offset == 0, "base buffers can't have offset"
|
||||
self._base = None
|
||||
self._lb_refcount = lb_refcount
|
||||
if opaque is not None: self.allocate(opaque)
|
||||
if initial_value is not None:
|
||||
self.allocate()
|
||||
self.copyin(memoryview(initial_value))
|
||||
else:
|
||||
assert base._base is None, "base can't have a base"
|
||||
assert device == base.device, "base must have the same device"
|
||||
self._base = base
|
||||
if preallocate: self.allocate()
|
||||
@property
def base(self) -> Buffer:
  """Return the buffer this one is a view of, or `self` when `_base` is None."""
  if self._base is None: return self
  return self._base
|
||||
@property
def lb_refcount(self):
  """Read the reference counter stored on the base buffer (views share their base's counter)."""
  owner = self.base
  return owner._lb_refcount
|
||||
def ref(self, cnt):
  """Add `cnt` (may be negative) to the reference counter kept on the base buffer."""
  owner = self.base
  owner._lb_refcount += cnt
|
||||
def is_allocated(self) -> bool:
  """Report whether the buffer has a `_buf` attribute, i.e. allocate() has assigned one."""
  return hasattr(self, '_buf')
|
||||
def ensure_allocated(self) -> Buffer:
  """Return `self` when `_buf` already exists; otherwise return the result of allocate()."""
  if hasattr(self, '_buf'): return self
  return self.allocate()
|
||||
def allocate(self, opaque=None) -> Buffer:
|
||||
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
|
||||
from tinygrad.device import Device
|
||||
self.allocator = Device[self.device].allocator
|
||||
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
|
||||
if self._base is not None:
|
||||
self._base.ensure_allocated()
|
||||
assert hasattr(self.allocator, "offset"), "offset function required for view"
|
||||
self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
|
||||
else:
|
||||
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
|
||||
return self
|
||||
def __reduce__(self):
|
||||
buf = None
|
||||
if self._base is not None:
|
||||
return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
|
||||
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
|
||||
if self.is_allocated():
|
||||
buf = bytearray(self.nbytes)
|
||||
@@ -41,10 +62,12 @@ class Buffer:
|
||||
def nbytes(self): return self.size*self.dtype.itemsize
|
||||
def __del__(self):
|
||||
if not hasattr(self, '_buf'): return
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
|
||||
self.allocator.free(self._buf, self.nbytes, self.options)
|
||||
if self._base is None:
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
|
||||
self.allocator.free(self._buf, self.nbytes, self.options)
|
||||
def __repr__(self):
  """Debug representation: allocation state, device, size, dtype, then offset (views only) and options.

  `base` is a property on the class, so `hasattr(self, "base")` is true for every buffer and the
  offset was printed unconditionally. A buffer is a view exactly when `_base` is set, so test that.
  """
  # getattr keeps repr usable on an instance whose __init__ failed before `_base` was assigned
  is_view = getattr(self, "_base", None) is not None
  return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
    (f" offset:{self.offset}" if is_view else "") + \
    (">" if self.options is None else f"{self.options=}>")
|
||||
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
|
||||
# zero copy with as_buffer (disabled by default due to use after free)
|
||||
@@ -63,3 +86,7 @@ class Buffer:
|
||||
assert self.is_allocated(), "can't copyout unallocated buffer"
|
||||
self.allocator.copyout(mv, self._buf)
|
||||
return mv
|
||||
def view(self, size:int, dtype:DType, offset:int) -> Buffer:
  """Create a Buffer of `size` elements of `dtype` that aliases this buffer starting at byte `offset`.

  Views are always attached to the base buffer: viewing a view adds the two offsets and reuses
  the existing base rather than chaining.
  """
  assert offset < self.nbytes, "offset must be less than nbytes"
  if self._base is None:
    parent, byte_offset = self, offset
  else:
    parent, byte_offset = self._base, self.offset+offset
  return Buffer(self.device, size, dtype, base=parent, offset=byte_offset)
|
||||
|
||||
@@ -113,6 +113,8 @@ class _MallocAllocator(LRUAllocator):
|
||||
def as_buffer(self, src) -> memoryview:
  """Wrap `src` in a memoryview and pass it through flat_mv (presumably flattening it to bytes; confirm in helpers)."""
  raw_view = memoryview(src)
  return flat_mv(raw_view)
|
||||
def copyin(self, dest, src:memoryview):
  """memmove `len(src)` bytes from the memoryview `src` into `dest`."""
  byte_count = len(src)
  ctypes.memmove(dest, from_mv(src), byte_count)
|
||||
def copyout(self, dest:memoryview, src):
  """memmove `len(dest)` bytes from `src` into the memoryview `dest`."""
  byte_count = len(dest)
  ctypes.memmove(from_mv(dest), src, byte_count)
|
||||
def offset(self, buf, size:int, offset:int):
  """Return a ctypes object (via from_mv) over the `size`-long window of `buf` that starts at `offset`."""
  window = self.as_buffer(buf)[offset:offset+size]
  return from_mv(window)
|
||||
|
||||
MallocAllocator = _MallocAllocator()
|
||||
|
||||
# **************** for Compiled Devices ****************
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, Device
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planner
|
||||
from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planner, EmptyOp, ViewOp
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
@@ -37,6 +37,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
|
||||
current_device = None
|
||||
|
||||
for ji in jit_cache:
|
||||
if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
|
||||
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
|
||||
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
|
||||
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
|
||||
@@ -110,7 +111,10 @@ class TinyJit(Generic[ReturnType]):
|
||||
def add_buffer(self, b:Buffer) -> Buffer:
|
||||
if found:=self.buffer_replace.get(b, None): return found
|
||||
if b.is_allocated() or b.lb_refcount > 0: return b
|
||||
self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
|
||||
if b._base is not None:
|
||||
self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.buffer_replace.get(b._base, b._base), offset=b.offset)
|
||||
else:
|
||||
self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
|
||||
return ret
|
||||
|
||||
def add(self, ei:ExecItem):
|
||||
@@ -119,6 +123,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
def reset(self):
|
||||
self.jit_cache: List[ExecItem] = []
|
||||
self.input_replace: Dict[Tuple[int, int], int] = {}
|
||||
self.extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
|
||||
self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
|
||||
self.cnt: int = 0
|
||||
|
||||
@@ -154,6 +159,13 @@ class TinyJit(Generic[ReturnType]):
|
||||
assert len(self.jit_cache), "didn't JIT anything!"
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
|
||||
|
||||
# track inputs that are views of buffers
|
||||
for ji in self.jit_cache:
|
||||
for b in ji.bufs:
|
||||
if b is not None and b._base is not None and b._base in input_rawbuffers:
|
||||
input_rawbuffers.append(b)
|
||||
self.extra_view_inputs.append((input_rawbuffers.index(b.base), b.offset, b.device, b.size, b.dtype))
|
||||
|
||||
# memory planning (optional)
|
||||
assigned = _internal_memory_planner([cast(List[Buffer], x.bufs) for x in self.jit_cache], debug_prefix="JIT ")
|
||||
self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.bufs if x is not None]) for ei in self.jit_cache]
|
||||
@@ -166,6 +178,8 @@ class TinyJit(Generic[ReturnType]):
|
||||
elif self.cnt >= 2:
|
||||
# jit exec
|
||||
assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
|
||||
for idx, offset, device, size, dtype in self.extra_view_inputs:
|
||||
input_rawbuffers.append(Buffer(device, size, dtype, base=input_rawbuffers[idx], offset=offset).ensure_allocated())
|
||||
for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_rawbuffers[input_idx]
|
||||
if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
|
||||
for ei in self.jit_cache: ei.run(var_vals, jit=True)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Iterable
|
||||
from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Union
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import DType
|
||||
@@ -13,7 +13,8 @@ class ExecItem:
|
||||
prg: Runner
|
||||
bufs: List[Optional[Buffer]]
|
||||
def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
|
||||
et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.bufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
|
||||
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
|
||||
et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
|
||||
if do_update_stats:
|
||||
GlobalCounters.kernel_count += 1
|
||||
GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
|
||||
@@ -36,6 +37,11 @@ class EmptyOp(Runner):
|
||||
def __init__(self, buf:Buffer):
  """Name the runner after the buffer's element count and dtype (yellow) and bind it to the buffer's device."""
  display_name = colored(f"empty {buf.size:10d} {buf.dtype}", "yellow")
  super().__init__(display_name, buf.device)
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
  """Do nothing: this runner performs no work on its buffers and returns None."""
  return None
|
||||
|
||||
class ViewOp(Runner):
  """Runner for LoadOps.VIEW items: it moves no data and only checks the view/base relationship."""

  def __init__(self, buf:Buffer):
    # the display name shows the view's byte size and its byte offset
    display_name = colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow")
    super().__init__(display_name, buf.device)

  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
    # rawbufs[0] must be a view whose base is the base of rawbufs[1]
    view_buf = rawbufs[0]
    assert view_buf._base is not None and view_buf._base == rawbufs[1].base, f"must be base {rawbufs}"
|
||||
|
||||
def lower_runner(runner:Runner, bufs) -> ExecItem:
  """Pair `runner` with the buffers it uses.

  When the runner has a `globals` attribute, pick `bufs[x[0]]` for each entry `x`;
  otherwise pass `bufs` through unchanged.
  """
  # TODO: not every runner defines `globals` (the disk runner did not); remove the need for this fallback
  if hasattr(runner, 'globals'):
    selected = [bufs[entry[0]] for entry in runner.globals]
  else:
    selected = bufs
  return ExecItem(runner, selected)
|
||||
@@ -53,6 +59,7 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem:
|
||||
return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs))
|
||||
if ast.op is LoadOps.CUSTOM: return ExecItem(CustomOp(ast.arg), list(si.bufs))
|
||||
if ast.op is LoadOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
|
||||
if ast.op is LoadOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
|
||||
raise RuntimeError(f"don't know how to lower {ast}")
|
||||
|
||||
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
|
||||
@@ -60,7 +67,8 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, Non
|
||||
|
||||
capturing: List = [] # put classes with an add method in here
|
||||
|
||||
def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") -> Dict[Buffer, Buffer]:
|
||||
def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
|
||||
if getenv("NO_MEMORY_PLANNER"): return {}
|
||||
last_appearance = {}
|
||||
for i,u in enumerate(buffers):
|
||||
for buf in u: last_appearance[buf] = i
|
||||
@@ -68,16 +76,24 @@ def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") ->
|
||||
# LRU algorithm
|
||||
assigned: Dict[Buffer, Buffer] = {}
|
||||
local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
|
||||
|
||||
def handle_buffer(buf):
|
||||
key = (buf.device, buf.size, buf.dtype)
|
||||
if buf not in assigned:
|
||||
if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
|
||||
else: assigned[buf] = Buffer(*key)
|
||||
if i == last_appearance[buf]:
|
||||
if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
|
||||
|
||||
for i,u in enumerate(buffers):
|
||||
for buf in u:
|
||||
# all unallocated unparented buffers are fair game to replace
|
||||
if buf.is_allocated() or buf.lb_refcount > 0: continue
|
||||
key = (buf.device, buf.size, buf.dtype)
|
||||
if buf not in assigned:
|
||||
if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
|
||||
else: assigned[buf] = Buffer(*key)
|
||||
if i == last_appearance[buf]:
|
||||
local_cache[key].append(assigned[buf])
|
||||
# handle view buffers
|
||||
if buf._base is not None:
|
||||
assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
|
||||
else:
|
||||
handle_buffer(buf)
|
||||
|
||||
if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
|
||||
print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
|
||||
|
||||
@@ -80,8 +80,8 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None]
|
||||
inputs: List[LazyBuffer] = []
|
||||
ast: List[LazyOp] = []
|
||||
var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
|
||||
if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
|
||||
ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], list(outs[0].srcs)
|
||||
if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
|
||||
ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], [x.base for x in outs[0].srcs]
|
||||
else:
|
||||
for i, out in enumerate(outs):
|
||||
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
|
||||
@@ -121,6 +121,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
|
||||
if buf.op is LoadOps.COPY:
|
||||
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
|
||||
realizes[buf.srcs[0].base] = None
|
||||
if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None
|
||||
for x in buf.srcs:
|
||||
children[x.base][buf] = None
|
||||
_recurse_lb(x, realizes, allbufs, simple_pads, children)
|
||||
|
||||
@@ -218,7 +218,8 @@ def cpu_time_execution(cb, enable):
|
||||
# *** ctypes helpers
|
||||
|
||||
# TODO: make this work with read only memoryviews (if possible)
|
||||
def from_mv(mv:memoryview, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type))
|
||||
def from_mv(mv:memoryview, to_type=ctypes.c_char):
  """Return a ctypes array of `len(mv)` `to_type` elements located at the address of `mv`'s memory.

  The result is the pointed-to array itself (`.contents`), not a pointer, and it aliases `mv`.
  `mv` must be writable, because `from_buffer` requires a writable buffer.
  """
  start_address = ctypes.addressof(to_type.from_buffer(mv))
  array_pointer_type = ctypes.POINTER(to_type * len(mv))
  return ctypes.cast(start_address, array_pointer_type).contents
|
||||
def to_mv(ptr, sz) -> memoryview:
  """Expose `sz` bytes starting at `ptr` as a memoryview with format "B" (no copy is made)."""
  byte_array_ptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz))
  return memoryview(byte_array_ptr.contents).cast("B")
|
||||
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char):
  """Build a ctypes array of `POINTER(to_type)`, one pointer per entry of `options`.

  Each entry is copied into its own string buffer (create_string_buffer) and cast to a pointer.
  """
  element_ptr_type = ctypes.POINTER(to_type)
  pointers = [ctypes.cast(ctypes.create_string_buffer(opt), element_ptr_type) for opt in options]
  return (element_ptr_type * len(options))(*pointers)
|
||||
@functools.lru_cache(maxsize=None)
|
||||
|
||||
@@ -22,6 +22,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=
|
||||
if enable_cache: lazycache[cache_key] = ret
|
||||
return ret
|
||||
|
||||
# Device names (the part before any ":" suffix) for which LazyBuffer turns a consecutive
# CONTIGUOUS / CAST-with-arg[1] op into a LoadOps.VIEW backed by Buffer.view() instead of an AST.
# test_subbuffer.py is also skipped when Device.DEFAULT is not in this set.
view_supported_devices = {"LLVM", "CLANG", "CUDA", "DISK"}
|
||||
class LazyBuffer:
|
||||
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
|
||||
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
@@ -32,8 +33,15 @@ class LazyBuffer:
|
||||
# properties on base
|
||||
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
|
||||
assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
|
||||
self.buffer: Buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
|
||||
self.buffer.lb_refcount += 1
|
||||
|
||||
if (self.op is LoadOps.CONTIGUOUS or (self.op is UnaryOps.CAST and self.arg[1] is True)) and srcs[0].st.consecutive and \
|
||||
not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
|
||||
# some LazyBuffers can be processed with only a view, no AST required
|
||||
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
|
||||
self.op = LoadOps.VIEW
|
||||
else:
|
||||
self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
|
||||
self.buffer.ref(1)
|
||||
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
|
||||
self.forced_realize = False
|
||||
else:
|
||||
@@ -42,7 +50,7 @@ class LazyBuffer:
|
||||
self._base = base
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'buffer'): self.buffer.lb_refcount -= 1
|
||||
if hasattr(self, 'buffer'): self.buffer.ref(-1)
|
||||
|
||||
def __repr__(self) -> str:
  """Debug string: device, shape, dtype name with its first 7 characters dropped, then a detail field.

  The detail is the ShapeTracker when this object is not its own base, otherwise the (op, realized) pair.
  """
  detail = self.st if self.base != self else (self.op, self.realized)
  dtype_name = str(self.dtype)[7:]
  return f"<LB {self.device} {self.shape} {dtype_name} {detail}>"
|
||||
|
||||
@@ -26,7 +26,7 @@ class ReduceOps(Enum):
|
||||
"""A -> B (reduce)"""
|
||||
SUM = auto(); MAX = auto() # noqa: E702
|
||||
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
|
||||
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
|
||||
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import ctypes
|
||||
from typing import Any, Optional, Tuple, Dict, List, cast
|
||||
import tinygrad.runtime.autogen.cuda as cuda
|
||||
from tinygrad.helpers import init_c_var, GraphException, getenv
|
||||
from tinygrad.device import CompiledRunner, Buffer, BufferXfer, Device, BufferOptions
|
||||
from tinygrad.helpers import init_c_var, GraphException
|
||||
from tinygrad.device import CompiledRunner, Buffer, BufferXfer, Device
|
||||
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
@@ -13,20 +13,20 @@ class CUDAGraph(MultiGraphRunner):
|
||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
||||
|
||||
# Check all jit items are compatible.
|
||||
if not all(isinstance(ji.prg, CompiledRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
|
||||
if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
|
||||
|
||||
self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
|
||||
self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
|
||||
|
||||
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
|
||||
self.cpu_buffers = []
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
if isinstance(ji.prg, CompiledRunner):
|
||||
global_size, local_size = ji.prg.launch_dims(var_vals)
|
||||
|
||||
new_node = cuda.CUgraphNode()
|
||||
deps = self._access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node)
|
||||
deps = self._access_resources([x.base for x in ji.bufs[ji.prg.outcount:] if x is not None],
|
||||
[x.base for x in ji.bufs[:ji.prg.outcount] if x is not None], new_dependency=new_node)
|
||||
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
||||
|
||||
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.vars])
|
||||
@@ -37,28 +37,14 @@ class CUDAGraph(MultiGraphRunner):
|
||||
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
|
||||
elif isinstance(ji.prg, BufferXfer):
|
||||
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
||||
src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device])
|
||||
src_dev = cast(CUDADevice, Device[src.device])
|
||||
node_from = cuda.CUgraphNode()
|
||||
deps = self._access_resources(read=[src], write=[dest], new_dependency=node_from)
|
||||
deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
|
||||
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
||||
if getenv("CUDA_P2P", int(CUDADevice.peer_access)):
|
||||
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
|
||||
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
|
||||
WidthInBytes=dest.nbytes, Height=1, Depth=1)
|
||||
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
|
||||
else:
|
||||
self.cpu_buffers.append(cpu_buffer:=Buffer(device=src.device, dtype=src.dtype, size=src.size, options=BufferOptions(host=True)).allocate())
|
||||
|
||||
node_to = cuda.CUgraphNode()
|
||||
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
|
||||
dstMemoryType=cuda.CU_MEMORYTYPE_HOST, dstHost=cpu_buffer._buf, dstPitch=dest.nbytes, dstHeight=1,
|
||||
WidthInBytes=dest.nbytes, Height=1, Depth=1)
|
||||
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_to), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
|
||||
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_HOST, srcHost=cpu_buffer._buf, srcPitch=src.nbytes, srcHeight=1,
|
||||
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
|
||||
WidthInBytes=dest.nbytes, Height=1, Depth=1)
|
||||
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, (cuda.CUgraphNode*1)(node_to), 1,
|
||||
ctypes.byref(cp_params), dest_dev.context))
|
||||
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
|
||||
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
|
||||
WidthInBytes=dest.nbytes, Height=1, Depth=1)
|
||||
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
|
||||
if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
|
||||
|
||||
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
|
||||
|
||||
@@ -148,6 +148,7 @@ class CUDAAllocator(LRUAllocator):
|
||||
check(cuda.cuEventRecord(sync_event, None))
|
||||
check(cuda.cuCtxSetCurrent(dest_dev.context))
|
||||
check(cuda.cuStreamWaitEvent(None, sync_event, 0)) # sync the default stream on the dest dev
|
||||
def offset(self, buf, size:int, offset:int):
  """Return a new device pointer that points `offset` past the start of `buf`.

  Args:
    buf: ctypes integer holding the base device address (only `buf.value` is read).
    size: size of the requested sub-buffer; not needed to compute the address and unused here.
    offset: distance from the base address, added directly to the raw address value.

  Returns:
    A fresh 64-bit ctypes integer holding `buf.value + offset`; `buf` itself is not modified.
  """
  # Use a fixed-width 64-bit type: c_ulong is only 32 bits on LLP64 platforms (e.g. Windows) and ctypes
  # truncates silently on overflow. On LP64 platforms c_uint64 is the same class as c_ulong.
  return ctypes.c_uint64(buf.value + offset)
|
||||
|
||||
class CUDADevice(Compiled):
|
||||
devices: List[CUDADevice] = []
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import os, mmap, _posixshmem, io, functools
|
||||
from typing import Dict, List, Any, Optional
|
||||
from tinygrad.helpers import prod, OSX
|
||||
from tinygrad.device import Compiled, Allocator, Runner, Buffer
|
||||
from tinygrad.ops import UnaryOps, LazyOp, BufferOps
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
import os, mmap, _posixshmem, io
|
||||
from typing import Optional
|
||||
from tinygrad.helpers import OSX
|
||||
from tinygrad.device import Compiled, Allocator
|
||||
|
||||
class DiskBuffer:
|
||||
def __init__(self, device:DiskDevice, size:int, offset=0):
|
||||
@@ -31,32 +29,7 @@ class DiskAllocator(Allocator):
|
||||
fo.readinto(dest)
|
||||
else:
|
||||
dest[:] = src._buf()
|
||||
|
||||
class DiskRunner(Runner):
  """Runner for disk tensors that re-points the output buffer instead of running a kernel.

  The AST must be a contiguous STORE whose source is a LOAD, optionally wrapped in one CAST.
  Construction records the dtype, element count and byte offset described by the LOAD's single view;
  calling the runner shifts the output buffer's offset relative to the input buffer's offset.
  """
  def __init__(self, ast:LazyOp):
    # two AST shapes are accepted: STORE(LOAD) and STORE(CAST(LOAD))
    assert ast.op is BufferOps.STORE, "output of AST must be store"
    assert ast.arg.st.contiguous, "shapetracker must be contiguous"
    stored = ast.src[0]
    # TODO: there shouldn't actually be casts here, bitcasts should fold into the load
    if stored.op is not UnaryOps.CAST:
      load = stored
      self.new_dtype = load.arg.dtype
    else:
      load = stored.src[0]
      # arg[1] is presumably the bitcast flag and arg[0] the target dtype -- confirm against UnaryOps.CAST
      assert stored.arg[1], "disk only supports bitcasts, not normal casts"
      self.new_dtype = stored.arg[0]
    assert load.op is BufferOps.LOAD, "top of AST must be load"
    load_views = load.arg.st.views
    assert len(load_views) == 1, "shapetracker must have 1 view"
    only_view = load_views[0]
    assert only_view.mask is None, "view cannot have a mask"
    assert strides_for_shape(only_view.shape) == only_view.strides, "disk tensors don't support strides"
    self.new_size = prod(only_view.shape)
    # the view offset is scaled by the itemsize of the dtype being loaded (not new_dtype)
    self.new_offset = only_view.offset * load.arg.dtype.itemsize
    super().__init__(f"sz 0x{self.new_size:X} offset 0x{self.new_offset:X}", "DISK")

  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False):
    """Set the offset of rawbufs[0]'s underlying buffer to rawbufs[1]'s offset plus new_offset."""
    assert len(rawbufs) == 2
    # TODO: this is a terrible hack that should be moved to lazy.py
    shifted_offset = rawbufs[1]._buf.offset + self.new_offset
    rawbufs[0]._buf.offset = shifted_offset
|
||||
def offset(self, buf:DiskBuffer, size:int, offset:int):
  """Build a DiskBuffer of `size` at `offset` on the same device as `buf`.

  Only `buf.device` is read from the source buffer; `size` and `offset` are passed through unchanged.
  """
  backing_device = buf.device
  return DiskBuffer(backing_device, size, offset)
|
||||
|
||||
class DiskDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
@@ -85,7 +58,3 @@ class DiskDevice(Compiled):
|
||||
if self.count == 0:
|
||||
if hasattr(self, 'fd'): os.close(self.fd)
|
||||
self.size = None
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def get_runner(self, *ast:LazyOp):
  """Return the DiskRunner for a single AST; results are memoized per (self, ast) by lru_cache."""
  ast_count = len(ast)
  assert ast_count == 1, "DiskRunner doesn't support multioutput kernels."
  only_ast = ast[0]
  return DiskRunner(only_ast)
|
||||
|
||||
@@ -106,12 +106,12 @@ class HWComputeQueue:
|
||||
return self
|
||||
|
||||
def wait(self, signal, value=0):
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), *nvdata64_le(value),
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
|
||||
(3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
|
||||
return self
|
||||
|
||||
def signal(self, signal, value=0, timestamp=False):
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), *nvdata64_le(value),
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
|
||||
(1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
|
||||
return self
|
||||
@@ -137,12 +137,12 @@ class HWCopyQueue:
|
||||
return self
|
||||
|
||||
def wait(self, signal, value=0):
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), value, 0x0,
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), value, 0x0,
|
||||
(3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
|
||||
return self
|
||||
|
||||
def signal(self, signal, value=0, timestamp=False):
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), *nvdata64_le(value),
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
|
||||
(1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user