From 0c3e4382290b035b684ab90da0bdb3035cd57fa2 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Sun, 29 Mar 2026 02:18:25 +0800 Subject: [PATCH] llama: mllog (#15502) --- examples/mlperf/model_train.py | 75 ++++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 8cb7ad05ae..7b3717119e 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1287,6 +1287,9 @@ def train_llama3(): from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup from examples.mlperf.optim import GradAccClipAdamW + INITMLPERF = getenv("INITMLPERF") + RUNMLPERF = getenv("RUNMLPERF") + LOGMLPERF = getenv("LOGMLPERF") BENCHMARK = getenv("BENCHMARK") config = {} @@ -1309,15 +1312,60 @@ def train_llama3(): EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16) EVAL_TARGET = config["EVAL_TARGET"] = getenv("EVAL_TARGET", 5.6) - # LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. DEV=AMD AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py - # trains to 7 + if LOGMLPERF: + from mlperf_logging import mllog + import mlperf_logging.mllog.constants as mllog_constants + + mllog.config(filename=f"result_llama31_{SEED}.log") + mllog.config(root_dir=Path(__file__).parents[3].as_posix()) + MLLOGGER = mllog.get_mllogger() + MLLOGGER.logger.propagate = False + + LLAMA_BENCHMARK = mllog_constants.LLAMA31_405B if getenv("LLAMA3_SIZE", "8B") == "405B" else mllog_constants.LLAMA31_8B + + if INITMLPERF: + assert BENCHMARK, "BENCHMARK must be set for INITMLPERF" + MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp") + MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox")) + MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED) + MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM) + + MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=LLAMA_BENCHMARK) + + diskcache_clear() + MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True) + MLLOGGER.start(key=mllog_constants.INIT_START, value=None) + + if RUNMLPERF: + MLLOGGER.start(key=mllog_constants.RUN_START, value=None) + MLLOGGER.event(key=mllog_constants.SEED, value=SEED) + + MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=GBS) + MLLOGGER.event(key=mllog_constants.MAX_SEQUENCE_LENGTH, value=SEQLEN) + MLLOGGER.event(key=mllog_constants.MAX_STEPS, value=MAX_STEPS) + MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=grad_acc) + MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=EVAL_SAMPLES) + MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=SAMPLES) + + MLLOGGER.event(key=mllog_constants.OPT_NAME, value=mllog_constants.ADAMW) + MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=LR) + MLLOGGER.event(key=mllog_constants.OPT_END_LR, value=END_LR) + MLLOGGER.event(key=mllog_constants.OPT_ADAMW_BETA_1, value=0.9) + MLLOGGER.event(key=mllog_constants.OPT_ADAMW_BETA_2, value=0.95) + MLLOGGER.event(key=mllog_constants.OPT_ADAMW_EPSILON, value=1e-5) + MLLOGGER.event(key=mllog_constants.OPT_ADAMW_WEIGHT_DECAY, value=0.1) + MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_STEPS, value=WARMUP_STEPS) + MLLOGGER.event(key=mllog_constants.NUM_WARMUP_STEPS, value=WARMUP_STEPS) + MLLOGGER.event(key=mllog_constants.OPT_LR_DECAY_STEPS, value=MAX_STEPS - WARMUP_STEPS) + MLLOGGER.event(key=mllog_constants.OPT_GRADIENT_CLIP_NORM, value=1.0) + else: + MLLOGGER = None opt_adamw_beta_1 = 0.9 opt_adamw_beta_2 = 0.95 opt_adamw_epsilon = 1e-5 opt_adamw_weight_decay = 0.1 - opt_gradient_clip_norm = 1.0 opt_learning_rate_warmup_steps = WARMUP_STEPS opt_learning_rate_decay_steps = MAX_STEPS - opt_learning_rate_warmup_steps opt_base_learning_rate = LR @@ -1451,6 +1499,11 @@ def train_llama3(): train_iter = get_train_iter() i, sequences_seen = resume_ckpt, 0 step_times = [] + + if MLLOGGER and RUNMLPERF: + MLLOGGER.start(key=mllog_constants.EPOCH_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen}) + MLLOGGER.start(key=mllog_constants.BLOCK_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen}) + while i < MAX_STEPS: GlobalCounters.reset() actual_gbs = GBS if i >= 2 else BS @@ -1533,6 +1586,10 @@ def train_llama3(): tqdm.write(f"evaluating after {sequences_seen} sequences") profile_marker(f"eval @ {i}") + if MLLOGGER and RUNMLPERF: + MLLOGGER.end(key=mllog_constants.BLOCK_STOP, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen}) + MLLOGGER.start(key=mllog_constants.EVAL_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen}) + # run eval eval_losses = [] eval_iter = get_eval_iter() @@ -1542,22 +1599,34 @@ def train_llama3(): eval_losses += eval_step(tokens).tolist() if BENCHMARK and (j+1) == min(BENCHMARK, EVAL_SAMPLES//EVAL_BS): + if MLLOGGER and INITMLPERF: + MLLOGGER.end(key=mllog_constants.INIT_STOP, value=None) return log_perplexity = sum(eval_losses) / len(eval_losses) tqdm.write(f"eval log perplexity: {log_perplexity:.4f}") + if MLLOGGER and RUNMLPERF: + MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=log_perplexity, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen}) + MLLOGGER.end(key=mllog_constants.EVAL_STOP, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen}) + if WANDB: wandb.log({"eval/log_perplexity": log_perplexity, "eval/sequences_seen": sequences_seen}) if log_perplexity < EVAL_TARGET: tqdm.write(f"target achieved after {sequences_seen} sequences") + if MLLOGGER and RUNMLPERF: + MLLOGGER.end(key=mllog_constants.EPOCH_STOP, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen}) + MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=sequences_seen) + MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata={mllog_constants.STATUS: mllog_constants.SUCCESS}) if getenv("CKPT"): if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir) fn = f"{ckpt_dir}/llama3.safe" safe_save(get_state_dict(model), fn) break + if MLLOGGER and RUNMLPERF: + MLLOGGER.start(key=mllog_constants.BLOCK_START, metadata={mllog_constants.SAMPLES_COUNT: sequences_seen}) def train_stable_diffusion(): from extra.models.unet import UNetModel