mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
training param for batchnorm
This commit is contained in:
@@ -16,6 +16,7 @@ NUM = int(os.getenv("NUM", 2))
|
||||
BS = int(os.getenv("BS", 8))
|
||||
CNT = int(os.getenv("CNT", 10))
|
||||
BACKWARD = int(os.getenv("BACKWARD", 0))
|
||||
TRAINING = int(os.getenv("TRAINING", 1))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
|
||||
@@ -23,7 +24,7 @@ if __name__ == "__main__":
|
||||
parameters = get_parameters(model)
|
||||
optimizer = optim.SGD(parameters, lr=0.001)
|
||||
|
||||
Tensor.training = True
|
||||
Tensor.training = TRAINING
|
||||
for i in trange(CNT):
|
||||
cpy = time.monotonic()
|
||||
x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize()
|
||||
|
||||
Reference in New Issue
Block a user