Tensorflow使用tfrecord輸入數(shù)據(jù)格式
Tensorflow 提供了一種統(tǒng)一的格式來存儲數(shù)據(jù),這個格式就是TFRecord,上一篇文章中所提到的方法當數(shù)據(jù)的來源更復雜,每個樣例中的信息更豐富的時候就很難有效的記錄輸入數(shù)據(jù)中的信息了,于是Tensorflow提供了TFRecord來統(tǒng)一存儲數(shù)據(jù),接下來我們就來介紹如何使用TFRecord來同意輸入數(shù)據(jù)的格式。
1. TFRecord格式介紹
TFRecord文件中的數(shù)據(jù)是通過tf.train.Example Protocol Buffer的格式存儲的,下面是tf.train.Example的定義
message Example {
Features features = 1;
};
message Features{
map<string,Feature> featrue = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
從上述代碼可以看到,ft.train.Example 的數(shù)據(jù)結構相對簡潔。tf.train.Example中包含了一個從屬性名稱到取值的字典,其中屬性名稱為一個字符串,屬性的取值可以為字符串(BytesList ),實數(shù)列表(FloatList )或整數(shù)列表(Int64List )。例如我們可以將解碼前的圖片作為字符串,圖像對應的類別標號作為整數(shù)列表。
2. 將自己的數(shù)據(jù)轉化為TFRecord格式
準備數(shù)據(jù)
在上一篇中,我們?yōu)榱讼駛ゴ蟮腗NIST致敬,所以選擇圖像的前綴來進行不同類別的分類依據(jù),但是大多數(shù)的情況下,在進行分類任務的過程中,不同的類別都會放在不同的文件夾下,而且類別的個數(shù)往往浮動性又很大,所以針對這樣的情況,我們現(xiàn)在利用不同類別在不同文件夾中的圖像來生成TFRecord.
我們在Iris&Contact這個文件夾下有兩個文件夾,分別為iris,contact。對于每個文件夾中存放的是對應的圖片
轉換數(shù)據(jù)
數(shù)據(jù)準備好以后,就開始準備生成TFRecord,具體代碼如下:
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
cwd='/home/ruyiwei/Documents/Iris&Contact/'
classes={'iris','contact'}
writer= tf.python_io.TFRecordWriter("iris_contact.tfrecords")
for index,name in enumerate(classes):
class_path=cwd+name+'/'
for img_name in os.listdir(class_path):
img_path=class_path+img_name
img=Image.open(img_path)
img= img.resize((512,80))
img_raw=img.tobytes()
#plt.imshow(img) # if you want to check you image,please delete '#'
#plt.show()
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
3. Tensorflow從TFRecord中讀取數(shù)據(jù)
def read_and_decode(filename): # read iris_contact.tfrecords
filename_queue = tf.train.string_input_producer([filename])# create a queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)#return file_name and file
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})#return image and label
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [512, 80, 3]) #reshape image to 512*80*3
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor
label = tf.cast(features['label'], tf.int32) #throw label tensor
return img, label
4. 將TFRecord中的數(shù)據(jù)保存為圖片
filename_queue = tf.train.string_input_producer(["iris_contact.tfrecords"])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #return file and file_name
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [512, 80, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess:
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
for i in range(20):
example, l = sess.run([image,label])#take out image and label
img=Image.fromarray(example, 'RGB')
img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#save image
print(example, l)
coord.request_stop()
coord.join(threads)
以上就是本文的全部內(nèi)容,希望對大家的學習有所幫助,也希望大家多多支持腳本之家。
相關文章
Caffe均值文件mean.binaryproto轉mean.npy的方法
今天小編就為大家分享一篇Caffe均值文件mean.binaryproto轉mean.npy的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07
PyCharm 無法 import pandas 程序卡住的解決方式
這篇文章主要介紹了PyCharm 無法 import pandas 程序卡住的解決方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-03-03
python使用SimpleXMLRPCServer實現(xiàn)簡單的rpc過程
這篇文章主要介紹了python使用SimpleXMLRPCServer實現(xiàn)簡單的rpc過程,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-06-06

