mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Fix cast error in render_load in wgsl (#1956)
* Fix cast error in wgsl * Use render_cast instead of introducing new method * Make it shorter * Add back webgpu tests: efficientnet and dtypes
This commit is contained in:
8
.github/workflows/test.yml
vendored
8
.github/workflows/test.yml
vendored
@@ -203,9 +203,11 @@ jobs:
|
||||
- name: Run linearizer and tensor core test
|
||||
run: METAL=1 python -m pytest -n=auto test/test_linearizer.py
|
||||
#- name: Run webgpu pytest
|
||||
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto --ignore test/models/ --ignore test/unit/test_example.py --ignore test/extra/test_lr_scheduler.py --ignore test/test_linearizer.py test/
|
||||
#- name: Build WEBGPU Efficientnet
|
||||
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
|
||||
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto
|
||||
- name: Run webgpu dtype tests
|
||||
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py
|
||||
- name: Build WEBGPU Efficientnet
|
||||
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
|
||||
|
||||
tests:
|
||||
strategy:
|
||||
|
||||
@@ -152,6 +152,7 @@ class TestInt32Dtype(unittest.TestCase):
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64")
|
||||
def test_int32_upcast_int64(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable")
|
||||
class TestBoolDtype(unittest.TestCase):
|
||||
def test_casts_from_bool(self): _test_casts_from([0,1,1,0], source_dtype=dtypes.bool, target_dtypes=[dtypes.float32, dtypes.int32])
|
||||
def test_casts_to_bool(self): _test_casts_to([0,1,1,0], source_dtypes=[dtypes.float32, dtypes.int32], target_dtype=dtypes.bool)
|
||||
|
||||
@@ -40,6 +40,7 @@ class CStyleLanguage(NamedTuple):
|
||||
|
||||
# returns a str expression of the casted xs with the given type
|
||||
def render_cast(self, x:List[str], var_dtype:DType) -> str:
|
||||
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
|
||||
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
|
||||
assert self.float4 is not None, "cast is not supported on this platform"
|
||||
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
|
||||
@@ -61,10 +62,12 @@ class CStyleLanguage(NamedTuple):
|
||||
return f"read_imagef({buf_name}, smp, {idx})"
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
|
||||
cast = f"({output_dtype.name})" if output_dtype != buf_dtype else ""
|
||||
if output_dtype.sz > 1:
|
||||
return f"{cast}(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx})))"
|
||||
return f"{cast}(*({buf_name}+{idx}))" if self.uses_ptr_arithmetic else f"{cast}({buf_name}[{idx}])"
|
||||
out_val = f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
|
||||
else:
|
||||
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
|
||||
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
|
||||
|
||||
def render_local(self, name:str, size:int):
|
||||
return self.smem_prefix + f"float {name}[{size}];"
|
||||
|
||||
@@ -45,6 +45,10 @@ class WGSLLanguage(CStyleLanguage):
|
||||
|
||||
def render_conditional(self, cond:str, x:str, y:str) -> str:
|
||||
return f"select(f32({y}), {x}, bool({cond}))"
|
||||
|
||||
def render_cast(self, x:List[str], var_dtype:DType) -> str:
|
||||
if type_map[var_dtype]: return f"{type_map[var_dtype]}({x[0]})"
|
||||
raise NotImplementedError(f"no cast for {var_dtype}")
|
||||
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
|
||||
return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"
|
||||
|
||||
Reference in New Issue
Block a user