將自己的數(shù)據(jù)集制作成TFRecord格式教程
在使用TensorFlow訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),首先面臨的問題是:網(wǎng)絡(luò)的輸入
此篇文章,教大家將自己的數(shù)據(jù)集制作成TFRecord格式,feed進(jìn)網(wǎng)絡(luò),除了TFRecord格式,TensorFlow也支持其他格
式的數(shù)據(jù),此處就不再介紹了。建議大家使用TFRecord格式,在后面可以通過api進(jìn)行多線程的讀取文件隊(duì)列。
1. 原本的數(shù)據(jù)集
此時(shí),我有兩類圖片,分別是xiansu100,xiansu60,每一類中有10張圖片。

2.制作成TFRecord格式
tfrecord會(huì)根據(jù)你選擇輸入文件的類,自動(dòng)給每一類打上同樣的標(biāo)簽。如在本例中,只有0,1 兩類,想知道文件夾名與label關(guān)系的,可以自己保存起來。
#生成整數(shù)型的屬性
def _int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
#生成字符串類型的屬性
def _bytes_feature(value):
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
#制作TFRecord格式
def createTFRecord(filename,mapfile):
class_map = {}
data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
classes = {'xiansu60','xiansu100'}
#輸出TFRecord文件的地址
writer = tf.python_io.TFRecordWriter(filename)
for index,name in enumerate(classes):
class_path=data_dir+name+'/'
class_map[index] = name
for img_name in os.listdir(class_path):
img_path = class_path + img_name #每個(gè)圖片的地址
img = Image.open(img_path)
img= img.resize((224,224))
img_raw = img.tobytes() #將圖片轉(zhuǎn)化成二進(jìn)制格式
example = tf.train.Example(features = tf.train.Features(feature = {
'label':_int64_feature(index),
'image_raw': _bytes_feature(img_raw)
}))
writer.write(example.SerializeToString())
writer.close()
txtfile = open(mapfile,'w+')
for key in class_map.keys():
txtfile.writelines(str(key)+":"+class_map[key]+"\n")
txtfile.close()
此段代碼,運(yùn)行完后會(huì)產(chǎn)生生成的.tfrecord文件。
3. 讀取TFRecord的數(shù)據(jù),進(jìn)行解析,此時(shí)使用了文件隊(duì)列以及多線程
#讀取train.tfrecord中的數(shù)據(jù)
def read_and_decode(filename):
#創(chuàng)建一個(gè)reader來讀取TFRecord文件中的樣例
reader = tf.TFRecordReader()
#創(chuàng)建一個(gè)隊(duì)列來維護(hù)輸入文件列表
filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
#從文件中讀出一個(gè)樣例,也可以使用read_up_to一次讀取多個(gè)樣例
_,serialized_example = reader.read(filename_queue)
# print _,serialized_example
#解析讀入的一個(gè)樣例,如果需要解析多個(gè),可以用parse_example
features = tf.parse_single_example(
serialized_example,
features = {'label':tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),})
#將字符串解析成圖像對應(yīng)的像素?cái)?shù)組
img = tf.decode_raw(features['image_raw'], tf.uint8)
img = tf.reshape(img,[224, 224, 3]) #reshape為128*128*3通道圖片
img = tf.image.per_image_standardization(img)
labels = tf.cast(features['label'], tf.int32)
return img, labels
4. 將圖片幾個(gè)一打包,形成batch
def createBatch(filename,batchsize):
images,labels = read_and_decode(filename)
min_after_dequeue = 10
capacity = min_after_dequeue + 3 * batchsize
image_batch, label_batch = tf.train.shuffle_batch([images, labels],
batch_size=batchsize,
capacity=capacity,
min_after_dequeue=min_after_dequeue
)
label_batch = tf.one_hot(label_batch,depth=2)
return image_batch, label_batch
5.主函數(shù)
if __name__ =="__main__":
#訓(xùn)練圖片兩張為一個(gè)batch,進(jìn)行訓(xùn)練,測試圖片一起進(jìn)行測試
mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
# createTFRecord(train_filename,mapfile)
test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
# createTFRecord(test_filename,mapfile)
image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
with tf.Session() as sess:
initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(initop)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess = sess, coord = coord)
try:
step = 0
while 1:
_image_batch,_label_batch = sess.run([image_batch,label_batch])
step += 1
print step
print (_label_batch)
except tf.errors.OutOfRangeError:
print (" trainData done!")
try:
step = 0
while 1:
_test_images,_test_labels = sess.run([test_images,test_labels])
step += 1
print step
# print _image_batch.shape
print (_test_labels)
except tf.errors.OutOfRangeError:
print (" TEST done!")
coord.request_stop()
coord.join(threads)
此時(shí),生成的batch,就可以feed進(jìn)網(wǎng)絡(luò)了。
以上這篇將自己的數(shù)據(jù)集制作成TFRecord格式教程就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
30秒學(xué)會(huì)30個(gè)超實(shí)用Python代碼片段【收藏版】
許多人在數(shù)據(jù)科學(xué)、機(jī)器學(xué)習(xí)、web開發(fā)、腳本編寫和自動(dòng)化等領(lǐng)域中都會(huì)使用Python,它是一種十分流行的語言。本文將簡要介紹30個(gè)簡短的、且能在30秒內(nèi)掌握的代碼片段,感興趣的朋友一起看看吧2019-10-10
Python基于numpy模塊實(shí)現(xiàn)回歸預(yù)測
這篇文章主要介紹了Python基于numpy模塊實(shí)現(xiàn)回歸預(yù)測,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-05-05
使用python批量修改文件名的方法(視頻合并時(shí))
這篇文章主要介紹了視頻合并時(shí)使用python批量修改文件名的方法,代碼簡單易懂,非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-08-08
Django?事務(wù)回滾的具體實(shí)現(xiàn)
本文主要介紹了Django?事務(wù)回滾的具體實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-02-02
pycharm遠(yuǎn)程連接vagrant虛擬機(jī)中mariadb數(shù)據(jù)庫
這篇文章主要介紹了pycharm遠(yuǎn)程連接vagrant虛擬機(jī)中mariadb數(shù)據(jù)庫,需要的朋友可以參考下2020-06-06
Python限制內(nèi)存和CPU使用量的方法(Unix系統(tǒng)適用)
這篇文章主要介紹了Python限制內(nèi)存和CPU的使用量的方法,文中講解非常細(xì)致,代碼幫助大家更好的理解和學(xué)習(xí),感興趣的朋友可以了解下2020-08-08

