TensorFlow读写数据

前言

只有光头才能变强。php

文本已收录至个人GitHub仓库,欢迎Star:https://github.com/ZhongFuCheng3y/3ypython

回顾前面:git

众所周知,要训练出一个模型,首先咱们得有数据。咱们第一个例子中,直接使用dataset的api去加载mnist的数据。(minst的数据要么咱们是提早下载好,放在对应的目录上,要么就根据他给的url直接从网上下载)。github

通常来讲,咱们使用TensorFlow是从TFRecord文件中读取数据的。api

TFRecord 文件格式是一种面向记录的简单二进制格式,不少 TensorFlow 应用采用此格式来训练数据网络

因此,这篇文章来聊聊怎么读取TFRecord文件的数据。session

1、入门对数据集的数据进行读和写

首先,咱们来体验一下怎么造一个TFRecord文件,怎么从TFRecord文件中读取数据,遍历(消费)这些数据。数据结构

1.1 造一个TFRecord文件

如今,咱们尚未TFRecord文件,咱们能够本身简单写一个:机器学习

def write_sample_to_tfrecord():
    gmv_values = np.arange(10)
    click_values = np.arange(10)
    label_values = np.arange(10)

    with tf.python_io.TFRecordWriter("/Users/zhongfucheng/data/fashin/demo.tfrecord", options=None) as writer:
        for _ in range(10):
            feature_internal = {
                "gmv": tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])),
                "click": tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]]))
            }
            features_extern = tf.train.Features(feature=feature_internal)

            # 使用tf.train.Example将features编码数据封装成特定的PB协议格式
            # example = tf.train.Example(features=tf.train.Features(feature=features_extern))
            example = tf.train.Example(features=features_extern)

            # 将example数据系列化为字符串
            example_str = example.SerializeToString()

            # 将系列化为字符串的example数据写入协议缓冲区
            writer.write(example_str)


if __name__ == '__main__':
    write_sample_to_tfrecord()

我相信你们代码应该是可以看得懂的,其实就是分了几步:函数

  • 生成TFRecord Writer
  • tf.train.Feature生成协议信息
  • 使用tf.train.Example将features编码数据封装成特定的PB协议格式
  • 将example数据系列化为字符串
  • 将系列化为字符串的example数据写入协议缓冲区

参考资料:

ok,如今咱们就有了一个TFRecord文件啦。

1.2 读取TFRecord文件

  • 其实就是经过tf.data.TFRecordDataset这个api来读取到TFRecord文件,生成处dataset对象

  • 对dataset进行处理(shape处理,格式处理...等等)
  • 使用迭代器对dataset进行消费(遍历)

demo代码以下:

import tensorflow as tf


def read_tensorflow_tfrecord_files():
    # 定义消费缓冲区协议的parser,做为dataset.map()方法中传入的lambda:
    def _parse_function(single_sample):
        features = {
            "gmv": tf.FixedLenFeature([1], tf.float32),
            "click": tf.FixedLenFeature([1], tf.int64),  # ()或者[]没啥影响
            "label": tf.FixedLenFeature([1], tf.int64)
        }
        parsed_features = tf.parse_single_example(single_sample, features=features)

        # 对parsed 以后的值进行cast.
        gmv = tf.cast(parsed_features["gmv"], tf.float64)
        click = tf.cast(parsed_features["click"], tf.float64)
        label = tf.cast(parsed_features["label"], tf.float64)

        return gmv, click, label

    # 开始定义dataset以及解析tfrecord格式
    filenames = tf.placeholder(tf.string, shape=[None])

    # 定义dataset 和 一些列trasformation method
    dataset = tf.data.TFRecordDataset(filenames)
    parsed_dataset = dataset.map(_parse_function)  # 消费缓冲区须要定义在dataset 的map 函数中
    batchd_dataset = parsed_dataset.batch(3)

    # 建立Iterator
    sample_iter = batchd_dataset.make_initializable_iterator()
    # 获取next_sample
    gmv, click, label = sample_iter.get_next()
    training_filenames = [
        "/Users/zhongfucheng/data/fashin/demo.tfrecord"]
    with tf.Session() as session:
        # 初始化带参数的Iterator
        session.run(sample_iter.initializer, feed_dict={filenames: training_filenames})
        # 读取文件
        print(session.run(gmv))


if __name__ == '__main__':
    read_tensorflow_tfrecord_files()

无心外的话,咱们能够输出这样的结果:

[[0.]
 [1.]
 [2.]]

ok,如今咱们已经大概知道怎么写一个TFRecord文件,以及怎么读取TFRecord文件的数据,而且消费这些数据了。

2、epoch和batchSize术语解释

我在学习TensorFlow翻阅资料时,常常看到一些机器学习的术语,因为本身没啥机器学习的基础,因此不少时候看到一些专业名词就开始懵逼了。

2.1epoch

当一个完整的数据集经过了神经网络一次而且返回了一次,这个过程称为一个epoch

这可能使咱们跟dataset.repeat()方法联系起来,这个方法可使当前数据集重复一遍。好比说,原有的数据集是[1,2,3,4,5],若是我调用dataset.repeat(2)的话,那么咱们的数据集就变成了[1,2,3,4,5],[1,2,3,4,5]

  • 因此会有个说法:假设原先的数据是一个epoch,使用repeat(5)就能够将之变成5个epoch

2.2batchSize

通常来讲咱们的数据集都是比较大的,没法一次性将整个数据集的数据喂进神经网络中,因此咱们会将数据集分红好几个部分。每次喂多少条样本进神经网络,这个叫作batchSize。

在TensorFlow也提供了方法给咱们设置:dataset.batch(),在API中是这样介绍batchSize的:

representing the number of consecutive elements of this dataset to combine in a single batch

咱们通常在每次训练以前,会将整个数据集的顺序打乱,提升咱们模型训练的效果。这里咱们用到的api是:dataset.shffle();

3、再来聊聊dataset

我从官网的介绍中截了一个dataset的方法图(部分):

dataset的方法图

dataset的功能主要有如下三种:

  • 建立dataset实例
    • 经过文件建立(好比TFRecord)
    • 经过内存建立
  • 对数据集的数据进行变换
    • 好比上面的batch(),常见的map(),flat_map(),zip(),repeat()等等
    • 文档中通常都有给出例子,跑一下通常就知道对应的意思了。
  • 建立迭代器,遍历数据集的数据

3.1 聊聊迭代器

迭代器能够分为四种:

  • 单次。对数据集进行一次迭代,不支持参数化
  • 可初始化迭代
    • 使用前须要进行初始化,支持传入参数。面向的是同一个DataSet
  • 可从新初始化:同一个Iterator从不一样的DataSet中读取数据
    • DataSet的对象具备相同的结构,可使用tf.data.Iterator.from_structure来进行初始化
    • 问题:每次 Iterator 切换时,数据都从头开始打印了
  • 可馈送(也是经过对象相同的结果来建立的迭代器)
    • 可以让您在两个数据集之间切换的可馈送迭代器
    • 经过一个string handler来实现。
    • 可馈送的 Iterator 在不一样的 Iterator 切换的时候,能够作到不从头开始

简单总结:

  • 一、 单次 Iterator ,它最简单,但没法重用,没法处理数据集参数化的要求。
  • 二、 能够初始化的 Iterator ,它能够知足 Dataset 重复加载数据,知足了参数化要求。
  • 三、可从新初始化的 Iterator,它能够对接不一样的 Dataset,也就是能够从不一样的 Dataset 中读取数据。
  • 四、可馈送的 Iterator,它能够经过 feeding 的方式,让程序在运行时候选择正确的 Iterator,它和可从新初始化的 Iterator 不一样的地方就是它的数据在不一样的 Iterator 切换时,能够作到不重头开始读取数据

string handler(可馈送的 Iterator)这种方式是最常使用的,我当时也写了一个Demo来使用了一下,代码以下:

def read_tensorflow_tfrecord_files():
    # 开始定义dataset以及解析tfrecord格式.
    train_filenames = tf.placeholder(tf.string, shape=[None])
    vali_filenames = tf.placeholder(tf.string, shape=[None])

    # 加载train_dataset   batch_inputs这个方法每一个人都不同的,这个方法我就不给了。
    train_dataset = batch_inputs([
        train_filenames], batch_size=5, type=False,
        num_epochs=2, num_preprocess_threads=3)
    # 加载validation_dataset  batch_inputs这个方法每一个人都不同的,这个方法我就不给了。
    validation_dataset = batch_inputs([vali_filenames
                                       ], batch_size=5, type=False,
                                      num_epochs=2, num_preprocess_threads=3)

    # 建立出string_handler()的迭代器(经过相同数据结构的dataset来构建)
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes)

    # 有了迭代器就能够调用next方法了。
    itemid = iterator.get_next()

    # 指定哪一种具体的迭代器,有单次迭代的,有初始化的。
    training_iterator = train_dataset.make_initializable_iterator()
    validation_iterator = validation_dataset.make_initializable_iterator()

    # 定义出placeholder的值
    training_filenames = [
        "/Users/zhongfucheng/tfrecord_test/data01aa"]
    validation_filenames = ["/Users/zhongfucheng/tfrecord_validation/part-r-00766"]

    with tf.Session() as sess:
        # 初始化迭代器
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())

        for _ in range(2):
            sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames})
            print("this is training iterator ----")

            for _ in range(5):
                print(sess.run(itemid, feed_dict={handle: training_handle}))

            sess.run(validation_iterator.initializer,
                     feed_dict={vali_filenames: validation_filenames})

            print("this is validation iterator ")
            for _ in range(5):
                print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle}))


if __name__ == '__main__':
    read_tensorflow_tfrecord_files()

参考资料:

3.2 dataset参考资料

在翻阅资料时,发现写得不错的一些博客:

最后

乐于输出干货的Java技术公众号:Java3y。公众号内有200多篇原创技术文章、海量视频资源、精美脑图,不妨来关注一下!

下一篇文章打算讲讲如何理解axis~

帅的人都关注了

以为个人文章写得不错,不妨点一下

相关文章
相关标签/搜索