[译] 如何在安卓应用中使用 TensorFlow Mobile

TensorFlow 是当今最流行的机器学习框架之一,您利用它能够轻松建立和训练深度模型 —— 一般也称为深度前馈神经网络,这些模型能够解决各类复杂问题,如图像分类、目标检测和天然语言理解。TensorFlow Mobile 是一个旨在帮助您在移动应用中利用这些模型的库。html

在本教程中,我将向您展现如何在 Android Studio 项目中使用 TensorFlow Mobile。前端

前期准备

为了可以跟上教程,您须要作的是:python

  • Android Studio 3.0 或更高版本
  • TensorFlow 1.5.0 或更高版本
  • 一台可以运行 API level 21 或更高的安卓设备
  • 以及对 TensorFlow 框架的基本了解

一、建立模型

在咱们开始使用 TensorFlow Mobile 以前,咱们须要一个已经训练好的 TensorFlow 模型。咱们如今建立一个。android

咱们的模型将很是基础,相似于异或门,接受两个输入,它们能够是零或一,而后有一个输出。若是两个输入相同,则输出为零。此外,由于它将是一个深度模型,它将有两个隐藏层,一个有四个神经元,另外一个有三个神经元。您能够自由改变隐藏层的数量以及它们包含的神经元的数量。ios

为了保持本教程的简洁,咱们将使用 TFLearn,这是一个很受欢迎的 TensorFlow 封装框架,它提供更加直接而简洁的 API,而不是直接使用低级别的 TensorFlow API。若是您还没安装它,请使用如下命令将其安装在 TensorFlow 虚拟环境中:git

pip install tflearn
复制代码

要开始建立模型,最好在空目录中先新建一个名为 create_model.py 的 Python 脚本,而后使用您最喜欢的文本编辑器打开它。github

在文件里,咱们须要作的第一件事是导入 TFLearn API。后端

import tflearn
复制代码

接下来,咱们必须建立训练数据。对于咱们的简单模型,只有四种可能的输入和输出,相似于异或门真值表的内容。数组

X = [
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1]
]
 
Y = [
    [0],  # Desired output for inputs 0, 0
    [1],  # Desired output for inputs 0, 1
    [1],  # Desired output for inputs 1, 0
    [0]   # Desired output for inputs 1, 1
]
复制代码

为隐藏层中的全部神经元分配初始权重时,最好的作法一般是使用从均匀分布中产生的随机数。可使用 uniform() 方法生成这些值。bash

weights = tflearn.initializations.uniform(minval = -1, maxval = 1)
复制代码

此时,咱们能够开始构建神经网络层。要建立输入层,咱们必须使用 input_data() 方法,它容许咱们指定网络能够接受的输入数量。一旦输入层准备就绪,咱们能够屡次调用 fully_connected() 方法来向网络添加更多层。

# 输入层
net = tflearn.input_data(
        shape = [None, 2],
        name = 'my_input'
)
 
# 隐藏层
net = tflearn.fully_connected(net, 4,
        activation = 'sigmoid',
        weights_init = weights
)
net = tflearn.fully_connected(net, 3,
        activation = 'sigmoid',
        weights_init = weights
)
 
# 输出层
net = tflearn.fully_connected(net, 1,
        activation = 'sigmoid', 
        weights_init = weights,
        name = 'my_output'
)
复制代码

注意,在上面的代码中,咱们赋予了输入层和输出层有意义的名称。这么作很重要,由于咱们在使用安卓应用中的网络时须要它们。还要注意隐藏层和输出层使用了 sigmoid 激活函数。您能够试试其余激活函数,例如 softmaxtanhrelu

做为咱们网络的最后一层,咱们必须使用 regression() 函数建立一个回归层,该函数须要一些超参数做为其参数,例如网络的学习率以及它应该使用的优化器和损失函数。如下代码向您展现了如何使用随机梯度降低(简称 SGD)做为优化器函数,均方偏差做为损失函数:

net = tflearn.regression(net,
        learning_rate = 2,
        optimizer = 'sgd',
        loss = 'mean_square'
)
复制代码

接下来,为了让 TFLearn 框架知道咱们的网络模型其实是一个深度神经网络模型,咱们需要调用 DNN() 函数。

model = tflearn.DNN(net)
复制代码

模型如今已经准备好了。咱们如今要作的就是使用咱们以前建立的训练数据进行训练。所以,调用模型的 fit() 方法,并指定训练数据与训练周期。因为训练数据很是小,咱们的模型将须要数千次迭代才能达到合理的精度。

model.fit(X, Y, 5000)
复制代码

一旦训练完成,咱们能够调用模型的 predict() 方法来检查它是否生成指望的输出。如下代码展现了如何检查全部有效输入的输出:

print("1 XOR 0 = %f" % model.predict([[1,0]]).item(0))
print("1 XOR 1 = %f" % model.predict([[1,1]]).item(0))
print("0 XOR 1 = %f" % model.predict([[0,1]]).item(0))
print("0 XOR 0 = %f" % model.predict([[0,0]]).item(0))
复制代码

若是如今运行 Python 脚本,您应该看到以下所示的输出:

训练后的预测结果

请注意,输出不会彻底是 0 或 1。而是接近 0 或 1 的浮点数。所以,在使用输出时,可能须要使用 Python 的 round() 函数。

除非咱们在训练后明确保存模型,不然只要程序结束,咱们就会失去模型。幸运的是,对于 TFLearn,只需调用 save() 方法便可保存模型。可是,为了可以在 TensorFlow Mobile 中使用保存的模型,在保存以前,咱们必须确保移除全部训练相关的操做。这些操做都在 tf.GraphKeys.TRAIN_OPS 集合中。如下代码展现了怎么去移除相关操做:

# 移除训练相关的操做
with net.graph.as_default():
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
 
# 保存模型
model.save('xor.tflearn')
复制代码

若是您再次运行该脚本,您会发现它会生成检查点文件、元数据文件、索引文件和数据文件,全部这些文件一块儿使用时能够快速重建咱们训练好的模型。

二、固化模型

除了保存模型外,咱们还必须先固化模型,而后才能将其与 TensorFlow Mobile 配合使用。正如您可能已经猜到的那样,固化模型的过程涉及将其全部变量转换为常量。此外,固化模型必须是符合 Google Protocol Buffers 序列化格式的单个二进制文件。

新建一个名为 freeze_model.py 的 Python 脚本,并使用文本编辑器打开它。咱们将在这个文件中编写固化的模型代码来。

因为 TFLearn 没有任何固化模型的功能,咱们如今必须直接使用 TensorFlow API。经过将如下行添加到文件来导入它们:

import tensorflow as tf
复制代码

整个脚本里面,咱们将使用单个 TensorFlow 会话。咱们使用 Session 类的构造函数建立会话。

with tf.Session() as session:
    # 代码的其余部分在这
复制代码

此时,咱们必须经过调用 import_meta_graph() 函数并将模型的元数据文件的名称传递给它来建立 Saver 对象,除了返回 Saver 对象外,import_meta_graph() 函数还会自动将模型的图定义添加到会话的图定义中。

一旦建立了保存器(saver),咱们能够经过调用 restore() 方法来初始化图定义中存在的全部变量,该方法须要包含模型最新检查点文件的目录路径。

my_saver = tf.train.import_meta_graph('xor.tflearn.meta')
my_saver.restore(session, tf.train.latest_checkpoint('.'))
复制代码

此时,咱们能够调用 convert_variables_to_constants() 函数来建立一个固化的图定义,其中模型的全部变量都替换成常量。做为其输入,函数须要当前会话、当前会话的图定义以及包含模型输出层名称的列表。

frozen_graph = tf.graph_util.convert_variables_to_constants(
    session,
    session.graph_def,
    ['my_output/Sigmoid']
)
复制代码

调用固化图定义的 SerializeToString() 方法为咱们提供了模型的二进制 protobuf 表示。经过使用 Python 基本的文件 I/O,我建议您把它保存为一个名为 frozen_model.pb 的文件。

with open('frozen_model.pb', 'wb') as f:
    f.write(frozen_graph.SerializeToString())
复制代码

如今能够运行脚原本生成固化模型。

咱们如今拥有开始使用 TensorFlow Mobile 所需的一切。

三、Android Studio 项目设置

TensorFlow Mobile 库可在 JCenter 上使用,因此咱们能够直接将它添加为 app 模块 build.gradle 文件中的 implementation 依赖项。

implementation 'org.tensorflow:tensorflow-android:1.7.0'
复制代码

要把固化的模型添加到项目中,请将 frozen_model.pb 文件放置到项目的 assets 文件夹中。

四、初始化 TensorFlow 接口

TensorFlow Mobile 提供了一个简单的接口,咱们可使用它与咱们的固化模型进行交互。要建立接口,请使用 TensorFlowInferenceInterface 类的构造函数,该类须要一个 AssetManager 实例和固化模型的文件名。

thread {
    val tfInterface = TensorFlowInferenceInterface(assets,
                                        "frozen_model.pb")
     
    // More code here
}
复制代码

在上面的代码中,您能够看到咱们正在产生一个新的线程。这是为了确保应用的 UI 保持响应,虽然没必要要,但建议这样作。

为了保证 TensorFlow Mobile 可以正确读取咱们模型的文件,如今让咱们尝试打印模型图中全部操做的名称。为了获得对图的引用,咱们可使用接口的 graph() 方法,并获取全部操做,即图的 operations() 方法。如下代码告诉您该怎么作:

val graph = tfInterface.graph()
graph.operations().forEach {
    println(it.name())
}
复制代码

若是如今运行该应用,则应该可以看到在 Android Studio 的 Logcat 窗口中打印的十几个操做名称。若是固化模型时没有出错,咱们能够在这些名称中找到输入和输出层的名称:my_input/Xmy_output/Sigmoid

Logcat 窗口展现了操做列表

五、使用模型

为了用模型进行预测,咱们将数据输入到输入层,在输出层获得数据。将数据输入到输入层须要使用接口的 feed() 方法,该方法须要输入层的名称、含有输入数据的数组以及数组的维数。如下代码展现如何将数字 01 输入到输入层:

tfInterface.feed("my_input/X",
            floatArrayOf(0f, 1f), 1, 2)
复制代码

数据加载到输入层后,咱们必须使用 run() 方法进行推断操做,该方法须要输出层的名称。一旦操做完成,输出层将包含模型的预测。为了将预测结果加载到 Kotlin 数组中,咱们可使用 fetch() 方法。如下代码显示了如何执行此操做:

tfInterface.run(arrayOf("my_output/Sigmoid"))
 
val output = floatArrayOf(-1f)
tfInterface.fetch("my_output/Sigmoid", output)
复制代码

您如今能够运行该应用来查看模型的预测是否正确。

Logcat window displaying the prediction

能够更改输入到输入层的数字,以确认模型的预测始终正确。

总结

您如今知道如何建立一个简单的 TensorFlow 模型以及在安卓应用上经过 TensorFlow Mobile 去使用该模型。不过没必要拘泥于本身的模型,用您今天学到的东西,使用更大的模型对您来讲应该没有任何问题。例如 MobileNet 以及 Inception,这些均可以在 TensorFlow 的 模型园 里找到。可是请注意,这些模型会使 APK 更大,从而给使用低端设备的用户形成问题。

要了解有关 TensorFlow Mobile 的更多信息,请参阅 官方文档.


掘金翻译计划 是一个翻译优质互联网技术文章的社区,文章来源为 掘金 上的英文分享文章。内容覆盖 AndroidiOS前端后端区块链产品设计人工智能等领域,想要查看更多优质译文请持续关注 掘金翻译计划官方微博知乎专栏

相关文章
相关标签/搜索