mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
training is False by default
This commit is contained in:
@@ -48,6 +48,7 @@ if __name__ == "__main__":
|
||||
BS, steps = int(os.getenv("BS", "64" if TINY else "16")), 2048
|
||||
print("training with batch size %d for %d steps" % (BS, steps))
|
||||
|
||||
Tensor.training = True
|
||||
for i in (t := trange(steps)):
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ class Device:
|
||||
|
||||
class Tensor:
|
||||
did_float_warning = False
|
||||
training = True
|
||||
training = False
|
||||
ops = defaultdict(dict)
|
||||
|
||||
def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
|
||||
|
||||
Reference in New Issue
Block a user