From c7368515d23b2ba15fd1a67449ee6f2d1727bf55 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 3 May 2024 14:28:36 -0400 Subject: [PATCH] refactor sparse_categorical_crossentropy (#4406) factor the shared negation (-1 *) and normalization (/ loss_mask.sum()) out of both the smoothing and non-smoothing terms --- tinygrad/tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 844797ca64..1cb4b2657e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1324,9 +1324,9 @@ class Tensor: # NOTE: self is a logits input log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) - y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) - smoothing = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum() - return (1 - label_smoothing) * (log_probs * y).sum() / loss_mask.sum() + smoothing + y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) + smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum() + return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum() # ***** cast ops *****