mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
webgpu: shader-f16 support in arch (#16370)
This commit is contained in:
committed by
GitHub
parent
4bcc53eb26
commit
d8f86be613
@@ -1449,9 +1449,7 @@ class TestOps(unittest.TestCase):
|
||||
np.arange(64,128,dtype=np.float32).reshape(8,8)])
|
||||
def test_small_gemm_eye(self):
|
||||
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "CL", "CUDA"] or (Device.DEFAULT == "CPU" and DEV.renderer == "LLVM") or IMAGE
|
||||
or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE")
|
||||
@unittest.skipIf(Device.DEFAULT == "QCOM", "not precise enough")
|
||||
@unittest.skipUnless(dtypes.half in Device[Device.DEFAULT].renderer.supported_dtypes(), "not precise enough when emulating")
|
||||
def test_gemm_fp16(self):
|
||||
helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3, grad_atol=5e-3, grad_rtol=5e-3)
|
||||
def test_gemm(self):
|
||||
|
||||
@@ -113,5 +113,5 @@ class WGSLRenderer(CStyleLanguage):
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
def supported_dtypes(self):
|
||||
return {dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half}
|
||||
def supported_dtypes(self): return {dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.int32, dtypes.uint32,
|
||||
dtypes.float, *((dtypes.half,) if "shader-f16" in self.target.arch else ())}
|
||||
|
||||
@@ -217,8 +217,8 @@ class WebGpuDevice(Compiled):
|
||||
self.device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback,
|
||||
webgpu.WGPURequestDeviceStatus, 1, 2, adapter_res, dev_desc)
|
||||
|
||||
super().__init__(device, WebGpuAllocator(self), [WGSLRenderer],
|
||||
functools.partial(WebGPUProgram, (self.device_res, webgpu.WGPUFeatureName_TimestampQuery in supported)))
|
||||
program = functools.partial(WebGPUProgram, (self.device_res, webgpu.WGPUFeatureName_TimestampQuery in supported))
|
||||
super().__init__(device, WebGpuAllocator(self), [WGSLRenderer], program, arch="shader-f16" * (webgpu.WGPUFeatureName_ShaderF16 in supported))
|
||||
|
||||
def synchronize(self):
|
||||
_run(webgpu.wgpuQueueOnSubmittedWorkDone2, webgpu.WGPUQueueWorkDoneCallbackInfo2, webgpu.WGPUQueueWorkDoneCallback2,
|
||||
|
||||
Reference in New Issue
Block a user