From d5d9cffe7cd4fa4ce30f1d369b1e57313beaba0f Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 4 Jul 2022 13:28:03 -0700 Subject: [PATCH] training param for batchnorm --- examples/benchmark_train_efficientnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/benchmark_train_efficientnet.py b/examples/benchmark_train_efficientnet.py index 53451ac9db..776b9858ab 100644 --- a/examples/benchmark_train_efficientnet.py +++ b/examples/benchmark_train_efficientnet.py @@ -16,6 +16,7 @@ NUM = int(os.getenv("NUM", 2)) BS = int(os.getenv("BS", 8)) CNT = int(os.getenv("CNT", 10)) BACKWARD = int(os.getenv("BACKWARD", 0)) +TRAINING = int(os.getenv("TRAINING", 1)) if __name__ == "__main__": print(f"NUM:{NUM} BS:{BS} CNT:{CNT}") @@ -23,7 +24,7 @@ if __name__ == "__main__": parameters = get_parameters(model) optimizer = optim.SGD(parameters, lr=0.001) - Tensor.training = True + Tensor.training = TRAINING for i in trange(CNT): cpy = time.monotonic() x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize()