PyTorch實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)的基本算法FedAvg
I. 前言
在之前的一篇博客聯(lián)邦學(xué)習(xí)基本算法FedAvg的代碼實(shí)現(xiàn)中利用numpy手搭神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)了FedAvg,手搭的神經(jīng)網(wǎng)絡(luò)效果已經(jīng)很好了,不過(guò)這還是屬于自己造輪子,建議優(yōu)先使用PyTorch來(lái)實(shí)現(xiàn)。
II. 數(shù)據(jù)介紹
聯(lián)邦學(xué)習(xí)中存在多個(gè)客戶端,每個(gè)客戶端都有自己的數(shù)據(jù)集,這個(gè)數(shù)據(jù)集他們是不愿意共享的。
本文選用的數(shù)據(jù)集為中國(guó)北方某城市十個(gè)區(qū)/縣從2016年到2019年三年的真實(shí)用電負(fù)荷數(shù)據(jù),采集時(shí)間間隔為1小時(shí),即每一天都有24個(gè)負(fù)荷值。
我們假設(shè)這10個(gè)地區(qū)的電力部門不愿意共享自己的數(shù)據(jù),但是他們又想得到一個(gè)由所有數(shù)據(jù)統(tǒng)一訓(xùn)練得到的全局模型。
除了電力負(fù)荷數(shù)據(jù)以外,還有一個(gè)備選數(shù)據(jù)集:風(fēng)功率數(shù)據(jù)集。兩個(gè)數(shù)據(jù)集通過(guò)參數(shù)type指定:type == 'load’表示負(fù)荷數(shù)據(jù),'wind’表示風(fēng)功率數(shù)據(jù)。
特征構(gòu)造
用某一時(shí)刻前24個(gè)時(shí)刻的負(fù)荷值以及該時(shí)刻的相關(guān)氣象數(shù)據(jù)(如溫度、濕度、壓強(qiáng)等)來(lái)預(yù)測(cè)該時(shí)刻的負(fù)荷值。
對(duì)于風(fēng)功率數(shù)據(jù),同樣使用某一時(shí)刻前24個(gè)時(shí)刻的風(fēng)功率值以及該時(shí)刻的相關(guān)氣象數(shù)據(jù)來(lái)預(yù)測(cè)該時(shí)刻的風(fēng)功率值。
各個(gè)地區(qū)應(yīng)該就如何制定特征集達(dá)成一致意見,本文使用的各個(gè)地區(qū)上的數(shù)據(jù)的特征是一致的,可以直接使用。
III. 聯(lián)邦學(xué)習(xí)
1. 整體框架
原始論文中提出的FedAvg的框架為:

客戶端模型采用PyTorch搭建:
class ANN(nn.Module):
def __init__(self, input_dim, name, B, E, type, lr):
super(ANN, self).__init__()
self.name = name
self.B = B
self.E = E
self.len = 0
self.type = type
self.lr = lr
self.loss = 0
self.fc1 = nn.Linear(input_dim, 20)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.dropout = nn.Dropout()
self.fc2 = nn.Linear(20, 20)
self.fc3 = nn.Linear(20, 20)
self.fc4 = nn.Linear(20, 1)
def forward(self, data):
x = self.fc1(data)
x = self.sigmoid(x)
x = self.fc2(x)
x = self.sigmoid(x)
x = self.fc3(x)
x = self.sigmoid(x)
x = self.fc4(x)
x = self.sigmoid(x)
return x
2. 服務(wù)器端
服務(wù)器端執(zhí)行以下步驟:

簡(jiǎn)單來(lái)說(shuō),每一輪通信時(shí)都只是選擇部分客戶端,這些客戶端利用本地的數(shù)據(jù)進(jìn)行參數(shù)更新,然后將更新后的參數(shù)傳給服務(wù)器,服務(wù)器匯總客戶端更新后的參數(shù)形成最新的全局參數(shù)。下一輪通信時(shí),服務(wù)器端將最新的參數(shù)分發(fā)給被選中的客戶端,進(jìn)行下一輪更新。
3. 客戶端
客戶端沒什么可說(shuō)的,就是利用本地?cái)?shù)據(jù)對(duì)神經(jīng)網(wǎng)絡(luò)模型的參數(shù)進(jìn)行更新。
IV. 代碼實(shí)現(xiàn)
1. 初始化
class FedAvg:
def __init__(self, options):
self.C = options['C']
self.E = options['E']
self.B = options['B']
self.K = options['K']
self.r = options['r']
self.input_dim = options['input_dim']
self.type = options['type']
self.lr = options['lr']
self.clients = options['clients']
self.nn = ANN(input_dim=self.input_dim, name='server', B=B, E=E, type=self.type, lr=self.lr).to(device)
self.nns = []
for i in range(K):
temp = copy.deepcopy(self.nn)
temp.name = self.clients[i]
self.nns.append(temp)
參數(shù):
- K,客戶端數(shù)量,本文為10個(gè),也就是10個(gè)地區(qū)。
- C:選擇率,每一輪通信時(shí)都只是選擇C * K個(gè)客戶端。
- E:客戶端更新本地模型的參數(shù)時(shí),在本地?cái)?shù)據(jù)集上訓(xùn)練E輪。
- B:客戶端更新本地模型的參數(shù)時(shí),本地?cái)?shù)據(jù)集batch大小為B
- r:服務(wù)器端和客戶端一共進(jìn)行r輪通信。
- clients:客戶端集合。
- type:指定數(shù)據(jù)類型,負(fù)荷預(yù)測(cè)or風(fēng)功率預(yù)測(cè)。
- lr:學(xué)習(xí)率。
- input_dim:數(shù)據(jù)輸入維度。
- nn:全局模型。
- nns: 客戶端模型集合。
2. 服務(wù)器端
服務(wù)器端代碼如下:
def server(self):
for t in range(self.r):
print('第', t + 1, '輪通信:')
m = np.max([int(self.C * self.K), 1])
# sampling
index = random.sample(range(0, self.K), m)
# dispatch
self.dispatch(index)
# local updating
self.client_update(index)
# aggregation
self.aggregation(index)
# return global model
return self.nn
其中client_update(index):
def client_update(self, index): # update nn
for k in index:
self.nns[k] = train(self.nns[k])
aggregation(index):
def aggregation(self, index):
s = 0
for j in index:
# normal
s += self.nns[j].len
params = {}
with torch.no_grad():
for k, v in self.nns[0].named_parameters():
params[k] = copy.deepcopy(v)
params[k].zero_()
for j in index:
with torch.no_grad():
for k, v in self.nns[j].named_parameters():
params[k] += v * (self.nns[j].len / s)
with torch.no_grad():
for k, v in self.nn.named_parameters():
v.copy_(params[k])
dispatch(index):
def dispatch(self, index):
params = {}
with torch.no_grad():
for k, v in self.nn.named_parameters():
params[k] = copy.deepcopy(v)
for j in index:
with torch.no_grad():
for k, v in self.nns[j].named_parameters():
v.copy_(params[k])
下面對(duì)重要代碼進(jìn)行分析:
客戶端的選擇
m = np.max([int(self.C * self.K), 1]) index = random.sample(range(0, self.K), m)
index中存儲(chǔ)中m個(gè)0~10間的整數(shù),表示被選中客戶端的序號(hào)。
客戶端的更新
for k in index:
self.client_update(self.nns[k])
服務(wù)器端匯總客戶端模型的參數(shù)
關(guān)于模型匯總方式,可以參考一下我的另一篇文章:對(duì)FedAvg中模型聚合過(guò)程的理解。
當(dāng)然,這只是一種很簡(jiǎn)單的匯總方式,還有一些其他類型的匯總方式。
論文Electricity Consumer Characteristics Identification: A Federated Learning Approach中總結(jié)了三種匯總方式:
normal:原始論文中的方式,即根據(jù)樣本數(shù)量來(lái)決定客戶端參數(shù)在最終組合時(shí)所占比例。
LA:根據(jù)客戶端模型的損失占所有客戶端損失和的比重來(lái)決定最終組合時(shí)參數(shù)所占比例。
LS:根據(jù)損失與樣本數(shù)量的乘積所占的比重來(lái)決定。 將更新后的參數(shù)分發(fā)給被選中的客戶端
def dispatch(self, index):
params = {}
with torch.no_grad():
for k, v in self.nn.named_parameters():
params[k] = copy.deepcopy(v)
for j in index:
with torch.no_grad():
for k, v in self.nns[j].named_parameters():
v.copy_(params[k])
3. 客戶端
客戶端只需要利用本地?cái)?shù)據(jù)來(lái)進(jìn)行更新就行了:
def client_update(self, index): # update nn
for k in index:
self.nns[k] = train(self.nns[k])
其中train():
def train(ann):
ann.train()
# print(p)
if ann.type == 'load':
Dtr, Dte = nn_seq(ann.name, ann.B, ann.type)
else:
Dtr, Dte = nn_seq_wind(ann.named, ann.B, ann.type)
ann.len = len(Dtr)
# print(len(Dtr))
loss_function = nn.MSELoss().to(device)
loss = 0
optimizer = torch.optim.Adam(ann.parameters(), lr=ann.lr)
for epoch in range(ann.E):
cnt = 0
for (seq, label) in Dtr:
cnt += 1
seq = seq.to(device)
label = label.to(device)
y_pred = ann(seq)
loss = loss_function(y_pred, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('epoch', epoch, ':', loss.item())
return ann
4. 測(cè)試
def global_test(self):
model = self.nn
model.eval()
c = clients if self.type == 'load' else clients_wind
for client in c:
model.name = client
test(model)
V. 實(shí)驗(yàn)及結(jié)果
本次實(shí)驗(yàn)的參數(shù)選擇為:
| K | C | E | B | r |
|---|---|---|---|---|
| 10 | 0.5 | 50 | 50 | 5 |
if __name__ == '__main__':
K, C, E, B, r = 10, 0.5, 50, 50, 5
type = 'load'
input_dim = 30 if type == 'load' else 28
_client = clients if type == 'load' else clients_wind
lr = 0.08
options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client,
'input_dim': input_dim, 'lr': lr}
fedavg = FedAvg(options)
fedavg.server()
fedavg.global_test()
各個(gè)客戶端單獨(dú)訓(xùn)練(訓(xùn)練50輪,batch大小為50)后在本地的測(cè)試集上的表現(xiàn)為:
| 客戶端編號(hào) | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
|---|---|---|---|---|---|---|---|---|---|---|
| MAPE / % | 5.33 | 4.11 | 3.03 | 4.20 | 3.02 | 2.70 | 2.94 | 2.99 | 2.30 | 4.10 |
可以看到,由于各個(gè)客戶端的數(shù)據(jù)都十分充足,所以每個(gè)客戶端自己訓(xùn)練的本地模型的預(yù)測(cè)精度已經(jīng)很高了。
服務(wù)器與客戶端通信5輪后,服務(wù)器上的全局模型在10個(gè)客戶端測(cè)試集上的表現(xiàn)如下所示:
| 客戶端編號(hào) | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
|---|---|---|---|---|---|---|---|---|---|---|
| MAPE / % | 6.84 | 4.54 | 3.56 | 5.11 | 3.75 | 4.47 | 4.30 | 3.90 | 3.15 | 4.58 |
可以看到,經(jīng)過(guò)聯(lián)邦學(xué)習(xí)框架得到全局模型在各個(gè)客戶端上表現(xiàn)同樣很好ÿ0c;這是因?yàn)槭畟€(gè)地區(qū)上的數(shù)據(jù)分布類似。
給出numpy和PyTorch的對(duì)比:
| 客戶端編號(hào) | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
|---|---|---|---|---|---|---|---|---|---|---|
| 本地 | 5.33 | 4.11 | 3.03 | 4.20 | 3.02 | 2.70 | 2.94 | 2.99 | 2.30 | 4.10 |
| numpy | 6.58 | 4.19 | 3.17 | 5.13 | 3.58 | 4.69 | 4.71 | 3.75 | 2.94 | 4.77 |
| PyTorch | 6.84 | 4.54 | 3.56 | 5.11 | 3.75 | 4.47 | 4.30 | 3.90 | 3.15 | 4.58 |
同樣本地模型的效果是最好的,PyTorch搭建的網(wǎng)絡(luò)和numpy搭建的網(wǎng)絡(luò)效果差不多,但推薦使用PyTorch,不要造輪子。
VI. 源碼及數(shù)據(jù)
我把數(shù)據(jù)和代碼放在了GitHub上:源碼及數(shù)據(jù),原創(chuàng)不易,下載時(shí)請(qǐng)隨手給個(gè)follow和star,感謝!
以上就是PyTorch實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)的基本算法FedAvg的詳細(xì)內(nèi)容,更多關(guān)于PyTorch實(shí)現(xiàn)FedAvg算法的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python非標(biāo)準(zhǔn)時(shí)間的轉(zhuǎn)換
本文主要介紹了python非標(biāo)準(zhǔn)時(shí)間的轉(zhuǎn)換,文中通過(guò)示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-07-07
python3 pillow生成簡(jiǎn)單驗(yàn)證碼圖片的示例
本篇文章主要介紹了python3 pillow生成簡(jiǎn)單驗(yàn)證碼圖片的示例,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2017-09-09
在django中查詢獲取數(shù)據(jù),get, filter,all(),values()操作
這篇文章主要介紹了在django中查詢獲取數(shù)據(jù),get, filter,all(),values()操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-08-08
如何使用?profile?進(jìn)行python代碼性能分析
對(duì)代碼優(yōu)化的前提是需要了解性能瓶頸在什么地方,程序運(yùn)行的主要時(shí)間是消耗在哪里,對(duì)于比較復(fù)雜的代碼可以借助一些工具來(lái)定位,python?內(nèi)置了豐富的性能分析工具,本文介紹如何使用profile進(jìn)行python代碼性能分析,感興趣的朋友一起看看吧2024-12-12
Python實(shí)現(xiàn)npy/mat文件的保存與讀取
除了常用的csv文件和excel文件之外,我們還可以通過(guò)Python把數(shù)據(jù)保存文npy文件格式和mat文件格式。本文為大家展示了實(shí)現(xiàn)npy文件與mat文件的保存與讀取的示例代碼,需要的可以參考一下2022-04-04
簡(jiǎn)單了解python PEP的一些知識(shí)
這篇文章主要介紹了簡(jiǎn)單了解python PEP的一些知識(shí),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-07-07
Python自動(dòng)創(chuàng)建Markdown表格使用實(shí)例探究
Markdown表格是文檔中整理和展示數(shù)據(jù)的重要方式之一,然而,手動(dòng)編寫大型表格可能會(huì)費(fèi)時(shí)且容易出錯(cuò),本文將介紹如何使用Python自動(dòng)創(chuàng)建Markdown表格,通過(guò)示例代碼詳細(xì)展示各種場(chǎng)景下的創(chuàng)建方法,提高表格生成的效率2024-01-01
python之tensorflow手把手實(shí)例講解貓狗識(shí)別實(shí)現(xiàn)
要說(shuō)到深度學(xué)習(xí)圖像分類的經(jīng)典案例之一,那就是貓狗大戰(zhàn)了。貓和狗在外觀上的差別還是挺明顯的,無(wú)論是體型、四肢、臉龐和毛發(fā)等等, 都是能通過(guò)肉眼很容易區(qū)分的。那么如何讓機(jī)器來(lái)識(shí)別貓和狗呢?網(wǎng)上已經(jīng)有不少人寫過(guò)這案例了,我也來(lái)嘗試下練練手。2021-09-09
PyQt5結(jié)合matplotlib繪圖的實(shí)現(xiàn)示例
這篇文章主要介紹了PyQt5結(jié)合matplotlib繪圖的實(shí)現(xiàn)示例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-09-09

