"""蒙特卡洛树搜索""" import numpy as np import copy from config import CONFIG def softmax(x): probs = np.exp(x - np.max(x)) probs /= np.sum(probs) return probs # 定义叶子节点 class TreeNode(object): """ mcts树中的节点, 树的子节点字典中, 键为动作, 值为TreeNode。记录当前节点选择的动作, 以及选择该动作后会跳转到的下一个子节点。 每个节点跟踪其自身的Q, 先验概率P及其访问次数调整的u """ def __init__(self, parent, prior_p): """ :param parent: 当前节点的父节点 :param prior_p: 当前节点被选择的先验概率 """ self._parent = parent self._children = {} # 从动作到TreeNode的映射 self._n_visits = 0 # 当前当前节点的访问次数 self._Q = 0 # 当前节点对应动作的平均动作价值 self._u = 0 # 当前节点的置信上限 # PUCT算法 self._P = prior_p def expand(self, action_priors): # 这里把不合法的动作概率全部设置为0 """通过创建新子节点来展开树""" for action, prob in action_priors: if action not in self._children: self._children[action] = TreeNode(self, prob) def select(self, c_puct): """ 在子节点中选择能够提供最大的Q+U的节点 return: (action, next_node)的二元组 """ return max(self._children.items(), key=lambda act_node: act_node[1].get_value(c_puct)) def get_value(self, c_puct): """ 计算并返回此节点的值, 它是节点评估Q和此节点的先验的组合 c_puct: 控制相对影响(0, inf) """ self._u = (c_puct * self._P * np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) return self._Q + self._u def update(self, leaf_value): """ 从叶节点评估中更新节点值 leaf_value: 这个子节点的评估值来自当前玩家的视角 """ # 统计访问次数 self._n_visits += 1 # 更新Q值, 取决于所有访问次数的平均树, 使用增量式更新方式 self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits # 使用递归的方法对所有节点(当前节点对应的支线)进行一次更新 def update_recursive(self, leaf_value): """就像调用update()一样, 但是对所有直系节点进行更新""" # 如果它不是根节点, 则应首先更新此节点的父节点 if self._parent: self._parent.update_recursive(-leaf_value) self.update(leaf_value) def is_leaf(self): """检查是否是叶节点, 即没有被扩展的节点""" return self._children == {} def is_root(self): return self._parent is None # 蒙特卡洛搜索树 class MCTS(object): def __init__(self, policy_value_fn, c_puct=5, n_playout=2000): """policy_value_fn: 接收board的盘面状态, 返回落子概率和盘面评估得分""" self._root = TreeNode(None, 1.0) self._policy = policy_value_fn self._c_puct = c_puct self._n_playout = n_playout def _playout(self, state): """ 进行一次搜索, 根据叶节点的评估值进行反向更新树节点的参数 注意:state已就地修改, 因此必须提供副本 """ node = self._root while True: if node.is_leaf(): break # 贪心算法选择下一步行动 action, node = node.select(self._c_puct) state.do_move(action) # 使用网络评估叶子节点, 网络输出(动作, 概率)元组p的列表以及当前玩家视角的得分[-1, 1] action_probs, leaf_value = self._policy(state) # 查看游戏是否结束 end, winner = state.game_end() if not end: node.expand(action_probs) else: # 对于结束状态, 将叶子节点的值换成1或-1 if winner == -1: # Tie leaf_value = 0.0 else: leaf_value = ( 1.0 if winner == state.get_current_player_id() else -1.0 ) # 在本次遍历中更新节点的值和访问次数 # 必须添加符号, 因为两个玩家共用一个搜索树 node.update_recursive(-leaf_value) def get_move_probs(self, state, temp=1e-3): """ 按顺序运行所有搜索并返回可用的动作及其相应的概率 state:当前游戏的状态 temp:介于(0, 1]之间的温度参数 """ for n in range(self._n_playout): state_copy = copy.deepcopy(state) self._playout(state_copy) # 跟据根节点处的访问计数来计算移动概率 act_visits= [(act, node._n_visits) for act, node in self._root._children.items()] acts, visits = zip(*act_visits) act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10)) return acts, act_probs def update_with_move(self, last_move): """ 在当前的树上向前一步, 保持我们已经直到的关于子树的一切 """ if last_move in self._root._children: self._root = self._root._children[last_move] self._root._parent = None else: self._root = TreeNode(None, 1.0) def __str__(self): return 'MCTS' # 基于MCTS的AI玩家 class MCTSPlayer(object): def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0): self.mcts = MCTS(policy_value_function, c_puct, n_playout) self._is_selfplay = is_selfplay self.agent = "AI" def set_player_ind(self, p): self.player = p # 重置搜索树 def reset_player(self): self.mcts.update_with_move(-1) def __str__(self): return 'MCTS {}'.format(self.player) # 得到行动 def get_action(self, board, temp=1e-3, return_prob=0): # 像alphaGo_Zero论文一样使用MCTS算法返回的pi向量 move_probs = np.zeros(2086) acts, probs = self.mcts.get_move_probs(board, temp) move_probs[list(acts)] = probs if self._is_selfplay: # 添加Dirichlet Noise进行探索(自我对弈需要) move = np.random.choice( acts, p=0.75*probs + 0.25*np.random.dirichlet(CONFIG['dirichlet'] * np.ones(len(probs))) ) # 更新根节点并重用搜索树 self.mcts.update_with_move(move) else: # 使用默认的temp=1e-3, 它几乎相当于选择具有最高概率的移动 move = np.random.choice(acts, p=probs) # 重置根节点 self.mcts.update_with_move(-1) if return_prob: return move, move_probs else: return move