diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index b3c2b93d99..d0b64895e0 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -268,10 +268,11 @@ def train_resnet(): BEAM.value = EVAL_BEAM if INITMLPERF: + i, proc = 0, fake_data_get(EVAL_BS) + else: it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch)) i, proc = 0, data_get(it) - else: - i, proc = 0, fake_data_get(EVAL_BS) + prev_cookies = [] while proc is not None: GlobalCounters.reset()