remove requires_grad use in extra/examples (#16238)

The only requires_grad uses left in place are the ones on tensors fed into an optimizer.
This commit is contained in:
chenyu
2026-05-16 18:40:26 -04:00
committed by GitHub
parent 8631b6f17d
commit dcee90aa3f
16 changed files with 35 additions and 37 deletions

View File

@@ -48,8 +48,8 @@ if __name__ == "__main__":
pos_params = list(itertools.accumulate(params, lambda x,y: x+y.numel(), initial=0))
adam_m = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
adam_v = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU").contiguous()
adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU").contiguous()
adam_params = [adam_m, adam_v, adam_b1_t, adam_b2_t]
# create loss and grads. init all state so the JIT works on microbatch

View File

@@ -69,7 +69,6 @@ class BatchNorm(nn.BatchNorm2d if getenv("SYNCBN") else UnsyncedBatchNorm):
def __init__(self, num_features):
super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
self.weight.requires_grad = False
self.bias.requires_grad = True
class ConvGroup:
def __init__(self, channels_in, channels_out):
@@ -264,7 +263,6 @@ def train_cifar():
# self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer
self.net_ema = SpeedyResNet(w)
for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()):
net_ema_param.requires_grad = False
net_ema_param.assign(net_param.numpy())
@TinyJit

View File

@@ -102,7 +102,7 @@ class Int8Embedding:
self.weight, self.scale = Tensor.ones(vocab_size, embed_size, dtype=dtypes.int8), Tensor.ones(vocab_size, dtype=dtypes.half)
def __call__(self, idx:Tensor) -> Tensor:
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).unsqueeze(-1)
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), (self.weight.cast(self.scale.dtype).T*self.scale).T
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)

View File

@@ -57,7 +57,7 @@ class EmbeddingBert(nn.Embedding):
def __call__(self, idx:Tensor) -> Tensor:
if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device)
arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).reshape(arange_shp)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
return (arange == idx).where(vals, 0).sum(2, dtype=vals.dtype)

View File

@@ -180,11 +180,11 @@ def train_resnet():
def fake_data_get(batch_size):
x = Tensor.zeros(batch_size, 224, 224, 3, dtype=dtypes.uchar).contiguous()
y = [0] * batch_size
return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, None
return x.shard(GPUS, axis=0).realize(), Tensor(y).shard(GPUS, axis=0), y, None
def data_get(it):
x, y, cookie = next(it)
return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, cookie
return x.shard(GPUS, axis=0).realize(), Tensor(y).shard(GPUS, axis=0), y, cookie
# ** epoch loop **
step_times = []
@@ -798,7 +798,7 @@ def train_unet3d():
@Tensor.train(mode=False)
def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y, requires_grad=False)
y_hat, y = Tensor(y_hat), Tensor(y)
loss = dice_ce_loss(y_hat, y)
score = dice_score(y_hat, y)
return loss.realize(), score.realize()

View File

@@ -21,7 +21,7 @@ class GradAccClipAdamW(Optimizer):
def __init__(self, params:list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, grad_acc=1, clip_norm=1.0, device=None, fused=FUSE_OPTIM):
super().__init__(params, lr, device, fused)
self.b1, self.b2, self.eps, self.wd = b1, b2, eps, weight_decay
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False) for _ in [b1, b2])
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device) for _ in [b1, b2])
self.m = self._new_optim_param()
self.v = self._new_optim_param()
self.grad_acc, self.clip_norm = grad_acc, clip_norm

View File

@@ -71,7 +71,7 @@ def train_generator(optimizer, data_fake):
if __name__ == "__main__":
# data for training and validation
X_train, _, _, _ = mnist()
ds_noise = Tensor.randn(64, 128, requires_grad=False)
ds_noise = Tensor.randn(64, 128)
# parameters
epochs, batch_size, k = 300, 512, 1
sample_interval = epochs // 10

View File

@@ -164,8 +164,8 @@ elif cmd == "train":
x_img = image_load(samples_base + "/" + str(sample_idx) + "a.png")
y_img = image_load(samples_base + "/" + str(sample_idx) + "b.png")
sample_x = Tensor(x_img, requires_grad = False)
sample_y = Tensor(y_img, requires_grad = False)
sample_x = Tensor(x_img)
sample_y = Tensor(y_img)
# magic code roughly from readme example
# An explanation, in case anyone else has to go down this path:

View File

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor
class LR_Scheduler:
def __init__(self, optimizer: Optimizer):
self.optimizer = optimizer
self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)
self.epoch_counter = Tensor([0], device=self.optimizer.device)
def get_lr(self): pass

View File

@@ -52,7 +52,7 @@ class BertForPretraining:
# Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1):
log_probs, loss_mask = predictions.log_softmax(dtype=dtypes.float), (labels != ignore_index)
y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
y_counter = Tensor.arange(predictions.shape[-1], device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero
@@ -159,7 +159,7 @@ class BertPooler:
return self.dense(hidden_states[:, 0]).tanh()
def gather(prediction_logits:Tensor, masked_lm_positions:Tensor):
counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device, requires_grad=False).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
onehot = counter == masked_lm_positions.unsqueeze(2).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
return onehot @ prediction_logits
@@ -189,7 +189,7 @@ class BertEmbeddings:
input_shape = input_ids.shape
seq_length = input_shape[1]
position_ids = Tensor.arange(seq_length, requires_grad=False, device=input_ids.device).unsqueeze(0).expand(*input_shape)
position_ids = Tensor.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(*input_shape)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

View File

@@ -78,7 +78,7 @@ def tensor_getitem(tensor, *keys):
# for gather with indicies only on axis=0
def tensor_gather(tensor, indices):
if not isinstance(indices, Tensor):
indices = Tensor(indices, requires_grad=False)
indices = Tensor(indices)
if len(tensor.shape) > 2:
rem_shape = list(tensor.shape)[1:]
tensor = tensor.reshape(tensor.shape[0], -1)

View File

@@ -15,7 +15,7 @@ class RNNT:
@TinyJit
def __call__(self, x, y, hc=None):
f, _ = self.encoder(x, None)
g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False))
g, _ = self.prediction(y, hc, Tensor.ones(1))
out = self.joint(f, g)
return out.realize()
@@ -30,10 +30,10 @@ class RNNT:
return outputs
def _greedy_decode(self, logits, logit_len):
hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size)
labels = []
label = Tensor.zeros(1, 1, requires_grad=False)
mask = Tensor.zeros(1, requires_grad=False)
label = Tensor.zeros(1, 1)
mask = Tensor.zeros(1)
for time_idx in range(logit_len):
logit = logits[time_idx, :, :].unsqueeze(0)
not_blank = True
@@ -41,7 +41,7 @@ class RNNT:
while not_blank and added < 30:
if len(labels) > 0:
mask = (mask + 1).clip(0, 1)
label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1
label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]]) + 1 - 1
jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
k = jhc[0, 0, :29].argmax(axis=0).numpy()
not_blank = k != 28
@@ -129,7 +129,7 @@ class LSTM:
return self.do_step(x_, hc_)
if hc is None:
hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False).contiguous().realize()
hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size).contiguous().realize()
output = None
for t in range(x.shape[0]):

View File

@@ -376,7 +376,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
if isinstance(xq.device, tuple) and not isinstance(attn_mask.device, tuple):
attn_mask = attn_mask.shard(xq.device, axis=0)
else:
attn_mask = Tensor.zeros((B, 1, N, N), requires_grad=False, device=single_device, dtype=dtypes.float32)
attn_mask = Tensor.zeros((B, 1, N, N), device=single_device, dtype=dtypes.float32)
if isinstance(xq.device, tuple):
attn_mask = attn_mask.shard(xq.device, axis=0)

View File

@@ -26,7 +26,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
losses, accuracies = [], []
for i in (t := trange(steps, disable=CI)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(transform(X_train[samp]), requires_grad=False)
x = Tensor(transform(X_train[samp]))
y = Tensor(target_transform(Y_train[samp]))
loss, accuracy = train_step(x, y)
# printing

View File

@@ -14,7 +14,7 @@ if __name__ == "__main__":
layer = TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context=0)
for x in nn.state.get_parameters(layer): x.replace(x.cast(dtypes.default_float)).realize()
freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().requires_grad_(False).realize()
freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().realize()
@TinyJit
def run(t): return layer(t, 0, freqs_cis, None)

View File

@@ -52,13 +52,13 @@ class TestTensorGradient(unittest.TestCase):
with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
def test_copy_to_device_gradient(self):
t = Tensor([1.0, 2, 3], requires_grad=True).realize()
t = Tensor([1.0, 2, 3]).realize()
t.to("CPU:1").square().sum().backward()
self.assertEqual(t.grad.device, t.device)
self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
def test_multiple_backward(self):
x = Tensor([3.], requires_grad=True)
x = Tensor([3.])
(x*2)[0].backward()
np.testing.assert_allclose(x.grad.numpy(), [2.0])
old_grad = x.grad
@@ -85,7 +85,7 @@ class TestTensorGradient(unittest.TestCase):
np.testing.assert_allclose(base.grad.numpy(), [0.0, 0.0, 0.0, 0.0]) # ...but detach blocks it from base
def test_setitem_on_grad_used_tensor_raises(self):
x = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True).realize()
x = Tensor([1.0, 2.0, 3.0, 4.0]).realize()
_ = (x * 2.0).sum()
with self.assertRaises(RuntimeError):
x[0] = 99.0
@@ -115,10 +115,10 @@ class TestMultiOutputGradient(unittest.TestCase):
def test_custom_kernel_multi_output_backward(self):
a_np, b_np = np.random.randn(4, 4).astype(np.float32), np.random.randn(4, 4).astype(np.float32)
a_ref, b_ref = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
a_ref, b_ref = Tensor(a_np), Tensor(b_np)
((a_ref + b_ref).sum() + (a_ref * b_ref).sum()).backward()
a, b = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
a, b = Tensor(a_np), Tensor(b_np)
Tensor.realize(a, b)
c, d, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), a, b, fxn=self.addmul_kernel, grad_fxn=self.backward_addmul)
(c.sum() + d.sum()).backward()
@@ -127,10 +127,10 @@ class TestMultiOutputGradient(unittest.TestCase):
def test_custom_kernel_multi_output_backward_interacting(self):
a_np, b_np = np.random.randn(4, 4).astype(np.float32), np.random.randn(4, 4).astype(np.float32)
a_ref, b_ref = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
a_ref, b_ref = Tensor(a_np), Tensor(b_np)
((a_ref + b_ref) * (a_ref * b_ref)).sum().backward()
a, b = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
a, b = Tensor(a_np), Tensor(b_np)
Tensor.realize(a, b)
c, d, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), a, b, fxn=self.addmul_kernel, grad_fxn=self.backward_addmul)
(c * d).sum().backward()
@@ -152,10 +152,10 @@ class TestMultiOutputGradient(unittest.TestCase):
return (None, None, None, grad_a, grad_b)
a_np, b_np = np.random.randn(4, 4).astype(np.float32), np.random.randn(4, 4).astype(np.float32)
a_ref, b_ref = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
a_ref, b_ref = Tensor(a_np), Tensor(b_np)
((a_ref + b_ref).sum() + (a_ref * b_ref).sum() + (a_ref - b_ref).sum()).backward()
a, b = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
a, b = Tensor(a_np), Tensor(b_np)
Tensor.realize(a, b)
c, d, e, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), Tensor.empty(4, 4), a, b,
fxn=addmulsub_kernel, grad_fxn=backward_addmulsub)
@@ -166,7 +166,7 @@ class TestMultiOutputGradient(unittest.TestCase):
class TestViewGradient(unittest.TestCase):
def test_expand(self):
x = Tensor.randn(5,2)
a = Tensor([3.], requires_grad=True)
a = Tensor([3.])
aex = a.expand(10)
(aex.reshape(5,2) * x).sum().backward()
np.testing.assert_allclose(aex.grad.numpy(), x.reshape(10).numpy())