pytorch中的自定義反向傳播,求導(dǎo)實例
pytorch中自定義backward()函數(shù)。在圖像處理過程中,我們有時候會使用自己定義的算法處理圖像,這些算法多是基于numpy或者scipy等包。
那么如何將自定義算法的梯度加入到pytorch的計算圖中,能使用Loss.backward()操作自動求導(dǎo)并優(yōu)化呢。下面的代碼展示了這個功能`
import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
x_abs = np.abs(x)
if x_abs < 1 and x_abs >= 0:
y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
elif x_abs > 1 and x_abs < 2:
y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
else:
y = 0
return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
# data_in = data_in.detach().numpy()
self.grad = np.zeros(data_in.shape,dtype=np.float32)
obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
data_tmp = data_in.copy()
data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
print(data_tmp.shape)
for axis0 in range(obj_shape[0]):
f_0 = float(axis0) / scale - np.floor(axis0 / scale)
int_0 = int(axis0 / scale) + 2
axis0_weight = np.array(
[[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
for axis1 in range(obj_shape[1]):
f_1 = float(axis1) / scale - np.floor(axis1 / scale)
int_1 = int(axis1 / scale) + 2
axis1_weight = np.array(
[[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
for i in range(4):
for j in range(4):
nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
for ii in range(data_in.shape[2]):
self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
tmp = np.matmul(axis0_weight, nbr_pixel)
data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
# img = np.transpose(img[0, :, :, :], [1, 2, 0])
return data_obj
def forward(self,input):
print(type(input))
input_ = input.detach().numpy()
output = self.bicubic_interpolate(input_)
# return input.new(output)
return torch.Tensor(output)
def backward(self,grad_output):
print(self.grad.shape,grad_output.shape)
grad_output.detach().numpy()
grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
for i in range(self.grad.shape[0]):
for j in range(self.grad.shape[1]):
grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
grad_input = grad_output_tmp*self.grad
print(type(grad_input))
# return grad_output.new(grad_input)
return torch.Tensor(grad_input)
def bicubic(input):
return Bicubic()(input)
def main():
hr = Image.open('./baboon/baboon_hr.png').convert('L')
hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
hr.requires_grad = True
lr = bicubic(hr)
print(lr.is_leaf)
loss=torch.mean(lr)
loss.backward()
if __name__ =='__main__':
main()
要想實現(xiàn)自動求導(dǎo),必須同時實現(xiàn)forward(),backward()兩個函數(shù)。
1、從代碼中可以看出來,forward()函數(shù)是針對numpy數(shù)據(jù)操作,返回值再重新指定為torch.Tensor類型。因此就有這個問題出現(xiàn)了:forward輸入input被轉(zhuǎn)換為numpy類型,輸出轉(zhuǎn)換為tensor類型,那么輸出output的grad_fn參數(shù)是如何指定的呢。調(diào)試發(fā)現(xiàn),當(dāng)main()中hr的requires_grad被指定為True,即hr被指定為需要求導(dǎo)的葉子節(jié)點。只要Bicubic類繼承自torch.autograd.Function,那么output也就是代碼中的lr的grad_fn就會被指定為<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic這個類。
2、backward()為求導(dǎo)的函數(shù),gard_output是鏈?zhǔn)角髮?dǎo)法則的上一級的梯度,grad_input即為我們想要得到的梯度。只需要在輸入指定grad_output,在調(diào)用loss.backward()過程中的某一步會執(zhí)行到Bicubic的backwward()函數(shù)
以上這篇pytorch中的自定義反向傳播,求導(dǎo)實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
- 使用pytorch進行張量計算、自動求導(dǎo)和神經(jīng)網(wǎng)絡(luò)構(gòu)建功能
- pytorch如何定義新的自動求導(dǎo)函數(shù)
- 在?pytorch?中實現(xiàn)計算圖和自動求導(dǎo)
- Pytorch自動求導(dǎo)函數(shù)詳解流程以及與TensorFlow搭建網(wǎng)絡(luò)的對比
- 淺談Pytorch中的自動求導(dǎo)函數(shù)backward()所需參數(shù)的含義
- 關(guān)于PyTorch 自動求導(dǎo)機制詳解
- Pytorch反向求導(dǎo)更新網(wǎng)絡(luò)參數(shù)的方法
- 關(guān)于pytorch求導(dǎo)總結(jié)(torch.autograd)
相關(guān)文章
python 使用re.search()篩選后 選取部分結(jié)果的方法
今天小編就為大家分享一篇python 使用re.search()篩選后 選取部分結(jié)果的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-11-11
Tensorflow2.4從頭訓(xùn)練Word?Embedding實現(xiàn)文本分類
這篇文章主要為大家介紹了Tensorflow2.4從頭訓(xùn)練Word?Embedding實現(xiàn)文本分類,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2023-01-01
詳解如何使用numpy提高Python數(shù)據(jù)分析效率
NumPy是Python語言的一個第三方庫,其支持大量高維度數(shù)組與矩陣運算。本文主要為大家介紹了如何使用numpy提高python數(shù)據(jù)分析效率,需要的可以參考一下2023-04-04
詳解centos7+django+python3+mysql+阿里云部署項目全流程
這篇文章主要介紹了詳解centos7+django+python3+mysql+阿里云部署項目全流程,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-11-11
Python 使用tf-idf算法計算文檔關(guān)鍵字權(quán)重并生成詞云的方法
這篇文章主要介紹了Python 使用tf-idf算法計算文檔關(guān)鍵字權(quán)重,并生成詞云,本文通過實例代碼給大家介紹的非常想詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2023-03-03
python教程網(wǎng)絡(luò)爬蟲及數(shù)據(jù)可視化原理解析
這篇文章主要為大家介紹了python教程中網(wǎng)絡(luò)爬蟲及數(shù)據(jù)可視化原理的示例解析,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步早日升職加薪2021-10-10

