mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
refactor webgpu (#16406)
This commit is contained in:
committed by
GitHub
parent
202adc644e
commit
0ae957bb0a
@@ -1,157 +1,120 @@
|
||||
import functools, struct
|
||||
from tinygrad.device import Compiled, Allocator, BufferSpec
|
||||
from tinygrad.renderer.wgsl import WGSLRenderer
|
||||
from tinygrad.helpers import round_up, suppress_finalizing
|
||||
from tinygrad.helpers import round_up, suppress_finalizing, getenv, to_mv
|
||||
from tinygrad.runtime.autogen import webgpu
|
||||
from tinygrad.runtime.support import c
|
||||
from typing import cast, List, Any, TypeAlias
|
||||
from typing import Callable
|
||||
import ctypes
|
||||
import os
|
||||
|
||||
WGPUDevPtr: TypeAlias = webgpu.WGPUDevice
|
||||
WGPUBufPtr: TypeAlias = webgpu.WGPUBuffer
|
||||
backend_types = {v: k for k, v in webgpu.enum_WGPUBackendType.items()}
|
||||
instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features=webgpu.WGPUInstanceFeatures(timedWaitAnyEnable=True)))
|
||||
|
||||
backend_types = {v: k for k, v in webgpu.enum_WGPUBackendType.items() }
|
||||
def from_wgpu_str(string_view:webgpu.WGPUStringView) -> str: return ctypes.string_at(string_view.data, string_view.length).decode()
|
||||
def to_wgpu_str(_str:str) -> webgpu.WGPUStringView: return webgpu.WGPUStringView(data=ctypes.create_string_buffer(_str.encode()), length=len(_str))
|
||||
|
||||
instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
|
||||
# gets a memoryview from a buffer, which is assumed to have MAP_READ (see _readable_buffer)
|
||||
def buf_to_mv(buf:webgpu.WGPUBuffer) -> memoryview:
|
||||
BufferMapAsync(buf, webgpu.WGPUMapMode_Read, 0, size:=webgpu.wgpuBufferGetSize(buf))
|
||||
return to_mv(webgpu.wgpuBufferGetConstMappedRange(buf, 0, size), size)
|
||||
|
||||
def to_c_string(_str:str) -> ctypes.Array: return ctypes.create_string_buffer(_str.encode('utf-8'))
|
||||
# turns a webgpu function returning a future into python-synchronous function
|
||||
# the new function handles the status code and optional error message, returning the other callback arguments
|
||||
def synchronous(status_enum:dict[int, str], has_emsg:bool=False):
|
||||
def wrap(fn:Callable[..., webgpu.WGPUFuture]) -> Callable:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args):
|
||||
status, payload, emsg = 0, [], None
|
||||
|
||||
def from_wgpu_str(string_view:webgpu.struct_WGPUStringView) -> str: return ctypes.string_at(string_view.data, string_view.length).decode("utf-8")
|
||||
@next(ty for nm, ty, *_ in fn.argtypes[-1]._real_fields_ if nm == "callback") # type: ignore
|
||||
def cb(s:int, *args):
|
||||
nonlocal status, payload, emsg
|
||||
# the last two arguments are "userdata1" and "userdata2", which we drop
|
||||
# we must process wgpu strings in this callback, as they will be freed after we return
|
||||
status, (*payload, emsg) = s, [from_wgpu_str(a) if type(a) is webgpu.WGPUStringView else a for a in args[:-2]] + ([] if has_emsg else [None])
|
||||
|
||||
def to_wgpu_str(_str:str) -> webgpu.struct_WGPUStringView:
|
||||
return webgpu.WGPUStringView(data=ctypes.cast(ctypes.pointer(to_c_string(_str)), ctypes.POINTER(ctypes.c_char)), length=len(_str))
|
||||
future = fn(*args, fn.argtypes[-1](mode=webgpu.WGPUCallbackMode_WaitAnyOnly, callback=cb)) # type: ignore
|
||||
if (future_status:=webgpu.wgpuInstanceWaitAny(instance, 1, webgpu.WGPUFutureWaitInfo(future), 2**64-1)) != webgpu.WGPUWaitStatus_Success:
|
||||
raise RuntimeError(f"error while waiting for future ({fn.__name__}): {webgpu.enum_WGPUWaitStatus.get(future_status)}")
|
||||
|
||||
def _wait(future:webgpu.struct_WGPUFuture):
|
||||
assert webgpu.wgpuInstanceWaitAny(instance, 1, webgpu.WGPUFutureWaitInfo(future=future), 2**64-1) == webgpu.WGPUWaitStatus_Success, "Future failed"
|
||||
if status != 1: raise RuntimeError(f"[{status_enum.get(status)}]{emsg or ''}")
|
||||
return payload if len(payload) > 1 else payload[0] if len(payload) == 1 else None
|
||||
return wrapper
|
||||
return wrap
|
||||
|
||||
def write_buffer(device:WGPUDevPtr, buf:WGPUBufPtr, offset:int, src:memoryview|bytearray|bytes):
|
||||
src = bytearray(src)
|
||||
webgpu.wgpuQueueWriteBuffer(webgpu.wgpuDeviceGetQueue(device), buf, offset, (ctypes.c_uint8 * len(src)).from_buffer(src), len(src))
|
||||
|
||||
def _run(async_fun, cb_info_type, cb_type, status_enum:dict|None, res_idx:int|None, msg_idx:int|None, *params):
|
||||
result: List[Any] = []
|
||||
|
||||
def cb(*params):
|
||||
result[:] = params
|
||||
if msg_idx: result[msg_idx] = from_wgpu_str(result[msg_idx])
|
||||
|
||||
cb_info = cb_info_type(mode=webgpu.WGPUCallbackMode_WaitAnyOnly, callback=cb_type(cb))
|
||||
_wait(async_fun(*params, cb_info))
|
||||
|
||||
if result[0] != 1: raise RuntimeError(f"[{status_enum.get(result[0]) if status_enum else 'ERROR'}]{result[msg_idx] if msg_idx else ''}")
|
||||
return result[res_idx] if res_idx else None
|
||||
|
||||
def copy_buffer_to_buffer(dev:WGPUDevPtr, src:WGPUBufPtr, src_offset:int, dst:WGPUBufPtr, dst_offset:int, size:int):
|
||||
encoder = webgpu.wgpuDeviceCreateCommandEncoder(dev, webgpu.WGPUCommandEncoderDescriptor())
|
||||
webgpu.wgpuCommandEncoderCopyBufferToBuffer(encoder, src, src_offset, dst, dst_offset, size)
|
||||
cb = webgpu.wgpuCommandEncoderFinish(encoder, webgpu.WGPUCommandBufferDescriptor())
|
||||
webgpu.wgpuQueueSubmit(webgpu.wgpuDeviceGetQueue(dev), 1, (webgpu.WGPUCommandBuffer*1)(cb))
|
||||
webgpu.wgpuCommandBufferRelease(cb)
|
||||
webgpu.wgpuCommandEncoderRelease(encoder)
|
||||
|
||||
def read_buffer(dev:WGPUDevPtr, buf:WGPUBufPtr) -> memoryview:
|
||||
size = webgpu.wgpuBufferGetSize(buf)
|
||||
tmp_buffer = webgpu.wgpuDeviceCreateBuffer(dev, webgpu.WGPUBufferDescriptor(size=size,
|
||||
usage=webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_MapRead, mappedAtCreation=False))
|
||||
copy_buffer_to_buffer(dev, buf, 0, tmp_buffer, 0, size)
|
||||
_run(webgpu.wgpuBufferMapAsync2, webgpu.WGPUBufferMapCallbackInfo2, webgpu.WGPUBufferMapCallback2, webgpu.enum_WGPUBufferMapAsyncStatus, None, 0,
|
||||
tmp_buffer, webgpu.WGPUMapMode_Read, 0, size)
|
||||
void_ptr = ctypes.cast(webgpu.wgpuBufferGetConstMappedRange(tmp_buffer, 0, size), ctypes.c_void_p)
|
||||
buf_copy = bytearray((ctypes.c_uint8 * size).from_address(void_ptr.value))
|
||||
webgpu.wgpuBufferUnmap(tmp_buffer)
|
||||
webgpu.wgpuBufferDestroy(tmp_buffer)
|
||||
return memoryview(buf_copy).cast("B")
|
||||
|
||||
def pop_error(device:WGPUDevPtr) -> str:
|
||||
return _run(webgpu.wgpuDevicePopErrorScopeF, webgpu.WGPUPopErrorScopeCallbackInfo, webgpu.WGPUPopErrorScopeCallback, None, 2, 2, device)
|
||||
|
||||
def create_uniform(wgpu_device:WGPUDevPtr, val:int|float) -> WGPUBufPtr:
|
||||
buf = webgpu.wgpuDeviceCreateBuffer(wgpu_device,
|
||||
webgpu.WGPUBufferDescriptor(size=4, usage=webgpu.WGPUBufferUsage_Uniform | webgpu.WGPUBufferUsage_CopyDst))
|
||||
write_buffer(wgpu_device, buf, 0, val.to_bytes(4, "little") if isinstance(val, int) else struct.pack('<f', val))
|
||||
return buf
|
||||
BufferMapAsync = synchronous(webgpu.enum_WGPUBufferMapAsyncStatus, True)(webgpu.wgpuBufferMapAsync2)
|
||||
DevicePopErrorScope = synchronous(webgpu.enum_WGPUPopErrorScopeStatus)(webgpu.wgpuDevicePopErrorScope2)
|
||||
DeviceCreateComputePipeline = synchronous(webgpu.enum_WGPUCreatePipelineAsyncStatus, True)(webgpu.wgpuDeviceCreateComputePipelineAsync2)
|
||||
InstanceRequestAdapter = synchronous(webgpu.enum_WGPURequestAdapterStatus, True)(webgpu.wgpuInstanceRequestAdapter2)
|
||||
AdapterRequestDevice = synchronous(webgpu.enum_WGPURequestDeviceStatus, True)(webgpu.wgpuAdapterRequestDevice2)
|
||||
QueueOnSubmittedWorkDone = synchronous(webgpu.enum_WGPUQueueWorkDoneStatus)(webgpu.wgpuQueueOnSubmittedWorkDone2)
|
||||
|
||||
class WebGPUProgram:
|
||||
def __init__(self, dev:tuple[WGPUDevPtr, bool], name:str, lib:bytes, **kwargs):
|
||||
(self.dev, self.timestamp_supported) = dev
|
||||
def __init__(self, dev:'WebGpuDevice', name:str, lib:bytes, **kwargs):
|
||||
self.dev, self.name = dev, to_wgpu_str(name)
|
||||
|
||||
# Creating shader module
|
||||
shader = webgpu.WGPUShaderModuleWGSLDescriptor(code=to_wgpu_str(lib.decode()),
|
||||
chain=webgpu.WGPUChainedStruct(sType=webgpu.WGPUSType_ShaderSourceWGSL))
|
||||
module = webgpu.WGPUShaderModuleDescriptor()
|
||||
module.nextInChain = ctypes.cast(ctypes.pointer(shader), c.POINTER[webgpu.struct_WGPUChainedStruct])
|
||||
chain=webgpu.WGPUChainedStruct(sType=webgpu.WGPUSType_ShaderSourceWGSL))
|
||||
module = webgpu.WGPUShaderModuleDescriptor(nextInChain=ctypes.cast(ctypes.pointer(shader), ctypes.POINTER(webgpu.struct_WGPUChainedStruct)))
|
||||
|
||||
# Check compiler error
|
||||
webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
|
||||
shader_module = webgpu.wgpuDeviceCreateShaderModule(self.dev, module)
|
||||
webgpu.wgpuDevicePushErrorScope(self.dev.device_res, webgpu.WGPUErrorFilter_Validation)
|
||||
self.prg = webgpu.wgpuDeviceCreateShaderModule(self.dev.device_res, module)
|
||||
if err := self.dev.pop_error(): raise RuntimeError(f"Shader compilation failed: {err}")
|
||||
|
||||
if err := pop_error(self.dev): raise RuntimeError(f"Shader compilation failed: {err}")
|
||||
@suppress_finalizing
|
||||
def __del__(self): webgpu.wgpuShaderModuleRelease(self.prg)
|
||||
|
||||
self.name, self.lib, self.prg = name, lib, shader_module
|
||||
def __call__(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
|
||||
def __call__(self, *bufs:webgpu.WGPUBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
|
||||
vals:tuple[int, ...]=(), wait=False, **kw) -> float|None:
|
||||
wait = wait and self.timestamp_supported
|
||||
tmp_bufs = [*bufs]
|
||||
buf_patch = False
|
||||
|
||||
# WebGPU does not allow using the same buffer for input and output
|
||||
for i in range(1, len(bufs)):
|
||||
if ctypes.addressof(bufs[i]) == ctypes.addressof(bufs[0]):
|
||||
tmp_bufs[0] = webgpu.wgpuDeviceCreateBuffer(self.dev,
|
||||
webgpu.WGPUBufferDescriptor(size=webgpu.wgpuBufferGetSize(bufs[0]), usage=webgpu.wgpuBufferGetUsage(bufs[0])))
|
||||
buf_patch = True
|
||||
wait = wait and webgpu.WGPUFeatureName_TimestampQuery in self.dev.features
|
||||
|
||||
# Creating bind group layout
|
||||
binding_layouts = [webgpu.WGPUBindGroupLayoutEntry(binding=0, visibility= webgpu.WGPUShaderStage_Compute,
|
||||
buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform))]
|
||||
binding_layouts += [webgpu.WGPUBindGroupLayoutEntry(binding=i+1, visibility=webgpu.WGPUShaderStage_Compute,
|
||||
buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform if i >= len(tmp_bufs)
|
||||
else webgpu.WGPUBufferBindingType_Storage)) for i in range(len(tmp_bufs)+len(vals))]
|
||||
def bgl_entry(n:int, ty:str):
|
||||
return webgpu.WGPUBindGroupLayoutEntry(binding=n, visibility=webgpu.WGPUShaderStage_Compute,
|
||||
buffer=webgpu.WGPUBufferBindingLayout(type=getattr(webgpu, f'WGPUBufferBindingType_{ty}')))
|
||||
bind_entries = (webgpu.WGPUBindGroupLayoutEntry * (1+len(bufs)+len(vals)))(
|
||||
bgl_entry(0, 'Uniform'), *(bgl_entry(i+1, 'Uniform' if i >= len(bufs) else 'Storage') for i in range(len(bufs)+len(vals))))
|
||||
|
||||
bl_arr_type = webgpu.WGPUBindGroupLayoutEntry * len(binding_layouts)
|
||||
webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
|
||||
bind_group_layouts = [webgpu.wgpuDeviceCreateBindGroupLayout(self.dev, webgpu.WGPUBindGroupLayoutDescriptor(
|
||||
entryCount=len(binding_layouts), entries=ctypes.cast(bl_arr_type(*binding_layouts), ctypes.POINTER(webgpu.WGPUBindGroupLayoutEntry))))]
|
||||
webgpu.wgpuDevicePushErrorScope(self.dev.device_res, webgpu.WGPUErrorFilter_Validation)
|
||||
bind_layout = webgpu.wgpuDeviceCreateBindGroupLayout(self.dev.device_res,
|
||||
webgpu.WGPUBindGroupLayoutDescriptor(entryCount=len(bind_entries), entries=bind_entries))
|
||||
|
||||
if bg_layout_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group layout: {bg_layout_err}")
|
||||
if err := self.dev.pop_error(): raise RuntimeError(f"Error creating bind group layout: {err}")
|
||||
|
||||
# Creating pipeline layout
|
||||
pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor(bindGroupLayoutCount=len(bind_group_layouts),
|
||||
bindGroupLayouts = (webgpu.WGPUBindGroupLayout * len(bind_group_layouts))(*bind_group_layouts))
|
||||
pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor(bindGroupLayoutCount=1, bindGroupLayouts=(webgpu.WGPUBindGroupLayout*1)(bind_layout))
|
||||
|
||||
webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
|
||||
pipeline_layout = webgpu.wgpuDeviceCreatePipelineLayout(self.dev, pipeline_layout_desc)
|
||||
|
||||
if pipe_err := pop_error(self.dev): raise RuntimeError(f"Error creating pipeline layout: {pipe_err}")
|
||||
webgpu.wgpuDevicePushErrorScope(self.dev.device_res, webgpu.WGPUErrorFilter_Validation)
|
||||
pipeline_layout = webgpu.wgpuDeviceCreatePipelineLayout(self.dev.device_res, pipeline_layout_desc)
|
||||
if err := self.dev.pop_error(): raise RuntimeError(f"Error creating pipeline layout: {err}")
|
||||
|
||||
# Creating bind group
|
||||
bindings = [webgpu.WGPUBindGroupEntry(binding=0, buffer=create_uniform(self.dev, float('inf')), offset=0, size=4)]
|
||||
bindings += [webgpu.WGPUBindGroupEntry(binding=i+1, buffer=create_uniform(self.dev, cast(int, x)) if i >= len(tmp_bufs) else x, offset=0,
|
||||
size=4 if i >= len(tmp_bufs) else webgpu.wgpuBufferGetSize(x)) for i,x in enumerate(tuple(tmp_bufs)+vals)]
|
||||
def bg_entry(n:int, x:webgpu.WGPUBuffer|int|float):
|
||||
buf = x if isinstance(x, webgpu.WGPUBuffer) else self.dev.create_uniform(x)
|
||||
return webgpu.WGPUBindGroupEntry(binding=n, buffer=buf, offset=0, size=webgpu.wgpuBufferGetSize(buf))
|
||||
bindings = (webgpu.WGPUBindGroupEntry * (1+len(bufs)+len(vals)))(bg_entry(0, float('inf')), *(bg_entry(i+1, x) for i,x in enumerate(bufs+vals)))
|
||||
|
||||
bg_arr_type = webgpu.WGPUBindGroupEntry * len(bindings)
|
||||
bind_group_desc = webgpu.WGPUBindGroupDescriptor(layout=bind_group_layouts[0], entryCount=len(bindings), entries=bg_arr_type(*bindings))
|
||||
webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
|
||||
bind_group = webgpu.wgpuDeviceCreateBindGroup(self.dev, bind_group_desc)
|
||||
|
||||
if bind_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group: {bind_err}")
|
||||
bind_group_desc = webgpu.WGPUBindGroupDescriptor(layout=bind_layout, entryCount=len(bindings), entries=bindings)
|
||||
webgpu.wgpuDevicePushErrorScope(self.dev.device_res, webgpu.WGPUErrorFilter_Validation)
|
||||
bind_group = webgpu.wgpuDeviceCreateBindGroup(self.dev.device_res, bind_group_desc)
|
||||
if err := self.dev.pop_error(): raise RuntimeError(f"Error creating bind group: {err}")
|
||||
|
||||
# Creating compute pipeline
|
||||
compute_desc = webgpu.WGPUComputePipelineDescriptor(layout=pipeline_layout,
|
||||
compute=webgpu.WGPUComputeState(module=self.prg, entryPoint=to_wgpu_str(self.name)))
|
||||
pipeline_result = _run(webgpu.wgpuDeviceCreateComputePipelineAsync2, webgpu.WGPUCreateComputePipelineAsyncCallbackInfo2,
|
||||
webgpu.WGPUCreateComputePipelineAsyncCallback2, webgpu.enum_WGPUCreatePipelineAsyncStatus, 1, None, self.dev, compute_desc)
|
||||
compute=webgpu.WGPUComputeState(module=self.prg, entryPoint=self.name))
|
||||
pipeline_result = DeviceCreateComputePipeline(self.dev.device_res, compute_desc)
|
||||
|
||||
command_encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.dev, webgpu.WGPUCommandEncoderDescriptor())
|
||||
command_encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.dev.device_res, webgpu.WGPUCommandEncoderDescriptor())
|
||||
comp_pass_desc = webgpu.WGPUComputePassDescriptor()
|
||||
|
||||
if wait:
|
||||
query_set = webgpu.wgpuDeviceCreateQuerySet(self.dev, webgpu.WGPUQuerySetDescriptor(type=webgpu.WGPUQueryType_Timestamp, count=2))
|
||||
query_buf = webgpu.wgpuDeviceCreateBuffer(self.dev,
|
||||
webgpu.WGPUBufferDescriptor(size=16, usage=webgpu.WGPUBufferUsage_QueryResolve | webgpu.WGPUBufferUsage_CopySrc))
|
||||
comp_pass_desc.timestampWrites = c.pointer(webgpu.WGPUComputePassTimestampWrites(
|
||||
querySet=query_set, beginningOfPassWriteIndex=0, endOfPassWriteIndex=1))
|
||||
query_set = webgpu.wgpuDeviceCreateQuerySet(self.dev.device_res, webgpu.WGPUQuerySetDescriptor(type=webgpu.WGPUQueryType_Timestamp, count=2))
|
||||
query_buf = webgpu.wgpuDeviceCreateBuffer(
|
||||
self.dev.device_res, webgpu.WGPUBufferDescriptor(size=16, usage=webgpu.WGPUBufferUsage_QueryResolve | webgpu.WGPUBufferUsage_CopySrc))
|
||||
comp_pass_desc.timestampWrites = c.pointer(webgpu.WGPUComputePassTimestampWrites(querySet=query_set, beginningOfPassWriteIndex=0,
|
||||
endOfPassWriteIndex=1))
|
||||
|
||||
# Begin compute pass
|
||||
compute_pass = webgpu.wgpuCommandEncoderBeginComputePass(command_encoder, comp_pass_desc)
|
||||
@@ -163,63 +126,96 @@ class WebGPUProgram:
|
||||
if wait: webgpu.wgpuCommandEncoderResolveQuerySet(command_encoder, query_set, 0, 2, query_buf, 0)
|
||||
|
||||
cmd_buf = webgpu.wgpuCommandEncoderFinish(command_encoder, webgpu.WGPUCommandBufferDescriptor())
|
||||
webgpu.wgpuQueueSubmit(webgpu.wgpuDeviceGetQueue(self.dev), 1, (webgpu.WGPUCommandBuffer*1)(cmd_buf))
|
||||
webgpu.wgpuQueueSubmit(self.dev.queue, 1, (webgpu.WGPUCommandBuffer*1)(cmd_buf))
|
||||
|
||||
if buf_patch:
|
||||
copy_buffer_to_buffer(self.dev, tmp_bufs[0], 0, bufs[0], 0, webgpu.wgpuBufferGetSize(bufs[0]))
|
||||
webgpu.wgpuBufferDestroy(tmp_bufs[0])
|
||||
# release created objects
|
||||
webgpu.wgpuBindGroupLayoutRelease(bind_layout)
|
||||
webgpu.wgpuPipelineLayoutRelease(pipeline_layout)
|
||||
webgpu.wgpuBindGroupRelease(bind_group)
|
||||
webgpu.wgpuComputePipelineRelease(pipeline_result)
|
||||
webgpu.wgpuCommandEncoderRelease(command_encoder)
|
||||
webgpu.wgpuComputePassEncoderRelease(compute_pass)
|
||||
webgpu.wgpuCommandBufferRelease(cmd_buf)
|
||||
|
||||
if wait:
|
||||
time = ((timestamps:=read_buffer(self.dev, query_buf).cast("Q").tolist())[1] - timestamps[0]) / 1e9
|
||||
webgpu.wgpuBufferDestroy(query_buf)
|
||||
time = ((timestamps:=buf_to_mv(tmp_buf:=self.dev._readable_buffer(query_buf)).cast("Q").tolist())[1] - timestamps[0]) / 1e9
|
||||
self.dev.free(query_buf)
|
||||
self.dev.free(tmp_buf)
|
||||
webgpu.wgpuQuerySetDestroy(query_set)
|
||||
webgpu.wgpuQuerySetRelease(query_set)
|
||||
return time
|
||||
return None
|
||||
|
||||
class WebGpuAllocator(Allocator['WebGpuDevice']):
|
||||
def _alloc(self, size:int, options:BufferSpec) -> WGPUBufPtr:
|
||||
def _alloc(self, size:int, options:BufferSpec) -> webgpu.WGPUBuffer:
|
||||
# WebGPU buffers have to be 4-byte aligned
|
||||
return webgpu.wgpuDeviceCreateBuffer(self.dev.device_res, webgpu.WGPUBufferDescriptor(size=round_up(size, 4),
|
||||
usage=webgpu.WGPUBufferUsage_Storage | webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_CopySrc))
|
||||
def _copyin(self, dest:WGPUBufPtr, src:memoryview):
|
||||
def _copyin(self, dest:webgpu.WGPUBuffer, src:memoryview):
|
||||
if src.nbytes % 4:
|
||||
padded_src = bytearray(round_up(src.nbytes, 4))
|
||||
padded_src[:src.nbytes] = src
|
||||
write_buffer(self.dev.device_res, dest, 0, padded_src if src.nbytes % 4 else src)
|
||||
def _copyout(self, dest:memoryview, src:WGPUBufPtr):
|
||||
buffer_data = read_buffer(self.dev.device_res, src)
|
||||
dest[:] = buffer_data[:dest.nbytes] if webgpu.wgpuBufferGetSize(src) > dest.nbytes else buffer_data
|
||||
@suppress_finalizing
|
||||
def _free(self, opaque:WGPUBufPtr, options:BufferSpec): webgpu.wgpuBufferDestroy(opaque)
|
||||
self.dev.write_buffer(dest, padded_src if src.nbytes % 4 else src)
|
||||
def _copyout(self, dest:memoryview, src:webgpu.WGPUBuffer):
|
||||
dest[:] = buf_to_mv(tmp_buf:=self.dev._readable_buffer(src))[:dest.nbytes]
|
||||
self.dev.free(tmp_buf)
|
||||
|
||||
def _free(self, opaque:webgpu.WGPUBuffer, options:BufferSpec): self.dev.free(opaque)
|
||||
|
||||
class WebGpuDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
# Requesting an adapter
|
||||
adapter_res = _run(webgpu.wgpuInstanceRequestAdapterF, webgpu.WGPURequestAdapterCallbackInfo, webgpu.WGPURequestAdapterCallback,
|
||||
webgpu.enum_WGPUCreatePipelineAsyncStatus, 1, 2, instance, webgpu.WGPURequestAdapterOptions(
|
||||
powerPreference=webgpu.WGPUPowerPreference_HighPerformance, backendType=backend_types.get(os.getenv("WEBGPU_BACKEND", ""), 0)))
|
||||
adapter_res = InstanceRequestAdapter(instance, webgpu.WGPURequestAdapterOptions(
|
||||
powerPreference=webgpu.WGPUPowerPreference_HighPerformance, backendType=backend_types.get(getenv("WEBGPU_BACKEND", ""), 0)))
|
||||
|
||||
# Get supported features
|
||||
supported_features = webgpu.WGPUSupportedFeatures()
|
||||
webgpu.wgpuAdapterGetFeatures(adapter_res, supported_features)
|
||||
supported = [supported_features.features[i] for i in range(supported_features.featureCount)]
|
||||
features = [feat for feat in [webgpu.WGPUFeatureName_TimestampQuery, webgpu.WGPUFeatureName_ShaderF16] if feat in supported]
|
||||
dev_desc = webgpu.WGPUDeviceDescriptor(requiredFeatureCount=len(features),
|
||||
requiredFeatures=c.Array(webgpu.WGPUFeatureName, len(features))(*features)) # type: ignore
|
||||
webgpu.wgpuAdapterGetFeatures(adapter_res, supported_features:=webgpu.WGPUSupportedFeatures())
|
||||
self.features = [feat for i in range(supported_features.featureCount)
|
||||
if (feat:=supported_features.features[i]) in [webgpu.WGPUFeatureName_TimestampQuery, webgpu.WGPUFeatureName_ShaderF16]]
|
||||
webgpu.wgpuSupportedFeaturesFreeMembers(supported_features)
|
||||
dev_desc = webgpu.WGPUDeviceDescriptor(requiredFeatureCount=len(self.features),
|
||||
requiredFeatures=(webgpu.WGPUFeatureName * len(self.features))(*self.features))
|
||||
|
||||
# Limits
|
||||
supported_limits = webgpu.WGPUSupportedLimits()
|
||||
webgpu.wgpuAdapterGetLimits(adapter_res, ctypes.cast(ctypes.pointer(supported_limits),ctypes.POINTER(webgpu.struct_WGPUSupportedLimits)))
|
||||
limits = webgpu.WGPURequiredLimits(limits=supported_limits.limits)
|
||||
dev_desc.requiredLimits = c.pointer(limits)
|
||||
webgpu.wgpuAdapterGetLimits(adapter_res, supported_limits:=webgpu.WGPUSupportedLimits())
|
||||
dev_desc.requiredLimits = c.pointer(webgpu.WGPURequiredLimits(limits=supported_limits.limits))
|
||||
|
||||
# Requesting a device
|
||||
self.device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback,
|
||||
webgpu.enum_WGPURequestDeviceStatus, 1, 2, adapter_res, dev_desc)
|
||||
self.device_res = AdapterRequestDevice(adapter_res, dev_desc)
|
||||
self.queue = webgpu.wgpuDeviceGetQueue(self.device_res)
|
||||
|
||||
program = functools.partial(WebGPUProgram, (self.device_res, webgpu.WGPUFeatureName_TimestampQuery in supported))
|
||||
super().__init__(device, WebGpuAllocator(self), [WGSLRenderer], program, arch="shader-f16" * (webgpu.WGPUFeatureName_ShaderF16 in supported))
|
||||
webgpu.wgpuAdapterRelease(adapter_res)
|
||||
|
||||
def synchronize(self):
|
||||
_run(webgpu.wgpuQueueOnSubmittedWorkDone2, webgpu.WGPUQueueWorkDoneCallbackInfo2, webgpu.WGPUQueueWorkDoneCallback2,
|
||||
webgpu.enum_WGPUQueueWorkDoneStatus, None, None, webgpu.wgpuDeviceGetQueue(self.device_res))
|
||||
super().__init__(device, WebGpuAllocator(self), [WGSLRenderer], functools.partial(WebGPUProgram, self),
|
||||
arch="shader-f16" * (webgpu.WGPUFeatureName_ShaderF16 in self.features))
|
||||
|
||||
def synchronize(self): QueueOnSubmittedWorkDone(self.queue)
|
||||
|
||||
@suppress_finalizing
|
||||
def free(self, buf:webgpu.WGPUBuffer):
|
||||
if webgpu.wgpuBufferGetMapState(buf) == webgpu.WGPUBufferMapState_Mapped: webgpu.wgpuBufferUnmap(buf)
|
||||
webgpu.wgpuBufferDestroy(buf)
|
||||
webgpu.wgpuBufferRelease(buf)
|
||||
|
||||
def pop_error(self) -> str: return DevicePopErrorScope(self.device_res)[1]
|
||||
def create_uniform(self, val:int|float) -> webgpu.WGPUBuffer:
|
||||
buf = webgpu.wgpuDeviceCreateBuffer(self.device_res,
|
||||
webgpu.WGPUBufferDescriptor(size=4, usage=webgpu.WGPUBufferUsage_Uniform | webgpu.WGPUBufferUsage_CopyDst))
|
||||
self.write_buffer(buf, val.to_bytes(4, "little") if isinstance(val, int) else struct.pack('<f', val))
|
||||
return buf
|
||||
def _readable_buffer(self, buf:webgpu.WGPUBuffer) -> webgpu.WGPUBuffer:
|
||||
size = webgpu.wgpuBufferGetSize(buf)
|
||||
ret = webgpu.wgpuDeviceCreateBuffer(self.device_res,
|
||||
webgpu.WGPUBufferDescriptor(size=size, usage=webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_MapRead, mappedAtCreation=False))
|
||||
|
||||
# copy_buffer_to_buffer
|
||||
encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.device_res, webgpu.WGPUCommandEncoderDescriptor())
|
||||
webgpu.wgpuCommandEncoderCopyBufferToBuffer(encoder, buf, 0, ret, 0, size)
|
||||
cmd_buf = webgpu.wgpuCommandEncoderFinish(encoder, webgpu.WGPUCommandBufferDescriptor())
|
||||
webgpu.wgpuQueueSubmit(self.queue, 1, (webgpu.WGPUCommandBuffer*1)(cmd_buf))
|
||||
webgpu.wgpuCommandBufferRelease(cmd_buf)
|
||||
webgpu.wgpuCommandEncoderRelease(encoder)
|
||||
|
||||
return ret
|
||||
def write_buffer(self, buf:webgpu.WGPUBuffer, src:memoryview|bytearray|bytes):
|
||||
webgpu.wgpuQueueWriteBuffer(self.queue, buf, 0, (ctypes.c_uint8 * len(src)).from_buffer_copy(src), len(src))
|
||||
|
||||
@@ -132,6 +132,7 @@ class DLL(ctypes.CDLL):
|
||||
nonlocal cfunc
|
||||
if cfunc is None: (cfunc:=getattr(self, fn.__name__)).argtypes, cfunc.restype = argtypes, restype
|
||||
return cfunc(*args)
|
||||
wrapper.restype, wrapper.argtypes = restype, argtypes # type: ignore
|
||||
return wrapper
|
||||
return wrap
|
||||
|
||||
|
||||
Reference in New Issue
Block a user