clean up binary_crossentropy_logits (#10958)

This commit is contained in:
chenyu
2025-06-24 12:23:40 -04:00
committed by GitHub
parent 2ccddfc0ca
commit bfa87f3490
2 changed files with 3 additions and 8 deletions

View File

@@ -165,7 +165,7 @@ class TestRealWorld(unittest.TestCase):
for v in data.values(): v.to_(Device.DEFAULT)
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
-data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 346)
+data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 347)
def test_bert_fuse_arange(self):
with Context(FUSE_ARANGE=1):

View File

@@ -3900,13 +3900,8 @@ class Tensor(MathTrait):
print(t.binary_crossentropy_logits(Y).item())
```
"""
-log_exp = (1 + self.abs().neg().exp()).log()
-base = self.maximum(0) - Y * self
-if pos_weight is None:
-return (base + log_exp)._do_reduction(reduction)
-pos_scalar = 1 + Y * (pos_weight - 1)
-pos_addition = Y * (pos_weight - 1) * self.neg().maximum(0)
-return (base + pos_addition + (pos_scalar * log_exp))._do_reduction(reduction)
+log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid()
+return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction)
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
"""