diff --git a/test/test_const_folding.py b/test/test_const_folding.py index dfffca8989..b78faee145 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -94,10 +94,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase): _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4)) def test_literal_one_pow(self): _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4])) - # this fails because of DETACH, it shouldn't - # update: passes after CONST(VIEW(DEVICE)) in tensor + # TODO: pow simplification def test_tensor_one_pow(self): - _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4])) + _check_ast_count(1, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4])) # folds advance indexing into basic indexing class TestIndexingConstFolding(unittest.TestCase): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8d048eb431..de142aa06f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3314,10 +3314,10 @@ class Tensor(SimpleMathTrait): if not base.is_floating_point(): raise RuntimeError("base needs to be float") # start with b ** e = exp(e * log(b)) ret = base.abs().log().mul(exponent).exp() - # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent) + # correct sign of negative base with odd exponent negative_base = (base < 0).detach().where(1, 0) # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent - correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1) + correct_sign = (exponent.int()%2==0).where(1, 1-2*negative_base) # inject nan for negative base and non-integer exponent inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1) # apply correct_sign inject_nan, and fix 0 ** 0 = 1