mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-16 01:48:24 +08:00
simplify pow to not call cos (#8877)
use %2 instead of cos to detect even numbers
This commit is contained in:
@@ -94,10 +94,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
|
||||
def test_literal_one_pow(self):
|
||||
_check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
|
||||
# this fails because of DETACH, it shouldn't
|
||||
# update: passes after CONST(VIEW(DEVICE)) in tensor
|
||||
# TODO: pow simplification
|
||||
def test_tensor_one_pow(self):
|
||||
_check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
|
||||
_check_ast_count(1, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
|
||||
|
||||
# folds advance indexing into basic indexing
|
||||
class TestIndexingConstFolding(unittest.TestCase):
|
||||
|
||||
@@ -3314,10 +3314,10 @@ class Tensor(SimpleMathTrait):
|
||||
if not base.is_floating_point(): raise RuntimeError("base needs to be float")
|
||||
# start with b ** e = exp(e * log(b))
|
||||
ret = base.abs().log().mul(exponent).exp()
|
||||
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
|
||||
# correct sign of negative base with odd exponent
|
||||
negative_base = (base < 0).detach().where(1, 0)
|
||||
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
|
||||
correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
|
||||
correct_sign = (exponent.int()%2==0).where(1, 1-2*negative_base)
|
||||
# inject nan for negative base and non-integer exponent
|
||||
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
|
||||
# apply correct_sign inject_nan, and fix 0 ** 0 = 1
|
||||
|
||||
Reference in New Issue
Block a user