Where Keras stores automatically downloaded datasets and models
Mac
# Datasets
~/.keras/datasets/
# Models
~/.keras/models/
Linux
# Datasets
~/.keras/datasets/
Windows
# Windows 10
C:\Users\user_name\.keras\datasets
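For a quick check of where a download actually lands, you can load any built-in dataset and list the cache directory. This is a minimal sketch assuming the default ~/.keras cache location; the exact file name (e.g. mnist.npz) may differ between Keras versions:
import os
from keras.datasets import mnist

# Downloading a built-in dataset caches it under ~/.keras/datasets/
(x_train, y_train), (x_test, y_test) = mnist.load_data()

cache_dir = os.path.expanduser(os.path.join("~", ".keras", "datasets"))
print(os.listdir(cache_dir))  # e.g. ['mnist.npz'], depending on the Keras version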
Bonus: using a Keras GAN to generate your own data and save the model
Without further ado, here is the code:
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import sys
import numpy as np
class GAN():
    def __init__(self):
        self.img_rows = 3
        self.img_cols = 60
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        optimizer = Adam(0.0002, 0.5)
        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])
        # Build the generator
        self.generator = self.build_generator()
        # The generator takes noise as input and produces fake images
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)
        # In the combined model, only the generator is trained
        self.discriminator.trainable = False
        # The discriminator takes the generated image as input and judges its validity
        validity = self.discriminator(img)
        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    def build_generator(self):
        model = Sequential()
        model.add(Dense(64, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(128))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # np.prod(self.img_shape) = 3 * 60 * 1
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))
        model.summary()
        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        # Takes noise as input, outputs an image
        return Model(noise, img)
    def build_discriminator(self):
        model = Sequential()
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(128))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(64))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)
    def train(self, epochs, batch_size=128, sample_interval=50):
        ############################################################
        # Change this part for your own dataset
        # Load the dataset
        data = np.load('data/相對(duì)大小分叉.npy')
        data = data[:, :, 0:60]
        # Rescale to [-1, 1] (assumes the raw data lies in [0, 1])
        data = data * 2 - 1
        data = np.expand_dims(data, axis=3)
        ############################################################
        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        for epoch in range(epochs):
            # ---------------------
            # Train the discriminator
            # ---------------------
            # data.shape[0] is the number of samples; draw batch_size random indices
            idx = np.random.randint(0, data.shape[0], batch_size)
            # Randomly pick batch_size samples as one training batch
            imgs = data[idx]
            # Noise with shape (batch_size, 100)
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # Let the generator produce fake images from the noise
            gen_imgs = self.generator.predict(noise)
            # Train the discriminator: real images are labelled 1, fake images 0
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # ---------------------
            # Train the generator
            # ---------------------
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)
            # Print the loss values
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
            # Every sample_interval epochs, save sample images and the model weights
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
                if not os.path.exists("keras_model"):
                    os.makedirs("keras_model")
                self.generator.save_weights("keras_model/G_model%d.hdf5" % epoch, True)
                self.discriminator.save_weights("keras_model/D_model%d.hdf5" % epoch, True)
    def sample_images(self, epoch):
        r, c = 10, 10
        # Generate a fresh batch of noise with shape (100, 100)
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # Rescale the generated samples back to [0, 1]
        gen = 0.5 * gen_imgs + 0.5
        gen = gen.reshape(-1, 3, 60)
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                xy = gen[cnt]
                for k in range(len(xy)):
                    x = xy[k][0:30]
                    y = xy[k][30:60]
                    if k == 0:
                        axs[i, j].plot(x, y, color='blue')
                    if k == 1:
                        axs[i, j].plot(x, y, color='red')
                    if k == 2:
                        axs[i, j].plot(x, y, color='green')
                plt.xlim(0., 1.)
                plt.ylim(0., 1.)
                plt.xticks(np.arange(0, 1, 0.1))
                plt.yticks(np.arange(0, 1, 0.1))
                axs[i, j].axis('off')
                cnt += 1
        if not os.path.exists("keras_imgs"):
            os.makedirs("keras_imgs")
        fig.savefig("keras_imgs/%d.png" % epoch)
        plt.close()
    def test(self, gen_nums=100, save=False):
        self.generator.load_weights("keras_model/G_model4000.hdf5", by_name=True)
        self.discriminator.load_weights("keras_model/D_model4000.hdf5", by_name=True)
        noise = np.random.normal(0, 1, (gen_nums, self.latent_dim))
        gen = self.generator.predict(noise)
        # Rescale the generated samples back to [0, 1]
        gen = 0.5 * gen + 0.5
        gen = gen.reshape(-1, 3, 60)
        print(gen.shape)
        ###############################################################
        # Save each generated sample as an image file
        if save:
            for i in range(0, len(gen)):
                plt.figure(figsize=(128, 128), dpi=1)
                plt.plot(gen[i][0][0:30], gen[i][0][30:60], color='blue', linewidth=300)
                plt.plot(gen[i][1][0:30], gen[i][1][30:60], color='red', linewidth=300)
                plt.plot(gen[i][2][0:30], gen[i][2][30:60], color='green', linewidth=300)
                plt.axis('off')
                plt.xlim(0., 1.)
                plt.ylim(0., 1.)
                plt.xticks(np.arange(0, 1, 0.1))
                plt.yticks(np.arange(0, 1, 0.1))
                if not os.path.exists("keras_gen"):
                    os.makedirs("keras_gen")
                plt.savefig("keras_gen" + os.sep + str(i) + '.jpg', dpi=1)
                plt.close()
        ##################################################################
        # Otherwise, visualise the generated samples directly
        else:
            for i in range(len(gen)):
                plt.plot(gen[i][0][0:30], gen[i][0][30:60], color='blue')
                plt.plot(gen[i][1][0:30], gen[i][1][30:60], color='red')
                plt.plot(gen[i][2][0:30], gen[i][2][30:60], color='green')
                plt.xlim(0., 1.)
                plt.ylim(0., 1.)
                plt.xticks(np.arange(0, 1, 0.1))
                plt.yticks(np.arange(0, 1, 0.1))
            plt.show()
if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=300000, batch_size=32, sample_interval=2000)
    # gan.test(save=True)
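If you want to try the script without the author's 相對(duì)大小分叉.npy file, the data-loading block inside train() can be swapped for a built-in dataset. This is only a hypothetical substitution, and it assumes you also change img_rows and img_cols to 28 in __init__ so that img_shape matches MNIST:
# Hypothetical replacement for the data-loading block in train(),
# assuming img_rows = img_cols = 28 and channels = 1 in __init__
from keras.datasets import mnist

(x_train, _), (_, _) = mnist.load_data()    # x_train: (60000, 28, 28), values 0-255
data = x_train.astype('float32') / 255.0    # scale to [0, 1]
data = data * 2 - 1                         # rescale to [-1, 1], as the original code expects
data = np.expand_dims(data, axis=3)         # shape (60000, 28, 28, 1)
Note that sample_images() and test() reshape the output to (-1, 3, 60) and plot curves, so for image data those would also need to be adapted (for example, displaying the samples with plt.imshow instead).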
That is all for this introduction to where Keras stores automatically downloaded datasets and models. I hope it serves as a useful reference, and thank you for supporting 腳本之家.