mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
safe softmax trick in MCTS ucb_explored_children (#7515)
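* safe softmax trick in MCTS ucb_explored_children. Fixes ``` File "numpy/random/mtrand.pyx", line 971, in numpy.random.mtrand.RandomState.choice ValueError: probabilities contain NaN ``` which occurred when all ucb_explored_children were large negative numbers: every exp() underflowed to 0, so normalizing by their sum (0/0) produced all-NaN probabilities. Subtracting the maximum before exponentiating keeps at least one term equal to 1. * better type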
This commit is contained in:
@@ -35,7 +35,7 @@ def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
|
||||
if node.children is None or len(node.children) == 0: return node
|
||||
unexplored_children = []
|
||||
explored_children = []
|
||||
ucb_explored_children = []
|
||||
ucb_explored_children: List[float] = []
|
||||
for child in node.children:
|
||||
if child.n == 0: unexplored_children.append(child)
|
||||
else:
|
||||
@@ -45,7 +45,8 @@ def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
|
||||
ucb_explored_children.append(ucb)
|
||||
if len(unexplored_children): return random.choice(unexplored_children)
|
||||
if not len(explored_children): return node
|
||||
ucb_exp = np.exp(np.array(ucb_explored_children)/TEMP)
|
||||
# safe softmax
|
||||
ucb_exp = np.exp((np.array(ucb_explored_children)-max(ucb_explored_children))/TEMP)
|
||||
return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm)
|
||||
|
||||
# this will expand/remove sometimes
|
||||
|
||||
Reference in New Issue
Block a user