解決pytorch中的kl divergence計算問題
偶然從pytorch討論論壇中看到的一個問題,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中計算結(jié)果不同,平時沒有注意到,記錄下
kl divergence 介紹
KL散度( Kullback–Leibler divergence),又稱相對熵,是描述兩個概率分布 P 和 Q 差異的一種方法。計算公式:

可以發(fā)現(xiàn),P 和 Q 中元素的個數(shù)不用相等,只需要兩個分布中的離散元素一致。
舉個簡單例子:
兩個離散分布分布分別為 P 和 Q
P 的分布為:{1,1,2,2,3}
Q 的分布為:{1,1,1,1,1,2,3,3,3,3}
我們發(fā)現(xiàn),雖然兩個分布中元素個數(shù)不相同,P 的元素個數(shù)為 5,Q 的元素個數(shù)為 10。但里面的元素都有 “1”,“2”,“3” 這三個元素。
當(dāng) x = 1時,在 P 分布中,“1” 這個元素的個數(shù)為 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 這個元素的個數(shù)為 5,故 Q(x = 1) = 5/10 = 0.5
同理,
當(dāng) x = 2 時,P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1
當(dāng) x = 3 時,P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4
把上述概率帶入公式:

至此,就計算完成了兩個離散變量分布的KL散度。
pytorch 中的 kl_div 函數(shù)
pytorch中有用于計算kl散度的函數(shù) kl_div
torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')

計算 D (p||q)
1、不用這個函數(shù)的計算結(jié)果為:

與手算結(jié)果相同
2、使用函數(shù):
(這是計算正確的,結(jié)果有差異是因為pytorch這個函數(shù)中默認(rèn)的是以e為底)

注意:
1、函數(shù)中的 p q 位置相反(也就是想要計算D(p||q),要寫成kl_div(q.log(),p)的形式),而且q要先取 log
2、reduction 是選擇對各部分結(jié)果做什么操作,默認(rèn)為取平均數(shù),這里選擇求和
好別扭的用法,不知道為啥官方把它設(shè)計成這樣
補充:pytorch 的KL divergence的實現(xiàn)
看代碼吧~
import torch.nn.functional as F
# p_logit: [batch, class_num]
# q_logit: [batch, class_num]
def kl_categorical(p_logit, q_logit):
p = F.softmax(p_logit, dim=-1)
_kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1)
- F.log_softmax(q_logit, dim=-1)), 1)
return torch.mean(_kl)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實現(xiàn) PS 圖像調(diào)整中的亮度調(diào)整
這篇文章主要介紹了Python實現(xiàn) PS 圖像調(diào)整中的亮度調(diào)整 ,需要的朋友可以參考下2019-06-06
詳解python路徑拼接os.path.join()函數(shù)的用法
os.path.join()函數(shù):連接兩個或更多的路徑名組件。這篇文章主要介紹了python路徑拼接os.path.join()函數(shù)的用法,需要的朋友可以參考下2019-10-10
python如何調(diào)用php文件中的函數(shù)詳解
這篇文章主要給大家介紹了關(guān)于python如何調(diào)用php文件中函數(shù)的相關(guān)資料,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-12-12
解析pandas apply() 函數(shù)用法(推薦)
這篇文章主要介紹了pandas apply() 函數(shù)用法,大家需要掌握函數(shù)作為一個對象,能作為參數(shù)傳遞給其它函數(shù),也能作為函數(shù)的返回值,具體內(nèi)容詳情跟隨小編一起看看吧2021-10-10
Python解析器安裝指南分享(Mac/Windows/Linux)
這篇文章主要介紹了Python解析器安裝指南(Mac/Windows/Linux),具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2025-03-03
python多進程及通信實現(xiàn)異步任務(wù)的方法
這篇文章主要介紹了python多進程及通信實現(xiàn)異步任務(wù)需求,本人也是很少接觸多進程的場景,對于python多進程的使用也是比較陌生的。在接觸了一些多進程的業(yè)務(wù)場景下,對python多進程的使用進行了學(xué)習(xí),覺得很有必要進行一個梳理總結(jié),感興趣的朋友一起看看吧2022-05-05

