training param for batchnorm

This commit is contained in:
George Hotz
2022-07-04 13:28:03 -07:00
parent 21c78b9316
commit d5d9cffe7c

View File

@@ -16,6 +16,7 @@ NUM = int(os.getenv("NUM", 2))
BS = int(os.getenv("BS", 8))
CNT = int(os.getenv("CNT", 10))
BACKWARD = int(os.getenv("BACKWARD", 0))
TRAINING = int(os.getenv("TRAINING", 1))
if __name__ == "__main__":
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
@@ -23,7 +24,7 @@ if __name__ == "__main__":
parameters = get_parameters(model)
optimizer = optim.SGD(parameters, lr=0.001)
Tensor.training = True
Tensor.training = TRAINING
for i in trange(CNT):
cpy = time.monotonic()
x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize()