Keras使用ImageNet上預訓練的模型方式
我就廢話不多說了,大家還是直接看代碼吧!
import keras import numpy as np from keras.applications import vgg16, inception_v3, resnet50, mobilenet #Load the VGG model vgg_model = vgg16.VGG16(weights='imagenet') #Load the Inception_V3 model inception_model = inception_v3.InceptionV3(weights='imagenet') #Load the ResNet50 model resnet_model = resnet50.ResNet50(weights='imagenet') #Load the MobileNet model mobilenet_model = mobilenet.MobileNet(weights='imagenet')
在以上代碼中,我們首先import各種模型對應的module,然后load模型,并用ImageNet的參數(shù)初始化模型的參數(shù)。
如果不想使用ImageNet上預訓練到的權(quán)重初始話模型,可以將各語句的中'imagenet'替換為'None'。
補充知識:keras上使用alexnet模型來高準確度對mnist數(shù)據(jù)進行分類
綱要
本文有兩個特點:一是直接對本地mnist數(shù)據(jù)進行讀取(假設事先已經(jīng)下載或從別處拷來)二是基于keras框架(網(wǎng)上多是基于tf)使用alexnet對mnist數(shù)據(jù)進行分類,并獲得較高準確度(約為98%)
本地數(shù)據(jù)讀取和分析
很多代碼都是一開始簡單調(diào)用一行代碼來從網(wǎng)站上下載mnist數(shù)據(jù),雖然只有10來MB,但是現(xiàn)在下載速度非常慢,而且經(jīng)常中途出錯,要費很大的勁才能拿到數(shù)據(jù)。
(X_train, y_train), (X_test, y_test) = mnist.load_data()
其實可以單獨來獲得這些數(shù)據(jù)(一共4個gz包,如下所示),然后調(diào)用別的接口來分析它們。

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #導入已經(jīng)下載好的數(shù)據(jù)集,"./MNIST_data"為存放mnist數(shù)據(jù)的目錄
x_train = mnist.train.images y_train = mnist.train.labels x_test = mnist.test.images y_test = mnist.test.labels
這里面要注意的是,兩種接口拿到的數(shù)據(jù)形式是不一樣的。 從網(wǎng)上直接下載下來的數(shù)據(jù) 其image data值的范圍是0~255,且label值為0,1,2,3...9。 而第二種接口獲取的數(shù)據(jù) image值已經(jīng)除以255(歸一化)變成0~1范圍,且label值已經(jīng)是one-hot形式(one_hot=True時),比如label值2的one-hot code為(0 0 1 0 0 0 0 0 0 0)
所以,以第一種方式獲取的數(shù)據(jù)需要做一些預處理(歸一和one-hot)才能輸入網(wǎng)絡模型進行訓練 而第二種接口拿到的數(shù)據(jù)則可以直接進行訓練。
Alexnet模型的微調(diào)
按照公開的模型框架,Alexnet只有第1、2個卷積層才跟著BatchNormalization,后面三個CNN都沒有(如有說錯,請指正)。如果按照這個來搭建網(wǎng)絡模型,很容易導致梯度消失,現(xiàn)象就是 accuracy值一直處在很低的值。 如下所示。

在每個卷積層后面都加上BN后,準確度才迭代提高。如下所示

完整代碼
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #tensorflow已經(jīng)包含了mnist案例的數(shù)據(jù)
batch_size = 64
num_classes = 10
epochs = 10
img_shape = (28,28,1)
# input dimensions
img_rows, img_cols = 28,28
# dataset input
#(x_train, y_train), (x_test, y_test) = mnist.load_data()
mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #導入已經(jīng)下載好的數(shù)據(jù)集,"./MNIST_data"為存放mnist數(shù)據(jù)的目錄
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
# data initialization
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
# Define the input layer
inputs = keras.Input(shape = [img_rows, img_cols, 1])
#Define the converlutional layer 1
conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)
# Define the pooling layer 1
pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)
# Define the standardization layer 1
stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)
# Define the converlutional layer 2
conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)
# Defien the pooling layer 2
pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)
# Define the standardization layer 2
stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)
# Define the converlutional layer 3
conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)
stand3 = keras.layers.BatchNormalization(axis=1)(conv3)
# Define the converlutional layer 4
conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)
stand4 = keras.layers.BatchNormalization(axis=1)(conv4)
# Define the converlutional layer 5
conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)
pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)
stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)
# Define the fully connected layer
flatten = keras.layers.Flatten()(stand5)
fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)
drop1 = keras.layers.Dropout(0.5)(fc1)
fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)
drop2 = keras.layers.Dropout(0.5)(fc2)
fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)
# 基于Model方法構(gòu)建模型
model = keras.Model(inputs= inputs, outputs = fc3)
# 編譯模型
model.compile(optimizer= tf.train.AdamOptimizer(0.001),
loss= keras.losses.categorical_crossentropy,
metrics= ['accuracy'])
# 訓練配置,僅供參考
model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))
以上這篇Keras使用ImageNet上預訓練的模型方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python增量循環(huán)刪除MySQL表數(shù)據(jù)的方法
這篇文章主要介紹了Python增量循環(huán)刪除MySQL表數(shù)據(jù)的相關(guān)資料,本文介紹的非常詳細,具有參考借鑒價值,需要的朋友可以參考下2016-09-09
Python利用標簽實現(xiàn)清理微信好友的自動化腳本
微信已經(jīng)成為我們?nèi)粘I钪胁豢苫蛉钡纳缃还ぞ?隨著使用時間的增長,我們的微信好友列表可能會變得越來越臃腫,所以本文為大家準備了通過標簽清理微信好友的Python自動化腳本,希望對大家有所幫助2024-12-12
python實現(xiàn)讀取excel文件中所有sheet操作示例
這篇文章主要介紹了python實現(xiàn)讀取excel文件中所有sheet操作,涉及Python基于openpyxl模塊的Excel文件讀取、遍歷相關(guān)操作技巧,需要的朋友可以參考下2019-08-08
python使用redis模塊來跟redis實現(xiàn)交互
這篇文章主要介紹了python使用redis模塊來跟redis實現(xiàn)交互,文章圍繞主題展開詳細的內(nèi)容介紹,具有一定的參考價值,需要的小伙伴可以參考一下2022-06-06
Python logging模塊異步線程寫日志實現(xiàn)過程解析
這篇文章主要介紹了Python logging模塊異步線程寫日志實現(xiàn)過程解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2020-06-06
python忽略警告(warning)的3種方法小結(jié)
python開發(fā)中經(jīng)常遇到報錯的情況,但是warning通常并不影響程序的運行,而且有時特別討厭,下面我們來說下如何忽略warning錯誤,這篇文章主要給大家介紹了關(guān)于python忽略警告(warning)的3種方法,需要的朋友可以參考下2023-10-10

