llama: mllog (#15502)

This commit is contained in:
wozeparrot
2026-03-29 02:18:25 +08:00
committed by GitHub
parent 7e57e101d5
commit 0c3e438229

View File

@@ -1287,6 +1287,9 @@ def train_llama3():
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
from examples.mlperf.optim import GradAccClipAdamW
INITMLPERF = getenv("INITMLPERF")
RUNMLPERF = getenv("RUNMLPERF")
LOGMLPERF = getenv("LOGMLPERF")
BENCHMARK = getenv("BENCHMARK")
config = {}
@@ -1309,15 +1312,60 @@ def train_llama3():
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16)
EVAL_TARGET = config["EVAL_TARGET"] = getenv("EVAL_TARGET", 5.6)
# LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. DEV=AMD AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
# trains to 7
if LOGMLPERF:
from mlperf_logging import mllog
import mlperf_logging.mllog.constants as mllog_constants
mllog.config(filename=f"result_llama31_{SEED}.log")
mllog.config(root_dir=Path(__file__).parents[3].as_posix())
MLLOGGER = mllog.get_mllogger()
MLLOGGER.logger.propagate = False
LLAMA_BENCHMARK = mllog_constants.LLAMA31_405B if getenv("LLAMA3_SIZE", "8B") == "405B" else mllog_constants.LLAMA31_8B
if INITMLPERF:
assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=LLAMA_BENCHMARK)
diskcache_clear()
MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
if RUNMLPERF:
MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
MLLOGGER.event(key=mllog_constants.SEED, value=SEED)
MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=GBS)
MLLOGGER.event(key=mllog_constants.MAX_SEQUENCE_LENGTH, value=SEQLEN)
MLLOGGER.event(key=mllog_constants.MAX_STEPS, value=MAX_STEPS)
MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=grad_acc)
MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=EVAL_SAMPLES)
MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=SAMPLES)
MLLOGGER.event(key=mllog_constants.OPT_NAME, value=mllog_constants.ADAMW)
MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=LR)
MLLOGGER.event(key=mllog_constants.OPT_END_LR, value=END_LR)
MLLOGGER.event(key=mllog_constants.OPT_ADAMW_BETA_1, value=0.9)
MLLOGGER.event(key=mllog_constants.OPT_ADAMW_BETA_2, value=0.95)
MLLOGGER.event(key=mllog_constants.OPT_ADAMW_EPSILON, value=1e-5)
MLLOGGER.event(key=mllog_constants.OPT_ADAMW_WEIGHT_DECAY, value=0.1)
MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_STEPS, value=WARMUP_STEPS)
MLLOGGER.event(key=mllog_constants.NUM_WARMUP_STEPS, value=WARMUP_STEPS)
MLLOGGER.event(key=mllog_constants.OPT_LR_DECAY_STEPS, value=MAX_STEPS - WARMUP_STEPS)
MLLOGGER.event(key=mllog_constants.OPT_GRADIENT_CLIP_NORM, value=1.0)
else:
MLLOGGER = None
opt_adamw_beta_1 = 0.9
opt_adamw_beta_2 = 0.95
opt_adamw_epsilon = 1e-5
opt_adamw_weight_decay = 0.1
opt_gradient_clip_norm = 1.0
opt_learning_rate_warmup_steps = WARMUP_STEPS
opt_learning_rate_decay_steps = MAX_STEPS - opt_learning_rate_warmup_steps
opt_base_learning_rate = LR
@@ -1451,6 +1499,11 @@ def train_llama3():
train_iter = get_train_iter()
i, sequences_seen = resume_ckpt, 0
step_times = []
if MLLOGGER and RUNMLPERF:
MLLOGGER.start(key=mllog_constants.EPOCH_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
MLLOGGER.start(key=mllog_constants.BLOCK_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
while i < MAX_STEPS:
GlobalCounters.reset()
actual_gbs = GBS if i >= 2 else BS
@@ -1533,6 +1586,10 @@ def train_llama3():
tqdm.write(f"evaluating after {sequences_seen} sequences")
profile_marker(f"eval @ {i}")
if MLLOGGER and RUNMLPERF:
MLLOGGER.end(key=mllog_constants.BLOCK_STOP, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
MLLOGGER.start(key=mllog_constants.EVAL_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
# run eval
eval_losses = []
eval_iter = get_eval_iter()
@@ -1542,22 +1599,34 @@ def train_llama3():
eval_losses += eval_step(tokens).tolist()
if BENCHMARK and (j+1) == min(BENCHMARK, EVAL_SAMPLES//EVAL_BS):
if MLLOGGER and INITMLPERF:
MLLOGGER.end(key=mllog_constants.INIT_STOP, value=None)
return
log_perplexity = sum(eval_losses) / len(eval_losses)
tqdm.write(f"eval log perplexity: {log_perplexity:.4f}")
if MLLOGGER and RUNMLPERF:
MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=log_perplexity, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
MLLOGGER.end(key=mllog_constants.EVAL_STOP, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
if WANDB:
wandb.log({"eval/log_perplexity": log_perplexity, "eval/sequences_seen": sequences_seen})
if log_perplexity < EVAL_TARGET:
tqdm.write(f"target achieved after {sequences_seen} sequences")
if MLLOGGER and RUNMLPERF:
MLLOGGER.end(key=mllog_constants.EPOCH_STOP, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=sequences_seen)
MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata={mllog_constants.STATUS: mllog_constants.SUCCESS})
if getenv("CKPT"):
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/llama3.safe"
safe_save(get_state_dict(model), fn)
break
if MLLOGGER and RUNMLPERF:
MLLOGGER.start(key=mllog_constants.BLOCK_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen})
def train_stable_diffusion():
from extra.models.unet import UNetModel