From e641bbc859d1c314fc70c1199db04cacc5dc6c8e Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 3 Nov 2024 15:59:31 -0500 Subject: [PATCH] safe softmax trick in MCTS ucb_explored_children (#7515) * safe softmax trick in MCTS ucb_explored_children fixed ``` File "numpy/random/mtrand.pyx", line 971, in numpy.random.mtrand.RandomState.choice ValueError: probabilities contain NaN ``` when all ucb_explored_children are big negative numbers result in all NaN probabilities * better type --- extra/mcts_search.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extra/mcts_search.py b/extra/mcts_search.py index f8edb94442..54189fabb3 100644 --- a/extra/mcts_search.py +++ b/extra/mcts_search.py @@ -35,7 +35,7 @@ def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode: if node.children is None or len(node.children) == 0: return node unexplored_children = [] explored_children = [] - ucb_explored_children = [] + ucb_explored_children: List[float] = [] for child in node.children: if child.n == 0: unexplored_children.append(child) else: @@ -45,7 +45,8 @@ def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode: ucb_explored_children.append(ucb) if len(unexplored_children): return random.choice(unexplored_children) if not len(explored_children): return node - ucb_exp = np.exp(np.array(ucb_explored_children)/TEMP) + # safe softmax + ucb_exp = np.exp((np.array(ucb_explored_children)-max(ucb_explored_children))/TEMP) return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm) # this will expand/remove sometimes