Fix cast error in render_load in wgsl (#1956)

* Fix cast error in wgsl

* User render_cast intead of introducing new method

* Make it shorter

* Add back webgpu tests: efficientnet and dtypes
This commit is contained in:
Ahmed Harmouche
2023-10-04 11:29:14 +02:00
committed by GitHub
parent 6a79d4044a
commit fb4d830a2a
4 changed files with 16 additions and 6 deletions

View File

@@ -203,9 +203,11 @@ jobs:
- name: Run linearizer and tensor core test
run: METAL=1 python -m pytest -n=auto test/test_linearizer.py
#- name: Run webgpu pytest
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto --ignore test/models/ --ignore test/unit/test_example.py --ignore test/extra/test_lr_scheduler.py --ignore test/test_linearizer.py test/
#- name: Build WEBGPU Efficientnet
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto
- name: Run webgpu dtype tests
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py
- name: Build WEBGPU Efficientnet
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
tests:
strategy:

View File

@@ -152,6 +152,7 @@ class TestInt32Dtype(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64")
def test_int32_upcast_int64(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable")
class TestBoolDtype(unittest.TestCase):
def test_casts_from_bool(self): _test_casts_from([0,1,1,0], source_dtype=dtypes.bool, target_dtypes=[dtypes.float32, dtypes.int32])
def test_casts_to_bool(self): _test_casts_to([0,1,1,0], source_dtypes=[dtypes.float32, dtypes.int32], target_dtype=dtypes.bool)

View File

@@ -40,6 +40,7 @@ class CStyleLanguage(NamedTuple):
# returns a str expression of the casted xs with the given type
def render_cast(self, x:List[str], var_dtype:DType) -> str:
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert self.float4 is not None, "cast is not supported on this platform"
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
@@ -61,10 +62,12 @@ class CStyleLanguage(NamedTuple):
return f"read_imagef({buf_name}, smp, {idx})"
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
cast = f"({output_dtype.name})" if output_dtype != buf_dtype else ""
if output_dtype.sz > 1:
return f"{cast}(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx})))"
return f"{cast}(*({buf_name}+{idx}))" if self.uses_ptr_arithmetic else f"{cast}({buf_name}[{idx}])"
out_val = f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
else:
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
def render_local(self, name:str, size:int):
return self.smem_prefix + f"float {name}[{size}];"

View File

@@ -45,6 +45,10 @@ class WGSLLanguage(CStyleLanguage):
def render_conditional(self, cond:str, x:str, y:str) -> str:
return f"select(f32({y}), {x}, bool({cond}))"
def render_cast(self, x:List[str], var_dtype:DType) -> str:
if type_map[var_dtype]: return f"{type_map[var_dtype]}({x[0]})"
raise NotImplementedError(f"no cast for {var_dtype}")
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"