From de832d26c64a9ec575e47aeb58efe27a0ccf4e0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com> Date: Fri, 26 Apr 2024 07:20:10 +0200 Subject: [PATCH] disable bfloat16 from ptx tests (#4305) --- test/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/helpers.py b/test/helpers.py index 1d2bd857ee..15ee0c56a6 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -25,7 +25,7 @@ def assert_jit_cache_len(fxn, expected_len): def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): if dtype == dtypes.bfloat16: # NOTE: this requires bf16 buffer support - return device in {"RHIP", "HSA"} or (device == "CUDA" and not CI) + return device in {"RHIP", "HSA"} or (device == "CUDA" and not CI and not getenv("PTX")) if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32] if device == "CUDA" and getenv("PTX") and dtype in (dtypes.int8, dtypes.uint8): return False # for CI GPU and OSX, cl_khr_fp16 isn't supported