一个例子了解迁移学习

迁移学习

对于传统机器学习而言,要求训练样本与测试样本知足独立同分布,并且必需要有足够多的训练样本。而迁移学习能把一个领域(即源领域)的知识,迁移到另一个领域(即目标领域),目标领域每每只有少许有标签样本,使得目标领域可以取得更好的学习效果。mysql

image

迁移方式

  • 样本迁移,在源领域中找出与目标领域类似的样本,增长该样本的权重,使其在预测目标与的比重加大。
  • 特征迁移,源领域与目标领域包含共同的交叉特征,经过特征变换将源领域和目标领域的的特征变换到相同空间,使它们具备相同分布。
  • 模型迁移,源领域和目标领域共享模型参数,将源领域已训练好的网络模型应用到目标领域的新问题上。
  • 关系迁移,源领域和目标领域具备某种类似关系,能够将源领域的逻辑关系应用到目标领域中。

模型迁移

这里基于预训练的卷积神经网络训练一组新参数,而后将其用于分类任务,这样就能共享模型参数,避免了从头开始训练模型的参数,大大减小训练时间。git

数据集

在示例中使用flower17数据集,它是一个包含17种花卉类别的数据集,每一个类别有80张图像。收集的花都是英国一些常见的花,这些图像具备大比例、不一样姿态和光线变化等性质。github

使用水仙花和款冬这两类花,而且在预训练的VGG16网络之上构建分类器。sql

image

image

实现

首先导入全部必需的库,包括应用程序、预处理、模型检查点以及相关对象,cv2库和NumPy库用于图像处理和数值的基本操做。数组

from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Model
from keras.layers import Dropout, Flatten, Dense
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.applications.vgg16 import preprocess_input
import cv2
import numpy as np
复制代码

定义输入、数据源及与训练参数相关的全部变量。bash

img_width, img_height = 224, 224
train_data_dir = "data/train"
validation_data_dir = "data/validation"
nb_train_samples = 300
nb_validation_samples = 100
batch_size = 16
epochs = 1
复制代码

调用VGG16预训练模型,其中不包括顶部的平整化层。冻结不参与训练的层,这里咱们冻结前五层,而后添加自定义层,从而建立最终的模型。网络

model = applications.VGG16(weights="imagenet", include_top=False, input_shape=(img_width, img_height, 3))
for layer in model.layers[:5]:
    layer.trainable = False
x = model.output
x = Flatten()(x)
x = Dense(1024, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(2, activation="softmax")(x)
model_final = Model(inputs=model.input, output=predictions)
复制代码

接着开始编译模型,并为训练、测试数据集建立图像数据加强生成器。并发

model_final.compile(loss="categorical_crossentropy", optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
                    metrics=["accuracy"])
train_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                   width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
test_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                  width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
复制代码

生成加强后新的数据,根据状况保存模型。app

train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(img_height, img_width),
                                                    batch_size=batch_size, class_mode="categorical")
validation_generator = test_datagen.flow_from_directory(validation_data_dir, target_size=(img_height, img_width),
                                                        class_mode="categorical")
checkpoint = ModelCheckpoint("vgg16_1.h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False,
                             mode='auto', period=1)
early = EarlyStopping(monitor='val_acc', min_delta=0, patience=10, verbose=1, mode='auto')
复制代码

开始对模型中新的网络层进行拟合。机器学习

model_final.fit_generator(train_generator, samples_per_epoch=nb_train_samples, nb_epoch=epochs,
                          validation_data=validation_generator, nb_val_samples=nb_validation_samples,
                          callbacks=[checkpoint, early])
复制代码

练完成后用水仙花图像测试这个新模型,输出的正确值应该为接近[1.,0.]的数组。

im = cv2.resize(cv2.imread('data/test/gaff2.jpg'), (img_width, img_height))
im = np.expand_dims(im, axis=0).astype(np.float32)
im = preprocess_input(im)
out = model_final.predict(im)
print(out)
print(np.argmax(out))
复制代码
1/18 [>.............................] - ETA: 16:43 - loss: 0.9380 - acc: 0.3750
 2/18 [==>...........................] - ETA: 13:51 - loss: 0.8720 - acc: 0.4062
 3/18 [====>.........................] - ETA: 12:32 - loss: 0.8382 - acc: 0.4167
 4/18 [=====>........................] - ETA: 10:53 - loss: 0.8103 - acc: 0.4663
 5/18 [=======>......................] - ETA: 10:00 - loss: 0.8208 - acc: 0.4606
 6/18 [=========>....................] - ETA: 9:12 - loss: 0.8083 - acc: 0.4567 
 7/18 [==========>...................] - ETA: 8:24 - loss: 0.7891 - acc: 0.4718
 8/18 [============>.................] - ETA: 7:37 - loss: 0.7994 - acc: 0.4832
 9/18 [==============>...............] - ETA: 6:51 - loss: 0.7841 - acc: 0.4850Epoch 00001: val_acc improved from -inf to 0.40000, saving model to vgg16_1.h5

 9/18 [==============>...............] - ETA: 7:16 - loss: 0.7841 - acc: 0.4850 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00[[0.2213877  0.77861226]]
复制代码

github

github.com/sea-boat/De…

-------------推荐阅读------------

个人开源项目汇总(机器&深度学习、NLP、网络IO、AIML、mysql协议、chatbot)

为何写《Tomcat内核设计剖析》

个人2017文章汇总——机器学习篇

个人2017文章汇总——Java及中间件

个人2017文章汇总——深度学习篇

个人2017文章汇总——JDK源码篇

个人2017文章汇总——天然语言处理篇

个人2017文章汇总——Java并发篇


跟我交流,向我提问:

欢迎关注:

相关文章
相关标签/搜索