get_buf() for Buffer (#16134)

* p

* mypy

* x
This commit is contained in:
nimlgen
2026-05-11 16:36:14 +03:00
committed by GitHub
parent 2dd84416bf
commit ad9738892c
7 changed files with 32 additions and 12 deletions

View File

@@ -9,7 +9,7 @@ def print_objects():
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
tensor_ram_used = sum([prod(x.shape)*4 for x in tensors])
lazybuffers = [x for x in gc.get_objects() if isinstance(x, UOp)]
gpubuffers = [x for x in gc.get_objects() if isinstance(x, Buffer) and hasattr(x, "_buf")]
gpubuffers = [x for x in gc.get_objects() if isinstance(x, Buffer) and x.is_initialized()]
realized_buffers = [x.realized for x in lazybuffers if x.base == x and x.realized]
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]

View File

@@ -78,7 +78,7 @@ def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelSnapshot], dict[int, in
if dst_id not in buf_pool:
buf_pool[dst_id] = dst_buf.nbytes
# Get source data if it's from numpy/CPU
if hasattr(src_buf, 'base') and src_buf.base is not None and hasattr(src_buf.base, '_buf'):
if hasattr(src_buf, 'base') and src_buf.base is not None and src_buf.base.is_allocated():
src_data = bytes(src_buf.base._buf)
buf_data[dst_id] = src_data
elif ast.op is Ops.PROGRAM:

View File

@@ -103,6 +103,7 @@ class Buffer:
uop_refcount=0, base:Buffer|None=None, offset:int=0, preallocate=False):
assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
self.device, self.size, self.dtype, self.options, self.offset, self.allocated_views = device, size, dtype, options, offset, 0
self._bufs: dict[str, Any] = {}
if base is None:
assert offset == 0, "base buffers can't have offset"
self._base = None
@@ -120,13 +121,24 @@ class Buffer:
def base(self) -> Buffer: return self._base if self._base is not None else self
@property
def uop_refcount(self):
  # the counter is kept on the base buffer, so every view reports its base's value
  return self.base._uop_refcount
@property
def _buf(self) -> Any:
  # Handle stored for this buffer's own device.
  # A missing entry surfaces as KeyError rather than AttributeError, so `hasattr(x, "_buf")`
  # cannot be used as an "is it allocated?" probe against this property.
  return self._bufs[self.device]
def ref(self, cnt):
  # bump the shared counter that lives on the base buffer, then return self so the call can be chained
  root = self.base
  root._uop_refcount += cnt
  return self
# check if the underlying buffer is allocated and the current buffer/view is initialized
# NOTE(review): this span is a diff hunk, so the method appears twice. The first definition is the
# pre-change form (probes for a `_buf` attribute); the second is the post-change form (looks for
# this buffer's own device key in `_bufs`).
def is_initialized(self) -> bool: return self.is_allocated() and hasattr(self, '_buf')
def is_initialized(self) -> bool: return self.is_allocated() and self.device in self._bufs
# check if the underlying buffer is allocated, possibly from the base object
# A view (non-None `_base`) delegates the question to its base; a base buffer answers for itself.
# NOTE(review): diff hunk — the first definition is the pre-change `hasattr` form, the second the
# post-change form that checks for the own-device key in `_bufs`.
def is_allocated(self) -> bool: return self.base.is_allocated() if self._base is not None else hasattr(self, '_buf')
def is_allocated(self) -> bool: return self.base.is_allocated() if self._base is not None else self.device in self._bufs
def get_buf(self, device: str) -> Any:
  # Return the handle for `device`, creating and caching it in `_bufs` on first request.
  if device in self._bufs: return self._bufs[device]
  target_allocator = Device[device].allocator
  if device == self.device:
    # own device: a regular allocation fills in the entry
    self.ensure_allocated()
  elif self._base is None:
    # base buffer requested from another device: map the own-device handle through that allocator
    self._bufs[device] = target_allocator._map(self.ensure_allocated()._buf)
  else:
    # view requested from another device: offset into the base's handle for that same device
    assert hasattr(target_allocator, "_offset"), "offset function required for view"
    self._bufs[device] = target_allocator._offset(self._base.get_buf(device), self.nbytes, self.offset)
  return self._bufs[device]
def ensure_allocated(self) -> Buffer:
  # nothing to do when this buffer/view already has its handle; otherwise return whatever allocate() returns
  if self.is_initialized(): return self
  return self.allocate()
def allocate(self, opaque=None, external_ptr=None) -> Buffer:
assert not self.is_initialized(), "can't allocate already allocated buffer"
@@ -140,25 +152,27 @@ class Buffer:
self._base.ensure_allocated()
self._base.allocated_views += 1
assert hasattr(self.allocator, "_offset"), "offset function required for view"
self._buf: Any = self.allocator._offset(self.base._buf, self.nbytes, self.offset)
self._bufs[self.device] = self.allocator._offset(self.base._buf, self.nbytes, self.offset)
else:
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
self._bufs[self.device] = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
if not self.device.startswith("DISK") and (self.options is None or self.options.external_ptr is None):
GlobalCounters.mem_used += self.nbytes
GlobalCounters.mem_used_per_device[self.device] += self.nbytes
if PROFILE: Buffer.profile_events.append(ProfilePointEvent(self.device, "alloc", self.trace_num, {"dtype":self.dtype, "sz":self.size}))
return self
# Release this buffer's storage and forget its cached per-device handles.
# NOTE(review): this span is a diff hunk, so three statements appear twice — once in the pre-change
# `_buf` form and once in the post-change `_bufs` form (the assert, and the final `del` / `clear()`).
def deallocate(self):
assert hasattr(self, '_buf'), "buffer must be allocated to deallocate"
assert self.device in self._bufs, "buffer must be allocated to deallocate"
if DEBUG is not None and DEBUG >= 7: print(f"buffer: deallocate {self.nbytes} bytes on {self.device}")
if self._base is None:
# memory accounting is skipped for DISK devices and for buffers whose options carry an external pointer
if GlobalCounters is not None and not self.device.startswith("DISK") and (self.options is None or self.options.external_ptr is None):
GlobalCounters.mem_used -= self.nbytes
GlobalCounters.mem_used_per_device[self.device] -= self.nbytes
if PROFILE: Buffer.profile_events.append(ProfilePointEvent(self.device, "free", self.trace_num))
# handles cached for other devices go to that device's allocator `_unmap` before the own-device free
for dev, mb in self._bufs.items():
if dev != self.device: Device[dev].allocator._unmap(mb)
self.allocator.free(self._buf, self.nbytes, self.options)
# view: only the base's count of allocated views is decremented here
elif self._base is not None: self._base.allocated_views -= 1
del self._buf
self._bufs.clear()
def __reduce__(self):
buf = None
if self._base is not None:
@@ -175,7 +189,7 @@ class Buffer:
@property
def nbytes(self):
  # total size in bytes: element count times the dtype's per-element size
  return self.size * self.dtype.itemsize
@suppress_finalizing
def __del__(self): (not hasattr(self, '_buf')) or self.deallocate()
def __del__(self):
  # Finalizer: release the storage only when a handle exists for the buffer's own device.
  # `_bufs` is fetched defensively because __del__ still runs on an object whose __init__ raised
  # before `device`/`_bufs` were assigned; a bare attribute read would raise AttributeError here.
  # `device` is assigned before `_bufs` in __init__, so it is safe to read once `_bufs` exists.
  handles = getattr(self, '_bufs', None)
  if handles is not None and self.device in handles: self.deallocate()
def __repr__(self):
  # Debug summary; the offset is shown only for views and the options only when they are set.
  parts = [f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}"]
  if self._base is not None: parts.append(f" offset:{self.offset}")
  if self.options is not None: parts.append(f" {self.options=}")
  parts.append(">")
  return "".join(parts)
@@ -227,6 +241,8 @@ class Allocator(Generic[DeviceType]):
def _free(self, opaque, options:BufferSpec):
  # default does nothing: a handle that is an ordinary Python object needs no explicit free
  return None
def _copyin(self, dest, src:memoryview):
  # placeholder: the base version always raises, so a usable allocator has to override it
  reason = "need copyin"
  raise NotImplementedError(reason)
def _copyout(self, dest:memoryview, src):
  # placeholder: the base version always raises, so a usable allocator has to override it
  reason = "need copyout"
  raise NotImplementedError(reason)
def _map(self, buf):
  # placeholder: mapping is unsupported unless a subclass overrides this hook
  reason = "need map"
  raise NotImplementedError(reason)
def _unmap(self, mb):
  # no-op by default; meant to be overridden when `_map` allocates iface-side state
  return None
# def _as_buffer(self, src) -> memoryview:
# def _offset(self, buf, size:int, offset:int):
# def _transfer(self, dest, src, sz:int, src_dev, dest_dev):

View File

@@ -209,7 +209,7 @@ class CapturedJit(Generic[ReturnType]):
for u in self._written_uops:
if (buf:=buffers.get(u)) is None: continue
for b in (buf.bufs if isinstance(buf, MultiBuffer) else (buf,)):
if hasattr(b, '_buf'): b.deallocate()
if b.is_initialized(): b.deallocate()
if (base:=b._base) is not None and base.allocated_views == 0 and base.is_allocated(): base.deallocate()
def _prepare_jit_inputs(args, kwargs):

View File

@@ -130,6 +130,8 @@ class CPUAllocator(HCQAllocator):
return to_mv(src.va_addr, src.size)
def _map(self, buf:HCQBuffer):
  # Wrap the buffer's MMIO view in a new HCQBuffer addressed at the view's address.
  mmio = buf.view
  # only a buffer that carries an MMIOInterface view can be mapped
  if mmio is None or not isinstance(mmio, MMIOInterface): raise RuntimeError("Cannot map buffer without view to cpu")
  return HCQBuffer(mmio.addr, buf.size, view=mmio, owner=buf.owner)
def _unmap(self, mb):
  # `_map` here only builds a wrapper around the existing view, so there is nothing to release
  return None
class CPUDevice(HCQCompiled):
def __init__(self, device:str=""):

View File

@@ -81,6 +81,7 @@ class RDMAAllocator(HCQAllocatorBase):
meta=self.dev.iface.mlx_dev.register_mem(pages, len(pages) * page_sz, page_sz.bit_length() - 1))
def _do_free(self, buf:HCQBuffer, options):
  # drop the memory registration recorded in the buffer's `meta`
  mlx = self.dev.iface.mlx_dev
  mlx.unregister_mem(buf.meta)
def _unmap(self, mb):
  # a mapping is released by unregistering the memory recorded in its `meta`
  mlx = self.dev.iface.mlx_dev
  mlx.unregister_mem(mb.meta)
def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:HCQCompiled, dest_dev:HCQCompiled):
# sync device

View File

@@ -564,10 +564,11 @@ class HCQAllocatorBase(LRUAllocator[HCQDeviceType], Generic[HCQDeviceType]):
# Free hook: synchronize every device listed in `buf.mapped_devs`, release the per-device entries in
# `buf.mappings`, then call this allocator's own `_do_free` when it defines one.
# NOTE(review): this span is a diff hunk — the two-line loop that calls `_do_free` on each mapping is
# the pre-change form, and the one-line loop that calls `_unmap` is the post-change form of the same step.
@suppress_finalizing
def _free(self, buf:HCQBuffer, options:BufferSpec|None=None):
for dev in buf.mapped_devs: dev.synchronize()
for d, mb in buf.mappings.items():
if hasattr(d.allocator, '_do_free'): d.allocator._do_free(mb, options)
for d, mb in buf.mappings.items(): d.allocator._unmap(mb)
if hasattr(self, '_do_free'): self._do_free(buf, options)
def _unmap(self, mb):
  # give the mapped buffer back through the device interface's `free`
  iface = self.dev.iface
  iface.free(mb)
def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
  # delegate to the buffer's own `offset`; note the parameter order here is (size, offset)
  # while the call passes both by keyword
  sub = buf.offset(offset=offset, size=size)
  return sub
class HCQAllocator(HCQAllocatorBase, Generic[HCQDeviceType]):