DL中咱们可能会根据tensor中元素的值进行不一样的操做(好比loss阶段会根据grandtruth或outputs中元素的大小进行不一样的loss操做),这时就要对tensor中的元素进行判断。在python中能够用for + if语句进行判断。但TF中输入是Tensor,for和if语句失效。python
格式:tf.where(condition, x=None, y=None, name=None)less
参数:
condition: 一个元素为bool型的tensor。元素内容为false,或true。
x: 一个和condition有相同shape的tensor,若是x是一个高维的tensor,x的第一维size必须和condition同样。
y: 和x有同样shape的tensorcode
返回:
一个和x,y有一样shape的tensorelement
功能:
遍历condition Tensor中的元素,若是该元素为true,则output Tensor中对应位置的元素来自x Tensor中对应位置的元素;不然output Tensor中对应位置的元素来自Y tensor中对应位置的元素。文档
好比当tensor中元素x大于等于5时,对应输出tensor中元素y=x * 2;不然 y= x * 3的计算get
import os import sys import tensorflow as tf import numpy as np # # y = x * 2 (x >= 5) # y = x * 3 (x < 5) # a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9]) tmp = tf.constant([0, 0, 0, 0, 0, 0, 0, 0, 0]) condition = tf.less(a, 5) smaller = tf.where(condition, a, tmp) bigger = tf.where(condition, tmp, a) compute_smaller = smaller * 3 compute_bigger = bigger * 2 result = compute_smaller + compute_bigger with tf.Session() as sess: print(sess.run(result)) # # 结果: [ 3 6 9 12 10 12 14 16 18]
上述过程也能够改为以下it
a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9]) condition = tf.less(a, 5) result = tf.where(condition, a * 3, a * 2) with tf.Session() as sess: print(sess.run(result))
整个过程当中核心部分是condition条件的设置,该条件能够用tf提供的less,equal等操做实现(详细查看tf文档)。
tmp变量为引入的一个临时变量,目的是为了保证where按条件选择后输出的Tensor大小不变(tmp的0元素在乘法中是无心义计算,用该方法保证未背选择的元素在smaller和bigger中不参与计算)io
详细的condition 和 tmp须要根据实际的计算进行不一样的设置import
https://stackoverflow.com/questions/42689342/compare-two-tensors-elementwise-tensorflow
https://stackoverflow.com/questions/37912161/how-can-i-compute-element-wise-conditionals-on-batches-in-tensorflow变量