viterbi过程 1.hmm相似。 状态转移,发射几率 2.逐次计算每一个序列节点的全部状态下的几率值,最大几率值对应的index。 3.几率值的计算,上一个节点的几率值*转移几率+当前几率值。 4.最后取出最大的一个值对应的indexes 难点: 理解viterbi的核心点,在于每一个时间步都保留每个可视状态,每个可视状态保留上一个时间步的最大隐状态转移, 每个时间步t记录上一个最大几率转移过来的时间步t-1的信息,包括index/几率值累积。 迭代完时间步,根据最后一个最大累积几率值,逐个往前找便可。 根据index对应的状态逐个往前找。 应用: 状态转移求解最佳转移路径。 只要连续时间步,每一个时间步有状态分布,先后时间步之间有状态转移,就能够使用viterbi进行最佳状态转移计算求解。 状态转移矩阵的做用在于 在每一个状态转移几率计算时,和固有的状态转移矩阵进行加和,再计算。至关于额外的几率添加。
import numpy as np def viterbi_decode(score, transition_params): """ 保留全部可视状态下,对seqlen中的每一步的全部可视状态状况下的中间状态求解几率最大值,如此 :param score: :param transition_params: :return: """ # score [seqlen,taglen] transition_params [taglen,taglen] trellis=np.zeros_like(score) trellis[0]=score[0] backpointers=np.zeros_like(score,dtype=np.int32) for t in range(1,len(score)): matrix_node=np.expand_dims(trellis[t-1],axis=1)+transition_params #axis=0 表明发射几率初始状态 trellis[t]=score[t]+np.max(matrix_node,axis=0) backpointers[t]=np.argmax(matrix_node,axis=0) viterbi=[np.argmax(trellis[-1],axis=0)] for backpointer in reversed(backpointers[1:]): viterbi.append(backpointer[viterbi[-1]]) viterbi_score = np.max(trellis[-1]) viterbi.reverse() print(trellis) return viterbi,viterbi_score def calculate(): score = np.array([[1, 2, 3], [2, 1, 3], [1, 3, 2], [3, 2,1]]) # (batch_size, time_step, num_tabs) transition = np.array([ [2, 1, 3], [1, 3, 2], [3, 2, 1] ] )# (num_tabs, num_tabs) lengths = [len(score[0])] # (batch_size, time_step) # numpy print("[numpy]") # np_op = viterbi_decode( score=np.array(score[0]), transition_params=np.array(transition)) # print(np_op[0]) # print(np_op[1]) print("=============") # tensorflow # score_t = tf.constant(score, dtype=tf.int64) # transition_t = transition, dtype=tf.int64 tf_op = viterbi_decode( score, transition) print('--------------------') print(tf_op) if __name__=='__main__': calculate()
// java 版本 import java.lang.reflect.Array; import java.util.ArrayList; import java.util.List; public class viterbi { public static int[] viterbi_decode(double[][]score,double[][]trans ) { //score(16,31) trans(31,31) int path[] = new int[score.length]; double trellis[][] = new double[score.length][score[0].length]; int backpointers[][] = new int [score.length][score[0].length]; trellis[0] = score[0]; for(int t = 1; t<score.length;t++) { // 一维数组,31个元素 [-1000,-1000,-1000,.......] double h[] = trellis[t - 1]; //i shape(31 ,1) 31行,1列 [ [-1000][-10000][-1000] ] //i = np.expand_dims(trellis[t - 1], 1) // // double expand_dims[][] = new double[trans.length][trans[0].length]; //?? // for(int j = 0;j<expand_dims[0].length;j++) { // expand_dims[j] = h; //todo // } //zyy begin double expand_h[][]=new double[trans.length][trans[0].length]; for(int i=0;i<trans.length;i++){ for(int j=0;j<trans.length;j++) { expand_h[i][j]=h[i]; } } double expand_dims[][] = new double[trans.length][trans[0].length]; //?? for(int j = 0;j<expand_dims[0].length;j++) { expand_dims[j] =expand_h[j] ; //todo } //zyy_end double v[][] = new double[trans.length][trans[0].length]; for(int i = 0; i < v.length; i++ ) { for(int j = 0; j< v[0].length ;j++) { v[i][j] = expand_dims[i][j] + trans[i][j]; } } //取每列最大的值 获得score.length个每列最大值,一维数组 double max_v[] = new double[trans[0].length]; int max_v_linepoint[] = new int[trans[0].length]; for (int j = 0; j < v[0].length; j++) { double max_column = v[0][j]; int line_point = 0; for (int i = 0; i < v.length; i++) { if(v[i][j] > max_column) { max_column = v[i][j]; line_point = i; } } max_v[j] = max_column; max_v_linepoint[j] = line_point; } for(int i = 0 ;i < score[0].length; i++ ) { trellis[t][i] = score[t][i] + max_v[i]; backpointers[t][i] = max_v_linepoint[i]; } } int viterbi[] = new int[score.length]; // List<Integer> viterbi = new ArrayList<>(); double max_trellis = trellis[score.length-1][0]; for(int j = 0; j< trellis[score.length-1].length ;j++) { if(trellis[score.length-1][j] > max_trellis) { max_trellis = trellis[score.length-1][j]; // viterbi.add(j); viterbi[0] = j; } } for(int i=1;i< 1+(backpointers.length)/2;i++){ int temp[] = backpointers[i]; backpointers[i] = backpointers[backpointers.length-i]; backpointers[backpointers.length-i]=temp; } for(int i = 1; i < backpointers.length; i++ ) { // viterbi.add( backpointers[i][viterbi.get(viterbi.size() - 1)]); viterbi[i] = backpointers[i][viterbi[i-1]]; } for(int i = 0;i < (viterbi.length)/2; i++){ //把数组的值赋给一个临时变量 int temp = viterbi[i]; viterbi[i] = viterbi[viterbi.length-i-1]; viterbi[viterbi.length-i-1] = temp; } return viterbi; } public static void main(String[] args){ List<List<Integer>> score=new ArrayList<>(); ArrayList<Integer> row1=new ArrayList<>(); row1.add(1); row1.add(2); row1.add(3); ArrayList<Integer> row2=new ArrayList<>(); row2.add(2); row2.add(1); row2.add(3); ArrayList<Integer> row3=new ArrayList<>(); row3.add(1); row3.add(3); row3.add(2); ArrayList<Integer> row4=new ArrayList<>(); row4.add(3); row4.add(2); row4.add(1); score.add(row1); score.add(row2); score.add(row3); score.add(row4); List<List<Integer>> trans=new ArrayList<>(); ArrayList<Integer> row11=new ArrayList<>(); row11.add(2); row11.add(1); row11.add(3); ArrayList<Integer> row12=new ArrayList<>(); row12.add(1); row12.add(3); row12.add(2); ArrayList<Integer> row13=new ArrayList<>(); row13.add(3); row13.add(2); row13.add(1); trans.add(row11); trans.add(row12); trans.add(row13); // double[][] score_double=(double[][]) score.toArray(); // double[][] trans_double=(double[][]) trans.toArray(); System.out.println(score); System.out.println(trans); double[][] score_double=new double[score.size()][score.get(0).size()]; for(int i=0;i<score.size();i++){ // score_double[i]=score.get(i); for(int j=0;j<score.get(0).size();j++){ score_double[i][j]=score.get(i).get(j); } } double[][] trans_double=new double[trans.size()][trans.get(0).size()]; for(int i=0;i<trans.size();i++){ // score_double[i]=score.get(i); for(int j=0;j<trans.get(0).size();j++){ trans_double[i][j]=trans.get(i).get(j); } } int[] result=viterbi_decode(score_double,trans_double); System.out.println("===========****==============="); System.out.println(result.toString()); } }