From 9fc4465557831b614b56dd645eebc940ca0fa1bb Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 3 May 2024 18:05:57 -0700 Subject: [PATCH] subbuffer support (#4397) * subbuffer support * diskbuffer offset * cuda subbuffer works * use subbuffer * more subbuffer tests * consecutive * cast * consec * offset * view is a better name * offset is in nbytes * fix view + memory planner * delete unused DiskRunner * reverse order * no subbuffers on unrealized consts * only enabled for disk * don't reverse memory * view supported devices * pickle buffer view * ring jit * support extra view inputs in jit * fix JIT=2 issue * test copy jit * p2p isn't an option anymore * fix dep tracking issue * fix mypy * fix pickle * from_nv is contents now --- test/test_multitensor.py | 45 ++++++++++++++++++++++------- test/test_pickle.py | 11 +++++++ test/test_subbuffer.py | 52 ++++++++++++++++++++++++++++++++++ test/unit/test_disk_tensor.py | 20 +++++++++++-- tinygrad/buffer.py | 47 +++++++++++++++++++++++------- tinygrad/device.py | 2 ++ tinygrad/engine/jit.py | 18 ++++++++++-- tinygrad/engine/realize.py | 34 ++++++++++++++++------ tinygrad/engine/schedule.py | 5 ++-- tinygrad/helpers.py | 3 +- tinygrad/lazy.py | 14 +++++++-- tinygrad/ops.py | 2 +- tinygrad/runtime/graph/cuda.py | 36 +++++++---------------- tinygrad/runtime/ops_cuda.py | 1 + tinygrad/runtime/ops_disk.py | 41 ++++----------------------- tinygrad/runtime/ops_nv.py | 8 +++--- 16 files changed, 234 insertions(+), 105 deletions(-) create mode 100644 test/test_subbuffer.py diff --git a/test/test_multitensor.py b/test/test_multitensor.py index a85d927d2d..2c49235856 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -26,6 +26,13 @@ N = 128 # shard_x is "data parallel" # shard_w is "model parallel" +def _test_allreduce(t:Tensor): + aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize() + ts = t.shard(tuple([d0, d1, d2, d3]), 0).realize() + b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, ts.lazydata.lbs), 0)) + b.realize() + return aa, b + @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI") class TestMultiTensor(unittest.TestCase): def test_to(self): @@ -132,19 +139,37 @@ class TestMultiTensor(unittest.TestCase): fn = f(n) np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6) - def _test_allreduce(self): - t = Tensor.rand(256, 256) - aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).numpy() - ts = t.shard(tuple([d0, d1, d2, d3]), 0).realize() - b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, ts.lazydata.lbs), 0)) - b.realize() - np.testing.assert_almost_equal(aa, b.numpy(), decimal=5) - def test_allreduce_naive(self): - with Context(RING=0): self._test_allreduce() + with Context(RING=0): + a,b = _test_allreduce(Tensor.rand(256, 256)) + np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5) def test_allreduce_ring(self): - with Context(RING=2): self._test_allreduce() + with Context(RING=2): + a,b = _test_allreduce(Tensor.rand(256, 256)) + np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5) + + def test_copy_jit(self): + @TinyJit + def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1) + for _ in range(5): + t = Tensor.rand(256).realize() + x = copy_tensor(t) + np.testing.assert_equal((t+1).numpy(), x.numpy()) + + def test_allreduce_naive_jit(self): + with Context(RING=0): + jit_allreduce = TinyJit(_test_allreduce) + for _ in range(5): + a,b = jit_allreduce(Tensor.rand(256, 256)) + np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5) + + def test_allreduce_ring_jit(self): + with Context(RING=2): + jit_allreduce = TinyJit(_test_allreduce) + for _ in range(5): + a,b = jit_allreduce(Tensor.rand(256, 256)) + np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5) @unittest.skip("slow") def test_fuzz_allreduce(self): diff --git a/test/test_pickle.py b/test/test_pickle.py index ef5597d9b8..7833bfe6af 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -16,6 +16,17 @@ class TestPickle(unittest.TestCase): t2:Tensor = pickle.loads(st) np.testing.assert_equal(t.numpy(), t2.numpy()) + def test_pickle_buffer_view(self): + t = Tensor.arange(10, device="CLANG").contiguous().realize() + vt = t[3:5].contiguous().realize() + assert hasattr(vt.lazydata.buffer, 'base') + ref_value = vt.tolist() + st = pickle.dumps(vt) + del t, vt + vt2 = pickle.loads(st) + assert hasattr(vt2.lazydata.buffer, 'base') + assert ref_value == vt2.tolist() + def test_pickle_numpy(self): t = Tensor(np.array([1,2,3,4.])) st = pickle.dumps(t) diff --git a/test/test_subbuffer.py b/test/test_subbuffer.py new file mode 100644 index 0000000000..85a55384f6 --- /dev/null +++ b/test/test_subbuffer.py @@ -0,0 +1,52 @@ +import unittest +from tinygrad import Device, dtypes, Tensor +from tinygrad.helpers import CI +from tinygrad.buffer import Buffer +from tinygrad.lazy import view_supported_devices + +@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported") +class TestSubBuffer(unittest.TestCase): + def setUp(self): + self.buf = Buffer(Device.DEFAULT, 10, dtypes.uint8).ensure_allocated() + self.buf.copyin(memoryview(bytearray(range(10)))) + + def test_subbuffer(self): + vbuf = self.buf.view(2, dtypes.uint8, offset=3).ensure_allocated() + tst = vbuf.as_buffer().tolist() + assert tst == [3, 4] + + def test_subbuffer_cast(self): + # NOTE: bitcast depends on endianness + vbuf = self.buf.view(2, dtypes.uint16, offset=3).ensure_allocated() + tst = vbuf.as_buffer().cast("H").tolist() + assert tst == [3|(4<<8), 5|(6<<8)] + + def test_subbuffer_double(self): + vbuf = self.buf.view(4, dtypes.uint8, offset=3).ensure_allocated() + vvbuf = vbuf.view(2, dtypes.uint8, offset=1).ensure_allocated() + tst = vvbuf.as_buffer().tolist() + assert tst == [4, 5] + + def test_subbuffer_len(self): + vbuf = self.buf.view(5, dtypes.uint8, 2).ensure_allocated() + mv = vbuf.as_buffer() + assert len(mv) == 5 + mv = vbuf.as_buffer(allow_zero_copy=True) + assert len(mv) == 5 + + def test_subbuffer_used(self): + t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize() + # TODO: why does it needs contiguous + vt = t[2:4].contiguous().realize() + out = (vt + 100).tolist() + assert out == [102, 103] + + @unittest.skipIf(Device.DEFAULT != "CUDA" or CI, "only CUDA") + def test_subbuffer_transfer(self): + t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize() + vt = t[2:5].contiguous().realize() + out = vt.to("CUDA:1").realize().tolist() + assert out == [2, 3, 4] + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 32ad61a84e..6095f716e8 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -231,6 +231,15 @@ class TestDiskTensor(unittest.TestCase): tout = [(x//256, x%256) for x in out] assert tout == list([(x+1,x) for x in range(32,64,2)]) + def test_simple_read_bitcast_alt(self): + fn = pathlib.Path(temp("dt1")) + fn.unlink(missing_ok=True) + fn.write_bytes(bytes(range(256))*2) + t = Tensor.empty(16, 16*2, device=f"disk:{temp('dt1')}", dtype=dtypes.uint8) + out = t.bitcast(dtypes.uint16)[1].to(Device.DEFAULT).tolist() + tout = [(x//256, x%256) for x in out] + assert tout == list([(x+1,x) for x in range(32,64,2)]) + def test_write_ones(self): pathlib.Path(temp("dt2")).unlink(missing_ok=True) @@ -269,6 +278,12 @@ class TestDiskTensor(unittest.TestCase): np.testing.assert_array_equal(t.numpy(), np.array([3] * 10)) + def test_bitcast(self): + with open(temp('bf16'), "wb") as f: f.write(bytes(range(10,20))) + t = Tensor.empty(5, dtype=dtypes.int16, device=f"disk:{temp('bf16')}") + ret = t.to("CLANG").bitcast(dtypes.uint16) + 1 + assert ret.tolist() == [2827, 3341, 3855, 4369, 4883] + @unittest.skipIf(Device.DEFAULT == "RHIP", "no real HIP device exists in CI") def test_bf16_disk_write_read(self): t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32) @@ -279,8 +294,9 @@ class TestDiskTensor(unittest.TestCase): adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)]) with open(temp('bf16'), "wb") as f: f.write(adat) - t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm_bf16_cast(dtypes.float) - assert t.numpy().tolist() == [9984., -1, -1000, -9984, 20] + t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}") + ct = t.llvm_bf16_cast(dtypes.float) + assert ct.numpy().tolist() == [9984., -1, -1000, -9984, 20] if __name__ == "__main__": unittest.main() diff --git a/tinygrad/buffer.py b/tinygrad/buffer.py index 15153c6c32..fb14242d6d 100644 --- a/tinygrad/buffer.py +++ b/tinygrad/buffer.py @@ -13,25 +13,46 @@ class BufferOptions: class Buffer: def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None, - initial_value:Optional[bytes]=None, lb_refcount=0): + initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False): assert isinstance(dtype, DType) if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be? - self.device, self.size, self.dtype, self.options, self.lb_refcount = device, size, dtype, options, lb_refcount - if opaque is not None: self.allocate(opaque) - if initial_value is not None: - self.allocate() - self.copyin(memoryview(initial_value)) + self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset + if base is None: + assert offset == 0, "base buffers can't have offset" + self._base = None + self._lb_refcount = lb_refcount + if opaque is not None: self.allocate(opaque) + if initial_value is not None: + self.allocate() + self.copyin(memoryview(initial_value)) + else: + assert base._base is None, "base can't have a base" + assert device == base.device, "base must have the same device" + self._base = base + if preallocate: self.allocate() + @property + def base(self) -> Buffer: return self._base if self._base is not None else self + @property + def lb_refcount(self): return self.base._lb_refcount + def ref(self, cnt): self.base._lb_refcount += cnt def is_allocated(self) -> bool: return hasattr(self, '_buf') def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self def allocate(self, opaque=None) -> Buffer: assert not hasattr(self, '_buf'), "can't allocate already allocated buffer" from tinygrad.device import Device self.allocator = Device[self.device].allocator - self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options) - if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes + if self._base is not None: + self._base.ensure_allocated() + assert hasattr(self.allocator, "offset"), "offset function required for view" + self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset) + else: + self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options) + if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes return self def __reduce__(self): buf = None + if self._base is not None: + return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf')) if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount) if self.is_allocated(): buf = bytearray(self.nbytes) @@ -41,10 +62,12 @@ class Buffer: def nbytes(self): return self.size*self.dtype.itemsize def __del__(self): if not hasattr(self, '_buf'): return - if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes - self.allocator.free(self._buf, self.nbytes, self.options) + if self._base is None: + if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes + self.allocator.free(self._buf, self.nbytes, self.options) def __repr__(self): return f"" if self.options is None else f"{self.options=}>") def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview: # zero copy with as_buffer (disabled by default due to use after free) @@ -63,3 +86,7 @@ class Buffer: assert self.is_allocated(), "can't copyout unallocated buffer" self.allocator.copyout(mv, self._buf) return mv + def view(self, size:int, dtype:DType, offset:int) -> Buffer: + assert offset < self.nbytes, "offset must be less than nbytes" + if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset) + return Buffer(self.device, size, dtype, base=self, offset=offset) diff --git a/tinygrad/device.py b/tinygrad/device.py index faa6a6aeef..5acca1216b 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -113,6 +113,8 @@ class _MallocAllocator(LRUAllocator): def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src)) def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src)) def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest)) + def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size]) + MallocAllocator = _MallocAllocator() # **************** for Compiled Devices **************** diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 0c054f5e57..a1cfae9da6 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -8,7 +8,7 @@ from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, Device from tinygrad.dtype import DType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, sint -from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planner +from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planner, EmptyOp, ViewOp from tinygrad.nn.state import get_parameters from weakref import WeakKeyDictionary @@ -37,6 +37,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer] current_device = None for ji in jit_cache: + if ji.prg.__class__ in {EmptyOp, ViewOp}: continue ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None. if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}: @@ -110,7 +111,10 @@ class TinyJit(Generic[ReturnType]): def add_buffer(self, b:Buffer) -> Buffer: if found:=self.buffer_replace.get(b, None): return found if b.is_allocated() or b.lb_refcount > 0: return b - self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options) + if b._base is not None: + self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.buffer_replace.get(b._base, b._base), offset=b.offset) + else: + self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options) return ret def add(self, ei:ExecItem): @@ -119,6 +123,7 @@ class TinyJit(Generic[ReturnType]): def reset(self): self.jit_cache: List[ExecItem] = [] self.input_replace: Dict[Tuple[int, int], int] = {} + self.extra_view_inputs: List[Tuple[int, int, str, int, DType]] = [] self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary() self.cnt: int = 0 @@ -154,6 +159,13 @@ class TinyJit(Generic[ReturnType]): assert len(self.jit_cache), "didn't JIT anything!" if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") + # track inputs that are views of buffers + for ji in self.jit_cache: + for b in ji.bufs: + if b is not None and b._base is not None and b._base in input_rawbuffers: + input_rawbuffers.append(b) + self.extra_view_inputs.append((input_rawbuffers.index(b.base), b.offset, b.device, b.size, b.dtype)) + # memory planning (optional) assigned = _internal_memory_planner([cast(List[Buffer], x.bufs) for x in self.jit_cache], debug_prefix="JIT ") self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.bufs if x is not None]) for ei in self.jit_cache] @@ -166,6 +178,8 @@ class TinyJit(Generic[ReturnType]): elif self.cnt >= 2: # jit exec assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT" + for idx, offset, device, size, dtype in self.extra_view_inputs: + input_rawbuffers.append(Buffer(device, size, dtype, base=input_rawbuffers[idx], offset=offset).ensure_allocated()) for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_rawbuffers[input_idx] if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels") for ei in self.jit_cache: ei.run(var_vals, jit=True) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index b00544a0dd..7aa363ce26 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Iterable +from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Union from collections import defaultdict from dataclasses import dataclass from tinygrad.dtype import DType @@ -13,7 +13,8 @@ class ExecItem: prg: Runner bufs: List[Optional[Buffer]] def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]: - et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.bufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2) + bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs] + et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2) if do_update_stats: GlobalCounters.kernel_count += 1 GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals)) @@ -36,6 +37,11 @@ class EmptyOp(Runner): def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device) def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass +class ViewOp(Runner): + def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device) + def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): + assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}" + def lower_runner(runner:Runner, bufs) -> ExecItem: # TODO: globals isn't on the stupid diskrunner, remove the need for it return ExecItem(runner, [bufs[x[0]] for x in runner.globals] if hasattr(runner, 'globals') else bufs) @@ -53,6 +59,7 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs)) if ast.op is LoadOps.CUSTOM: return ExecItem(CustomOp(ast.arg), list(si.bufs)) if ast.op is LoadOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs)) + if ast.op is LoadOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs)) raise RuntimeError(f"don't know how to lower {ast}") def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]: @@ -60,7 +67,8 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, Non capturing: List = [] # put classes with an add method in here -def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") -> Dict[Buffer, Buffer]: +def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], debug_prefix="") -> Dict[Buffer, Buffer]: + if getenv("NO_MEMORY_PLANNER"): return {} last_appearance = {} for i,u in enumerate(buffers): for buf in u: last_appearance[buf] = i @@ -68,16 +76,24 @@ def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") -> # LRU algorithm assigned: Dict[Buffer, Buffer] = {} local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list) + + def handle_buffer(buf): + key = (buf.device, buf.size, buf.dtype) + if buf not in assigned: + if len(ll:=local_cache[key]): assigned[buf] = ll.pop() + else: assigned[buf] = Buffer(*key) + if i == last_appearance[buf]: + if assigned[buf] not in local_cache[key]: local_cache[key].append(assigned[buf]) + for i,u in enumerate(buffers): for buf in u: # all unallocated unparented buffers are fair game to replace if buf.is_allocated() or buf.lb_refcount > 0: continue - key = (buf.device, buf.size, buf.dtype) - if buf not in assigned: - if len(ll:=local_cache[key]): assigned[buf] = ll.pop() - else: assigned[buf] = Buffer(*key) - if i == last_appearance[buf]: - local_cache[key].append(assigned[buf]) + # handle view buffers + if buf._base is not None: + assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf._base, buf._base), offset=buf.offset) + else: + handle_buffer(buf) if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())): print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,", diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index fbab147a4e..f6f530862a 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -80,8 +80,8 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None] inputs: List[LazyBuffer] = [] ast: List[LazyOp] = [] var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs]) - if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}: - ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], list(outs[0].srcs) + if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}: + ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], [x.base for x in outs[0].srcs] else: for i, out in enumerate(outs): output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape) @@ -121,6 +121,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La if buf.op is LoadOps.COPY: assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig" realizes[buf.srcs[0].base] = None + if buf.op is LoadOps.VIEW: realizes[buf.srcs[0].base] = None for x in buf.srcs: children[x.base][buf] = None _recurse_lb(x, realizes, allbufs, simple_pads, children) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e41dd8885a..781a590a97 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -218,7 +218,8 @@ def cpu_time_execution(cb, enable): # *** ctypes helpers # TODO: make this work with read only memoryviews (if possible) -def from_mv(mv:memoryview, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type)) +def from_mv(mv:memoryview, to_type=ctypes.c_char): + return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B") def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501 @functools.lru_cache(maxsize=None) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index af13640c00..a4c725ef78 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -22,6 +22,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]= if enable_cache: lazycache[cache_key] = ret return ret +view_supported_devices = {"LLVM", "CLANG", "CUDA", "DISK"} class LazyBuffer: def __init__(self, device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), @@ -32,8 +33,15 @@ class LazyBuffer: # properties on base self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized" - self.buffer: Buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype) - self.buffer.lb_refcount += 1 + + if (self.op is LoadOps.CONTIGUOUS or (self.op is UnaryOps.CAST and self.arg[1] is True)) and srcs[0].st.consecutive and \ + not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices: + # some LazyBuffers can be processed with only a view, no AST required + self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize) + self.op = LoadOps.VIEW + else: + self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype) + self.buffer.ref(1) self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None self.forced_realize = False else: @@ -42,7 +50,7 @@ class LazyBuffer: self._base = base def __del__(self): - if hasattr(self, 'buffer'): self.buffer.lb_refcount -= 1 + if hasattr(self, 'buffer'): self.buffer.ref(-1) def __repr__(self) -> str: return f"" diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 7f3919cae7..0cb80538dc 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -26,7 +26,7 @@ class ReduceOps(Enum): """A -> B (reduce)""" SUM = auto(); MAX = auto() # noqa: E702 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702 -class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto() # noqa: E702 +class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702 Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps] OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]] diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 41ddf03bdc..b15fed3cb4 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -1,8 +1,8 @@ import ctypes from typing import Any, Optional, Tuple, Dict, List, cast import tinygrad.runtime.autogen.cuda as cuda -from tinygrad.helpers import init_c_var, GraphException, getenv -from tinygrad.device import CompiledRunner, Buffer, BufferXfer, Device, BufferOptions +from tinygrad.helpers import init_c_var, GraphException +from tinygrad.device import CompiledRunner, Buffer, BufferXfer, Device from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution from tinygrad.shape.symbolic import Variable from tinygrad.engine.realize import ExecItem @@ -13,20 +13,20 @@ class CUDAGraph(MultiGraphRunner): super().__init__(jit_cache, input_rawbuffers, var_vals) # Check all jit items are compatible. - if not all(isinstance(ji.prg, CompiledRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException + if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()])) self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy) self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))) - self.cpu_buffers = [] for j,ji in enumerate(self.jit_cache): if isinstance(ji.prg, CompiledRunner): global_size, local_size = ji.prg.launch_dims(var_vals) new_node = cuda.CUgraphNode() - deps = self._access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node) + deps = self._access_resources([x.base for x in ji.bufs[ji.prg.outcount:] if x is not None], + [x.base for x in ji.bufs[:ji.prg.outcount] if x is not None], new_dependency=new_node) c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.vars]) @@ -37,28 +37,14 @@ class CUDAGraph(MultiGraphRunner): self.updatable_nodes[j] = (new_node, kern_params, c_args, False) elif isinstance(ji.prg, BufferXfer): dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] - src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device]) + src_dev = cast(CUDADevice, Device[src.device]) node_from = cuda.CUgraphNode() - deps = self._access_resources(read=[src], write=[dest], new_dependency=node_from) + deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from) c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None - if getenv("CUDA_P2P", int(CUDADevice.peer_access)): - cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, - dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, - WidthInBytes=dest.nbytes, Height=1, Depth=1) - check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) - else: - self.cpu_buffers.append(cpu_buffer:=Buffer(device=src.device, dtype=src.dtype, size=src.size, options=BufferOptions(host=True)).allocate()) - - node_to = cuda.CUgraphNode() - cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, - dstMemoryType=cuda.CU_MEMORYTYPE_HOST, dstHost=cpu_buffer._buf, dstPitch=dest.nbytes, dstHeight=1, - WidthInBytes=dest.nbytes, Height=1, Depth=1) - check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_to), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) - cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_HOST, srcHost=cpu_buffer._buf, srcPitch=src.nbytes, srcHeight=1, - dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, - WidthInBytes=dest.nbytes, Height=1, Depth=1) - check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, (cuda.CUgraphNode*1)(node_to), 1, - ctypes.byref(cp_params), dest_dev.context)) + cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, + dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, + WidthInBytes=dest.nbytes, Height=1, Depth=1) + check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True) self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0))) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 23803dc9cc..56c15ff938 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -148,6 +148,7 @@ class CUDAAllocator(LRUAllocator): check(cuda.cuEventRecord(sync_event, None)) check(cuda.cuCtxSetCurrent(dest_dev.context)) check(cuda.cuStreamWaitEvent(None, sync_event, 0)) # sync the default stream on the dest dev + def offset(self, buf, size:int, offset:int): return ctypes.c_ulong(buf.value + offset) class CUDADevice(Compiled): devices: List[CUDADevice] = [] diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py index 832ccfcf10..99899e8eea 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -1,10 +1,8 @@ from __future__ import annotations -import os, mmap, _posixshmem, io, functools -from typing import Dict, List, Any, Optional -from tinygrad.helpers import prod, OSX -from tinygrad.device import Compiled, Allocator, Runner, Buffer -from tinygrad.ops import UnaryOps, LazyOp, BufferOps -from tinygrad.shape.view import strides_for_shape +import os, mmap, _posixshmem, io +from typing import Optional +from tinygrad.helpers import OSX +from tinygrad.device import Compiled, Allocator class DiskBuffer: def __init__(self, device:DiskDevice, size:int, offset=0): @@ -31,32 +29,7 @@ class DiskAllocator(Allocator): fo.readinto(dest) else: dest[:] = src._buf() - -class DiskRunner(Runner): - def __init__(self, ast:LazyOp): - # two ASTs are allowed here. - assert ast.op is BufferOps.STORE, "output of AST must be store" - assert ast.arg.st.contiguous, "shapetracker must be contiguous" - # TODO: there shouldn't actually be casts here, bitcasts should fold into the load - if ast.src[0].op is UnaryOps.CAST: - top_src = ast.src[0].src[0] - assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts" - self.new_dtype = ast.src[0].arg[0] - else: - top_src = ast.src[0] - self.new_dtype = top_src.arg.dtype - assert top_src.op is BufferOps.LOAD, "top of AST must be load" - assert len(top_src.arg.st.views) == 1, "shapetracker must have 1 view" - view = top_src.arg.st.views[0] - assert view.mask is None, "view cannot have a mask" - assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides" - self.new_size = prod(view.shape) - self.new_offset = view.offset * top_src.arg.dtype.itemsize - super().__init__(f"sz 0x{self.new_size:X} offset 0x{self.new_offset:X}", "DISK") - def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False): - assert len(rawbufs) == 2 - # TODO: this is a terrible hack that should be moved to lazy.py - rawbufs[0]._buf.offset = rawbufs[1]._buf.offset+self.new_offset + def offset(self, buf:DiskBuffer, size:int, offset:int): return DiskBuffer(buf.device, size, offset) class DiskDevice(Compiled): def __init__(self, device:str): @@ -85,7 +58,3 @@ class DiskDevice(Compiled): if self.count == 0: if hasattr(self, 'fd'): os.close(self.fd) self.size = None - @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none - def get_runner(self, *ast:LazyOp): - assert len(ast) == 1, "DiskRunner doesn't support multioutput kernels." - return DiskRunner(ast[0]) diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py index 366a014946..227185c7f9 100644 --- a/tinygrad/runtime/ops_nv.py +++ b/tinygrad/runtime/ops_nv.py @@ -106,12 +106,12 @@ class HWComputeQueue: return self def wait(self, signal, value=0): - self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), *nvdata64_le(value), + self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value), (3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT return self def signal(self, signal, value=0, timestamp=False): - self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), *nvdata64_le(value), + self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value), (1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0] return self @@ -137,12 +137,12 @@ class HWCopyQueue: return self def wait(self, signal, value=0): - self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), value, 0x0, + self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), value, 0x0, (3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT return self def signal(self, signal, value=0, timestamp=False): - self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal).contents)), *nvdata64_le(value), + self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value), (1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0] return self