使用pytorch提取卷積神經(jīng)網(wǎng)絡(luò)的特征圖可視化
前言
文章中的代碼是參考基于Pytorch的特征圖提取編寫(xiě)的代碼本身很簡(jiǎn)單這里只做簡(jiǎn)單的描述。
1. 效果圖
先看效果圖(第一張是原圖,后面的都是相應(yīng)的特征圖,這里使用的網(wǎng)絡(luò)是resnet50,需要注意的是下面圖片顯示的特征圖是經(jīng)過(guò)放大后的圖,原圖是比較小的圖,因?yàn)樘〔焕谖覀冇^察):



2. 完整代碼
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import cv2
class FeatureExtractor(nn.Module):
def __init__(self, submodule, extracted_layers):
super(FeatureExtractor, self).__init__()
self.submodule = submodule
self.extracted_layers = extracted_layers
def forward(self, x):
outputs = {}
for name, module in self.submodule._modules.items():
if "fc" in name:
x = x.view(x.size(0), -1)
x = module(x)
print(name)
if self.extracted_layers is None or name in self.extracted_layers and 'fc' not in name:
outputs[name] = x
return outputs
def get_picture(pic_name, transform):
img = skimage.io.imread(pic_name)
img = skimage.transform.resize(img, (256, 256))
img = np.asarray(img, dtype=np.float32)
return transform(img)
def make_dirs(path):
if os.path.exists(path) is False:
os.makedirs(path)
def get_feature():
pic_dir = './images/2.jpg'
transform = transforms.ToTensor()
img = get_picture(pic_dir, transform)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 插入維度
img = img.unsqueeze(0)
img = img.to(device)
net = models.resnet101().to(device)
net.load_state_dict(torch.load('./model/resnet101-5d3b4d8f.pt'))
exact_list = None
dst = './feautures'
therd_size = 256
myexactor = FeatureExtractor(net, exact_list)
outs = myexactor(img)
for k, v in outs.items():
features = v[0]
iter_range = features.shape[0]
for i in range(iter_range):
#plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
if 'fc' in k:
continue
feature = features.data.numpy()
feature_img = feature[i,:,:]
feature_img = np.asarray(feature_img * 255, dtype=np.uint8)
dst_path = os.path.join(dst, k)
make_dirs(dst_path)
feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
if feature_img.shape[0] < therd_size:
tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
tmp_img = feature_img.copy()
tmp_img = cv2.resize(tmp_img, (therd_size,therd_size), interpolation = cv2.INTER_NEAREST)
cv2.imwrite(tmp_file, tmp_img)
dst_file = os.path.join(dst_path, str(i) + '.png')
cv2.imwrite(dst_file, feature_img)
if __name__ == '__main__':
get_feature()
3. 代碼說(shuō)明
下面的模塊是根據(jù)所指定的模型篩選出指定層的特征圖輸出,如果未指定也就是extracted_layers是None則以字典的形式輸出全部的特征圖,另外因?yàn)槿B接層本身是一維的沒(méi)必要輸出因此進(jìn)行了過(guò)濾。
class FeatureExtractor(nn.Module):
def __init__(self, submodule, extracted_layers):
super(FeatureExtractor, self).__init__()
self.submodule = submodule
self.extracted_layers = extracted_layers
def forward(self, x):
outputs = {}
for name, module in self.submodule._modules.items():
if "fc" in name:
x = x.view(x.size(0), -1)
x = module(x)
print(name)
if self.extracted_layers is None or name in self.extracted_layers and 'fc' not in name:
outputs[name] = x
return outputs
這段主要是存儲(chǔ)圖片,為每個(gè)層創(chuàng)建一個(gè)文件夾將特征圖以JET的colormap進(jìn)行按順序存儲(chǔ)到該文件夾,并且如果特征圖過(guò)小也會(huì)對(duì)特征圖放大同時(shí)存儲(chǔ)原始圖和放大后的圖。
for k, v in outs.items():
features = v[0]
iter_range = features.shape[0]
for i in range(iter_range):
#plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
if 'fc' in k:
continue
feature = features.data.numpy()
feature_img = feature[i,:,:]
feature_img = np.asarray(feature_img * 255, dtype=np.uint8)
dst_path = os.path.join(dst, k)
make_dirs(dst_path)
feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
if feature_img.shape[0] < therd_size:
tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
tmp_img = feature_img.copy()
tmp_img = cv2.resize(tmp_img, (therd_size,therd_size), interpolation = cv2.INTER_NEAREST)
cv2.imwrite(tmp_file, tmp_img)
dst_file = os.path.join(dst_path, str(i) + '.png')
cv2.imwrite(dst_file, feature_img)
這里主要是一些參數(shù),比如要提取的網(wǎng)絡(luò),網(wǎng)絡(luò)的權(quán)重,要提取的層,指定的圖像放大的大小,存儲(chǔ)路徑等等。
net = models.resnet101().to(device)
net.load_state_dict(torch.load('./model/resnet101-5d3b4d8f.pt'))
exact_list = None#['conv1']
dst = './feautures'
therd_size = 256
4. 可視化梯度,feature
上面的辦法只是簡(jiǎn)單的將經(jīng)過(guò)網(wǎng)絡(luò)計(jì)算的圖片的輸出的feature進(jìn)行圖片,github上有將CNN的梯度等全部進(jìn)行可視化的代碼:pytorch-cnn-visualizations,需要注意的是如果只是簡(jiǎn)單的替換成自己的網(wǎng)絡(luò)可能無(wú)法運(yùn)行,大概率會(huì)報(bào)model沒(méi)有features或者classifier等錯(cuò)誤,這兩個(gè)是進(jìn)行分類(lèi)網(wǎng)絡(luò)定義時(shí)的Sequential,其實(shí)就是索引網(wǎng)絡(luò)的每一層,自己稍微修改用model.children()等方法進(jìn)行替換即可,我自己修改之后得到的代碼grayondream-pytorch-visualization(本來(lái)想稍微封裝一下成為一個(gè)更加通用的結(jié)構(gòu),暫時(shí)沒(méi)時(shí)間以后再說(shuō)吧!),下面是效果圖:













總結(jié)
到此這篇關(guān)于使用pytorch提取卷積神經(jīng)網(wǎng)絡(luò)的特征圖可視化的文章就介紹到這了,更多相關(guān)pytorch提取特征圖可視化內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Flask框架debug與配置項(xiàng)的開(kāi)啟與設(shè)置詳解
這篇文章主要介紹了Flask框架debug與配置項(xiàng)的開(kāi)啟與設(shè)置,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2022-09-09
用python寫(xiě)一個(gè)定時(shí)提醒程序的實(shí)現(xiàn)代碼
今天小編就為大家分享一篇用python寫(xiě)一個(gè)定時(shí)提醒程序的實(shí)現(xiàn)代碼,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-07-07
Python中的Request請(qǐng)求重試機(jī)制
這篇文章主要介紹了Python中的Request請(qǐng)求重試機(jī)制,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-06-06
python正則表達(dá)式re.sub各個(gè)參數(shù)的超詳細(xì)講解
Python 的 re 模塊提供了re.sub用于替換字符串中的匹配項(xiàng),下面這篇文章主要給大家介紹了關(guān)于python正則表達(dá)式re.sub各個(gè)參數(shù)的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-07-07
用uWSGI和Nginx部署Flask項(xiàng)目的方法示例
這篇文章主要介紹了用uWSGI和Nginx部署Flask項(xiàng)目的方法示例,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2019-05-05

