subbuffer support (#4397)

* subbuffer support

* diskbuffer offset

* cuda subbuffer works

* use subbuffer

* more subbuffer tests

* consecutive

* cast

* consec

* offset

* view is a better name

* offset is in nbytes

* fix view + memory planner

* delete unused DiskRunner

* reverse order

* no subbuffers on unrealized consts

* only enabled for disk

* don't reverse memory

* view supported devices

* pickle buffer view

* ring jit

* support extra view inputs in jit

* fix JIT=2 issue

* test copy jit

* p2p isn't an option anymore

* fix dep tracking issue

* fix mypy

* fix pickle

* from_nv is contents now
This commit is contained in:
George Hotz
2024-05-03 18:05:57 -07:00
committed by GitHub
parent c7368515d2
commit 9fc4465557
16 changed files with 234 additions and 105 deletions

View File

@@ -26,6 +26,13 @@ N = 128
# shard_x is "data parallel"
# shard_w is "model parallel"
def _test_allreduce(t:Tensor):
  """Return (reference, all-reduced) tensors for `t`.

  The reference is the sum of the four 64-row slices of `t`, repeated four times along axis 0.
  The second value is `t` sharded over d0..d3 on axis 0 and passed through all_reduce(SUM).
  Both results are realized before returning; the comparison is left to the caller.
  """
  expected = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
  sharded = t.shard((d0, d1, d2, d3), 0).realize()
  reduced = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, sharded.lazydata.lbs), 0))
  reduced.realize()
  return expected, reduced
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestMultiTensor(unittest.TestCase):
def test_to(self):
@@ -132,19 +139,37 @@ class TestMultiTensor(unittest.TestCase):
fn = f(n)
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
def _test_allreduce(self):
t = Tensor.rand(256, 256)
aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).numpy()
ts = t.shard(tuple([d0, d1, d2, d3]), 0).realize()
b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, ts.lazydata.lbs), 0))
b.realize()
np.testing.assert_almost_equal(aa, b.numpy(), decimal=5)
def test_allreduce_naive(self):
with Context(RING=0): self._test_allreduce()
with Context(RING=0):
a,b = _test_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_allreduce_ring(self):
with Context(RING=2): self._test_allreduce()
with Context(RING=2):
a,b = _test_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_copy_jit(self):
@TinyJit
def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)
for _ in range(5):
t = Tensor.rand(256).realize()
x = copy_tensor(t)
np.testing.assert_equal((t+1).numpy(), x.numpy())
def test_allreduce_naive_jit(self):
with Context(RING=0):
jit_allreduce = TinyJit(_test_allreduce)
for _ in range(5):
a,b = jit_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_allreduce_ring_jit(self):
with Context(RING=2):
jit_allreduce = TinyJit(_test_allreduce)
for _ in range(5):
a,b = jit_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
@unittest.skip("slow")
def test_fuzz_allreduce(self):

View File

@@ -16,6 +16,17 @@ class TestPickle(unittest.TestCase):
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
def test_pickle_buffer_view(self):
  """A realized view buffer must round-trip through pickle as a view with the same contents."""
  t = Tensor.arange(10, device="CLANG").contiguous().realize()
  vt = t[3:5].contiguous().realize()
  # `Buffer.base` is a property available on every Buffer (it falls back to the buffer itself),
  # so hasattr(..., 'base') can never fail. `_base` is only non-None when the buffer is a view.
  assert vt.lazydata.buffer._base is not None
  ref_value = vt.tolist()
  st = pickle.dumps(vt)
  # drop the originals so the unpickled tensor cannot be satisfied by the live objects
  del t, vt
  vt2 = pickle.loads(st)
  assert vt2.lazydata.buffer._base is not None
  assert ref_value == vt2.tolist()
def test_pickle_numpy(self):
t = Tensor(np.array([1,2,3,4.]))
st = pickle.dumps(t)

52
test/test_subbuffer.py Normal file
View File

@@ -0,0 +1,52 @@
import unittest
from tinygrad import Device, dtypes, Tensor
from tinygrad.helpers import CI
from tinygrad.buffer import Buffer
from tinygrad.lazy import view_supported_devices
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
class TestSubBuffer(unittest.TestCase):
  """Exercises Buffer.view on the default device against a 10-byte base buffer holding 0..9."""

  def setUp(self):
    # base buffer contents: byte i has value i
    self.buf = Buffer(Device.DEFAULT, 10, dtypes.uint8)
    self.buf.ensure_allocated()
    self.buf.copyin(memoryview(bytearray(range(10))))

  def test_subbuffer(self):
    # two bytes starting at byte offset 3
    sub = self.buf.view(2, dtypes.uint8, offset=3).ensure_allocated()
    got = sub.as_buffer().tolist()
    assert got == [3, 4]

  def test_subbuffer_cast(self):
    # NOTE: bitcast depends on endianness
    sub = self.buf.view(2, dtypes.uint16, offset=3).ensure_allocated()
    got = sub.as_buffer().cast("H").tolist()
    assert got == [3|(4<<8), 5|(6<<8)]

  def test_subbuffer_double(self):
    # a view taken from a view: offsets 3 and 1 land on bytes 4 and 5 of the base
    outer = self.buf.view(4, dtypes.uint8, offset=3).ensure_allocated()
    inner = outer.view(2, dtypes.uint8, offset=1).ensure_allocated()
    got = inner.as_buffer().tolist()
    assert got == [4, 5]

  def test_subbuffer_len(self):
    # the returned memoryview must have the view's length on both the copy and zero-copy paths
    sub = self.buf.view(5, dtypes.uint8, 2).ensure_allocated()
    copied = sub.as_buffer()
    assert len(copied) == 5
    zero_copy = sub.as_buffer(allow_zero_copy=True)
    assert len(zero_copy) == 5

  def test_subbuffer_used(self):
    base_t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize()
    # TODO: why does it needs contiguous
    view_t = base_t[2:4].contiguous().realize()
    result = (view_t + 100).tolist()
    assert result == [102, 103]

  @unittest.skipIf(Device.DEFAULT != "CUDA" or CI, "only CUDA")
  def test_subbuffer_transfer(self):
    # move a sliced tensor to a second CUDA device
    base_t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize()
    view_t = base_t[2:5].contiguous().realize()
    result = view_t.to("CUDA:1").realize().tolist()
    assert result == [2, 3, 4]
if __name__ == '__main__':
unittest.main()

View File

@@ -231,6 +231,15 @@ class TestDiskTensor(unittest.TestCase):
tout = [(x//256, x%256) for x in out]
assert tout == list([(x+1,x) for x in range(32,64,2)])
def test_simple_read_bitcast_alt(self):
fn = pathlib.Path(temp("dt1"))
fn.unlink(missing_ok=True)
fn.write_bytes(bytes(range(256))*2)
t = Tensor.empty(16, 16*2, device=f"disk:{temp('dt1')}", dtype=dtypes.uint8)
out = t.bitcast(dtypes.uint16)[1].to(Device.DEFAULT).tolist()
tout = [(x//256, x%256) for x in out]
assert tout == list([(x+1,x) for x in range(32,64,2)])
def test_write_ones(self):
pathlib.Path(temp("dt2")).unlink(missing_ok=True)
@@ -269,6 +278,12 @@ class TestDiskTensor(unittest.TestCase):
np.testing.assert_array_equal(t.numpy(), np.array([3] * 10))
def test_bitcast(self):
with open(temp('bf16'), "wb") as f: f.write(bytes(range(10,20)))
t = Tensor.empty(5, dtype=dtypes.int16, device=f"disk:{temp('bf16')}")
ret = t.to("CLANG").bitcast(dtypes.uint16) + 1
assert ret.tolist() == [2827, 3341, 3855, 4369, 4883]
@unittest.skipIf(Device.DEFAULT == "RHIP", "no real HIP device exists in CI")
def test_bf16_disk_write_read(self):
t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
@@ -279,8 +294,9 @@ class TestDiskTensor(unittest.TestCase):
adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
with open(temp('bf16'), "wb") as f: f.write(adat)
t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm_bf16_cast(dtypes.float)
assert t.numpy().tolist() == [9984., -1, -1000, -9984, 20]
t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}")
ct = t.llvm_bf16_cast(dtypes.float)
assert ct.numpy().tolist() == [9984., -1, -1000, -9984, 20]
if __name__ == "__main__":
unittest.main()

View File

@@ -13,25 +13,46 @@ class BufferOptions:
class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
initial_value:Optional[bytes]=None, lb_refcount=0):
initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
assert isinstance(dtype, DType)
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
self.device, self.size, self.dtype, self.options, self.lb_refcount = device, size, dtype, options, lb_refcount
if opaque is not None: self.allocate(opaque)
if initial_value is not None:
self.allocate()
self.copyin(memoryview(initial_value))
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
if base is None:
assert offset == 0, "base buffers can't have offset"
self._base = None
self._lb_refcount = lb_refcount
if opaque is not None: self.allocate(opaque)
if initial_value is not None:
self.allocate()
self.copyin(memoryview(initial_value))
else:
assert base._base is None, "base can't have a base"
assert device == base.device, "base must have the same device"
self._base = base
if preallocate: self.allocate()
@property
def base(self) -> Buffer: return self._base if self._base is not None else self
@property
def lb_refcount(self): return self.base._lb_refcount
def ref(self, cnt): self.base._lb_refcount += cnt
def is_allocated(self) -> bool: return hasattr(self, '_buf')
def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
def allocate(self, opaque=None) -> Buffer:
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
from tinygrad.device import Device
self.allocator = Device[self.device].allocator
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
if self._base is not None:
self._base.ensure_allocated()
assert hasattr(self.allocator, "offset"), "offset function required for view"
self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
else:
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
return self
def __reduce__(self):
buf = None
if self._base is not None:
return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
if self.is_allocated():
buf = bytearray(self.nbytes)
@@ -41,10 +62,12 @@ class Buffer:
def nbytes(self): return self.size*self.dtype.itemsize
def __del__(self):
if not hasattr(self, '_buf'): return
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
self.allocator.free(self._buf, self.nbytes, self.options)
if self._base is None:
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
self.allocator.free(self._buf, self.nbytes, self.options)
def __repr__(self):
  """Return a short debug string: allocation state, device, size, dtype, view offset (views only) and options."""
  # `base` is a property defined on every Buffer, so hasattr(self, "base") was always True and the
  # offset was printed for plain buffers as well. Only real views carry a non-None `_base`.
  # getattr keeps repr usable on an instance whose __init__ raised before `_base` was assigned.
  is_view = getattr(self, "_base", None) is not None
  return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
    (f" offset:{self.offset}" if is_view else "") + \
    (">" if self.options is None else f"{self.options=}>")
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
# zero copy with as_buffer (disabled by default due to use after free)
@@ -63,3 +86,7 @@ class Buffer:
assert self.is_allocated(), "can't copyout unallocated buffer"
self.allocator.copyout(mv, self._buf)
return mv
def view(self, size:int, dtype:DType, offset:int) -> Buffer:
  """Create a Buffer of `size` elements of `dtype` that is a view into this buffer.

  `offset` is compared against `nbytes`, i.e. it is a byte offset relative to this buffer.
  The returned Buffer is constructed with `base=`/`offset=` only; nothing is allocated here.
  """
  assert offset < self.nbytes, "offset must be less than nbytes"
  # views never chain: a view of a view is attached to the underlying base with the offsets summed
  if self._base is None: parent, total_offset = self, offset
  else: parent, total_offset = self._base, self.offset+offset
  return Buffer(self.device, size, dtype, base=parent, offset=total_offset)

View File

@@ -113,6 +113,8 @@ class _MallocAllocator(LRUAllocator):
def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
MallocAllocator = _MallocAllocator()
# **************** for Compiled Devices ****************

View File

@@ -8,7 +8,7 @@ from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planner
from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planner, EmptyOp, ViewOp
from tinygrad.nn.state import get_parameters
from weakref import WeakKeyDictionary
@@ -37,6 +37,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
current_device = None
for ji in jit_cache:
if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
@@ -110,7 +111,10 @@ class TinyJit(Generic[ReturnType]):
def add_buffer(self, b:Buffer) -> Buffer:
if found:=self.buffer_replace.get(b, None): return found
if b.is_allocated() or b.lb_refcount > 0: return b
self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
if b._base is not None:
self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.buffer_replace.get(b._base, b._base), offset=b.offset)
else:
self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
return ret
def add(self, ei:ExecItem):
@@ -119,6 +123,7 @@ class TinyJit(Generic[ReturnType]):
def reset(self):
self.jit_cache: List[ExecItem] = []
self.input_replace: Dict[Tuple[int, int], int] = {}
self.extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
self.cnt: int = 0
@@ -154,6 +159,13 @@ class TinyJit(Generic[ReturnType]):
assert len(self.jit_cache), "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# track inputs that are views of buffers
for ji in self.jit_cache:
for b in ji.bufs:
if b is not None and b._base is not None and b._base in input_rawbuffers:
input_rawbuffers.append(b)
self.extra_view_inputs.append((input_rawbuffers.index(b.base), b.offset, b.device, b.size, b.dtype))
# memory planning (optional)
assigned = _internal_memory_planner([cast(List[Buffer], x.bufs) for x in self.jit_cache], debug_prefix="JIT ")
self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.bufs if x is not None]) for ei in self.jit_cache]
@@ -166,6 +178,8 @@ class TinyJit(Generic[ReturnType]):
elif self.cnt >= 2:
# jit exec
assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
for idx, offset, device, size, dtype in self.extra_view_inputs:
input_rawbuffers.append(Buffer(device, size, dtype, base=input_rawbuffers[idx], offset=offset).ensure_allocated())
for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_rawbuffers[input_idx]
if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
for ei in self.jit_cache: ei.run(var_vals, jit=True)

View File

@@ -1,4 +1,4 @@
from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Iterable
from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Union
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import DType
@@ -13,7 +13,8 @@ class ExecItem:
prg: Runner
bufs: List[Optional[Buffer]]
def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.bufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
if do_update_stats:
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
@@ -36,6 +37,11 @@ class EmptyOp(Runner):
def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
class ViewOp(Runner):
  """Runner used for LoadOps.VIEW items: it moves no data and only validates the buffer relationship."""
  def __init__(self, buf:Buffer):
    # display name shows the view's byte size and its byte offset
    display_name = colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow")
    super().__init__(display_name, buf.device)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
    # the output (rawbufs[0]) must be a view, and its parent must be the base of the input (rawbufs[1])
    assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
def lower_runner(runner:Runner, bufs) -> ExecItem:
# TODO: globals isn't on the stupid diskrunner, remove the need for it
return ExecItem(runner, [bufs[x[0]] for x in runner.globals] if hasattr(runner, 'globals') else bufs)
@@ -53,6 +59,7 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem:
return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs))
if ast.op is LoadOps.CUSTOM: return ExecItem(CustomOp(ast.arg), list(si.bufs))
if ast.op is LoadOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if ast.op is LoadOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
raise RuntimeError(f"don't know how to lower {ast}")
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
@@ -60,7 +67,8 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, Non
capturing: List = [] # put classes with an add method in here
def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") -> Dict[Buffer, Buffer]:
def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]:
if getenv("NO_MEMORY_PLANNER"): return {}
last_appearance = {}
for i,u in enumerate(buffers):
for buf in u: last_appearance[buf] = i
@@ -68,16 +76,24 @@ def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") ->
# LRU algorithm
assigned: Dict[Buffer, Buffer] = {}
local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
def handle_buffer(buf):
key = (buf.device, buf.size, buf.dtype)
if buf not in assigned:
if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
else: assigned[buf] = Buffer(*key)
if i == last_appearance[buf]:
if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf])
for i,u in enumerate(buffers):
for buf in u:
# all unallocated unparented buffers are fair game to replace
if buf.is_allocated() or buf.lb_refcount > 0: continue
key = (buf.device, buf.size, buf.dtype)
if buf not in assigned:
if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
else: assigned[buf] = Buffer(*key)
if i == last_appearance[buf]:
local_cache[key].append(assigned[buf])
# handle view buffers
if buf._base is not None:
assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset)
else:
handle_buffer(buf)
if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",

View File

@@ -80,8 +80,8 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None]
inputs: List[LazyBuffer] = []
ast: List[LazyOp] = []
var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], list(outs[0].srcs)
if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], [x.base for x in outs[0].srcs]
else:
for i, out in enumerate(outs):
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
@@ -121,6 +121,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
if buf.op is LoadOps.COPY:
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
realizes[buf.srcs[0].base] = None
if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None
for x in buf.srcs:
children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children)

View File

@@ -218,7 +218,8 @@ def cpu_time_execution(cb, enable):
# *** ctypes helpers
# TODO: make this work with read only memoryviews (if possible)
def from_mv(mv:memoryview, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type))
def from_mv(mv:memoryview, to_type=ctypes.c_char):
  """Return a ctypes array of `to_type` that aliases the memory behind `mv` (no copy).

  `mv` must be writable and at least `sizeof(to_type)` bytes long, as required by `from_buffer`.
  The array is rebuilt from a raw address, so the caller should keep `mv`'s owner alive while using it.
  """
  # Size the array by how many `to_type` elements fit in the buffer. len(mv) counts elements of the
  # memoryview's own format, which under- or over-states the span whenever the item sizes differ.
  count = mv.nbytes // ctypes.sizeof(to_type)
  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * count)).contents
def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
@functools.lru_cache(maxsize=None)

View File

@@ -22,6 +22,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=
if enable_cache: lazycache[cache_key] = ret
return ret
view_supported_devices = {"LLVM", "CLANG", "CUDA", "DISK"}
class LazyBuffer:
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
@@ -32,8 +33,15 @@ class LazyBuffer:
# properties on base
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
self.buffer: Buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
self.buffer.lb_refcount += 1
if (self.op is LoadOps.CONTIGUOUS or (self.op is UnaryOps.CAST and self.arg[1] is True)) and srcs[0].st.consecutive and \
not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
# some LazyBuffers can be processed with only a view, no AST required
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
self.op = LoadOps.VIEW
else:
self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
self.buffer.ref(1)
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
self.forced_realize = False
else:
@@ -42,7 +50,7 @@ class LazyBuffer:
self._base = base
def __del__(self):
if hasattr(self, 'buffer'): self.buffer.lb_refcount -= 1
if hasattr(self, 'buffer'): self.buffer.ref(-1)
def __repr__(self) -> str:
return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"

View File

@@ -26,7 +26,7 @@ class ReduceOps(Enum):
"""A -> B (reduce)"""
SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]

View File

@@ -1,8 +1,8 @@
import ctypes
from typing import Any, Optional, Tuple, Dict, List, cast
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, GraphException, getenv
from tinygrad.device import CompiledRunner, Buffer, BufferXfer, Device, BufferOptions
from tinygrad.helpers import init_c_var, GraphException
from tinygrad.device import CompiledRunner, Buffer, BufferXfer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem
@@ -13,20 +13,20 @@ class CUDAGraph(MultiGraphRunner):
super().__init__(jit_cache, input_rawbuffers, var_vals)
# Check all jit items are compatible.
if not all(isinstance(ji.prg, CompiledRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
self.cpu_buffers = []
for j,ji in enumerate(self.jit_cache):
if isinstance(ji.prg, CompiledRunner):
global_size, local_size = ji.prg.launch_dims(var_vals)
new_node = cuda.CUgraphNode()
deps = self._access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node)
deps = self._access_resources([x.base for x in ji.bufs[ji.prg.outcount:] if x is not None],
[x.base for x in ji.bufs[:ji.prg.outcount] if x is not None], new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.vars])
@@ -37,28 +37,14 @@ class CUDAGraph(MultiGraphRunner):
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device])
src_dev = cast(CUDADevice, Device[src.device])
node_from = cuda.CUgraphNode()
deps = self._access_resources(read=[src], write=[dest], new_dependency=node_from)
deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
if getenv("CUDA_P2P", int(CUDADevice.peer_access)):
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
WidthInBytes=dest.nbytes, Height=1, Depth=1)
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
else:
self.cpu_buffers.append(cpu_buffer:=Buffer(device=src.device, dtype=src.dtype, size=src.size, options=BufferOptions(host=True)).allocate())
node_to = cuda.CUgraphNode()
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
dstMemoryType=cuda.CU_MEMORYTYPE_HOST, dstHost=cpu_buffer._buf, dstPitch=dest.nbytes, dstHeight=1,
WidthInBytes=dest.nbytes, Height=1, Depth=1)
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_to), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_HOST, srcHost=cpu_buffer._buf, srcPitch=src.nbytes, srcHeight=1,
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
WidthInBytes=dest.nbytes, Height=1, Depth=1)
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, (cuda.CUgraphNode*1)(node_to), 1,
ctypes.byref(cp_params), dest_dev.context))
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
WidthInBytes=dest.nbytes, Height=1, Depth=1)
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))

View File

@@ -148,6 +148,7 @@ class CUDAAllocator(LRUAllocator):
check(cuda.cuEventRecord(sync_event, None))
check(cuda.cuCtxSetCurrent(dest_dev.context))
check(cuda.cuStreamWaitEvent(None, sync_event, 0)) # sync the default stream on the dest dev
def offset(self, buf, size:int, offset:int): return ctypes.c_ulong(buf.value + offset)
class CUDADevice(Compiled):
devices: List[CUDADevice] = []

View File

@@ -1,10 +1,8 @@
from __future__ import annotations
import os, mmap, _posixshmem, io, functools
from typing import Dict, List, Any, Optional
from tinygrad.helpers import prod, OSX
from tinygrad.device import Compiled, Allocator, Runner, Buffer
from tinygrad.ops import UnaryOps, LazyOp, BufferOps
from tinygrad.shape.view import strides_for_shape
import os, mmap, _posixshmem, io
from typing import Optional
from tinygrad.helpers import OSX
from tinygrad.device import Compiled, Allocator
class DiskBuffer:
def __init__(self, device:DiskDevice, size:int, offset=0):
@@ -31,32 +29,7 @@ class DiskAllocator(Allocator):
fo.readinto(dest)
else:
dest[:] = src._buf()
class DiskRunner(Runner):
def __init__(self, ast:LazyOp):
# two ASTs are allowed here.
assert ast.op is BufferOps.STORE, "output of AST must be store"
assert ast.arg.st.contiguous, "shapetracker must be contiguous"
# TODO: there shouldn't actually be casts here, bitcasts should fold into the load
if ast.src[0].op is UnaryOps.CAST:
top_src = ast.src[0].src[0]
assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts"
self.new_dtype = ast.src[0].arg[0]
else:
top_src = ast.src[0]
self.new_dtype = top_src.arg.dtype
assert top_src.op is BufferOps.LOAD, "top of AST must be load"
assert len(top_src.arg.st.views) == 1, "shapetracker must have 1 view"
view = top_src.arg.st.views[0]
assert view.mask is None, "view cannot have a mask"
assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides"
self.new_size = prod(view.shape)
self.new_offset = view.offset * top_src.arg.dtype.itemsize
super().__init__(f"sz 0x{self.new_size:X} offset 0x{self.new_offset:X}", "DISK")
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False):
assert len(rawbufs) == 2
# TODO: this is a terrible hack that should be moved to lazy.py
rawbufs[0]._buf.offset = rawbufs[1]._buf.offset+self.new_offset
def offset(self, buf:DiskBuffer, size:int, offset:int):
  """Return a DiskBuffer of `size` bytes on the same device, starting `offset` bytes into `buf`."""
  # compose with any offset `buf` already carries so the result stays relative to `buf`;
  # identical to the previous behaviour when buf.offset is 0
  return DiskBuffer(buf.device, size, buf.offset + offset)
class DiskDevice(Compiled):
def __init__(self, device:str):
@@ -85,7 +58,3 @@ class DiskDevice(Compiled):
if self.count == 0:
if hasattr(self, 'fd'): os.close(self.fd)
self.size = None
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def get_runner(self, *ast:LazyOp):
assert len(ast) == 1, "DiskRunner doesn't support multioutput kernels."
return DiskRunner(ast[0])

View File

@@ -106,12 +106,12 @@ class HWComputeQueue:
return self
def wait(self, signal, value=0):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), *nvdata64_le(value),
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
return self
def signal(self, signal, value=0, timestamp=False):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), *nvdata64_le(value),
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
return self
@@ -137,12 +137,12 @@ class HWCopyQueue:
return self
def wait(self, signal, value=0):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), value, 0x0,
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), value, 0x0,
(3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
return self
def signal(self, signal, value=0, timestamp=False):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), *nvdata64_le(value),
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
return self