llama: fix FP8=1 FAKEDATA=1 (#15564)

This commit is contained in:
qazal
2026-04-01 14:53:03 +03:00
committed by GitHub
parent 6d1e992e89
commit 09f60d80fd

View File

@@ -1397,7 +1397,7 @@ def train_llama3():
if getenv("FAKEDATA"):
for v in get_parameters(model):
v = v.assign(Tensor.empty(v.shape))
v = v.assign(Tensor.empty(v.shape, dtype=v.dtype))
is_dp = (DP := getenv("DP", 1)) > 1
is_mp = (MP := getenv("MP", 1)) > 1