safe softmax trick in MCTS ucb_explored_children (#7515)

* safe softmax trick in MCTS ucb_explored_children

fixed
```
  File "numpy/random/mtrand.pyx", line 971, in numpy.random.mtrand.RandomState.choice
ValueError: probabilities contain NaN
```
when all ucb_explored_children are big negative numbers result in all NaN probabilities

* better type
This commit is contained in:
chenyu
2024-11-03 15:59:31 -05:00
committed by GitHub
parent 3ef3b5b5f8
commit e641bbc859

View File

@@ -35,7 +35,7 @@ def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
if node.children is None or len(node.children) == 0: return node
unexplored_children = []
explored_children = []
ucb_explored_children = []
ucb_explored_children: List[float] = []
for child in node.children:
if child.n == 0: unexplored_children.append(child)
else:
@@ -45,7 +45,8 @@ def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
ucb_explored_children.append(ucb)
if len(unexplored_children): return random.choice(unexplored_children)
if not len(explored_children): return node
ucb_exp = np.exp(np.array(ucb_explored_children)/TEMP)
# safe softmax
ucb_exp = np.exp((np.array(ucb_explored_children)-max(ucb_explored_children))/TEMP)
return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm)
# this will expand/remove sometimes