最近看刘汝佳老师的《算法经典入门训练指南》,搜相关的算法博客时,发现一本神书《人工智能 一种现代的方法》(如下简称《人工智能》),里面囊括的算法,也让我对算法有了新的认知。在学校时,相信你们都学过数据结构和算法,这些算法是你们接触的最基础的算法,再往上走,你们作人工智能,又涉及到机器学习和深度学习。鉴于本身的认知视野优先,总感受这两类之间的算法中间跳过了不少东西。所以最近看算法时,难免要搜到不少博客,才知道算法的种类不少,每一个类下面又有不少分支,图大类下的路径搜索,约束知足问题,以及这篇要讲的对抗搜索等等。东西太多可是在了解的过程当中,也解答了之前的一些疑惑,好比路径搜索在游戏中的应用,约束知足的应用,以及对抗搜索在棋牌类的应用。最后感叹在校期间没有参加ACM训练,没有抓住机会扩展本身的视野。强烈建议在校生有机会参加参加ACM训练,拿不到奖牌也能够扩展本身的视野。真的很重要。正文开始:java
(本文不少东西都是参考上面两本书,文末会贴上本身的代码。)算法
相信你们在网上或多或少的都玩过不少对抗类游戏,好比五子棋、象棋、国际象棋、围棋等。初期时,可能不少时候是和电脑进行人机联系。电脑方的出牌策略就是应用了不少对抗搜索的算法(对抗类游戏可能有多我的参与,本文只讨论二人对抗的游戏,多人对抗的请参考《人工智能》这本书)。数组
这些问题很是难于求解,例如国际象棋的平均分支因子大约是35,一盘棋通常每一个游戏者走50步,因此搜索树大约有35100即10154个结点(尽管整个状态空间“只”约1040个不一样结点)。和现实世界同样,游戏要求即便没法找到最优决策也必须能作某种决策,而不能花费太多的时间。换句话说,这些游戏有严格的时间限制(time limit)。因此对博弈的研究也产生了一些有趣的思想,如何尽量充分的利用好时间。数据结构
咱们知道对抗类游戏是参与者轮流出招,咱们能够将其写成相似于路径寻找的问题:app
初始状态:包括棋盘局面和肯定该哪一个游戏者出招
后继函数:返回(move, state)列表,每一项表示一个合法招数和对应的结果状态。
终止测试:判断游戏是否结束。游戏结束的状态称为终止状态。
效用函数:也称目标函数或收益函数,是终止状态的得分。国际象棋中赢、输、平分别是1,-1和0分,而围棋、黑白棋等能够有更多的结果。机器学习
考虑一个简单的游戏:井字棋。如今有两个参与者MAX和MIN,在3×3的棋盘上,MAX划叉,MIN划圆圈。任何一种图案占据了一行或者一列或者一整条斜对角线(主副对角线),那么断定相应的游戏者获胜。以下图(摘自《人工智能》)。初始状态棋盘为空,而后依次由MAX、MIN方轮流走,这样就造成了一颗相似于搜索树的博弈树。数据结构和算法
这个图列举了双方能接受的全部选择。咱们能够看到只有叶子节点才有评价函数,能够看到从根节点到叶子节点是双方按照当前路径走下来的最终结果(赢、输、平局)。每条路径都对应一个结果,双反不论在何时,确定都要选择“最利于”本身获胜的步骤。此时的核心问题就是在每一步的时候,MAX/MIN如何来设计评价函数来选择“最利于”本身的下棋步骤。好比在第一步时,MAX到底该如何得知本身要选择9个选择中的哪个。函数
这里采用极大极小值方法: 对MAX方来讲,评价函数越大越好,而对MIN方来讲,评价函数越小越好。也就是在每一步中,MAX方选择全部节点中评价函数最大的节点,做为本身当前的落棋选择,而MIN则相反。若是一个MIN结点有三个儿子,评价值分别为3,4,-1。最聪明的对手必定会选择那个-1的儿子(这样对MAX最不利),而若是对手并无发现这个走步(或者并不以为它的后继状态对MAX最不利),它可能选择的是3或者4。学习
惋惜因为博弈树太大,若是要直接追踪到最终状态,这对于计算机来讲也是一个超大的负荷,所以合理的方案是在固定深度截断,在这个深度内的“叶子节点”双发按照极大极小值方法来选择本身每一步的落棋选择。对于井字棋游戏,一个可能的评价函数是:
e(s) = (MAX可能占有的行/列/对角线数) - (MIN可能占有的行/列/对角线数)测试
其中“可能占有”的意思是“此行/列/对角线”不含对方的符号。更复杂的评价函数每每是对各类特征进行加权计算。 下图是深度为2时的评价函数计算。
能够验证对max的第一步来讲,选择走中间那个节点是最优的选择。若是此时MIN选择走第一行正中,那么此时节点的部分搜索树以下。
刚才所述的算法成为MAXMIN算法,咱们采起递归的计算方式来描述整个算法:
int max_value ( int dep , state s ){ if ( terminal ( s )) return e ( s ); //终止状态 if ( dep == maxdepth ) return e ( s ); //深度截断,返回评价函数 v = - inf ; //初始化为负无穷 succ = make_successors ( s ); // succ [ i为第]个后继状态i for ( i = 0; i < succ . count ; i ++) v = max (v , min_value ( succ [ i ])); //计算全部儿子的最大值 return v ; }
int min_value ( int dep , state s ){ if ( terminal ( s )) return e ( s ); //终止状态 if ( dep == maxdepth ) return e ( s ); //深度截断,返回评价函数 v = inf ; //初始化为无穷大 succ = make_successors ( s ); // succ [ i为第]个后继状态i for ( i = 0; i < succ . count ; i ++) v = min (v , max_value ( succ [ i ])); //计算全部儿子的最小值(刘汝佳老师的书中是错的) return v ; }
文末附上“井字棋”的完整JAVA代码
package search; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; import javax.xml.parsers.FactoryConfigurationError; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; /** * 对抗搜索-井字棋游戏 */ public class Search_Ant { private static int n = 3; private static final int maxPlayer = 1; private static final int minPlayer = 2; private static int maxDepth = 3; //对抗预测的最大深度 private static int[][] initArray = new int[n][n]; private static int depth = 0; private static int alpha = -10000; private static int beta = 10000; public static void main(String[] args) { State initState = new State(); initState.setCurrentState(initArray); antSearch(initState, depth); } /** * 开始对抗搜索 * 偶数max执行; * 奇数min执行; * 共对抗执行的最大次数: n * n ; */ public static void antSearch(State state, int depth) { if(isSuccess(state)){ System.out.println("某一方成功赢了"); return; } if (depth >= n * n) { System.out.println("双方平局"); return; } //偶数次由max方走,奇数次由min方走; if (depth % 2 == 0) { maxValue(state, 0); int rowIndex = state.getNextBestState().getRowIndex(); int columnIndex = state.getNextBestState().getColumnIndex(); state.setRowIndex(rowIndex); state.setColumnIndex(columnIndex); state.getCurrentState()[rowIndex][columnIndex] = 1; display(state, maxPlayer, depth); // System.out.println(String.format("max方执行: (%s,%s)", state.getRowIndex(), state.getColumnIndex())); } else if (depth % 2 == 1) { minValue(state, 0); int rowIndex = state.getNextBestState().getRowIndex(); int columnIndex = state.getNextBestState().getColumnIndex(); state.setRowIndex(rowIndex); state.setColumnIndex(columnIndex); state.getCurrentState()[rowIndex][columnIndex] = 2; display(state, minPlayer, depth); // System.out.println(String.format("min方执行: (%s,%s)", state.getRowIndex(), state.getColumnIndex())); } // ++depth; // 开始下一次迭代 // antSearch(state, depth); } /** * 得到当前状态的后继节点 * * @param currentState * @return */ public static List<State> getSuccessor(State currentState, int player) { List<State> successorList = new ArrayList<State>(); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { if (currentState.getCurrentState()[i][j] == 0) { int[][] array = copyArray(currentState.getCurrentState()); if (player == 1) { array[i][j] = 1; } else { array[i][j] = 2; } State nextState = new State(i, j, array); successorList.add(nextState); } } } return successorList; } public static int[][] copyArray(int[][] array) { int[][] copyArray = new int[n][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { copyArray[i][j] = array[i][j]; } } return copyArray; } /** * max方执行 * * @return */ public static int maxValue(State state, int currentDepth) { int[][] currentState = state.getCurrentState(); if (currentDepth >= maxDepth) { //计算差值 return evalFunction(currentState); } if (isSuccess(state)) { //计算差值 return evalFunction(currentState); } List<State> successor = getSuccessor(state, 1); int v = -10000; int target = -1; for (int i = 0; i < successor.size(); i++) { //这里须要优化,Alpha-Beta剪枝,缩小检索空间 int value = minValue(successor.get(i), currentDepth + 1); if (v < value) { v = value; target = i; } // System.out.println(String.format("depth=%s, value=%s, maxPlayer,数组: %s", currentDepth, value,display((successor.get(i))))); } if (target == -1) { //已经此时是平局 return 0; } System.out.println(String.format("depth=%s, value=%s, maxPlayer,数组: %s", currentDepth, v,display((successor.get(target))))); state.setNextBestState(successor.get(target)); return v; } /** * min方执行 * * @return */ public static int minValue(State state, int currentDepth) { int[][] currentState = state.getCurrentState(); if (currentDepth >= maxDepth) { //计算差值 return evalFunction(currentState); } if (isSuccess(state)) { //计算差值 return evalFunction(currentState); } List<State> successor = getSuccessor(state, 2); int v = 10000; int target = -1; for (int i = 0; i < successor.size(); i++) { int value = maxValue(successor.get(i), currentDepth + 1); //这里须要优化,Alpha-Beta剪枝,缩小检索空间 if (v > value) { target = i; v = value; } // System.out.println(String.format("depth=%s, value=%s, minPlayer,数组: %s", currentDepth,value,display((successor.get(i))))); } if (target == -1) { //已经此时是平局 return 0; } System.out.println(String.format("depth=%s, value=%s, minPlayer,数组: %s", currentDepth, v,display((successor.get(target))))); state.setNextBestState(successor.get(target)); return v; } /** * 当前状态的评估函数 * * @param currentState * @return */ public static int evalFunction(int[][] currentState) { int minPlayerResult = getPlayerOccupy(currentState, maxPlayer); int maxPlayerResult = getPlayerOccupy(currentState, minPlayer); return maxPlayerResult - minPlayerResult; } /** * 得到某一方所占用的坐标 * * @param currentState * @return */ public static int getPlayerOccupy(int[][] currentState, int palyer) { Set<Integer> rowOccupy = new HashSet<Integer>(); Set<Integer> columnOccupy = new HashSet<Integer>(); boolean mainDiag = false; boolean viceDiag = false; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { if (currentState[i][j] == palyer) { rowOccupy.add(i); columnOccupy.add(j); if (i == j) { //在主对角线上 mainDiag = true; } if (i + j == n - 1) {//在副对角线上 viceDiag = true; } } } } int result = 0; result += n - rowOccupy.size(); result += n - columnOccupy.size(); result += mainDiag ? 0 : 1; result += viceDiag ? 0 : 1; return result; } /** * 判断当前状态是否能够判断某一方已经胜利 * * @return */ public static boolean isSuccess(State state) { if (isRowSame(state) || isColumnSame(state) || isMainDiagSame(state) || isViceDiagSame(state)) { return true; } return false; } /** * 一行为相同 * * @return */ public static boolean isRowSame(State state) { int[][] currentState = state.getCurrentState(); int rowIndex = state.getRowIndex(); int preValue = currentState[rowIndex][0]; if (preValue == 0) { return false; } for (int i = 1; i < n; i++) { if (currentState[rowIndex][i] != preValue) { return false; } } return true; } /** * 列相同 * * @return */ public static boolean isColumnSame(State state) { int[][] currentState = state.getCurrentState(); int columnIndex = state.getColumnIndex(); int preValue = currentState[0][columnIndex]; if (preValue == 0) { return false; } for (int i = 1; i < n; i++) { if (currentState[i][columnIndex] != preValue) { return false; } } return true; } /** * 主对角线是否相同 * * @return */ public static boolean isMainDiagSame(State state) { int[][] currentState = state.getCurrentState(); int rowIndex = state.getRowIndex(); int columnIndex = state.getColumnIndex(); if (rowIndex == columnIndex) { int preValue = currentState[0][0]; if (preValue == 0) { return false; } for (int i = 1; i < n; i++) { if (currentState[i][i] != preValue) { return false; } } return true; } return false; } /** * 副对角线是否相同 * * @return */ public static boolean isViceDiagSame(State state) { int[][] currentState = state.getCurrentState(); int rowIndex = state.getRowIndex(); int columnIndex = state.getColumnIndex(); if (rowIndex + columnIndex == n - 1) { int preValue = currentState[0][n - 1]; if (preValue == 0) { return false; } int m = 0; int k = n - 1; for (int i = 1; i < n; i++) { if (currentState[m + i][k - i] != preValue) { return false; } } return true; } return false; } public static void display(State state, int player, int depth) { int[][] array = state.getCurrentState(); System.out.println("==============================================="); System.out.println(String.format("第%s步: 当前方以及走的坐标: %s --> (%s,%s)", depth, player, state.getRowIndex(), state.getColumnIndex())); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { System.out.print(array[i][j] + " "); } System.out.println(); } } public static String display(State state) { StringBuffer buffer = new StringBuffer(); int[][] array = state.getCurrentState(); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { // System.out.print(array[i][j]+" "); buffer.append(array[i][j] + " "); } } return buffer.toString(); } } @Getter @Setter @AllArgsConstructor @NoArgsConstructor class State { private int rowIndex; private int columnIndex; private int[][] currentState; private State nextBestState; //当前最好的状态 public State(int i, int j, int[][] array) { this.rowIndex = i; this.columnIndex = j; this.currentState = array; } }