AI => TF20的LSTM与GRU(return_sequences与return_state)参数源码

前言

舒适提示:

本文只适用于: 了解LSTM 和 GRU的结构,可是不懂Tensorflow20中LSTM和GRU的参数的人)

额外说明

看源码不等于高大上。
当你各类博客翻烂,发现内容不是互相引用,就是相互"借鉴"。。。且绝望时。
你可能会翻翻文档,其实有些文档写的并非很详细。
这时,看源码是你最好的理解方式。(LSTM 和 GRU 部分源码仍是比较好看的)函数

标题写不下了: TF20 ==> Tensorflow2.0(Stable)
tk ===> tensorflow.keras
LSTM 和 GRU 已经放在 tk.layers模块中。code

return_sequences = True
return_state = True

这两个参数是使用率最高的两个了, 而且LSTM 和 GRU 中都有。
那它们到底是什么意思呢???
来,开始吧!

进入源码方式:
    import tensorflow.keras as tk
    tk.layers.GRU()
    tk.layers.LSTM()
    用pycharm ctrl+左键 点进源码便可~~~

LSTM源码

我截取了部分主干源码:文档

...
...
  states = [new_h, new_c]           # 很显然,第一个是横向状态h, 另外一个是记忆细胞c

if self.return_sequences:         # 若是return_sequences设为True
  output = outputs                    # 则输出值为全部LSTM单元的 输出y,注意还没return
else:                             # 若是return_sequences设为False
  output = last_output                # 则只输出LSTM最后一个单元的信息, 注意还没return

if self.return_state:             # 若是return_state设为True
  return [output] + list(states)      # 则最终返回 上面的output + [new_h, new_c]
else:                             # 若是return_state设为False
  return output                       # 则最终返回 只返回上面的output

小技巧: 瞄准 return 关键词。 你就会很是清晰,它会返回什么了。

GRU源码

...
...
########  咱们主要看这一部分 #########################################
  last_output, outputs, runtime, states = self._defun_gru_call( 
      inputs, initial_state, training, mask)
#####################################################################          
...
...

######### 下面不用看了, 这下面代码和  LSTM是如出一辙的 ###################
if self.return_sequences:
  output = outputs
else:
  output = last_output

if self.return_state:
  return [output] + list(states)
else:
  return output

如今咱们的寻找关键点只在于, states 是怎么获得的???
你继续点进去 "self._defun_gru_call" 这个函数的源码, 你会发现 states 就直接暴露在里面pycharm

states = [new_h]
return ..., states

如今源码几乎所有分析完毕。 咱们回头思考总结一下:input

LSTM 和 GRU 中的 return_sequences 和 return_state 部分的源码是如出一辙的!!!
    return_sequences: 只管理 output变量的赋值,(最后一个单元 或 所有单元)
    return_state: 负责返回 output变量,而且按条件决定是否再一并多返回一个 states变量
    
进而咱们把问题关注点转换到  output变量, 和 states变量:

LSTM 和 GRU 的 output变量: 大体类似,不用管。
LSTM 和 GRU 的 ststes变量:
    LSTM的 states变量:  [H, C]    # 若是你了解LSTM的结构,看到这里你应该很清楚,LSTM有C和H
    GRU的 states变量:   [H]       # 若是你了解GRU的结构,看到这里你应该很清楚,GRU就一个H

最终使用层总结:

LSTM:

有四种组合使用:源码

  1. return_sequences = False 且 return_state = False (默认)博客

    返回值: 只返回 最后一个 LSTM单元的输出Y
  2. return_sequences = True 且 return_state = Falseit

    返回值: 只返回 全部 LSTM单元的输出Y
  3. return_sequences = False 且 return_state = Trueio

    返回值: 返回最后一个LSTM单元的输出Y   和    C + H 两个(隐层信息)
  4. return_sequences = True 且 return_state = Truetable

    返回值: 返回全部LSTM单元的输出Y  和  C + H 两个(隐层信息)  (适用于Atention)

GRU:

有四种组合使用:

  1. return_sequences = False 且 return_state = False (默认)

    返回值: 同LSTM
  2. return_sequences = True 且 return_state = False

    返回值: 同LSTM
  3. return_sequences = False 且 return_state = True

    返回值: 返回 最后一个 LSTM单元的输出Y   和   一个H(隐层信息)
  4. return_sequences = True 且 return_state = True

    返回值: 返回 全部 LSTM单元的输出Y  和 一个H(隐层信息)  (适用于Atention)
相关文章
相关标签/搜索