From acbeaf0ba90b5fb157cda0aa447996ef0e5753ea Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 19 Jul 2022 09:33:07 -0700 Subject: [PATCH] adam in benchmark_train_efficientnet --- examples/benchmark_train_efficientnet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/benchmark_train_efficientnet.py b/examples/benchmark_train_efficientnet.py index a6019f7af0..f9d450f8f1 100644 --- a/examples/benchmark_train_efficientnet.py +++ b/examples/benchmark_train_efficientnet.py @@ -17,13 +17,15 @@ BS = int(os.getenv("BS", 8)) CNT = int(os.getenv("CNT", 10)) BACKWARD = int(os.getenv("BACKWARD", 0)) TRAINING = int(os.getenv("TRAINING", 1)) +ADAM = int(os.getenv("ADAM", 0)) if __name__ == "__main__": print(f"NUM:{NUM} BS:{BS} CNT:{CNT}") model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False) parameters = get_parameters(model) for p in parameters: p.realize() - optimizer = optim.SGD(parameters, lr=0.001) + if ADAM: optimizer = optim.Adam(parameters, lr=0.001) + else: optimizer = optim.SGD(parameters, lr=0.001) Tensor.training = TRAINING Tensor.no_grad = not BACKWARD