TensorFlow Lite 是用于移动设备和嵌入式设备的轻量级解决方案。TensorFlow Lite 支持 Android、iOS 甚至树莓派等多种平台。android
咱们知道大多数的 AI 是在云端运算的,可是在移动端使用 AI 具备无网络延迟、响应更加及时、数据隐私等特性。git
对于离线的场合,云端的 AI 就没法使用了,而此时能够在移动设备中使用 TensorFlow Lite。github
TensorFlow 生成的模型是没法直接给移动端使用的,须要离线转换成.tflite文件格式。网络
tflite 存储格式是 flatbuffers。app
FlatBuffers 是由Google开源的一个免费软件库,用于实现序列化格式。它相似于Protocol Buffers、Thrift、Apache Avro。ide
所以,若是要给移动端使用的话,必须把 TensorFlow 训练好的 protobuf 模型文件转换成 FlatBuffers 格式。官方提供了 toco 来实现模型格式的转换。post
TensorFlow Lite 提供了 C ++ 和 Java 两种类型的 API。不管哪一种 API 都须要加载模型和运行模型。测试
而 TensorFlow Lite 的 Java API 使用了 Interpreter 类(解释器)来完成加载模型和运行模型的任务。后面的例子会看到如何使用 Interpreter。gradle
mnist 是手写数字图片数据集,包含60000张训练样本和10000张测试样本。 测试集也是一样比例的手写数字数据。每张图片有28x28个像素点构成,每一个像素点用一个灰度值表示,这里是将28x28的像素展开为一个一维的行向量(每行784个值)。ui
mnist 数据集获取地址:yann.lecun.com/exdb/mnist/
下面的 demo 中已经包含了 mnist.tflite 模型文件。(若是没有的话,须要本身训练保存成pb文件,再转换成tflite 格式)
对于一个识别类,首先须要初始化 TensorFlow Lite 解释器,以及输入、输出。
// The tensorflow lite file
private lateinit var tflite: Interpreter
// Input byte buffer
private lateinit var inputBuffer: ByteBuffer
// Output array [batch_size, 10]
private lateinit var mnistOutput: Array<FloatArray>
init {
try {
tflite = Interpreter(loadModelFile(activity))
inputBuffer = ByteBuffer.allocateDirect(
BYTE_SIZE_OF_FLOAT * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE)
inputBuffer.order(ByteOrder.nativeOrder())
mnistOutput = Array(DIM_BATCH_SIZE) { FloatArray(NUMBER_LENGTH) }
Log.d(TAG, "Created a Tensorflow Lite MNIST Classifier.")
} catch (e: IOException) {
Log.e(TAG, "IOException loading the tflite file failed.")
}
}
复制代码
从 asserts 文件中加载 mnist.tflite 模型:
/** * Load the model file from the assets folder */
@Throws(IOException::class)
private fun loadModelFile(activity: Activity): MappedByteBuffer {
val fileDescriptor = activity.assets.openFd(MODEL_PATH)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
复制代码
真正识别手写数字是在 classify() 方法:
val digit = mnistClassifier.classify(Bitmap.createScaledBitmap(paintView.bitmap, PIXEL_WIDTH, PIXEL_WIDTH, false))
复制代码
classify() 方法包含了预处理用于初始化 inputBuffer、运行 mnist 模型、识别出数字。
/** * Classifies the number with the mnist model. * * @param bitmap * @return the identified number */
fun classify(bitmap: Bitmap): Int {
if (tflite == null) {
Log.e(TAG, "Image classifier has not been initialized; Skipped.")
}
preProcess(bitmap)
runModel()
return postProcess()
}
/** * Converts it into the Byte Buffer to feed into the model * * @param bitmap */
private fun preProcess(bitmap: Bitmap?) {
if (bitmap == null || inputBuffer == null) {
return
}
// Reset the image data
inputBuffer.rewind()
val width = bitmap.width
val height = bitmap.height
// The bitmap shape should be 28 x 28
val pixels = IntArray(width * height)
bitmap.getPixels(pixels, 0, width, 0, 0, width, height)
for (i in pixels.indices) {
// Set 0 for white and 255 for black pixels
val pixel = pixels[i]
// The color of the input is black so the blue channel will be 0xFF.
val channel = pixel and 0xff
inputBuffer.putFloat((0xff - channel).toFloat())
}
}
/** * Run the TFLite model */
private fun runModel() = tflite.run(inputBuffer, mnistOutput)
/** * Go through the output and find the number that was identified. * * @return the number that was identified (returns -1 if one wasn't found) */
private fun postProcess(): Int {
for (i in 0 until mnistOutput[0].size) {
val value = mnistOutput[0][i]
if (value == 1f) {
return i
}
}
return -1
}
复制代码
对于 Android 有一个地方须要注意,必须在 app 模块的 build.gradle 中添加以下的语句,不然没法加载模型。
android {
......
aaptOptions {
noCompress "tflite"
}
}
复制代码
demo 运行效果以下:
本文只是 TF Lite 的初探,不少细节并无详细阐述。应该会在将来的文章中详细介绍。
本文 demo 的 github 地址:github.com/fengzhizi71…
固然,也能够跑一下官方的例子:github.com/tensorflow/…
Java与Android技术栈:每周更新推送原创技术文章,欢迎扫描下方的公众号二维码并关注,期待与您的共同成长和进步。