webgpu: shader-f16 support in arch (#16370)

This commit is contained in:
Christopher Milan
2026-05-25 16:20:59 -07:00
committed by GitHub
parent 4bcc53eb26
commit d8f86be613
3 changed files with 5 additions and 7 deletions

View File

@@ -1449,9 +1449,7 @@ class TestOps(unittest.TestCase):
np.arange(64,128,dtype=np.float32).reshape(8,8)])
def test_small_gemm_eye(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "CL", "CUDA"] or (Device.DEFAULT == "CPU" and DEV.renderer == "LLVM") or IMAGE
or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE")
@unittest.skipIf(Device.DEFAULT == "QCOM", "not precise enough")
@unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "not precise enough when emulating")
def test_gemm_fp16(self):
helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3, grad_atol=5e-3, grad_rtol=5e-3)
def test_gemm(self):

View File

@@ -113,5 +113,5 @@ class WGSLRenderer(CStyleLanguage):
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
def supported_dtypes(self):
return {dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half}
def supported_dtypes(self): return {dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.int32, dtypes.uint32,
dtypes.float, *((dtypes.half,) if "shader-f16" in self.target.arch else ())}

View File

@@ -217,8 +217,8 @@ class WebGpuDevice(Compiled):
self.device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback,
webgpu.WGPURequestDeviceStatus, 1, 2, adapter_res, dev_desc)
super().__init__(device, WebGpuAllocator(self), [WGSLRenderer],
functools.partial(WebGPUProgram, (self.device_res, webgpu.WGPUFeatureName_TimestampQuery in supported)))
program = functools.partial(WebGPUProgram, (self.device_res, webgpu.WGPUFeatureName_TimestampQuery in supported))
super().__init__(device, WebGpuAllocator(self), [WGSLRenderer], program, arch="shader-f16" * (webgpu.WGPUFeatureName_ShaderF16 in supported))
def synchronize(self):
_run(webgpu.wgpuQueueOnSubmittedWorkDone2, webgpu.WGPUQueueWorkDoneCallbackInfo2, webgpu.WGPUQueueWorkDoneCallback2,