Pytorch使用MNIST數(shù)據(jù)集實(shí)現(xiàn)基礎(chǔ)GAN和DCGAN詳解
原始生成對(duì)抗網(wǎng)絡(luò)Generative Adversarial Networks GAN包含生成器Generator和判別器Discriminator,數(shù)據(jù)有真實(shí)數(shù)據(jù)groundtruth,還有需要網(wǎng)絡(luò)生成的“fake”數(shù)據(jù),目的是網(wǎng)絡(luò)生成的fake數(shù)據(jù)可以“騙過(guò)”判別器,讓判別器認(rèn)不出來(lái),就是讓判別器分不清進(jìn)入的數(shù)據(jù)是真實(shí)數(shù)據(jù)還是fake數(shù)據(jù)。總的來(lái)說(shuō)是:判別器區(qū)分真實(shí)數(shù)據(jù)和fake數(shù)據(jù)的能力越強(qiáng)越好;生成器生成的數(shù)據(jù)騙過(guò)判別器的能力越強(qiáng)越好,這個(gè)是矛盾的,所以只能交替訓(xùn)練網(wǎng)絡(luò)。
需要搭建生成器網(wǎng)絡(luò)和判別器網(wǎng)絡(luò),訓(xùn)練的時(shí)候交替訓(xùn)練。
首先訓(xùn)練判別器的參數(shù),固定生成器的參數(shù),讓判別器判斷生成器生成的數(shù)據(jù),讓其和0接近,讓判別器判斷真實(shí)數(shù)據(jù),讓其和1接近;
接著訓(xùn)練生成器的參數(shù),固定判別器的參數(shù),讓生成器生成的數(shù)據(jù)進(jìn)入判別器,讓判斷結(jié)果和1接近。生成器生成數(shù)據(jù)需要給定隨機(jī)初始值
線性版:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec
def showimg(images,count):
images=images.detach().numpy()[0:16,:]
images=255*(0.5*images+0.5)
images = images.astype(np.uint8)
grid_length=int(np.ceil(np.sqrt(images.shape[0])))
plt.figure(figsize=(4,4))
width = int(np.sqrt((images.shape[1])))
gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
# gs.update(wspace=0, hspace=0)
print('starting...')
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape([width,width]),cmap = plt.cm.gray)
plt.axis('off')
plt.tight_layout()
print('showing...')
plt.tight_layout()
plt.savefig('./GAN_Image/%d.png'%count, bbox_inches='tight')
def loadMNIST(batch_size): #MNIST圖片的大小是28*28
trans_img=transforms.Compose([transforms.ToTensor()])
trainset=MNIST('./data',train=True,transform=trans_img,download=True)
testset=MNIST('./data',train=False,transform=trans_img,download=True)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
return trainset,testset,trainloader,testloader
class discriminator(nn.Module):
def __init__(self):
super(discriminator,self).__init__()
self.dis=nn.Sequential(
nn.Linear(784,300),
nn.LeakyReLU(0.2),
nn.Linear(300,150),
nn.LeakyReLU(0.2),
nn.Linear(150,1),
nn.Sigmoid()
)
def forward(self, x):
x=self.dis(x)
return x
class generator(nn.Module):
def __init__(self,input_size):
super(generator,self).__init__()
self.gen=nn.Sequential(
nn.Linear(input_size,150),
nn.ReLU(True),
nn.Linear(150,300),
nn.ReLU(True),
nn.Linear(300,784),
nn.Tanh()
)
def forward(self, x):
x=self.gen(x)
return x
if __name__=="__main__":
criterion=nn.BCELoss()
num_img=100
z_dimension=100
D=discriminator()
G=generator(z_dimension)
trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
'''
交替訓(xùn)練的方式訓(xùn)練網(wǎng)絡(luò)
先訓(xùn)練判別器網(wǎng)絡(luò)D再訓(xùn)練生成器網(wǎng)絡(luò)G
不同網(wǎng)絡(luò)的訓(xùn)練次數(shù)是超參數(shù)
也可以兩個(gè)網(wǎng)絡(luò)訓(xùn)練相同的次數(shù)
這樣就可以不用分別訓(xùn)練兩個(gè)網(wǎng)絡(luò)
'''
count=0
#鑒別器D的訓(xùn)練,固定G的參數(shù)
epoch = 100
gepoch = 1
for i in range(epoch):
for (img, label) in trainloader:
# num_img=img.size()[0]
real_img=img.view(num_img,-1)#展開(kāi)為28*28=784
real_label=torch.ones(num_img)#真實(shí)label為1
fake_label=torch.zeros(num_img)#假的label為0
#compute loss of real_img
real_out=D(real_img) #真實(shí)圖片送入判別器D輸出0~1
d_loss_real=criterion(real_out,real_label)#得到loss
real_scores=real_out#真實(shí)圖片放入判別器輸出越接近1越好
#compute loss of fake_img
z=torch.randn(num_img,z_dimension)#隨機(jī)生成向量
fake_img=G(z)#將向量放入生成網(wǎng)絡(luò)G生成一張圖片
fake_out=D(fake_img)#判別器判斷假的圖片
d_loss_fake=criterion(fake_out,fake_label)#假的圖片的loss
fake_scores=fake_out#假的圖片放入判別器輸出越接近0越好
#D bp and optimize
d_loss=d_loss_real+d_loss_fake
d_optimizer.zero_grad() #判別器D的梯度歸零
d_loss.backward() #反向傳播
d_optimizer.step() #更新判別器D參數(shù)
#生成器G的訓(xùn)練compute loss of fake_img
for j in range(gepoch):
fake_label = torch.ones(num_img) # 真實(shí)label為1
z = torch.randn(num_img, z_dimension) # 隨機(jī)生成向量
fake_img = G(z) # 將向量放入生成網(wǎng)絡(luò)G生成一張圖片
output = D(fake_img) # 經(jīng)過(guò)判別器得到結(jié)果
g_loss = criterion(output, fake_label)#得到假的圖片與真實(shí)標(biāo)簽的loss
#bp and optimize
g_optimizer.zero_grad() #生成器G的梯度歸零
g_loss.backward() #反向傳播
g_optimizer.step()#更新生成器G參數(shù)
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
'D real: {:.6f}, D fake: {:.6f}'.format(
i, epoch, d_loss.data[0], g_loss.data[0],
real_scores.data.mean(), fake_scores.data.mean()))
showimg(fake_img,count)
# plt.show()
count += 1
這里的圖分別是 epoch為0、50、100、150、190的運(yùn)行結(jié)果,可以看到圖片中的數(shù)字并不單一

卷積版 Deep Convolutional Generative Adversarial Networks:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import matplotlib.gridspec as gridspec
import os
def showimg(images,count):
images=images.to('cpu')
images=images.detach().numpy()
images=images[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]]
images=255*(0.5*images+0.5)
images = images.astype(np.uint8)
grid_length=int(np.ceil(np.sqrt(images.shape[0])))
plt.figure(figsize=(4,4))
width = images.shape[2]
gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
print(images.shape)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape(width,width),cmap = plt.cm.gray)
plt.axis('off')
plt.tight_layout()
# print('showing...')
plt.tight_layout()
# plt.savefig('./GAN_Imaget/%d.png'%count, bbox_inches='tight')
def loadMNIST(batch_size): #MNIST圖片的大小是28*28
trans_img=transforms.Compose([transforms.ToTensor()])
trainset=MNIST('./data',train=True,transform=trans_img,download=True)
testset=MNIST('./data',train=False,transform=trans_img,download=True)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
return trainset,testset,trainloader,testloader
class discriminator(nn.Module):
def __init__(self):
super(discriminator,self).__init__()
self.dis=nn.Sequential(
nn.Conv2d(1,32,5,stride=1,padding=2),
nn.LeakyReLU(0.2,True),
nn.MaxPool2d((2,2)),
nn.Conv2d(32,64,5,stride=1,padding=2),
nn.LeakyReLU(0.2,True),
nn.MaxPool2d((2,2))
)
self.fc=nn.Sequential(
nn.Linear(7 * 7 * 64, 1024),
nn.LeakyReLU(0.2, True),
nn.Linear(1024, 1),
nn.Sigmoid()
)
def forward(self, x):
x=self.dis(x)
x=x.view(x.size(0),-1)
x=self.fc(x)
return x
class generator(nn.Module):
def __init__(self,input_size,num_feature):
super(generator,self).__init__()
self.fc=nn.Linear(input_size,num_feature) #1*56*56
self.br=nn.Sequential(
nn.BatchNorm2d(1),
nn.ReLU(True)
)
self.gen=nn.Sequential(
nn.Conv2d(1,50,3,stride=1,padding=1),
nn.BatchNorm2d(50),
nn.ReLU(True),
nn.Conv2d(50,25,3,stride=1,padding=1),
nn.BatchNorm2d(25),
nn.ReLU(True),
nn.Conv2d(25,1,2,stride=2),
nn.Tanh()
)
def forward(self, x):
x=self.fc(x)
x=x.view(x.size(0),1,56,56)
x=self.br(x)
x=self.gen(x)
return x
if __name__=="__main__":
criterion=nn.BCELoss()
num_img=100
z_dimension=100
D=discriminator()
G=generator(z_dimension,3136) #1*56*56
trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
D=D.cuda()
G=G.cuda()
d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
'''
交替訓(xùn)練的方式訓(xùn)練網(wǎng)絡(luò)
先訓(xùn)練判別器網(wǎng)絡(luò)D再訓(xùn)練生成器網(wǎng)絡(luò)G
不同網(wǎng)絡(luò)的訓(xùn)練次數(shù)是超參數(shù)
也可以兩個(gè)網(wǎng)絡(luò)訓(xùn)練相同的次數(shù),
這樣就可以不用分別訓(xùn)練兩個(gè)網(wǎng)絡(luò)
'''
count=0
#鑒別器D的訓(xùn)練,固定G的參數(shù)
epoch = 100
gepoch = 1
for i in range(epoch):
for (img, label) in trainloader:
# num_img=img.size()[0]
img=Variable(img).cuda()
real_label=Variable(torch.ones(num_img)).cuda()#真實(shí)label為1
fake_label=Variable(torch.zeros(num_img)).cuda()#假的label為0
#compute loss of real_img
real_out=D(img) #真實(shí)圖片送入判別器D輸出0~1
d_loss_real=criterion(real_out,real_label)#得到loss
real_scores=real_out#真實(shí)圖片放入判別器輸出越接近1越好
#compute loss of fake_img
z=Variable(torch.randn(num_img,z_dimension)).cuda()#隨機(jī)生成向量
fake_img=G(z)#將向量放入生成網(wǎng)絡(luò)G生成一張圖片
fake_out=D(fake_img)#判別器判斷假的圖片
d_loss_fake=criterion(fake_out,fake_label)#假的圖片的loss
fake_scores=fake_out#假的圖片放入判別器輸出越接近0越好
#D bp and optimize
d_loss=d_loss_real+d_loss_fake
d_optimizer.zero_grad() #判別器D的梯度歸零
d_loss.backward() #反向傳播
d_optimizer.step() #更新判別器D參數(shù)
#生成器G的訓(xùn)練compute loss of fake_img
for j in range(gepoch):
fake_label = Variable(torch.ones(num_img)).cuda() # 真實(shí)label為1
z = Variable(torch.randn(num_img, z_dimension)).cuda() # 隨機(jī)生成向量
fake_img = G(z) # 將向量放入生成網(wǎng)絡(luò)G生成一張圖片
output = D(fake_img) # 經(jīng)過(guò)判別器得到結(jié)果
g_loss = criterion(output, fake_label)#得到假的圖片與真實(shí)標(biāo)簽的loss
#bp and optimize
g_optimizer.zero_grad() #生成器G的梯度歸零
g_loss.backward() #反向傳播
g_optimizer.step()#更新生成器G參數(shù)
# if ((i+1)%1000==0):
# print("[%d/%d] GLoss: %.5f" % (i + 1, gepoch, g_loss.data[0]))
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
'D real: {:.6f}, D fake: {:.6f}'.format(
i, epoch, d_loss.data[0], g_loss.data[0],
real_scores.data.mean(), fake_scores.data.mean()))
showimg(fake_img,count)
plt.show()
count += 1
這里的gepoch設(shè)置為1,運(yùn)行39次的結(jié)果是:

gepoch設(shè)置為2,運(yùn)行0、25、50、75、100次的結(jié)果是:

gepoch設(shè)置為3,運(yùn)行25、50、75次的結(jié)果是:

gepoch設(shè)置為4,運(yùn)行0、10、20、30、35次的結(jié)果是:

gepoch設(shè)置為5,運(yùn)行0、10、20、25、29次的結(jié)果是:

gepoch設(shè)置為3,z_dimension設(shè)置為190,epoch運(yùn)行0、10、15、20、25、35的結(jié)果是:

可以看到生成的數(shù)字基本沒(méi)有太多的規(guī)律,可能最終都是同個(gè)數(shù)字,不能生成指定的數(shù)字,CGAN就很好的解決這個(gè)問(wèn)題,可以生成指定的數(shù)字 Pytorch使用MNIST數(shù)據(jù)集實(shí)現(xiàn)CGAN和生成指定的數(shù)字方式
以上這篇Pytorch使用MNIST數(shù)據(jù)集實(shí)現(xiàn)基礎(chǔ)GAN和DCGAN詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python3.8對(duì)可迭代解包的改進(jìn)及用法詳解
這篇文章主要介紹了Python3.8對(duì)可迭代解包的改進(jìn)及用法詳解,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-10-10
利用Python實(shí)現(xiàn)炸彈人游戲的完整代碼
這篇文章主要介紹了如何使用Python的Pygame庫(kù)實(shí)現(xiàn)一個(gè)炸彈人游戲,并對(duì)其進(jìn)行多方面的優(yōu)化,文中通過(guò)代碼介紹的非常詳細(xì),需要的朋友可以參考下2025-01-01
Django框架自定義模型管理器與元選項(xiàng)用法分析
這篇文章主要介紹了Django框架自定義模型管理器與元選項(xiàng)用法,結(jié)合實(shí)例形式分析了自定義模型管理器與元選項(xiàng)的功能、用法及相關(guān)操作注意事項(xiàng),需要的朋友可以參考下2019-07-07
在Python Flask App中獲取已發(fā)布的JSON對(duì)象的解決方案
這篇文章主要介紹了在Python Flask App中獲取已發(fā)布的JSON對(duì)象的解決方案,文中通過(guò)代碼示例介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作有一定的幫助,需要的朋友可以參考下2024-08-08
Python安裝Numpy和matplotlib的方法(推薦)
下面小編就為大家?guī)?lái)一篇Python安裝Numpy和matplotlib的方法(推薦)。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2017-11-11
tkinter如何實(shí)現(xiàn)label超鏈接調(diào)用瀏覽器打開(kāi)網(wǎng)址
這篇文章主要介紹了tkinter如何實(shí)現(xiàn)label超鏈接調(diào)用瀏覽器打開(kāi)網(wǎng)址問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-01-01
Python實(shí)現(xiàn)分段讀取和保存遙感數(shù)據(jù)
當(dāng)遇到批量讀取大量遙感數(shù)據(jù)進(jìn)行運(yùn)算的時(shí)候,如果不進(jìn)行分段讀取操作的話,電腦內(nèi)存可能面臨著不夠使用的情況,所以我們要進(jìn)行分段讀取數(shù)據(jù)然后進(jìn)行運(yùn)算,運(yùn)算結(jié)束之后把這段數(shù)據(jù)保存成tif文件,本文介紹了Python實(shí)現(xiàn)分段讀取和保存遙感數(shù)據(jù),需要的朋友可以參考下2023-08-08

