pytorch:實(shí)現(xiàn)簡單的GAN示例(MNIST數(shù)據(jù)集)
更新時(shí)間:2020年01月10日 09:17:37 作者:xckkcxxck
今天小編就為大家分享一篇pytorch:實(shí)現(xiàn)簡單的GAN示例(MNIST數(shù)據(jù)集),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
我就廢話不多說了,直接上代碼吧!
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 設(shè)置畫圖的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
def show_images(images): # 定義畫圖工具
images = np.reshape(images, [images.shape[0], -1])
sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
fig = plt.figure(figsize=(sqrtn, sqrtn))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape([sqrtimg,sqrtimg]))
return
def preprocess_img(x):
x = tfs.ToTensor()(x)
return (x - 0.5) / 0.5
def deprocess_img(x):
return (x + 1.0) / 2.0
class ChunkSampler(sampler.Sampler): # 定義一個(gè)取樣的函數(shù)
"""Samples elements sequentially from some offset.
Arguments:
num_samples: # of desired datapoints
start: offset where we should start selecting from
"""
def __init__(self, num_samples, start=0):
self.num_samples = num_samples
self.start = start
def __iter__(self):
return iter(range(self.start, self.start + self.num_samples))
def __len__(self):
return self.num_samples
NUM_TRAIN = 50000
NUM_VAL = 5000
NOISE_DIM = 96
batch_size = 128
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可視化圖片效果
show_images(imgs)
#判別網(wǎng)絡(luò)
def discriminator():
net = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1)
)
return net
#生成網(wǎng)絡(luò)
def generator(noise_dim=NOISE_DIM):
net = nn.Sequential(
nn.Linear(noise_dim, 1024),
nn.ReLU(True),
nn.Linear(1024, 1024),
nn.ReLU(True),
nn.Linear(1024, 784),
nn.Tanh()
)
return net
#判別器的 loss 就是將真實(shí)數(shù)據(jù)的得分判斷為 1,假的數(shù)據(jù)的得分判斷為 0,而生成器的 loss 就是將假的數(shù)據(jù)判斷為 1
bce_loss = nn.BCEWithLogitsLoss()#交叉熵?fù)p失函數(shù)
def discriminator_loss(logits_real, logits_fake): # 判別器的 loss
size = logits_real.shape[0]
true_labels = Variable(torch.ones(size, 1)).float()
false_labels = Variable(torch.zeros(size, 1)).float()
loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
return loss
def generator_loss(logits_fake): # 生成器的 loss
size = logits_fake.shape[0]
true_labels = Variable(torch.ones(size, 1)).float()
loss = bce_loss(logits_fake, true_labels)
return loss
# 使用 adam 來進(jìn)行訓(xùn)練,學(xué)習(xí)率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
return optimizer
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
noise_size=96, num_epochs=10):
iter_count = 0
for epoch in range(num_epochs):
for x, _ in train_data:
bs = x.shape[0]
# 判別網(wǎng)絡(luò)
real_data = Variable(x).view(bs, -1) # 真實(shí)數(shù)據(jù)
logits_real = D_net(real_data) # 判別網(wǎng)絡(luò)得分
sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均勻分布
g_fake_seed = Variable(sample_noise)
fake_images = G_net(g_fake_seed) # 生成的假的數(shù)據(jù)
logits_fake = D_net(fake_images) # 判別網(wǎng)絡(luò)得分
d_total_error = discriminator_loss(logits_real, logits_fake) # 判別器的 loss
D_optimizer.zero_grad()
d_total_error.backward()
D_optimizer.step() # 優(yōu)化判別網(wǎng)絡(luò)
# 生成網(wǎng)絡(luò)
g_fake_seed = Variable(sample_noise)
fake_images = G_net(g_fake_seed) # 生成的假的數(shù)據(jù)
gen_logits_fake = D_net(fake_images)
g_error = generator_loss(gen_logits_fake) # 生成網(wǎng)絡(luò)的 loss
G_optimizer.zero_grad()
g_error.backward()
G_optimizer.step() # 優(yōu)化生成網(wǎng)絡(luò)
if (iter_count % show_every == 0):
print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
show_images(imgs_numpy[0:16])
plt.show()
print()
iter_count += 1
D = discriminator()
G = generator()
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
以上這篇pytorch:實(shí)現(xiàn)簡單的GAN示例(MNIST數(shù)據(jù)集)就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
您可能感興趣的文章:
- pytorch實(shí)現(xiàn)mnist數(shù)據(jù)集的圖像可視化及保存
- 關(guān)于Pytorch的MNIST數(shù)據(jù)集的預(yù)處理詳解
- 使用 PyTorch 實(shí)現(xiàn) MLP 并在 MNIST 數(shù)據(jù)集上驗(yàn)證方式
- 用Pytorch訓(xùn)練CNN(數(shù)據(jù)集MNIST,使用GPU的方法)
- 詳解PyTorch手寫數(shù)字識(shí)別(MNIST數(shù)據(jù)集)
- pytorch 把MNIST數(shù)據(jù)集轉(zhuǎn)換成圖片和txt的方法
- Python PyTorch 如何獲取 MNIST 數(shù)據(jù)
相關(guān)文章
python 實(shí)現(xiàn)一個(gè)反向單位矩陣示例
今天小編就為大家分享一篇python 實(shí)現(xiàn)一個(gè)反向單位矩陣示例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-11-11
使用Python實(shí)現(xiàn)快速搭建本地HTTP服務(wù)器
這篇文章主要介紹了如何使用Python快速搭建本地HTTP服務(wù)器,輕松實(shí)現(xiàn)一鍵 HTTP 文件共享,同時(shí)結(jié)合二維碼技術(shù),讓訪問更簡單,感興趣的小伙伴可以了解下2025-04-04
Python密碼學(xué)XOR算法編碼流程及乘法密碼教程
這篇文章主要為大家介紹了Python密碼學(xué)XOR流程及乘法密碼教程示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05
Django一小時(shí)寫出賬號(hào)密碼管理系統(tǒng)
這篇文章主要介紹了Django一小時(shí)寫出賬號(hào)密碼管理系統(tǒng),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04
利用Pycharm + Django搭建一個(gè)簡單Python Web項(xiàng)目的步驟
這篇文章主要介紹了利用Pycharm + Django搭建一個(gè)簡單Python Web項(xiàng)目的步驟,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-10-10
python運(yùn)行其他程序的實(shí)現(xiàn)方法
這篇文章主要介紹了python運(yùn)行其他程序的實(shí)現(xiàn)方法的相關(guān)資料,需要的朋友可以參考下2017-07-07

