diff --git a/test/backend/test_ops.py b/test/backend/test_ops.py index 0195d60b01..0f21519637 100644 --- a/test/backend/test_ops.py +++ b/test/backend/test_ops.py @@ -1449,9 +1449,7 @@ class TestOps(unittest.TestCase): np.arange(64,128,dtype=np.float32).reshape(8,8)]) def test_small_gemm_eye(self): helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)]) - @unittest.skipIf(CI and Device.DEFAULT in ["NV", "CL", "CUDA"] or (Device.DEFAULT == "CPU" and DEV.renderer == "LLVM") or IMAGE - or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE") - @unittest.skipIf(Device.DEFAULT == "QCOM", "not precise enough") + @unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "not precise enough when emulating") def test_gemm_fp16(self): helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3, grad_atol=5e-3, grad_rtol=5e-3) def test_gemm(self): diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index c1dc753724..ee9bf3fedf 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -113,5 +113,5 @@ class WGSLRenderer(CStyleLanguage): prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3," return prg + "@builtin(local_invocation_id) lindex: vec3) {\n" + "\n".join(kernel) + "\n}" - def supported_dtypes(self): - return {dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half} + def supported_dtypes(self): return {dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.int32, dtypes.uint32, + dtypes.float, *((dtypes.half,) if "shader-f16" in self.target.arch else ())} diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index 2e861a0dff..7482477adb 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -217,8 +217,8 @@ class WebGpuDevice(Compiled): self.device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback, webgpu.WGPURequestDeviceStatus, 1, 2, adapter_res, dev_desc) - super().__init__(device, WebGpuAllocator(self), [WGSLRenderer], - functools.partial(WebGPUProgram, (self.device_res, webgpu.WGPUFeatureName_TimestampQuery in supported))) + program = functools.partial(WebGPUProgram, (self.device_res, webgpu.WGPUFeatureName_TimestampQuery in supported)) + super().__init__(device, WebGpuAllocator(self), [WGSLRenderer], program, arch="shader-f16" * (webgpu.WGPUFeatureName_ShaderF16 in supported)) def synchronize(self): _run(webgpu.wgpuQueueOnSubmittedWorkDone2, webgpu.WGPUQueueWorkDoneCallbackInfo2, webgpu.WGPUQueueWorkDoneCallback2,