1 Getting started
2 Multiple inputs and outputs
3 Shared layers
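Consider the following problem: we want to decide whether two tweets were written by the same person.
First we process each of the two tweets, then concatenate the two results, and finally put a logistic regression on top that outputs the probability that both tweets come from the same person.
Because both tweets are processed in exactly the same way, the model that processes the first tweet can be reused for the second one. We will use an LSTM for this.
Assume each of the two inputs is a matrix of shape (280, 256), i.e. a sequence of 280 vectors of size 256.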
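The text does not say how a tweet becomes a (280, 256) matrix. One possible encoding, assumed here purely for illustration, is to one-hot encode up to 280 bytes of the tweet over the 256 possible byte values. The helper name `encode_tweet` and the padding/truncation rules are hypothetical, not something the original specifies.

```python
MAX_LEN = 280      # number of time steps (rows)
ALPHABET = 256     # size of each one-hot vector (columns)


def encode_tweet(text):
    """Hypothetical helper: one-hot encode the UTF-8 bytes of `text`.

    Returns a list of MAX_LEN rows, each a list of ALPHABET floats.
    Text longer than MAX_LEN bytes is truncated; shorter text is
    padded with all-zero rows.
    """
    data = text.encode("utf-8")[:MAX_LEN]
    matrix = [[0.0] * ALPHABET for _ in range(MAX_LEN)]
    for step, byte in enumerate(data):
        matrix[step][byte] = 1.0
    return matrix


tweet = encode_tweet("hello keras")
print(len(tweet), len(tweet[0]))  # 280 256
```

A nested list like this can be converted to an array and stacked along a new first axis to obtain a batch of shape (num_samples, 280, 256).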
First, define the inputs:
    import keras
    from keras.layers import Input, LSTM, Dense
    from keras.models import Model

    tweet_a = Input(shape=(280, 256))
    tweet_b = Input(shape=(280, 256))
Next we share the LSTM. Sharing a layer is simple: instantiate the layer once, then call that one instance on every tensor you want it to process.
    # This layer can take as input a matrix
    # and will return a vector of size 64
    shared_lstm = LSTM(64)

    # When we reuse the same layer instance
    # multiple times, the weights of the layer
    # are also being reused
    # (it is effectively *the same* layer)
    encoded_a = shared_lstm(tweet_a)
    encoded_b = shared_lstm(tweet_b)

    # We can then concatenate the two vectors:
    merged_vector = keras.layers.concatenate([encoded_a, encoded_b], axis=-1)

    # And add a logistic regression on top
    predictions = Dense(1, activation='sigmoid')(merged_vector)

    # We define a trainable model linking the
    # tweet inputs to the predictions
    model = Model(inputs=[tweet_a, tweet_b], outputs=predictions)

    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    # data_a, data_b and labels are not defined in this post. They are assumed
    # to be prepared beforehand: two arrays of shape (num_samples, 280, 256)
    # and one array of 0/1 labels with num_samples entries.
    model.fit([data_a, data_b], labels, epochs=10)
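To make the idea of weight sharing concrete without any framework, here is a minimal pure-Python sketch. `TinyEncoder` is a hypothetical stand-in for the LSTM (it only averages a weighted sum, it is not a recurrent layer), and the final logistic unit uses unit weights and zero bias as a simplifying assumption. The point is only that one instance, with one set of weights, encodes both inputs.

```python
import math


class TinyEncoder:
    """Toy stand-in for a layer: one weight per input feature."""

    def __init__(self, weights):
        self.weights = list(weights)

    def __call__(self, sequence):
        # Weighted sum of each step, averaged over the sequence.
        total = 0.0
        for step in sequence:
            total += sum(w * x for w, x in zip(self.weights, step))
        return [total / len(sequence)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


shared = TinyEncoder([0.5, -0.25])   # instantiated once
seq_a = [[1.0, 0.0], [0.0, 1.0]]
seq_b = [[1.0, 1.0], [1.0, 1.0]]

encoded_a = shared(seq_a)            # the same weights ...
encoded_b = shared(seq_b)            # ... are used for both inputs
merged = encoded_a + encoded_b       # list concatenation
prob = sigmoid(sum(merged))          # logistic unit: unit weights, zero bias

# Updating the single set of weights changes both encodings at once.
shared.weights[0] = 1.0
assert shared(seq_a) != encoded_a and shared(seq_b) != encoded_b
```

Had two separate encoder instances been created instead, each would carry its own weights and the model would have twice as many encoder parameters to train.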
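Put simply, calling one layer instance more than once is what it means to share that layer. This brings in the concept of a layer "node".
When you call a layer on an input tensor, the call produces an output tensor and adds a node to the layer; the node links the two tensors (the input tensor and the output tensor). When you call the same layer several times, the nodes it creates are indexed 0, 1, 2, and so on.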
So when a layer has several nodes, how do we get its output?
Reading it directly through the output attribute raises an error:
    a = Input(shape=(280, 256))
    b = Input(shape=(280, 256))

    lstm = LSTM(32)
    encoded_a = lstm(a)
    encoded_b = lstm(b)

    lstm.output
>> AttributeError: Layer lstm_1 has multiple inbound nodes, hence the notion of "layer output" is ill-defined. Use `get_output_at(node_index)` instead.
In that case, retrieve the output by node index instead:
    assert lstm.get_output_at(0) == encoded_a
    assert lstm.get_output_at(1) == encoded_b
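The node bookkeeping described above can be pictured with a toy class. This is only a sketch of the idea, not how Keras implements it: `ToyLayer` is hypothetical, and strings stand in for tensors. Each call appends one (input, output) pair, the position in the list acts as the node index, and the `output` shortcut is only well defined while there is exactly one node.

```python
class ToyLayer:
    """Toy model of node bookkeeping; not Keras's actual implementation."""

    def __init__(self, name):
        self.name = name
        self._nodes = []  # one (input, output) pair per call

    def __call__(self, tensor):
        output = f"{self.name}({tensor})"
        self._nodes.append((tensor, output))  # node index = list position
        return output

    @property
    def output(self):
        if len(self._nodes) != 1:
            raise AttributeError(
                f"Layer {self.name} has {len(self._nodes)} nodes; "
                "use get_output_at(node_index) instead."
            )
        return self._nodes[0][1]

    def get_output_at(self, node_index):
        return self._nodes[node_index][1]


layer = ToyLayer("lstm")
out_a = layer("a")
assert layer.output == out_a          # one node: unambiguous

out_b = layer("b")                    # second call adds node 1
assert layer.get_output_at(0) == out_a
assert layer.get_output_at(1) == out_b
```

After the second call, reading `layer.output` raises `AttributeError`, mirroring the behaviour shown in the Keras example.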
The same applies on the input side, for example to input_shape:
    from keras.layers import Input, Conv2D

    a = Input(shape=(32, 32, 3))
    b = Input(shape=(64, 64, 3))

    conv = Conv2D(16, (3, 3), padding='same')
    conved_a = conv(a)

    # Only one input so far, the following will work:
    assert conv.input_shape == (None, 32, 32, 3)

    conved_b = conv(b)
    # now the `.input_shape` property wouldn't work, but this does:
    assert conv.get_input_shape_at(0) == (None, 32, 32, 3)
    assert conv.get_input_shape_at(1) == (None, 64, 64, 3)
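It may not be obvious why a single convolution layer can be called on both a 32x32 and a 64x64 image. The shape arithmetic below is a small sketch of the reason, assuming channels-last data and stride 1 (the default stride): the kernel shape depends only on the kernel size, the number of input channels and the number of filters, never on the image height or width. Both helper functions are hypothetical names introduced for this illustration.

```python
def conv2d_kernel_shape(kernel_size, in_channels, filters):
    """Shape of the kernel weights: independent of image height/width."""
    kh, kw = kernel_size
    return (kh, kw, in_channels, filters)


def conv2d_same_output_shape(input_shape, filters):
    """Output shape for padding='same' with stride 1 (channels-last)."""
    batch, height, width, _ = input_shape
    return (batch, height, width, filters)


shape_a = (None, 32, 32, 3)
shape_b = (None, 64, 64, 3)

# Both inputs have 3 channels, so one kernel fits both.
kernel = conv2d_kernel_shape((3, 3), in_channels=3, filters=16)
print(kernel)                                   # (3, 3, 3, 16)
print(conv2d_same_output_shape(shape_a, 16))    # (None, 32, 32, 16)
print(conv2d_same_output_shape(shape_b, 16))    # (None, 64, 64, 16)
```

Since each call sees a different input shape, the layer has to record one input shape per node, which is why the lookup goes through a node index.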