tensorflow踩坑:按tensor中元素进行比较

问题描述

DL中咱们可能会根据tensor中元素的值进行不一样的操做(好比loss阶段会根据grandtruth或outputs中元素的大小进行不一样的loss操做),这时就要对tensor中的元素进行判断。在python中能够用for + if语句进行判断。但TF中输入是Tensor,for和if语句失效。python

tf.where说明

  • 格式:tf.where(condition, x=None, y=None, name=None)less

  • 参数:
    condition: 一个元素为bool型的tensor。元素内容为false,或true。
    x: 一个和condition有相同shape的tensor,若是x是一个高维的tensor,x的第一维size必须和condition同样。
    y: 和x有同样shape的tensorcode

  • 返回:
    一个和x,y有一样shape的tensorelement

  • 功能:
    遍历condition Tensor中的元素,若是该元素为true,则output Tensor中对应位置的元素来自x Tensor中对应位置的元素;不然output Tensor中对应位置的元素来自Y tensor中对应位置的元素。文档

使用例子

好比当tensor中元素x大于等于5时,对应输出tensor中元素y=x * 2;不然 y= x * 3的计算get

import os
import sys
import tensorflow as tf
import numpy as np
#
# y = x * 2 (x >= 5)
# y = x * 3 (x < 5)
#
a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])
tmp = tf.constant([0, 0, 0, 0, 0, 0, 0, 0, 0]) 
condition = tf.less(a, 5)
smaller = tf.where(condition, a, tmp)
bigger = tf.where(condition, tmp, a)
compute_smaller = smaller * 3
compute_bigger = bigger * 2
result = compute_smaller + compute_bigger
with tf.Session() as sess:
    print(sess.run(result))
#
# 结果: [ 3  6  9 12 10 12 14 16 18]

上述过程也能够改为以下it

a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])
condition = tf.less(a, 5)
result = tf.where(condition, a * 3, a * 2)
with tf.Session() as sess:
    print(sess.run(result))

整个过程当中核心部分是condition条件的设置,该条件能够用tf提供的less,equal等操做实现(详细查看tf文档)。
tmp变量为引入的一个临时变量,目的是为了保证where按条件选择后输出的Tensor大小不变(tmp的0元素在乘法中是无心义计算,用该方法保证未背选择的元素在smaller和bigger中不参与计算)io

详细的condition 和 tmp须要根据实际的计算进行不一样的设置import

参考

https://stackoverflow.com/questions/42689342/compare-two-tensors-elementwise-tensorflow
https://stackoverflow.com/questions/37912161/how-can-i-compute-element-wise-conditionals-on-batches-in-tensorflow变量

相关文章
相关标签/搜索