本文将介绍LSTM模型在实现整数加法方面的应用。
咱们以0-255之间的整数加法为例,生成的结果在0到510之间。为了能利用深度学习模型模拟整数的加法运算,咱们须要将输入的两个加数和输出的结果用二进制表示,这样就能获得向量,如加数在0-255内,能够用8位0-1向量来表示,前面的空位用0填充;结果在0-510内,能够用9位0-1向量来表示,前面的空位用0填充。由于两个加数均在0-255内变化,因此共有256*256=65536个输入向量以及65536个输出向量,输入向量为两个加数的二进制向量的拼接结果,于是是个16为的输入向量。用如下的Python代码能够模拟以上过程:python
import numpy as np # 最多8位二进制 BINARY_DIM = 8 # 将整数表示成为binary_dim位的二进制数,高位用0补齐 def int_2_binary(number, binary_dim): binary_list = list(map(lambda x: int(x), bin(number)[2:])) number_dim = len(binary_list) result_list = [0]*(binary_dim-number_dim)+binary_list return result_list # 将一个二进制数组转为整数 def binary2int(binary_array): out = 0 for index, x in enumerate(reversed(binary_array)): out += x * pow(2, index) return out # 将[0,2**BINARY_DIM)全部数表示成二进制 binary = np.array([int_2_binary(x, BINARY_DIM) for x in range(2**BINARY_DIM)]) # print(binary) # 样本的输入向量和输出向量 dataX = [] dataY = [] for i in range(binary.shape[0]): for j in range(binary.shape[0]): dataX.append(np.append(binary[i], binary[j])) dataY.append(int_2_binary(i+j, BINARY_DIM+1)) # print(dataX) # print(dataY) # 从新特征X和目标变量Y数组,适应LSTM模型的输入和输出 X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1)) # print(X.shape) Y = np.array(dataY) # print(dataY.shape)
在以上代码中,获得的dataX和dataY以知足要求,但为了能让LSTM模型处理,须要改变这两个数据集的形状。
咱们采用LSTM模型来训练上述数据,LSTM模型的结构很简单,就是简单的一层LSTM层,而后加上Dropout层,最后是全链接层,激活函数采用sigmoid函数,采用的损失函数为平均平方偏差。整个结构的示意图以下:web
模型训练的代码以下:算法
from keras.models import Sequential from keras.layers import Dense from keras.layers import Dropout from keras.layers import LSTM from keras import losses from keras.utils import plot_model # 定义LSTM模型 model = Sequential() model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]))) model.add(Dropout(0.2)) model.add(Dense(Y.shape[1], activation='sigmoid')) model.compile(loss=losses.mean_squared_error, optimizer='adam') # print(model.summary()) # plot model plot_model(model, to_file=r'./model.png', show_shapes=True) # train model epochs = 100 model.fit(X, Y, epochs=epochs, batch_size=128) # save model mp = r'./LSTM_Operation.h5' model.save(mp)
该LSTM模型每批训练128个样本,共训练100次,采用Adam优化器减小损失值。
对这个模型进行训练,训练100次,损失值为0.0045。接下来咱们就要用这个训练好的模型来预测。咱们预测的方法为,虽然挑两个在0-255内的加数,转化为二进制向量做为输入向量,而后由LSTM模型输出结果,将该结果取整做为输出向量中的元素,最后将这个输出向量转化为整数,就是预测的两个加数的和。模型预测的代码以下:数组
# use LSTM model to predict for _ in range(100): start = np.random.randint(0, len(dataX)-1) # print(dataX[start]) number1 = dataX[start][0:BINARY_DIM] number2 = dataX[start][BINARY_DIM:] print('='*30) print('%s: %s'%(number1, binary2int(number1))) print('%s: %s'%(number2, binary2int(number2))) sample = np.reshape(X[start], (1, 2*BINARY_DIM, 1)) predict = np.round(model.predict(sample), 0).astype(np.int32)[0] print('%s: %s'%(predict, binary2int(predict)))
预测的100组样本的输出结果以下:微信
============================== [1 0 0 1 1 1 0 1]: 157 [0 1 1 1 0 0 0 1]: 113 [1 0 0 0 0 1 1 1 0]: 270 ============================== [1 1 1 0 1 0 1 0]: 234 [0 1 0 0 1 1 0 0]: 76 [1 0 0 1 1 0 1 1 0]: 310 ============================== [1 1 0 0 0 1 0 0]: 196 [1 1 0 1 1 0 1 1]: 219 [1 1 0 0 1 1 1 1 1]: 415 ============================== [0 0 1 1 1 0 1 0]: 58 [0 0 1 0 0 0 1 1]: 35 [0 0 1 0 1 1 1 0 1]: 93 ============================== [1 0 0 0 0 0 0 0]: 128 [0 1 1 1 1 0 0 1]: 121 [0 1 1 1 1 1 0 0 1]: 249 ============================== [1 1 1 1 0 1 1 0]: 246 [1 1 0 1 0 1 0 1]: 213 [1 1 1 0 0 1 0 1 1]: 459 ============================== [1 1 1 0 0 1 1 0]: 230 [1 0 0 0 0 0 0 0]: 128 [1 0 1 1 0 0 1 1 0]: 358 ============================== [1 0 1 0 0 0 1 1]: 163 [0 1 1 0 0 1 0 1]: 101 [1 0 0 0 0 1 0 0 0]: 264 ============================== [1 0 1 0 0 1 1 0]: 166 [0 1 0 1 0 0 0 0]: 80 [0 1 1 1 1 0 1 1 0]: 246 ============================== [0 0 0 0 1 0 1 1]: 11 [0 1 0 0 0 1 0 1]: 69 [0 0 1 0 1 0 0 0 0]: 80 ============================== [1 1 1 1 0 1 1 1]: 247 [0 1 1 1 0 0 0 0]: 112 [1 0 1 1 0 0 1 1 1]: 359 ============================== [1 0 1 0 1 0 0 1]: 169 [1 1 0 0 0 0 0 0]: 192 [1 0 1 1 0 1 0 0 1]: 361 ============================== [1 0 1 1 0 0 0 1]: 177 [1 0 0 0 1 0 1 1]: 139 [1 0 0 1 1 1 1 0 0]: 316 ============================== [0 1 0 0 0 1 1 0]: 70 [0 0 1 0 1 1 1 0]: 46 [0 0 1 1 1 0 1 0 0]: 116 ============================== [1 0 0 1 1 0 1 1]: 155 [1 1 0 0 0 0 0 1]: 193 [1 0 1 0 1 1 1 0 0]: 348 ============================== [1 0 1 1 0 0 1 0]: 178 [1 0 0 0 1 1 1 1]: 143 [1 0 1 0 0 0 0 0 1]: 321 ============================== [0 1 0 1 1 1 1 1]: 95 [1 1 1 0 0 1 0 0]: 228 [1 0 1 0 0 0 0 1 1]: 323 ============================== [1 0 0 1 1 1 1 0]: 158 [0 0 0 1 1 0 0 1]: 25 [0 1 0 1 1 0 1 1 1]: 183 ============================== [1 1 1 0 1 0 1 1]: 235 [1 1 0 0 0 0 0 1]: 193 [1 1 0 1 0 1 1 0 0]: 428 ============================== [0 1 0 1 1 1 0 1]: 93 [0 1 1 1 0 1 1 0]: 118 [0 1 1 0 1 0 0 1 1]: 211 ============================== [1 1 1 1 1 1 1 1]: 255 [1 1 1 1 1 1 1 0]: 254 [1 1 1 1 1 1 1 0 1]: 509 ============================== [0 1 0 1 1 0 0 1]: 89 [0 1 0 1 1 1 1 0]: 94 [0 1 0 1 1 0 1 1 1]: 183 ============================== [0 1 1 1 0 0 0 0]: 112 [0 0 1 1 0 1 0 0]: 52 [0 1 0 1 0 0 1 0 0]: 164 ============================== [1 0 0 0 0 0 0 0]: 128 [1 1 0 1 1 0 1 0]: 218 [1 0 1 0 1 1 0 1 0]: 346 ============================== [0 0 1 1 0 1 0 1]: 53 [1 0 1 1 1 1 1 0]: 190 [0 1 1 1 1 0 0 1 1]: 243 ============================== [0 1 1 1 1 0 0 0]: 120 [1 1 0 1 0 1 0 1]: 213 [1 0 1 0 0 1 1 0 1]: 333 ============================== [0 1 1 1 1 0 1 1]: 123 [1 1 1 0 1 1 0 1]: 237 [1 0 1 1 0 1 0 0 0]: 360 ============================== [1 0 0 1 1 0 1 0]: 154 [0 1 1 0 1 0 0 1]: 105 [1 0 0 0 0 0 0 1 1]: 259 ============================== [0 0 0 1 1 0 0 1]: 25 [0 1 0 1 1 0 1 0]: 90 [0 0 1 1 1 0 0 1 1]: 115 ============================== [1 1 1 1 0 0 0 1]: 241 [0 0 0 1 1 1 1 1]: 31 [1 0 0 0 1 0 0 0 0]: 272 ============================== [0 1 0 0 0 1 1 0]: 70 [1 1 1 0 1 0 0 1]: 233 [1 0 0 1 0 1 1 1 1]: 303 ============================== [1 0 1 0 1 1 0 1]: 173 [0 1 1 1 0 1 0 0]: 116 [1 0 0 1 0 0 0 0 1]: 289 ============================== [0 1 0 0 1 0 0 0]: 72 [1 1 1 1 1 0 1 0]: 250 [1 0 1 0 0 0 0 1 0]: 322 ============================== [1 1 1 1 0 0 0 0]: 240 [0 1 0 0 0 0 1 0]: 66 [1 0 0 1 1 0 0 1 0]: 306 ============================== [0 1 0 0 0 1 1 1]: 71 [1 0 0 1 0 1 1 0]: 150 [0 1 1 0 1 1 1 0 1]: 221 ============================== [0 1 1 0 1 1 0 1]: 109 [0 0 1 0 0 1 0 1]: 37 [0 1 0 0 1 0 0 1 0]: 146 ============================== [1 1 0 0 0 0 0 0]: 192 [1 1 1 0 0 0 0 1]: 225 [1 1 0 1 0 0 0 0 1]: 417 ============================== [1 0 0 0 0 0 1 1]: 131 [1 1 0 1 1 1 1 0]: 222 [1 0 1 1 0 0 0 0 1]: 353 ============================== [0 0 0 0 0 1 0 0]: 4 [1 1 1 0 0 0 1 0]: 226 [0 1 1 1 0 0 1 1 0]: 230 ============================== [1 1 1 0 1 1 1 1]: 239 [1 1 0 1 1 0 1 1]: 219 [1 1 1 0 0 1 0 1 0]: 458 ============================== [0 0 1 1 0 1 0 1]: 53 [1 1 1 1 0 0 1 0]: 242 [1 0 0 1 0 0 1 1 1]: 295 ============================== [1 0 0 1 0 0 0 1]: 145 [0 1 0 0 0 1 0 0]: 68 [0 1 1 0 1 0 1 0 1]: 213 ============================== [0 0 1 1 0 0 0 0]: 48 [1 0 1 1 0 1 1 1]: 183 [0 1 1 1 0 0 1 1 1]: 231 ============================== [0 1 1 0 0 1 1 1]: 103 [0 0 0 1 1 1 1 0]: 30 [0 1 0 0 0 0 1 0 1]: 133 ============================== [0 1 0 1 1 1 0 1]: 93 [1 1 0 1 0 0 1 0]: 210 [1 0 0 1 0 1 1 1 1]: 303 ============================== [1 0 0 0 1 0 1 0]: 138 [0 1 1 1 1 0 0 1]: 121 [1 0 0 0 0 0 0 1 1]: 259 ============================== [0 0 0 0 0 0 1 1]: 3 [0 0 1 1 0 0 0 1]: 49 [0 0 0 1 1 0 1 0 0]: 52 ============================== [1 0 0 0 0 0 1 0]: 130 [0 0 0 1 0 0 0 0]: 16 [0 1 0 0 1 0 0 1 0]: 146 ============================== [0 0 0 1 0 0 0 0]: 16 [1 0 0 1 0 0 1 0]: 146 [0 1 0 1 0 0 0 1 0]: 162 ============================== [0 1 0 1 0 1 0 0]: 84 [0 0 0 0 1 1 0 0]: 12 [0 0 1 1 0 0 0 0 0]: 96 ============================== [1 0 1 0 1 0 1 1]: 171 [1 1 0 1 1 0 1 1]: 219 [1 1 0 0 0 0 1 1 0]: 390 ============================== [1 1 1 1 1 1 1 0]: 254 [0 1 1 0 1 0 1 0]: 106 [1 0 1 1 0 1 0 0 0]: 360 ============================== [1 0 0 0 0 0 1 0]: 130 [0 0 0 0 1 1 1 0]: 14 [0 1 0 0 1 0 0 0 0]: 144 ============================== [1 0 1 0 0 1 0 1]: 165 [0 0 1 1 1 0 1 1]: 59 [0 1 1 1 0 0 0 0 0]: 224 ============================== [0 0 1 1 1 0 1 0]: 58 [1 1 1 1 0 0 1 0]: 242 [1 0 0 1 0 1 1 0 0]: 300 ============================== [0 1 0 0 1 1 0 1]: 77 [0 0 0 1 1 1 1 1]: 31 [0 0 1 1 0 1 1 0 0]: 108 ============================== [1 0 0 1 1 0 1 0]: 154 [0 1 0 1 0 1 0 1]: 85 [0 1 1 1 0 1 1 1 1]: 239 ============================== [0 1 1 0 1 1 0 1]: 109 [0 1 1 0 1 0 0 1]: 105 [0 1 1 0 1 0 1 1 0]: 214 ============================== [0 1 1 1 1 1 1 1]: 127 [0 1 1 1 0 0 1 0]: 114 [0 1 1 1 1 0 0 0 1]: 241 ============================== [0 1 1 0 0 1 0 1]: 101 [0 1 0 1 0 0 0 0]: 80 [0 1 0 1 1 0 1 0 1]: 181 ============================== [0 1 1 0 1 1 1 0]: 110 [0 1 0 1 0 1 1 0]: 86 [0 1 1 0 0 0 1 0 0]: 196 ============================== [0 0 0 1 0 0 1 1]: 19 [1 0 0 1 0 0 0 0]: 144 [0 1 0 1 0 0 0 1 1]: 163 ============================== [1 1 1 1 0 1 0 0]: 244 [1 1 0 1 0 0 1 1]: 211 [1 1 1 0 0 0 1 1 1]: 455 ============================== [0 0 0 0 1 1 1 0]: 14 [1 0 1 1 0 0 1 0]: 178 [0 1 1 0 0 0 0 0 0]: 192 ============================== [0 1 1 0 0 0 0 0]: 96 [1 0 0 1 1 1 0 0]: 156 [0 1 1 1 1 1 1 0 0]: 252 ============================== [0 0 1 1 0 1 0 0]: 52 [0 1 1 1 1 1 0 1]: 125 [0 1 0 1 1 0 0 0 1]: 177 ============================== [0 0 0 0 1 1 0 0]: 12 [0 1 0 1 1 1 0 1]: 93 [0 0 1 1 0 1 0 0 1]: 105 ============================== [0 1 1 0 0 1 0 1]: 101 [1 1 0 1 0 1 0 0]: 212 [1 0 0 1 1 1 0 0 1]: 313 ============================== [1 1 0 0 0 0 0 1]: 193 [1 1 0 0 1 1 0 1]: 205 [1 1 0 0 0 1 1 1 0]: 398 ============================== [0 1 1 1 0 0 1 0]: 114 [0 0 0 0 0 0 0 0]: 0 [0 0 1 1 1 0 0 1 0]: 114 ============================== [1 0 0 0 1 1 1 0]: 142 [1 0 1 1 1 1 0 1]: 189 [1 0 1 0 0 1 0 1 1]: 331 ============================== [1 0 1 1 0 1 1 1]: 183 [0 1 0 1 0 1 1 0]: 86 [1 0 0 0 0 1 1 0 1]: 269 ============================== [1 0 1 0 0 0 1 1]: 163 [1 1 1 0 0 1 0 1]: 229 [1 1 0 0 0 1 0 0 0]: 392 ============================== [0 0 1 1 0 0 0 1]: 49 [1 1 1 0 0 1 1 1]: 231 [1 0 0 0 1 1 0 0 0]: 280 ============================== [1 0 0 0 1 1 1 1]: 143 [1 0 1 0 1 0 0 0]: 168 [1 0 0 1 1 0 1 1 1]: 311 ============================== [0 1 0 0 0 0 0 0]: 64 [0 0 0 0 0 1 0 1]: 5 [0 0 1 0 0 0 1 0 1]: 69 ============================== [1 1 1 1 1 0 1 1]: 251 [1 0 1 1 1 0 0 1]: 185 [1 1 0 1 1 0 1 0 0]: 436 ============================== [1 1 1 0 1 1 1 0]: 238 [1 1 0 0 0 0 1 0]: 194 [1 1 0 1 1 0 0 0 0]: 432 ============================== [0 0 1 1 1 1 0 0]: 60 [0 0 0 1 0 1 1 1]: 23 [0 0 1 0 1 0 0 1 1]: 83 ============================== [0 1 1 1 0 1 0 0]: 116 [1 1 1 1 1 1 0 0]: 252 [1 0 1 1 1 0 0 0 0]: 368 ============================== [1 1 0 1 0 1 1 0]: 214 [1 1 1 1 0 1 0 0]: 244 [1 1 1 0 0 1 0 1 0]: 458 ============================== [1 1 1 1 1 1 1 0]: 254 [1 1 0 1 0 0 0 1]: 209 [1 1 1 0 0 1 1 1 1]: 463 ============================== [0 0 0 0 0 0 1 0]: 2 [0 0 0 0 1 1 0 1]: 13 [0 0 0 0 0 1 1 1 1]: 15 ============================== [0 1 1 0 0 1 1 1]: 103 [1 0 1 1 1 1 1 0]: 190 [1 0 0 1 0 0 1 0 1]: 293 ============================== [1 1 1 1 0 1 1 0]: 246 [0 1 0 1 0 0 1 0]: 82 [1 0 1 0 0 1 0 0 0]: 328 ============================== [0 1 1 1 0 0 1 1]: 115 [0 0 1 1 1 0 1 1]: 59 [0 1 0 1 0 1 1 1 0]: 174 ============================== [0 1 0 1 1 0 0 1]: 89 [0 1 1 0 1 0 1 1]: 107 [0 1 1 0 0 0 1 0 0]: 196 ============================== [0 1 0 0 0 1 0 0]: 68 [0 0 1 1 1 0 0 0]: 56 [0 0 1 1 1 1 1 0 0]: 124 ============================== [1 1 0 0 1 0 0 0]: 200 [1 0 1 0 0 0 1 0]: 162 [1 0 1 1 0 1 0 1 0]: 362 ============================== [1 1 1 1 0 0 1 1]: 243 [0 1 1 0 0 0 1 1]: 99 [1 0 1 0 1 0 1 1 0]: 342 ============================== [0 0 1 0 1 0 0 1]: 41 [0 1 0 0 1 0 0 1]: 73 [0 0 1 1 1 0 0 1 0]: 114 ============================== [0 0 0 1 1 1 0 1]: 29 [1 0 1 0 1 1 1 0]: 174 [0 1 1 0 0 1 0 1 1]: 203 ============================== [0 0 0 0 1 1 1 1]: 15 [0 0 1 1 1 1 0 1]: 61 [0 0 1 0 0 1 1 0 0]: 76 ============================== [1 1 1 1 1 0 1 1]: 251 [1 1 0 1 0 0 0 0]: 208 [1 1 1 0 0 1 0 1 1]: 459 ============================== [1 1 1 0 1 0 0 0]: 232 [0 1 1 0 0 0 1 0]: 98 [1 0 1 0 0 1 0 1 0]: 330 ============================== [1 0 1 1 0 1 0 0]: 180 [0 1 0 1 0 1 1 1]: 87 [1 0 0 0 0 1 0 1 1]: 267 ============================== [1 0 0 0 0 1 1 0]: 134 [1 0 0 1 0 1 0 1]: 149 [1 0 0 0 1 1 0 1 1]: 283 ============================== [1 0 1 0 1 1 0 1]: 173 [0 1 1 1 1 1 0 0]: 124 [1 0 0 1 0 1 0 0 1]: 297 ============================== [0 1 0 0 1 0 0 0]: 72 [0 1 1 0 0 0 1 1]: 99 [0 1 0 1 0 1 0 1 1]: 171 ============================== [1 1 0 1 0 1 0 1]: 213 [0 0 0 1 1 1 1 0]: 30 [0 1 1 1 1 0 0 1 1]: 243
能够看到,这个简单的LSTM模型的预测的结果所有正确。所以,这就能够用来模拟0-255内的整数的加法运算,是否是很神奇呢?
若是须要想将加数的范围扩大,只须要改变代码中的BINARY_DIM变量便可。可是,加数的范围越大,样本就越大,如2^10=1024内的加法,就会有1024*1024=1048576个样本,这样大的样本量的无疑须要更多的训练时间。
本文到此结束,感谢阅读~若是不当之处,请速联系笔者,欢迎你们交流~祝您好运~app
注意:本人现已开通微信公众号: Python爬虫与算法(微信号为:easy_web_scrape), 欢迎你们关注哦~~dom
完整的Python代码以下:函数
import numpy as np from keras.models import Sequential from keras.layers import Dense from keras.layers import Dropout from keras.layers import LSTM from keras import losses from keras.utils import plot_model # 最多8位二进制 BINARY_DIM = 8 # 将整数表示成为binary_dim位的二进制数,高位用0补齐 def int_2_binary(number, binary_dim): binary_list = list(map(lambda x: int(x), bin(number)[2:])) number_dim = len(binary_list) result_list = [0]*(binary_dim-number_dim)+binary_list return result_list # 将一个二进制数组转为整数 def binary2int(binary_array): out = 0 for index, x in enumerate(reversed(binary_array)): out += x * pow(2, index) return out # 将[0,2**BINARY_DIM)全部数表示成二进制 binary = np.array([int_2_binary(x, BINARY_DIM) for x in range(2**BINARY_DIM)]) # print(binary) # 样本的输入向量和输出向量 dataX = [] dataY = [] for i in range(binary.shape[0]): for j in range(binary.shape[0]): dataX.append(np.append(binary[i], binary[j])) dataY.append(int_2_binary(i+j, BINARY_DIM+1)) # print(dataX) # print(dataY) # 从新特征X和目标变量Y数组,适应LSTM模型的输入和输出 X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1)) # print(X.shape) Y = np.array(dataY) # print(dataY.shape) # 定义LSTM模型 model = Sequential() model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]))) model.add(Dropout(0.2)) model.add(Dense(Y.shape[1], activation='sigmoid')) model.compile(loss=losses.mean_squared_error, optimizer='adam') # print(model.summary()) # plot model plot_model(model, to_file=r'./model.png', show_shapes=True) # train model epochs = 100 model.fit(X, Y, epochs=epochs, batch_size=128) # save model mp = r'./LSTM_Operation.h5' model.save(mp) # use LSTM model to predict for _ in range(100): start = np.random.randint(0, len(dataX)-1) # print(dataX[start]) number1 = dataX[start][0:BINARY_DIM] number2 = dataX[start][BINARY_DIM:] print('='*30) print('%s: %s'%(number1, binary2int(number1))) print('%s: %s'%(number2, binary2int(number2))) sample = np.reshape(X[start], (1, 2*BINARY_DIM, 1)) predict = np.round(model.predict(sample), 0).astype(np.int32)[0] print('%s: %s'%(predict, binary2int(predict)))