mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
clean up binary_crossentropy_logits (#10958)
This commit is contained in:
@@ -165,7 +165,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
for v in data.values(): v.to_(Device.DEFAULT)
|
||||
|
||||
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
|
||||
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 346)
|
||||
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 347)
|
||||
|
||||
def test_bert_fuse_arange(self):
|
||||
with Context(FUSE_ARANGE=1):
|
||||
|
||||
@@ -3900,13 +3900,8 @@ class Tensor(MathTrait):
|
||||
print(t.binary_crossentropy_logits(Y).item())
|
||||
```
|
||||
"""
|
||||
log_exp = (1 + self.abs().neg().exp()).log()
|
||||
base = self.maximum(0) - Y * self
|
||||
if pos_weight is None:
|
||||
return (base + log_exp)._do_reduction(reduction)
|
||||
pos_scalar = 1 + Y * (pos_weight - 1)
|
||||
pos_addition = Y * (pos_weight - 1) * self.neg().maximum(0)
|
||||
return (base + pos_addition + (pos_scalar * log_exp))._do_reduction(reduction)
|
||||
log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid()
|
||||
return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction)
|
||||
|
||||
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user