From 5a5fbfa1ebe87a30328f1f2ebdcae2ae3bf8b2e1 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 26 Sep 2024 04:54:28 -0400 Subject: [PATCH] smaller bert script change (#6768) only WANDB and RUNMLPERF order. BENCHMARK and BEAM will be done differently --- examples/mlperf/model_train.py | 39 ++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 67d66b688d..73e574c09f 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -644,6 +644,13 @@ def train_bert(): else: MLLOGGER = None + # ** init wandb ** + WANDB = getenv("WANDB") + if WANDB: + import wandb + wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {} + wandb.init(config=config, **wandb_args, project="MLPerf-BERT") + # ** hyperparameters ** BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS)) EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS)) @@ -672,7 +679,7 @@ def train_bert(): Tensor.manual_seed(seed) # seed for weight initialization - model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None) + model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None) for _, x in get_state_dict(model).items(): x.realize().to_(GPUS) @@ -727,14 +734,8 @@ def train_bert(): start_step = int(scheduler_wd.epoch_counter.numpy().item()) print(f"resuming from {ckpt} at step {start_step}") - # ** init wandb ** - WANDB = getenv("WANDB") - if WANDB: - import wandb - wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {} - wandb.init(config=config, **wandb_args, project="MLPerf-BERT") - - if not INITMLPERF: + if RUNMLPERF: + # only load real data with RUNMLPERF eval_it = iter(batch_load_val_bert(EVAL_BS)) train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK)) for _ in range(start_step): next(train_it) # Fast forward @@ -743,10 +744,12 @@ def train_bert(): step_times = [] # ** train loop ** wc_start = time.perf_counter() - if INITMLPERF: - i, train_data = start_step, get_fake_data_bert(GPUS, BS) - else: + if RUNMLPERF: + # only load real data with RUNMLPERF i, train_data = start_step, get_data_bert(GPUS, train_it) + else: + i, train_data = start_step, get_fake_data_bert(GPUS, BS) + while train_data is not None and i < train_steps and not achieved: Tensor.training = True BEAM.value = TRAIN_BEAM @@ -759,10 +762,10 @@ def train_bert(): pt = time.perf_counter() try: - if INITMLPERF: - next_data = get_fake_data_bert(GPUS, BS) - else: + if RUNMLPERF: next_data = get_data_bert(GPUS, train_it) + else: + next_data = get_fake_data_bert(GPUS, BS) except StopIteration: next_data = None @@ -807,10 +810,10 @@ def train_bert(): BEAM.value = EVAL_BEAM for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK): - if INITMLPERF: - eval_data = get_fake_data_bert(GPUS, EVAL_BS) - else: + if RUNMLPERF: eval_data = get_data_bert(GPUS, eval_it) + else: + eval_data = get_fake_data_bert(GPUS, EVAL_BS) GlobalCounters.reset() st = time.time()