diff --git a/test/helpers.py b/test/helpers.py index 1d2bd857ee..15ee0c56a6 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -25,7 +25,7 @@ def assert_jit_cache_len(fxn, expected_len): def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): if dtype == dtypes.bfloat16: # NOTE: this requires bf16 buffer support - return device in {"RHIP", "HSA"} or (device == "CUDA" and not CI) + return device in {"RHIP", "HSA"} or (device == "CUDA" and not CI and not getenv("PTX")) if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32] if device == "CUDA" and getenv("PTX") and dtype in (dtypes.int8, dtypes.uint8): return False # for CI GPU and OSX, cl_khr_fp16 isn't supported