typedef bf16 amd (#7850)

This commit is contained in:
ignaciosica
2024-11-22 16:29:01 -03:00
committed by GitHub
parent a352a6938f
commit fb10ea563e

View File

@@ -412,7 +412,7 @@ class AMDRenderer(CStyleLanguage):
prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
used_dtypes = uops_to_dtypes(uops)
-if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("struct hip_bfloat16 { unsigned short data; };")
+if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper