From dcee90aa3f1e8f5471ee6e218955eca3f7073d61 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 16 May 2026 18:40:26 -0400 Subject: [PATCH] remove requires_grad use in extra/examples (#16238) except the ones fed into optimizer --- examples/gradaccum_mnist.py | 4 ++-- examples/hlb_cifar10.py | 2 -- examples/llama3.py | 2 +- examples/mlperf/initializers.py | 2 +- examples/mlperf/model_train.py | 6 +++--- examples/mlperf/optim.py | 2 +- examples/mnist_gan.py | 2 +- examples/vgg7.py | 4 ++-- extra/lr_scheduler.py | 2 +- extra/models/bert.py | 6 +++--- extra/models/mask_rcnn.py | 2 +- extra/models/rnnt.py | 12 ++++++------ extra/thunder/tiny/fa.py | 2 +- extra/training.py | 2 +- test/external/external_test_llama3_layer.py | 2 +- test/unit/test_gradient.py | 20 ++++++++++---------- 16 files changed, 35 insertions(+), 37 deletions(-) diff --git a/examples/gradaccum_mnist.py b/examples/gradaccum_mnist.py index 2a0ac6f143..fbebe16c22 100644 --- a/examples/gradaccum_mnist.py +++ b/examples/gradaccum_mnist.py @@ -48,8 +48,8 @@ if __name__ == "__main__": pos_params = list(itertools.accumulate(params, lambda x,y: x+y.numel(), initial=0)) adam_m = Tensor.zeros(pos_params[-1], device="CPU").contiguous() adam_v = Tensor.zeros(pos_params[-1], device="CPU").contiguous() - adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous() - adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous() + adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU").contiguous() + adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU").contiguous() adam_params = [adam_m, adam_v, adam_b1_t, adam_b2_t] # create loss and grads. init all state so the JIT works on microbatch diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 6fc6d08fd1..6870e2139c 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -69,7 +69,6 @@ class BatchNorm(nn.BatchNorm2d if getenv("SYNCBN") else UnsyncedBatchNorm): def __init__(self, num_features): super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True) self.weight.requires_grad = False - self.bias.requires_grad = True class ConvGroup: def __init__(self, channels_in, channels_out): @@ -264,7 +263,6 @@ def train_cifar(): # self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer self.net_ema = SpeedyResNet(w) for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()): - net_ema_param.requires_grad = False net_ema_param.assign(net_param.numpy()) @TinyJit diff --git a/examples/llama3.py b/examples/llama3.py index c7476c50b8..f55fdb3273 100644 --- a/examples/llama3.py +++ b/examples/llama3.py @@ -102,7 +102,7 @@ class Int8Embedding: self.weight, self.scale = Tensor.ones(vocab_size, embed_size, dtype=dtypes.int8), Tensor.ones(vocab_size, dtype=dtypes.half) def __call__(self, idx:Tensor) -> Tensor: - if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1) + if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).unsqueeze(-1) big_shp = idx.shape+(self.vocab_sz, self.embed_sz) arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), (self.weight.cast(self.scale.dtype).T*self.scale).T return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype) diff --git a/examples/mlperf/initializers.py b/examples/mlperf/initializers.py index d10792d917..226b5c49a2 100644 --- a/examples/mlperf/initializers.py +++ b/examples/mlperf/initializers.py @@ -57,7 +57,7 @@ class EmbeddingBert(nn.Embedding): def __call__(self, idx:Tensor) -> Tensor: if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device) arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,) - if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp) + if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).reshape(arange_shp) arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp) return (arange == idx).where(vals, 0).sum(2, dtype=vals.dtype) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index e944279155..87335313c3 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -180,11 +180,11 @@ def train_resnet(): def fake_data_get(batch_size): x = Tensor.zeros(batch_size, 224, 224, 3, dtype=dtypes.uchar).contiguous() y = [0] * batch_size - return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, None + return x.shard(GPUS, axis=0).realize(), Tensor(y).shard(GPUS, axis=0), y, None def data_get(it): x, y, cookie = next(it) - return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, cookie + return x.shard(GPUS, axis=0).realize(), Tensor(y).shard(GPUS, axis=0), y, cookie # ** epoch loop ** step_times = [] @@ -798,7 +798,7 @@ def train_unet3d(): @Tensor.train(mode=False) def eval_step(model, x, y): y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS) - y_hat, y = Tensor(y_hat), Tensor(y, requires_grad=False) + y_hat, y = Tensor(y_hat), Tensor(y) loss = dice_ce_loss(y_hat, y) score = dice_score(y_hat, y) return loss.realize(), score.realize() diff --git a/examples/mlperf/optim.py b/examples/mlperf/optim.py index 1148a331d8..800be4b1b5 100644 --- a/examples/mlperf/optim.py +++ b/examples/mlperf/optim.py @@ -21,7 +21,7 @@ class GradAccClipAdamW(Optimizer): def __init__(self, params:list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, grad_acc=1, clip_norm=1.0, device=None, fused=FUSE_OPTIM): super().__init__(params, lr, device, fused) self.b1, self.b2, self.eps, self.wd = b1, b2, eps, weight_decay - self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False) for _ in [b1, b2]) + self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device) for _ in [b1, b2]) self.m = self._new_optim_param() self.v = self._new_optim_param() self.grad_acc, self.clip_norm = grad_acc, clip_norm diff --git a/examples/mnist_gan.py b/examples/mnist_gan.py index 3e1ebe0ef0..640d485032 100644 --- a/examples/mnist_gan.py +++ b/examples/mnist_gan.py @@ -71,7 +71,7 @@ def train_generator(optimizer, data_fake): if __name__ == "__main__": # data for training and validation X_train, _, _, _ = mnist() - ds_noise = Tensor.randn(64, 128, requires_grad=False) + ds_noise = Tensor.randn(64, 128) # parameters epochs, batch_size, k = 300, 512, 1 sample_interval = epochs // 10 diff --git a/examples/vgg7.py b/examples/vgg7.py index 1b01fd1aaf..0aa49c30b9 100644 --- a/examples/vgg7.py +++ b/examples/vgg7.py @@ -164,8 +164,8 @@ elif cmd == "train": x_img = image_load(samples_base + "/" + str(sample_idx) + "a.png") y_img = image_load(samples_base + "/" + str(sample_idx) + "b.png") - sample_x = Tensor(x_img, requires_grad = False) - sample_y = Tensor(y_img, requires_grad = False) + sample_x = Tensor(x_img) + sample_y = Tensor(y_img) # magic code roughly from readme example # An explanation, in case anyone else has to go down this path: diff --git a/extra/lr_scheduler.py b/extra/lr_scheduler.py index 87ff077a40..094155a712 100644 --- a/extra/lr_scheduler.py +++ b/extra/lr_scheduler.py @@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor class LR_Scheduler: def __init__(self, optimizer: Optimizer): self.optimizer = optimizer - self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device) + self.epoch_counter = Tensor([0], device=self.optimizer.device) def get_lr(self): pass diff --git a/extra/models/bert.py b/extra/models/bert.py index 4528be8920..df619724f0 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -52,7 +52,7 @@ class BertForPretraining: # Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315 def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1): log_probs, loss_mask = predictions.log_softmax(dtype=dtypes.float), (labels != ignore_index) - y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1]) + y_counter = Tensor.arange(predictions.shape[-1], device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1]) y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1]) return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero @@ -159,7 +159,7 @@ class BertPooler: return self.dense(hidden_states[:, 0]).tanh() def gather(prediction_logits:Tensor, masked_lm_positions:Tensor): - counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device, requires_grad=False).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1]) + counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1]) onehot = counter == masked_lm_positions.unsqueeze(2).expand(*masked_lm_positions.shape, prediction_logits.shape[1]) return onehot @ prediction_logits @@ -189,7 +189,7 @@ class BertEmbeddings: input_shape = input_ids.shape seq_length = input_shape[1] - position_ids = Tensor.arange(seq_length, requires_grad=False, device=input_ids.device).unsqueeze(0).expand(*input_shape) + position_ids = Tensor.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(*input_shape) words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) diff --git a/extra/models/mask_rcnn.py b/extra/models/mask_rcnn.py index 052ccd2328..d61b811e75 100644 --- a/extra/models/mask_rcnn.py +++ b/extra/models/mask_rcnn.py @@ -78,7 +78,7 @@ def tensor_getitem(tensor, *keys): # for gather with indicies only on axis=0 def tensor_gather(tensor, indices): if not isinstance(indices, Tensor): - indices = Tensor(indices, requires_grad=False) + indices = Tensor(indices) if len(tensor.shape) > 2: rem_shape = list(tensor.shape)[1:] tensor = tensor.reshape(tensor.shape[0], -1) diff --git a/extra/models/rnnt.py b/extra/models/rnnt.py index e7ad0f54b9..83218cd087 100644 --- a/extra/models/rnnt.py +++ b/extra/models/rnnt.py @@ -15,7 +15,7 @@ class RNNT: @TinyJit def __call__(self, x, y, hc=None): f, _ = self.encoder(x, None) - g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False)) + g, _ = self.prediction(y, hc, Tensor.ones(1)) out = self.joint(f, g) return out.realize() @@ -30,10 +30,10 @@ class RNNT: return outputs def _greedy_decode(self, logits, logit_len): - hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False) + hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size) labels = [] - label = Tensor.zeros(1, 1, requires_grad=False) - mask = Tensor.zeros(1, requires_grad=False) + label = Tensor.zeros(1, 1) + mask = Tensor.zeros(1) for time_idx in range(logit_len): logit = logits[time_idx, :, :].unsqueeze(0) not_blank = True @@ -41,7 +41,7 @@ class RNNT: while not_blank and added < 30: if len(labels) > 0: mask = (mask + 1).clip(0, 1) - label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1 + label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]]) + 1 - 1 jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask) k = jhc[0, 0, :29].argmax(axis=0).numpy() not_blank = k != 28 @@ -129,7 +129,7 @@ class LSTM: return self.do_step(x_, hc_) if hc is None: - hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False).contiguous().realize() + hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size).contiguous().realize() output = None for t in range(x.shape[0]): diff --git a/extra/thunder/tiny/fa.py b/extra/thunder/tiny/fa.py index 7e589237a3..0a893ada01 100644 --- a/extra/thunder/tiny/fa.py +++ b/extra/thunder/tiny/fa.py @@ -376,7 +376,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False if isinstance(xq.device, tuple) and not isinstance(attn_mask.device, tuple): attn_mask = attn_mask.shard(xq.device, axis=0) else: - attn_mask = Tensor.zeros((B, 1, N, N), requires_grad=False, device=single_device, dtype=dtypes.float32) + attn_mask = Tensor.zeros((B, 1, N, N), device=single_device, dtype=dtypes.float32) if isinstance(xq.device, tuple): attn_mask = attn_mask.shard(xq.device, axis=0) diff --git a/extra/training.py b/extra/training.py index d134e68bf2..8b4b57ca67 100644 --- a/extra/training.py +++ b/extra/training.py @@ -26,7 +26,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou losses, accuracies = [], [] for i in (t := trange(steps, disable=CI)): samp = np.random.randint(0, X_train.shape[0], size=(BS)) - x = Tensor(transform(X_train[samp]), requires_grad=False) + x = Tensor(transform(X_train[samp])) y = Tensor(target_transform(Y_train[samp])) loss, accuracy = train_step(x, y) # printing diff --git a/test/external/external_test_llama3_layer.py b/test/external/external_test_llama3_layer.py index 31e752b627..0193b2c0e5 100644 --- a/test/external/external_test_llama3_layer.py +++ b/test/external/external_test_llama3_layer.py @@ -14,7 +14,7 @@ if __name__ == "__main__": layer = TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context=0) for x in nn.state.get_parameters(layer): x.replace(x.cast(dtypes.default_float)).realize() - freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().requires_grad_(False).realize() + freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().realize() @TinyJit def run(t): return layer(t, 0, freqs_cis, None) diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py index 64765af02a..3a566f0e18 100644 --- a/test/unit/test_gradient.py +++ b/test/unit/test_gradient.py @@ -52,13 +52,13 @@ class TestTensorGradient(unittest.TestCase): with self.assertRaises(RuntimeError): x.float().sum().gradient(x) def test_copy_to_device_gradient(self): - t = Tensor([1.0, 2, 3], requires_grad=True).realize() + t = Tensor([1.0, 2, 3]).realize() t.to("CPU:1").square().sum().backward() self.assertEqual(t.grad.device, t.device) self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0]) def test_multiple_backward(self): - x = Tensor([3.], requires_grad=True) + x = Tensor([3.]) (x*2)[0].backward() np.testing.assert_allclose(x.grad.numpy(), [2.0]) old_grad = x.grad @@ -85,7 +85,7 @@ class TestTensorGradient(unittest.TestCase): np.testing.assert_allclose(base.grad.numpy(), [0.0, 0.0, 0.0, 0.0]) # ...but detach blocks it from base def test_setitem_on_grad_used_tensor_raises(self): - x = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True).realize() + x = Tensor([1.0, 2.0, 3.0, 4.0]).realize() _ = (x * 2.0).sum() with self.assertRaises(RuntimeError): x[0] = 99.0 @@ -115,10 +115,10 @@ class TestMultiOutputGradient(unittest.TestCase): def test_custom_kernel_multi_output_backward(self): a_np, b_np = np.random.randn(4, 4).astype(np.float32), np.random.randn(4, 4).astype(np.float32) - a_ref, b_ref = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True) + a_ref, b_ref = Tensor(a_np), Tensor(b_np) ((a_ref + b_ref).sum() + (a_ref * b_ref).sum()).backward() - a, b = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True) + a, b = Tensor(a_np), Tensor(b_np) Tensor.realize(a, b) c, d, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), a, b, fxn=self.addmul_kernel, grad_fxn=self.backward_addmul) (c.sum() + d.sum()).backward() @@ -127,10 +127,10 @@ class TestMultiOutputGradient(unittest.TestCase): def test_custom_kernel_multi_output_backward_interacting(self): a_np, b_np = np.random.randn(4, 4).astype(np.float32), np.random.randn(4, 4).astype(np.float32) - a_ref, b_ref = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True) + a_ref, b_ref = Tensor(a_np), Tensor(b_np) ((a_ref + b_ref) * (a_ref * b_ref)).sum().backward() - a, b = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True) + a, b = Tensor(a_np), Tensor(b_np) Tensor.realize(a, b) c, d, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), a, b, fxn=self.addmul_kernel, grad_fxn=self.backward_addmul) (c * d).sum().backward() @@ -152,10 +152,10 @@ class TestMultiOutputGradient(unittest.TestCase): return (None, None, None, grad_a, grad_b) a_np, b_np = np.random.randn(4, 4).astype(np.float32), np.random.randn(4, 4).astype(np.float32) - a_ref, b_ref = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True) + a_ref, b_ref = Tensor(a_np), Tensor(b_np) ((a_ref + b_ref).sum() + (a_ref * b_ref).sum() + (a_ref - b_ref).sum()).backward() - a, b = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True) + a, b = Tensor(a_np), Tensor(b_np) Tensor.realize(a, b) c, d, e, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), Tensor.empty(4, 4), a, b, fxn=addmulsub_kernel, grad_fxn=backward_addmulsub) @@ -166,7 +166,7 @@ class TestMultiOutputGradient(unittest.TestCase): class TestViewGradient(unittest.TestCase): def test_expand(self): x = Tensor.randn(5,2) - a = Tensor([3.], requires_grad=True) + a = Tensor([3.]) aex = a.expand(10) (aex.reshape(5,2) * x).sum().backward() np.testing.assert_allclose(aex.grad.numpy(), x.reshape(10).numpy())