mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
smaller bert script change (#6768)
Only the placement of the WANDB initialization and the order of the RUNMLPERF branches change here; BENCHMARK and BEAM will be handled differently.
This commit is contained in:
@@ -644,6 +644,13 @@ def train_bert():
|
||||
else:
|
||||
MLLOGGER = None
|
||||
|
||||
# ** init wandb **
|
||||
WANDB = getenv("WANDB")
|
||||
if WANDB:
|
||||
import wandb
|
||||
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
|
||||
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
|
||||
|
||||
# ** hyperparameters **
|
||||
BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
|
||||
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
|
||||
@@ -672,7 +679,7 @@ def train_bert():
|
||||
|
||||
Tensor.manual_seed(seed) # seed for weight initialization
|
||||
|
||||
model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None)
|
||||
model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)
|
||||
|
||||
for _, x in get_state_dict(model).items():
|
||||
x.realize().to_(GPUS)
|
||||
@@ -727,14 +734,8 @@ def train_bert():
|
||||
start_step = int(scheduler_wd.epoch_counter.numpy().item())
|
||||
print(f"resuming from {ckpt} at step {start_step}")
|
||||
|
||||
# ** init wandb **
|
||||
WANDB = getenv("WANDB")
|
||||
if WANDB:
|
||||
import wandb
|
||||
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
|
||||
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
|
||||
|
||||
if not INITMLPERF:
|
||||
if RUNMLPERF:
|
||||
# only load real data with RUNMLPERF
|
||||
eval_it = iter(batch_load_val_bert(EVAL_BS))
|
||||
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
|
||||
for _ in range(start_step): next(train_it) # Fast forward
|
||||
@@ -743,10 +744,12 @@ def train_bert():
|
||||
step_times = []
|
||||
# ** train loop **
|
||||
wc_start = time.perf_counter()
|
||||
if INITMLPERF:
|
||||
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
|
||||
else:
|
||||
if RUNMLPERF:
|
||||
# only load real data with RUNMLPERF
|
||||
i, train_data = start_step, get_data_bert(GPUS, train_it)
|
||||
else:
|
||||
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
|
||||
|
||||
while train_data is not None and i < train_steps and not achieved:
|
||||
Tensor.training = True
|
||||
BEAM.value = TRAIN_BEAM
|
||||
@@ -759,10 +762,10 @@ def train_bert():
|
||||
pt = time.perf_counter()
|
||||
|
||||
try:
|
||||
if INITMLPERF:
|
||||
next_data = get_fake_data_bert(GPUS, BS)
|
||||
else:
|
||||
if RUNMLPERF:
|
||||
next_data = get_data_bert(GPUS, train_it)
|
||||
else:
|
||||
next_data = get_fake_data_bert(GPUS, BS)
|
||||
except StopIteration:
|
||||
next_data = None
|
||||
|
||||
@@ -807,10 +810,10 @@ def train_bert():
|
||||
BEAM.value = EVAL_BEAM
|
||||
|
||||
for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
|
||||
if INITMLPERF:
|
||||
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
|
||||
else:
|
||||
if RUNMLPERF:
|
||||
eval_data = get_data_bert(GPUS, eval_it)
|
||||
else:
|
||||
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
|
||||
GlobalCounters.reset()
|
||||
st = time.time()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user