pytorch自定義二值化網(wǎng)絡(luò)層方式
任務(wù)要求:
自定義一個(gè)層主要是定義該層的實(shí)現(xiàn)函數(shù),只需要重載Function的forward和backward函數(shù)即可,如下:
import torch from torch.autograd import Function from torch.autograd import Variable
定義二值化函數(shù)
class BinarizedF(Function):
def forward(self, input):
self.save_for_backward(input)
a = torch.ones_like(input)
b = -torch.ones_like(input)
output = torch.where(input>=0,a,b)
return output
def backward(self, output_grad):
input, = self.saved_tensors
input_abs = torch.abs(input)
ones = torch.ones_like(input)
zeros = torch.zeros_like(input)
input_grad = torch.where(input_abs<=1,ones, zeros)
return input_grad
定義一個(gè)module
class BinarizedModule(nn.Module):
def __init__(self):
super(BinarizedModule, self).__init__()
self.BF = BinarizedF()
def forward(self,input):
print(input.shape)
output =self.BF(input)
return output
進(jìn)行測(cè)試
a = Variable(torch.randn(4,480,640), requires_grad=True) output = BinarizedModule()(a) output.backward(torch.ones(a.size())) print(a) print(a.grad)
其中, 二值化函數(shù)部分也可以按照方式寫,但是速度慢了0.05s
class BinarizedF(Function):
def forward(self, input):
self.save_for_backward(input)
output = torch.ones_like(input)
output[input<0] = -1
return output
def backward(self, output_grad):
input, = self.saved_tensors
input_grad = output_grad.clone()
input_abs = torch.abs(input)
input_grad[input_abs>1] = 0
return input_grad
以上這篇pytorch自定義二值化網(wǎng)絡(luò)層方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)深度遍歷和廣度遍歷的方法
今天小編就為大家分享一篇Python實(shí)現(xiàn)深度遍歷和廣度遍歷的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01
解決Python print 輸出文本顯示 gbk 編碼錯(cuò)誤問(wèn)題
這篇文章主要介紹了解決Python print 輸出文本顯示 gbk 編碼錯(cuò)誤問(wèn)題,本文給出了三種解決方法,需要的朋友可以參考下2018-07-07
Django 1.10以上版本 url 配置注意事項(xiàng)詳解
這篇文章主要介紹了Django 1.10以上版本 url 配置注意事項(xiàng)詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-08-08
pycharm遠(yuǎn)程連接服務(wù)器運(yùn)行pytorch的過(guò)程詳解
這篇文章主要介紹了在Linux環(huán)境下使用Anaconda管理不同版本的Python環(huán)境,并通過(guò)PyCharm遠(yuǎn)程連接服務(wù)器來(lái)運(yùn)行PyTorch的過(guò)程,包括安裝PyTorch、CUDA以及配置PyCharm遠(yuǎn)程開(kāi)發(fā)環(huán)境的詳細(xì)步驟,需要的朋友可以參考下2025-02-02
YOLOv5中SPP/SPPF結(jié)構(gòu)源碼詳析(內(nèi)含注釋分析)
其實(shí)關(guān)于YOLOv5的網(wǎng)絡(luò)結(jié)構(gòu)其實(shí)網(wǎng)上相關(guān)的講解已經(jīng)有很多了,但是覺(jué)著還是有必要再給大家介紹下,下面這篇文章主要給大家介紹了關(guān)于YOLOv5中SPP/SPPF結(jié)構(gòu)源碼的相關(guān)資料,需要的朋友可以參考下2022-05-05
python turtle工具繪制四葉草的實(shí)例分享
在本篇文章里小編給各位整理的是關(guān)于python turtle工具繪制四葉草的實(shí)例分享,有興趣的朋友們可以跟著學(xué)習(xí)下。2020-02-02

