原文连接 tensorflow中取下标的函数包括:tf.gather , tf.gather_nd 和 tf.batch_gather。python
indices必须是一维张量 主要参数:api
返回值:经过indices获取params下标的张量。 例子:函数
import tensorflow as tf tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]]) tensor_b = tf.Variable([1,2,0],dtype=tf.int32) tensor_c = tf.Variable([0,0],dtype=tf.int32) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(tf.gather(tensor_a,tensor_b))) print(sess.run(tf.gather(tensor_a,tensor_c)))
上个例子tf.gather(tensor_a,tensor_b) 的值为[[4,5,6],[7,8,9],[1,2,3]],tf.gather(tensor_a,tensor_b) 的值为[[1,2,3],[1,2,3]]学习
对于tensor_a,其第1个元素为[4,5,6],第2个元素为[7,8,9],第0个元素为[1,2,3],因此以[1,2,0]为索引的返回值是[[4,5,6],[7,8,9],[1,2,3]],一样的,以[0,0]为索引的值为[[1,2,3],[1,2,3]]。spa
https://www.tensorflow.org/api_docs/python/tf/gather.net
功能和参数与tf.gather相似,不一样之处在于tf.gather_nd支持多维度索引,即indices能够使多维张量。 例子:code
import tensorflow as tf tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]]) tensor_b = tf.Variable([[1,0],[1,1],[1,2]],dtype=tf.int32) tensor_c = tf.Variable([[0,2],[2,0]],dtype=tf.int32) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(tf.gather_nd(tensor_a,tensor_b))) print(sess.run(tf.gather_nd(tensor_a,tensor_c))) tf.gather_nd(tensor_a,tensor_b)值为[4,5,6],tf.gather_nd(tensor_a,tensor_c)的值为[3,7].
对于tensor_a,下标[1,0]的元素为4,下标为[1,1]的元素为5,下标为[1,2]的元素为6,索引[1,0],[1,1],[1,2]]的返回值为[4,5,6],一样的,索引[[0,2],[2,0]]的返回值为[3,7].blog
https://www.tensorflow.org/api_docs/python/tf/gather_nd索引
支持对张量的批量索引,各参数意义见(1)中描述。注意由于是批处理,因此indices要有和params相同的第0个维度。get
例子:
import tensorflow as tf tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]]) tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32) tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(tf.batch_gather(tensor_a,tensor_b))) print(sess.run(tf.batch_gather(tensor_a,tensor_c))) tf.gather_nd(tensor_a,tensor_b)值为[1,5,9],tf.gather_nd(tensor_a,tensor_c)的值为[1,4,7].
tensor_a的三个元素[1,2,3],[4,5,6],[7,8,9]分别对应索引元素的第一,第二和第三个值。[1,2,3]的第0个元素为1,[4,5,6]的第1个元素为5,[7,8,9]的第2个元素为9,因此索引[[0],[1],[2]]的返回值为[1,5,9],一样地,索引[[0],[0],[0]]的返回值为[1,4,7].
https://www.tensorflow.org/api_docs/python/tf/batch_gather
在深度学习的模型训练中,有时候须要对一个batch的数据进行相似于tf.gather_nd的操做,但tensorflow中并无tf.batch_gather_nd之类的操做,此时须要tf.map_fn和tf.gather_nd结合来实现上述操做。