smaller bert script change (#6768)

only WANDB and RUNMLPERF order. BENCHMARK and BEAM will be done differently
This commit is contained in:
chenyu
2024-09-26 04:54:28 -04:00
committed by GitHub
parent abd484a9f7
commit 5a5fbfa1eb

View File

@@ -644,6 +644,13 @@ def train_bert():
else:
MLLOGGER = None
# ** init wandb **
WANDB = getenv("WANDB")
if WANDB:
import wandb
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
# ** hyperparameters **
BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
@@ -672,7 +679,7 @@ def train_bert():
Tensor.manual_seed(seed) # seed for weight initialization
model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None)
model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)
for _, x in get_state_dict(model).items():
x.realize().to_(GPUS)
@@ -727,14 +734,8 @@ def train_bert():
start_step = int(scheduler_wd.epoch_counter.numpy().item())
print(f"resuming from {ckpt} at step {start_step}")
# ** init wandb **
WANDB = getenv("WANDB")
if WANDB:
import wandb
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
if not INITMLPERF:
if RUNMLPERF:
# only load real data with RUNMLPERF
eval_it = iter(batch_load_val_bert(EVAL_BS))
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
for _ in range(start_step): next(train_it) # Fast forward
@@ -743,10 +744,12 @@ def train_bert():
step_times = []
# ** train loop **
wc_start = time.perf_counter()
if INITMLPERF:
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
else:
if RUNMLPERF:
# only load real data with RUNMLPERF
i, train_data = start_step, get_data_bert(GPUS, train_it)
else:
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
while train_data is not None and i < train_steps and not achieved:
Tensor.training = True
BEAM.value = TRAIN_BEAM
@@ -759,10 +762,10 @@ def train_bert():
pt = time.perf_counter()
try:
if INITMLPERF:
next_data = get_fake_data_bert(GPUS, BS)
else:
if RUNMLPERF:
next_data = get_data_bert(GPUS, train_it)
else:
next_data = get_fake_data_bert(GPUS, BS)
except StopIteration:
next_data = None
@@ -807,10 +810,10 @@ def train_bert():
BEAM.value = EVAL_BEAM
for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
if INITMLPERF:
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
else:
if RUNMLPERF:
eval_data = get_data_bert(GPUS, eval_it)
else:
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
GlobalCounters.reset()
st = time.time()