From fb4d830a2a7c2bd5e40d03ab61df0c85f6f60ec0 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Wed, 4 Oct 2023 11:29:14 +0200 Subject: [PATCH] Fix cast error in render_load in wgsl (#1956) * Fix cast error in wgsl * Use render_cast instead of introducing new method * Make it shorter * Add back webgpu tests: efficientnet and dtypes --- .github/workflows/test.yml | 8 +++++--- test/test_dtype.py | 1 + tinygrad/renderer/cstyle.py | 9 ++++++--- tinygrad/renderer/wgsl.py | 4 ++++ 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c00d10318c..9fc3a907c8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -203,9 +203,11 @@ jobs: - name: Run linearizer and tensor core test run: METAL=1 python -m pytest -n=auto test/test_linearizer.py #- name: Run webgpu pytest - # run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto --ignore test/models/ --ignore test/unit/test_example.py --ignore test/extra/test_lr_scheduler.py --ignore test/test_linearizer.py test/ - #- name: Build WEBGPU Efficientnet - # run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet + # run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto + - name: Run webgpu dtype tests + run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py + - name: Build WEBGPU Efficientnet + run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet tests: strategy: diff --git a/test/test_dtype.py b/test/test_dtype.py index b56543042f..b090a72b01 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -152,6 +152,7 @@ class TestInt32Dtype(unittest.TestCase): @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64") def test_int32_upcast_int64(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int64, target_dtype=dtypes.int64) +@unittest.skipIf(Device.DEFAULT == "WEBGPU", "host-shareablity is a requirement for storage buffers, 
but 'bool' type is not host-shareable") class TestBoolDtype(unittest.TestCase): def test_casts_from_bool(self): _test_casts_from([0,1,1,0], source_dtype=dtypes.bool, target_dtypes=[dtypes.float32, dtypes.int32]) def test_casts_to_bool(self): _test_casts_to([0,1,1,0], source_dtypes=[dtypes.float32, dtypes.int32], target_dtype=dtypes.bool) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 1508ccf527..98882d247f 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -40,6 +40,7 @@ class CStyleLanguage(NamedTuple): # returns a str expression of the casted xs with the given type def render_cast(self, x:List[str], var_dtype:DType) -> str: + if len(x) == 1: return f"({var_dtype.name})({x[0]})" assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}" assert self.float4 is not None, "cast is not supported on this platform" if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})" @@ -61,10 +62,12 @@ class CStyleLanguage(NamedTuple): return f"read_imagef({buf_name}, smp, {idx})" if self.uses_vload and buf_dtype == dtypes.float16: return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})" - cast = f"({output_dtype.name})" if output_dtype != buf_dtype else "" if output_dtype.sz > 1: - return f"{cast}(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx})))" - return f"{cast}(*({buf_name}+{idx}))" if self.uses_ptr_arithmetic else f"{cast}({buf_name}[{idx}])" + out_val = f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" + else: + out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" + + return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val def render_local(self, name:str, size:int): return self.smem_prefix + f"float {name}[{size}];" diff --git a/tinygrad/renderer/wgsl.py 
b/tinygrad/renderer/wgsl.py index 897aa709e5..40a3431e8a 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -45,6 +45,10 @@ class WGSLLanguage(CStyleLanguage): def render_conditional(self, cond:str, x:str, y:str) -> str: return f"select(f32({y}), {x}, bool({cond}))" + + def render_cast(self, x:List[str], var_dtype:DType) -> str: + if type_map[var_dtype]: return f"{type_map[var_dtype]}({x[0]})" + raise NotImplementedError(f"no cast for {var_dtype}") def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"