Tensorflow中使用TFRecords高效读取数据--结合NLP数据实践

之前一篇博客在进行论文仿真的时候用到了TFRecords进行数据的读取操作,但是因为当时比较忙,所以没有进行深入学习。这两天看了一下,决定写篇博客专门结合该代码记录一下TFRecords的相关操作。
首先说一下为什么要使用TFRecords来进行文件的读写,在TF中数据的传入方式主要包含以下几种:

  1. 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
  2. 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
  3. 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。

之前都是使用1和3进行数据的操作,但是当我们遇到数据集比较大的情况时,这两种方法会及其占用内存,效率很差。那么为甚么使用TFRecords会比较快呢?因为其使用二进制存储文件,也就是将数据存储在一个内存块中,相比其它文件格式要快很多,特别是如果你使用hdd而不是ssd,因为它涉及移动磁盘阅读器头并且需要相当长的时间。总体而言,通过使用二进制文件,您可以更轻松地分发数据,使数据更好地对齐,以实现高效的读取。接下来我们看一下具体的操作。

个人感觉可以分成两部分,一是将文件保存成TFRecords格式的.tfrecords文件,这里主要涉及到使用tf.python_io.TFRecordWriter("train.tfrecords")tf.train.Example以及tf.train.Features三个函数,第一个是生成需要对应格式的文件,后面两个函数主要是将我们要传入的数据按照一定的格式进行规范化。这里还要提到一点就是使用TFRecords可以避免多个文件的使用,比如说我们一般会将一次要传入的数据的不同部分分别存放在不同文件夹中,question一个,answer一个,query一个等等,但是使用TFRecords之后,我们可以将一批数据同时保存在一个文件之中,这样方便我们在后续程序中的使用。

另一部分就是在训练模型时将我们生成的.tfrecords文件读入并传到模型中进行使用。这部分主要涉及到使用tf.TFRecordReader("train.tfrecords")tf.parse_single_example两个函数。第一个函数是将我们的二进制文件读入,第二个则是进行解析然后得到我们想要的数据。

接下来我们结合代码进行理解:

生成TFRecords文件

这里关于要使用的数据集的介绍可以参考我的上一篇博客,主要是QA任务的数据集。代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def tokenize(index, word):
#index是每个单词对应词袋子之中的索引值,word是所有出现的单词
directories = ['cnn/questions/training/', 'cnn/questions/validation/', 'cnn/questions/test/']
for directory in directories:
#分别读取训练测试验证集的数据
out_name = directory.split('/')[-2] + '.tfrecords'
#生成对应.tfrecords文件
writer = tf.python_io.TFRecordWriter(out_name)
#每个文件夹下面都有若干文件,每个文件代表一个QA队,也就是一条训练数据
files = map(lambda file_name: directory + file_name, os.listdir(directory))
for file_name in files:
with open(file_name, 'r') as f:
lines = f.readlines()
#对每条数据分别获得文档,问题,答案三个值,并将相应单词转化为索引
document = [index[token] for token in lines[2].split()]
query = [index[token] for token in lines[4].split()]
answer = [index[token] for token in lines[6].split()]
#调用Example和Features函数将数据格式化保存起来。注意Features传入的参数应该是一个字典,方便后续读数据时的操作
example = tf.train.Example(
features = tf.train.Features(
feature = {
'document': tf.train.Feature(
int64_list=tf.train.Int64List(value=document)),
'query': tf.train.Feature(
int64_list=tf.train.Int64List(value=query)),
'answer': tf.train.Feature(
int64_list=tf.train.Int64List(value=answer))
}))
true #写数据
serialized = example.SerializeToString()
writer.write(serialized)

读取.tfrecords文件

因为在读取数据之后我们可能还会进行一些额外的操作,使我们的数据格式满足模型输入,所以这里会引入一些额外的函数来实现我们的目的。这里介绍几个个人感觉较重要常用的函数。不过还是推荐到官网API去查,或者有某种需求的时候到Stack Overflow上面搜一搜,一般都能找到满足自己需求的函数。
1,string_input_producer( string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None )其输出是一个输入管道的队列

2,shuffle_batch( tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, shared_name=None, name=None )产生随机打乱之后的batch数据

3,sparse_ops.serialize_sparse(sp_input, name=None): 返回一个字符串的3-vector(1-D的tensor),分别表示索引、值、shape

4,deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): 将多个稀疏的serialized_sparse合并成一个

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def read_records(index=0):
#生成读取数据的队列,要指定epoches
train_queue = tf.train.string_input_producer(['training.tfrecords'], num_epochs=FLAGS.epochs)
validation_queue = tf.train.string_input_producer(['validation.tfrecords'], num_epochs=FLAGS.epochs)
test_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=FLAGS.epochs)
queue = tf.QueueBase.from_list(index, [train_queue, validation_queue, test_queue])
#定义一个recordreader对象,用于数据的读取
reader = tf.TFRecordReader()
#从之前的队列中读取数据到serialized_example
_, serialized_example = reader.read(queue)
#调用parse_single_example函数解析数据
features = tf.parse_single_example(
serialized_example,
features={
'document': tf.VarLenFeature(tf.int64),
'query': tf.VarLenFeature(tf.int64),
'answer': tf.FixedLenFeature([], tf.int64)
})
#返回索引、值、shape的三元组信息
document = sparse_ops.serialize_sparse(features['document'])
query = sparse_ops.serialize_sparse(features['query'])
answer = features['answer']
#生成batch切分数据
document_batch_serialized, query_batch_serialized, answer_batch = tf.train.shuffle_batch(
[document, query, answer], batch_size=FLAGS.batch_size,
capacity=2000,
min_after_dequeue=1000)
sparse_document_batch = sparse_ops.deserialize_many_sparse(document_batch_serialized, dtype=tf.int64)
sparse_query_batch = sparse_ops.deserialize_many_sparse(query_batch_serialized, dtype=tf.int64)
document_batch = tf.sparse_tensor_to_dense(sparse_document_batch)
document_weights = tf.sparse_to_dense(sparse_document_batch.indices, sparse_document_batch.shape, 1)
query_batch = tf.sparse_tensor_to_dense(sparse_query_batch)
query_weights = tf.sparse_to_dense(sparse_query_batch.indices, sparse_query_batch.shape, 1)
return document_batch, document_weights, query_batch, query_weights, answer_batch

至此,我们就可以
参考连接:
https://www.tensorflow.org/programmers_guide/reading_data

http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

http://ycszen.github.io/2016/08/17/TensorFlow%E9%AB%98%E6%95%88%E8%AF%BB%E5%8F%96%E6%95%B0%E6%8D%AE/

坚持原创技术分享,您的支持将鼓励我继续创作!