mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
llama: fix FP8=1 FAKEDATA=1 (#15564)
This commit is contained in:
@@ -1397,7 +1397,7 @@ def train_llama3():
|
||||
|
||||
if getenv("FAKEDATA"):
|
||||
for v in get_parameters(model):
|
||||
v = v.assign(Tensor.empty(v.shape))
|
||||
v = v.assign(Tensor.empty(v.shape, dtype=v.dtype))
|
||||
|
||||
is_dp = (DP := getenv("DP", 1)) > 1
|
||||
is_mp = (MP := getenv("MP", 1)) > 1
|
||||
|
||||
Reference in New Issue
Block a user