raise RuntimeError for int base pow (#8852)

current implementation is not precise and blocking other simplification change
This commit is contained in:
chenyu
2025-02-01 12:11:57 -05:00
committed by GitHub
parent 72e1f41f8e
commit 73ee2d74c0
3 changed files with 7 additions and 0 deletions

View File

@@ -95,6 +95,10 @@ backend_test.exclude('test_dequantizelinear_e5m2_cpu')
# we don't support indexes
backend_test.exclude('test_nonzero_*')
# no support for int pow
backend_test.exclude('test_pow_types_int32_int32_cpu')
backend_test.exclude('test_pow_types_int64_int64_cpu')
# no support for fmod
backend_test.exclude('test_mod_int64_fmod_cpu')
backend_test.exclude('test_mod_mixed_sign_float16_cpu')

View File

@@ -619,6 +619,7 @@ class TestOps(unittest.TestCase):
# TODO: fix backward, should be nan
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
@unittest.skip("not supported")
def test_pow_int(self):
def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)

View File

@@ -3310,6 +3310,8 @@ class Tensor(SimpleMathTrait):
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
base, exponent = self._broadcasted(x, reverse=reverse)
# TODO: int pow
if not base.is_floating_point(): raise RuntimeError("base needs to be float")
# start with b ** e = exp(e * log(b))
ret = base.abs().log().mul(exponent).exp()
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)