simplify pow to not call cos (#8877)

use %2 instead of cos to detect even numbers
This commit is contained in:
chenyu
2025-02-03 12:54:18 -05:00
committed by GitHub
parent d1aa9f30bc
commit cce26009f0
2 changed files with 4 additions and 5 deletions

View File

@@ -94,10 +94,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
def test_literal_one_pow(self):
_check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
# this fails because of DETACH, it shouldn't
# update: passes after CONST(VIEW(DEVICE)) in tensor
# TODO: pow simplification
def test_tensor_one_pow(self):
_check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
_check_ast_count(1, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
# folds advance indexing into basic indexing
class TestIndexingConstFolding(unittest.TestCase):

View File

@@ -3314,10 +3314,10 @@ class Tensor(SimpleMathTrait):
if not base.is_floating_point(): raise RuntimeError("base needs to be float")
# start with b ** e = exp(e * log(b))
ret = base.abs().log().mul(exponent).exp()
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
# correct sign of negative base with odd exponent
negative_base = (base < 0).detach().where(1, 0)
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
correct_sign = (exponent.int()%2==0).where(1, 1-2*negative_base)
# inject nan for negative base and non-integer exponent
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
# apply correct_sign inject_nan, and fix 0 ** 0 = 1