diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index e6facd4a64..66f416633b 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -165,7 +165,7 @@ class TestRealWorld(unittest.TestCase): for v in data.values(): v.to_(Device.DEFAULT) helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \ - data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 346) + data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 347) def test_bert_fuse_arange(self): with Context(FUSE_ARANGE=1): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 76989d3c1d..15c33de2d0 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3900,13 +3900,8 @@ class Tensor(MathTrait): print(t.binary_crossentropy_logits(Y).item()) ``` """ - log_exp = (1 + self.abs().neg().exp()).log() - base = self.maximum(0) - Y * self - if pos_weight is None: - return (base + log_exp)._do_reduction(reduction) - pos_scalar = 1 + Y * (pos_weight - 1) - pos_addition = Y * (pos_weight - 1) * self.neg().maximum(0) - return (base + pos_addition + (pos_scalar * log_exp))._do_reduction(reduction) + log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid() + return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction) def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor: """