Training a U-Net on your own dataset with PyTorch
In image segmentation there are two main schools: encoder-decoder networks and dilated convolutions. This article covers U-Net, the most classic of the encoder-decoder networks. As backbone networks have evolved, many of the networks derived since are essentially improvements on U-Net, but the underlying idea has changed little; examples include FCDenseNet, which combines DenseNet with U-Net, and UNet++.
1. Introduction to the U-Net architecture
Paper: https://arxiv.org/abs/1505.04597v1 (2015)
U-Net was designed for medical image segmentation. Because medical imaging datasets are small, the paper's method markedly improves results when training on little data, and it also proposes an effective way to process large images (the overlap-tile strategy).
U-Net's architecture is inherited from the FCN, with some changes on top of it. It introduced the encoder-decoder formulation, which is essentially the FCN idea of convolving down first and then upsampling.
[Figure: the U-Net architecture diagram from the paper]
The figure above shows the U-Net structure. From it we can see:
The left side is the encoder, i.e., the downsampling path that extracts features. Its basic module is a double convolution: the input passes through two 3×3 convolutions. The paper uses valid (unpadded) convolutions; in our implementation we add padding (same convolutions) so that feature-map sizes line up for the skip connections. Downsampling is done by max pooling, which shrinks the feature map by a factor of 2 each time.
The right side is the decoder, i.e., the upsampling path that restores the image size and produces the prediction. The decoder uses the same double-convolution blocks; upsampling is implemented with transposed convolutions, each enlarging the feature map by a factor of 2.
The "copy and crop" connections in the middle are concatenations (torch.cat), i.e., stacking encoder and decoder feature maps along the channel dimension. A small sketch of these two operations follows.
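A minimal sketch of upsampling plus concatenation (the shapes here are illustrative, not taken from the paper):

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
decoder_feat = torch.randn(1, 128, 32, 32)   # coarse decoder feature map
skip_feat = torch.randn(1, 64, 64, 64)       # matching encoder feature map

upsampled = up(decoder_feat)                       # (1, 64, 64, 64): spatial size doubled
merged = torch.cat([upsampled, skip_feat], dim=1)  # (1, 128, 64, 64): channels stacked
print(upsampled.shape, merged.shape)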
2. Training U-Net on the Carvana dataset
2.1 U-Net implementation
Given the symmetric and simple structure described above, the implementation in unet.py is as follows:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two 3x3 same convolutions, each followed by BatchNorm and ReLU."""
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)
class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder: each stage doubles the channels, each pooling halves the resolution
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: each transposed convolution doubles the resolution; the
        # following DoubleConv sees skip + upsampled channels concatenated
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.output = nn.Conv2d(64, out_ch, 1)  # 1x1 conv to per-class scores
    def forward(self, x):
        # Encoder
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)
        # Decoder: upsample, concatenate with the matching encoder feature, convolve
        up6 = self.up6(conv5)
        merge6 = torch.cat([up6, conv4], dim=1)
        conv6 = self.conv6(merge6)
        up7 = self.up7(conv6)
        merge7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(merge7)
        up8 = self.up8(conv7)
        merge8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(merge8)
        up9 = self.up9(conv8)
        merge9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(merge9)
        out = self.output(conv9)
        return out
if __name__ == "__main__":
    model = Unet(3, 21)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
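    # Sanity check (my addition, not in the original post): with same
    # convolutions the output spatial size should equal the input size.
    x = torch.randn(1, 3, 256, 256).to(device)
    print(model(x).shape)  # expected: torch.Size([1, 21, 256, 256])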
    print(model)

2.2 Dataset preparation


The data comes from Kaggle's Carvana competition (I have since lost the download link). There are two classes, car and background, with 5k+ images in total; we simply split them into training and validation sets by a fixed ratio. See carnava.py:
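Based on the loader code below, the expected directory layout is roughly as follows (the file names are illustrative):

../dataset/carvana/
    train_hq/       car photos, e.g. <car id>_01.jpg
    train_masks/    binary masks named <photo name>_mask.gif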
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
import numpy as np
import os
import matplotlib.pyplot as plt
class Car(Dataset):
    def __init__(self, root, train=True):
        self.root = root
        self.crop_size = (256, 256)
        self.img_path = os.path.join(root, "train_hq")
        self.label_path = os.path.join(root, "train_masks")
        img_path_list = [os.path.join(self.img_path, im) for im in os.listdir(self.img_path)]
        train_path_list, val_path_list = self._split_data_set(img_path_list)
        if train:
            self.imgs_list = train_path_list
        else:
            self.imgs_list = val_path_list
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Image transform: resize + center crop to 256x256, then normalize
        self.transforms = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor(),
            normalize
        ])
        # Label transform: same resize + crop but no normalization; nearest
        # interpolation keeps the mask colors discrete at object borders
        self.transforms_val = T.Compose([
            T.Resize(256, interpolation=T.InterpolationMode.NEAREST),
            T.CenterCrop(256)
        ])
        self.color_map = [[0, 0, 0], [255, 255, 255]]  # background, car
    def __getitem__(self, index: int):
        im_path = self.imgs_list[index]
        image = Image.open(im_path).convert("RGB")
        data = self.transforms(image)
        # Masks are named "<image name>_mask.gif"
        (filepath, filename) = os.path.split(im_path)
        filename = filename.split('.')[0]
        label = Image.open(self.label_path + "/" + filename + "_mask.gif").convert("RGB")
        label = self.transforms_val(label)
        # Map each RGB color in color_map to its class index via a lookup
        # table indexed by (r*256 + g)*256 + b
        cm2lb = np.zeros(256 ** 3)
        for i, cm in enumerate(self.color_map):
            cm2lb[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
        label_np = np.array(label, dtype=np.int64)
        idx = (label_np[:, :, 0] * 256 + label_np[:, :, 1]) * 256 + label_np[:, :, 2]
        label = torch.from_numpy(np.array(cm2lb[idx], dtype=np.int64)).long()
        return data, label
    def label2img(self, label):
        # Inverse mapping: class indices back to RGB colors for visualization
        cmap = np.array(self.color_map).astype(np.uint8)
        return cmap[label]

    def __len__(self):
        return len(self.imgs_list)

    def _split_data_set(self, img_path_list):
        # Every 8th image goes to the validation set (roughly a 7:1 split)
        val_path_list = img_path_list[::8]
        train_path_list = [item for item in img_path_list if item not in val_path_list]
        return train_path_list, val_path_list
if __name__ == "__main__":
    root = "../dataset/carvana"
    car_train = Car(root, train=True)
    train_dataloader = DataLoader(car_train, batch_size=8, shuffle=True)
    print(len(car_train))
    print(len(train_dataloader))
    # Visualize one mask to check the label conversion
    (data, label) = car_train[190]
    label_np = label.data.numpy()
    label_im = car_train.label2img(label_np)
    plt.figure()
    plt.imshow(label_im)
    plt.show()

2.3 Training
Segmentation is really just classifying every pixel, so the loss function is still cross-entropy, and pixel accuracy is the number of correctly classified pixels divided by the total number of pixels.
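A minimal sketch of both computations (shapes are illustrative): nn.CrossEntropyLoss takes logits of shape (N, C, H, W) and integer labels of shape (N, H, W), and pixel accuracy is an elementwise comparison:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(2, 2, 4, 4)             # (batch, classes, H, W)
target = torch.randint(0, 2, (2, 4, 4))      # (batch, H, W) class indices
loss = criterion(logits, target)
pred = logits.argmax(dim=1)                  # per-pixel predicted class
pixel_acc = (pred == target).float().mean()  # correct pixels / all pixels
print(loss.item(), pixel_acc.item())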
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from carnava import Car
from unet import Unet
import os
import numpy as np
from torch import optim
# Accumulate the confusion matrix between ground truth and prediction
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist
def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaluation results:
    - overall accuracy
    - mean per-class accuracy
    - mean IoU
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iu = np.nanmean(iu)
    return acc, acc_cls, mean_iu
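# Worked example (illustrative, not from the original post): with n_class = 2,
# label_true = [0, 0, 1, 1] and label_pred = [0, 1, 1, 1] give
# hist = [[1, 1], [0, 2]], so acc = 3/4, acc_cls = (1/2 + 2/2) / 2 = 0.75,
# and mean_iu = (1/2 + 2/3) / 2 ≈ 0.58.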
out_path = "./out"
if not os.path.exists(out_path):
    os.makedirs(out_path)
log_path = os.path.join(out_path, "result.txt")
if os.path.exists(log_path):
    os.remove(log_path)
model_path = os.path.join(out_path, "best_model.pth")
root = "../dataset/carvana"
epochs = 5
numclasses = 2
train_data = Car(root, train=True)
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
val_data = Car(root, train=False)
val_dataloader = DataLoader(val_data, batch_size=16, shuffle=True)
net = Unet(3, numclasses)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
def train_model():
    best_score = 0.0
    for e in range(epochs):
        net.train()
        train_loss = 0.0
        label_true = torch.LongTensor()
        label_pred = torch.LongTensor()
        for batch_id, (data, label) in enumerate(train_dataloader):
            data, label = data.to(device), label.to(device)
            output = net(data)
            loss = criterion(output, label)
            pred = output.argmax(dim=1).squeeze().data.cpu()
            real = label.data.cpu()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.cpu().item()
            # Collect predictions and labels to compute epoch-level metrics
            label_true = torch.cat((label_true, real), dim=0)
            label_pred = torch.cat((label_pred, pred), dim=0)
        train_loss /= len(train_dataloader)
        acc, acc_cls, mean_iu = label_accuracy_score(label_true.numpy(), label_pred.numpy(), numclasses)
        print("\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}".format(
            e + 1, train_loss, acc, acc_cls, mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
                e + 1, train_loss, acc, acc_cls, mean_iu))
        # Validation
        net.eval()
        val_loss = 0.0
        val_label_true = torch.LongTensor()
        val_label_pred = torch.LongTensor()
        with torch.no_grad():
            for batch_id, (data, label) in enumerate(val_dataloader):
                data, label = data.to(device), label.to(device)
                output = net(data)
                loss = criterion(output, label)
                pred = output.argmax(dim=1).squeeze().data.cpu()
                real = label.data.cpu()
                val_loss += loss.cpu().item()
                val_label_true = torch.cat((val_label_true, real), dim=0)
                val_label_pred = torch.cat((val_label_pred, pred), dim=0)
        val_loss /= len(val_dataloader)
        val_acc, val_acc_cls, val_mean_iu = label_accuracy_score(val_label_true.numpy(),
                                                                 val_label_pred.numpy(), numclasses)
        print('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
            e + 1, val_loss, val_acc, val_acc_cls, val_mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
                e + 1, val_loss, val_acc, val_acc_cls, val_mean_iu))
        # Keep the checkpoint with the best (acc_cls + mean_iu) / 2 on validation
        score = (val_acc_cls + val_mean_iu) / 2
        if score > best_score:
            best_score = score
            torch.save(net.state_dict(), model_path)
def evaluate():
    import random
    import matplotlib.pyplot as plt
    net.load_state_dict(torch.load(model_path))
    net.eval()  # switch BatchNorm to its running statistics
    index = random.randint(0, len(val_data) - 1)
    val_image, val_label = val_data[index]
    with torch.no_grad():
        out = net(val_image.unsqueeze(0).to(device))
    pred = out.argmax(dim=1).squeeze().data.cpu().numpy()
    label = val_label.data.numpy()
    img_pred = val_data.label2img(pred)
    img_label = val_data.label2img(label)
    # Undo the normalization (min-max rescale) so the input displays correctly
    temp = val_image.numpy()
    temp = (temp - np.min(temp)) / (np.max(temp) - np.min(temp)) * 255
    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(temp.transpose(1, 2, 0).astype("uint8"))
    ax[1].imshow(img_label)
    ax[2].imshow(img_pred)
    plt.show()
if __name__ == "__main__":
    # train_model()
    evaluate()

The final training results:
[Figure: training and validation metrics logged over 5 epochs]
Because the data is fairly simple, the mIoU already reaches 0.97 by epoch 5.
Finally, a quick test of the trained model:
[Figure: sample prediction on a validation image]
From left to right: input image, ground-truth label, predicted label.
Note:
I originally trained on the VOC dataset, but the results were terrible and I never found the problem. Switching datasets made everything work, probably for one of two reasons:
1. I made a mistake somewhere in my VOC preprocessing and never caught it.
2. This dataset is simple and easy to learn, so it trains well.