mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-08 05:54:59 +08:00
remove requires_grad use in extra/examples (#16238)
except the ones fed into optimizer
This commit is contained in:
@@ -48,8 +48,8 @@ if __name__ == "__main__":
|
||||
pos_params = list(itertools.accumulate(params, lambda x,y: x+y.numel(), initial=0))
|
||||
adam_m = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
|
||||
adam_v = Tensor.zeros(pos_params[-1], device="CPU").contiguous()
|
||||
adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
|
||||
adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU", requires_grad=False).contiguous()
|
||||
adam_b1_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU").contiguous()
|
||||
adam_b2_t = Tensor.ones((1,), dtype=dtypes.float32, device="CPU").contiguous()
|
||||
adam_params = [adam_m, adam_v, adam_b1_t, adam_b2_t]
|
||||
|
||||
# create loss and grads. init all state so the JIT works on microbatch
|
||||
|
||||
@@ -69,7 +69,6 @@ class BatchNorm(nn.BatchNorm2d if getenv("SYNCBN") else UnsyncedBatchNorm):
|
||||
def __init__(self, num_features):
|
||||
super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
|
||||
self.weight.requires_grad = False
|
||||
self.bias.requires_grad = True
|
||||
|
||||
class ConvGroup:
|
||||
def __init__(self, channels_in, channels_out):
|
||||
@@ -264,7 +263,6 @@ def train_cifar():
|
||||
# self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer
|
||||
self.net_ema = SpeedyResNet(w)
|
||||
for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()):
|
||||
net_ema_param.requires_grad = False
|
||||
net_ema_param.assign(net_param.numpy())
|
||||
|
||||
@TinyJit
|
||||
|
||||
@@ -102,7 +102,7 @@ class Int8Embedding:
|
||||
self.weight, self.scale = Tensor.ones(vocab_size, embed_size, dtype=dtypes.int8), Tensor.ones(vocab_size, dtype=dtypes.half)
|
||||
|
||||
def __call__(self, idx:Tensor) -> Tensor:
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).unsqueeze(-1)
|
||||
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
|
||||
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), (self.weight.cast(self.scale.dtype).T*self.scale).T
|
||||
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
|
||||
|
||||
@@ -57,7 +57,7 @@ class EmbeddingBert(nn.Embedding):
|
||||
def __call__(self, idx:Tensor) -> Tensor:
|
||||
if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device)
|
||||
arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
|
||||
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, device=self.weight.device).reshape(arange_shp)
|
||||
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
|
||||
return (arange == idx).where(vals, 0).sum(2, dtype=vals.dtype)
|
||||
|
||||
|
||||
@@ -180,11 +180,11 @@ def train_resnet():
|
||||
def fake_data_get(batch_size):
|
||||
x = Tensor.zeros(batch_size, 224, 224, 3, dtype=dtypes.uchar).contiguous()
|
||||
y = [0] * batch_size
|
||||
return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, None
|
||||
return x.shard(GPUS, axis=0).realize(), Tensor(y).shard(GPUS, axis=0), y, None
|
||||
|
||||
def data_get(it):
|
||||
x, y, cookie = next(it)
|
||||
return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, cookie
|
||||
return x.shard(GPUS, axis=0).realize(), Tensor(y).shard(GPUS, axis=0), y, cookie
|
||||
|
||||
# ** epoch loop **
|
||||
step_times = []
|
||||
@@ -798,7 +798,7 @@ def train_unet3d():
|
||||
@Tensor.train(mode=False)
|
||||
def eval_step(model, x, y):
|
||||
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
|
||||
y_hat, y = Tensor(y_hat), Tensor(y, requires_grad=False)
|
||||
y_hat, y = Tensor(y_hat), Tensor(y)
|
||||
loss = dice_ce_loss(y_hat, y)
|
||||
score = dice_score(y_hat, y)
|
||||
return loss.realize(), score.realize()
|
||||
|
||||
@@ -21,7 +21,7 @@ class GradAccClipAdamW(Optimizer):
|
||||
def __init__(self, params:list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, grad_acc=1, clip_norm=1.0, device=None, fused=FUSE_OPTIM):
|
||||
super().__init__(params, lr, device, fused)
|
||||
self.b1, self.b2, self.eps, self.wd = b1, b2, eps, weight_decay
|
||||
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False) for _ in [b1, b2])
|
||||
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device) for _ in [b1, b2])
|
||||
self.m = self._new_optim_param()
|
||||
self.v = self._new_optim_param()
|
||||
self.grad_acc, self.clip_norm = grad_acc, clip_norm
|
||||
|
||||
@@ -71,7 +71,7 @@ def train_generator(optimizer, data_fake):
|
||||
if __name__ == "__main__":
|
||||
# data for training and validation
|
||||
X_train, _, _, _ = mnist()
|
||||
ds_noise = Tensor.randn(64, 128, requires_grad=False)
|
||||
ds_noise = Tensor.randn(64, 128)
|
||||
# parameters
|
||||
epochs, batch_size, k = 300, 512, 1
|
||||
sample_interval = epochs // 10
|
||||
|
||||
@@ -164,8 +164,8 @@ elif cmd == "train":
|
||||
x_img = image_load(samples_base + "/" + str(sample_idx) + "a.png")
|
||||
y_img = image_load(samples_base + "/" + str(sample_idx) + "b.png")
|
||||
|
||||
sample_x = Tensor(x_img, requires_grad = False)
|
||||
sample_y = Tensor(y_img, requires_grad = False)
|
||||
sample_x = Tensor(x_img)
|
||||
sample_y = Tensor(y_img)
|
||||
|
||||
# magic code roughly from readme example
|
||||
# An explanation, in case anyone else has to go down this path:
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor
|
||||
class LR_Scheduler:
|
||||
def __init__(self, optimizer: Optimizer):
|
||||
self.optimizer = optimizer
|
||||
self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)
|
||||
self.epoch_counter = Tensor([0], device=self.optimizer.device)
|
||||
|
||||
def get_lr(self): pass
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class BertForPretraining:
|
||||
# Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
|
||||
def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1):
|
||||
log_probs, loss_mask = predictions.log_softmax(dtype=dtypes.float), (labels != ignore_index)
|
||||
y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
|
||||
y_counter = Tensor.arange(predictions.shape[-1], device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
|
||||
y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
|
||||
return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero
|
||||
|
||||
@@ -159,7 +159,7 @@ class BertPooler:
|
||||
return self.dense(hidden_states[:, 0]).tanh()
|
||||
|
||||
def gather(prediction_logits:Tensor, masked_lm_positions:Tensor):
|
||||
counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device, requires_grad=False).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
|
||||
counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
|
||||
onehot = counter == masked_lm_positions.unsqueeze(2).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
|
||||
return onehot @ prediction_logits
|
||||
|
||||
@@ -189,7 +189,7 @@ class BertEmbeddings:
|
||||
input_shape = input_ids.shape
|
||||
seq_length = input_shape[1]
|
||||
|
||||
position_ids = Tensor.arange(seq_length, requires_grad=False, device=input_ids.device).unsqueeze(0).expand(*input_shape)
|
||||
position_ids = Tensor.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(*input_shape)
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
@@ -78,7 +78,7 @@ def tensor_getitem(tensor, *keys):
|
||||
# for gather with indicies only on axis=0
|
||||
def tensor_gather(tensor, indices):
|
||||
if not isinstance(indices, Tensor):
|
||||
indices = Tensor(indices, requires_grad=False)
|
||||
indices = Tensor(indices)
|
||||
if len(tensor.shape) > 2:
|
||||
rem_shape = list(tensor.shape)[1:]
|
||||
tensor = tensor.reshape(tensor.shape[0], -1)
|
||||
|
||||
@@ -15,7 +15,7 @@ class RNNT:
|
||||
@TinyJit
|
||||
def __call__(self, x, y, hc=None):
|
||||
f, _ = self.encoder(x, None)
|
||||
g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False))
|
||||
g, _ = self.prediction(y, hc, Tensor.ones(1))
|
||||
out = self.joint(f, g)
|
||||
return out.realize()
|
||||
|
||||
@@ -30,10 +30,10 @@ class RNNT:
|
||||
return outputs
|
||||
|
||||
def _greedy_decode(self, logits, logit_len):
|
||||
hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
|
||||
hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size)
|
||||
labels = []
|
||||
label = Tensor.zeros(1, 1, requires_grad=False)
|
||||
mask = Tensor.zeros(1, requires_grad=False)
|
||||
label = Tensor.zeros(1, 1)
|
||||
mask = Tensor.zeros(1)
|
||||
for time_idx in range(logit_len):
|
||||
logit = logits[time_idx, :, :].unsqueeze(0)
|
||||
not_blank = True
|
||||
@@ -41,7 +41,7 @@ class RNNT:
|
||||
while not_blank and added < 30:
|
||||
if len(labels) > 0:
|
||||
mask = (mask + 1).clip(0, 1)
|
||||
label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1
|
||||
label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]]) + 1 - 1
|
||||
jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
|
||||
k = jhc[0, 0, :29].argmax(axis=0).numpy()
|
||||
not_blank = k != 28
|
||||
@@ -129,7 +129,7 @@ class LSTM:
|
||||
return self.do_step(x_, hc_)
|
||||
|
||||
if hc is None:
|
||||
hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False).contiguous().realize()
|
||||
hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size).contiguous().realize()
|
||||
|
||||
output = None
|
||||
for t in range(x.shape[0]):
|
||||
|
||||
@@ -376,7 +376,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
if isinstance(xq.device, tuple) and not isinstance(attn_mask.device, tuple):
|
||||
attn_mask = attn_mask.shard(xq.device, axis=0)
|
||||
else:
|
||||
attn_mask = Tensor.zeros((B, 1, N, N), requires_grad=False, device=single_device, dtype=dtypes.float32)
|
||||
attn_mask = Tensor.zeros((B, 1, N, N), device=single_device, dtype=dtypes.float32)
|
||||
if isinstance(xq.device, tuple):
|
||||
attn_mask = attn_mask.shard(xq.device, axis=0)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
|
||||
losses, accuracies = [], []
|
||||
for i in (t := trange(steps, disable=CI)):
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||
x = Tensor(transform(X_train[samp]), requires_grad=False)
|
||||
x = Tensor(transform(X_train[samp]))
|
||||
y = Tensor(target_transform(Y_train[samp]))
|
||||
loss, accuracy = train_step(x, y)
|
||||
# printing
|
||||
|
||||
2
test/external/external_test_llama3_layer.py
vendored
2
test/external/external_test_llama3_layer.py
vendored
@@ -14,7 +14,7 @@ if __name__ == "__main__":
|
||||
layer = TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context=0)
|
||||
for x in nn.state.get_parameters(layer): x.replace(x.cast(dtypes.default_float)).realize()
|
||||
|
||||
freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().requires_grad_(False).realize()
|
||||
freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().realize()
|
||||
|
||||
@TinyJit
|
||||
def run(t): return layer(t, 0, freqs_cis, None)
|
||||
|
||||
@@ -52,13 +52,13 @@ class TestTensorGradient(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
|
||||
|
||||
def test_copy_to_device_gradient(self):
|
||||
t = Tensor([1.0, 2, 3], requires_grad=True).realize()
|
||||
t = Tensor([1.0, 2, 3]).realize()
|
||||
t.to("CPU:1").square().sum().backward()
|
||||
self.assertEqual(t.grad.device, t.device)
|
||||
self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
|
||||
|
||||
def test_multiple_backward(self):
|
||||
x = Tensor([3.], requires_grad=True)
|
||||
x = Tensor([3.])
|
||||
(x*2)[0].backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), [2.0])
|
||||
old_grad = x.grad
|
||||
@@ -85,7 +85,7 @@ class TestTensorGradient(unittest.TestCase):
|
||||
np.testing.assert_allclose(base.grad.numpy(), [0.0, 0.0, 0.0, 0.0]) # ...but detach blocks it from base
|
||||
|
||||
def test_setitem_on_grad_used_tensor_raises(self):
|
||||
x = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True).realize()
|
||||
x = Tensor([1.0, 2.0, 3.0, 4.0]).realize()
|
||||
_ = (x * 2.0).sum()
|
||||
with self.assertRaises(RuntimeError):
|
||||
x[0] = 99.0
|
||||
@@ -115,10 +115,10 @@ class TestMultiOutputGradient(unittest.TestCase):
|
||||
|
||||
def test_custom_kernel_multi_output_backward(self):
|
||||
a_np, b_np = np.random.randn(4, 4).astype(np.float32), np.random.randn(4, 4).astype(np.float32)
|
||||
a_ref, b_ref = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
|
||||
a_ref, b_ref = Tensor(a_np), Tensor(b_np)
|
||||
((a_ref + b_ref).sum() + (a_ref * b_ref).sum()).backward()
|
||||
|
||||
a, b = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
|
||||
a, b = Tensor(a_np), Tensor(b_np)
|
||||
Tensor.realize(a, b)
|
||||
c, d, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), a, b, fxn=self.addmul_kernel, grad_fxn=self.backward_addmul)
|
||||
(c.sum() + d.sum()).backward()
|
||||
@@ -127,10 +127,10 @@ class TestMultiOutputGradient(unittest.TestCase):
|
||||
|
||||
def test_custom_kernel_multi_output_backward_interacting(self):
|
||||
a_np, b_np = np.random.randn(4, 4).astype(np.float32), np.random.randn(4, 4).astype(np.float32)
|
||||
a_ref, b_ref = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
|
||||
a_ref, b_ref = Tensor(a_np), Tensor(b_np)
|
||||
((a_ref + b_ref) * (a_ref * b_ref)).sum().backward()
|
||||
|
||||
a, b = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
|
||||
a, b = Tensor(a_np), Tensor(b_np)
|
||||
Tensor.realize(a, b)
|
||||
c, d, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), a, b, fxn=self.addmul_kernel, grad_fxn=self.backward_addmul)
|
||||
(c * d).sum().backward()
|
||||
@@ -152,10 +152,10 @@ class TestMultiOutputGradient(unittest.TestCase):
|
||||
return (None, None, None, grad_a, grad_b)
|
||||
|
||||
a_np, b_np = np.random.randn(4, 4).astype(np.float32), np.random.randn(4, 4).astype(np.float32)
|
||||
a_ref, b_ref = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
|
||||
a_ref, b_ref = Tensor(a_np), Tensor(b_np)
|
||||
((a_ref + b_ref).sum() + (a_ref * b_ref).sum() + (a_ref - b_ref).sum()).backward()
|
||||
|
||||
a, b = Tensor(a_np, requires_grad=True), Tensor(b_np, requires_grad=True)
|
||||
a, b = Tensor(a_np), Tensor(b_np)
|
||||
Tensor.realize(a, b)
|
||||
c, d, e, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), Tensor.empty(4, 4), a, b,
|
||||
fxn=addmulsub_kernel, grad_fxn=backward_addmulsub)
|
||||
@@ -166,7 +166,7 @@ class TestMultiOutputGradient(unittest.TestCase):
|
||||
class TestViewGradient(unittest.TestCase):
|
||||
def test_expand(self):
|
||||
x = Tensor.randn(5,2)
|
||||
a = Tensor([3.], requires_grad=True)
|
||||
a = Tensor([3.])
|
||||
aex = a.expand(10)
|
||||
(aex.reshape(5,2) * x).sum().backward()
|
||||
np.testing.assert_allclose(aex.grad.numpy(), x.reshape(10).numpy())
|
||||
|
||||
Reference in New Issue
Block a user