您好,登錄后才能下訂單哦!
本文小編為大家詳細介紹“PyTorch怎么實現基本算法FedAvg”,內容詳細,步驟清晰,細節處理妥當,希望這篇“PyTorch怎么實現基本算法FedAvg”文章能幫助大家解決疑惑,下面跟著小編的思路慢慢深入,一起來學習新知識吧。
聯邦學習中存在多個客戶端,每個客戶端都有自己的數據集,這個數據集他們是不愿意共享的。
本文選用的數據集為中國北方某城市十個區/縣從2016年到2019年三年的真實用電負荷數據,采集時間間隔為1小時,即每一天都有24個負荷值。
我們假設這10個地區的電力部門不愿意共享自己的數據,但是他們又想得到一個由所有數據統一訓練得到的全局模型。
除了電力負荷數據以外,還有一個備選數據集:風功率數據集。兩個數據集通過參數type指定:type == 'load’表示負荷數據,'wind’表示風功率數據。
用某一時刻前24個時刻的負荷值以及該時刻的相關氣象數據(如溫度、濕度、壓強等)來預測該時刻的負荷值。
對于風功率數據,同樣使用某一時刻前24個時刻的風功率值以及該時刻的相關氣象數據來預測該時刻的風功率值。
各個地區應該就如何制定特征集達成一致意見,本文使用的各個地區上的數據的特征是一致的,可以直接使用。
原始論文中提出的FedAvg的框架為:
客戶端模型采用PyTorch搭建:
class ANN(nn.Module): def __init__(self, input_dim, name, B, E, type, lr): super(ANN, self).__init__() self.name = name self.B = B self.E = E self.len = 0 self.type = type self.lr = lr self.loss = 0 self.fc1 = nn.Linear(input_dim, 20) self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() self.dropout = nn.Dropout() self.fc2 = nn.Linear(20, 20) self.fc3 = nn.Linear(20, 20) self.fc4 = nn.Linear(20, 1) def forward(self, data): x = self.fc1(data) x = self.sigmoid(x) x = self.fc2(x) x = self.sigmoid(x) x = self.fc3(x) x = self.sigmoid(x) x = self.fc4(x) x = self.sigmoid(x) return x
服務器端執行以下步驟:
簡單來說,每一輪通信時都只是選擇部分客戶端,這些客戶端利用本地的數據進行參數更新,然后將更新后的參數傳給服務器,服務器匯總客戶端更新后的參數形成最新的全局參數。下一輪通信時,服務器端將最新的參數分發給被選中的客戶端,進行下一輪更新。
客戶端沒什么可說的,就是利用本地數據對神經網絡模型的參數進行更新。
class FedAvg: def __init__(self, options): self.C = options['C'] self.E = options['E'] self.B = options['B'] self.K = options['K'] self.r = options['r'] self.input_dim = options['input_dim'] self.type = options['type'] self.lr = options['lr'] self.clients = options['clients'] self.nn = ANN(input_dim=self.input_dim, name='server', B=B, E=E, type=self.type, lr=self.lr).to(device) self.nns = [] for i in range(K): temp = copy.deepcopy(self.nn) temp.name = self.clients[i] self.nns.append(temp)
參數:
K,客戶端數量,本文為10個,也就是10個地區。
C:選擇率,每一輪通信時都只是選擇C * K個客戶端。
E:客戶端更新本地模型的參數時,在本地數據集上訓練E輪。
B:客戶端更新本地模型的參數時,本地數據集batch大小為B
r:服務器端和客戶端一共進行r輪通信。
clients:客戶端集合。
type:指定數據類型,負荷預測or風功率預測。
lr:學習率。
input_dim:數據輸入維度。
nn:全局模型。
nns: 客戶端模型集合。
服務器端代碼如下:
def server(self): for t in range(self.r): print('第', t + 1, '輪通信:') m = np.max([int(self.C * self.K), 1]) # sampling index = random.sample(range(0, self.K), m) # dispatch self.dispatch(index) # local updating self.client_update(index) # aggregation self.aggregation(index) # return global model return self.nn
其中client_update(index):
def client_update(self, index): # update nn for k in index: self.nns[k] = train(self.nns[k])
aggregation(index):
def aggregation(self, index): s = 0 for j in index: # normal s += self.nns[j].len params = {} with torch.no_grad(): for k, v in self.nns[0].named_parameters(): params[k] = copy.deepcopy(v) params[k].zero_() for j in index: with torch.no_grad(): for k, v in self.nns[j].named_parameters(): params[k] += v * (self.nns[j].len / s) with torch.no_grad(): for k, v in self.nn.named_parameters(): v.copy_(params[k])
dispatch(index):
def dispatch(self, index): params = {} with torch.no_grad(): for k, v in self.nn.named_parameters(): params[k] = copy.deepcopy(v) for j in index: with torch.no_grad(): for k, v in self.nns[j].named_parameters(): v.copy_(params[k])
下面對重要代碼進行分析:
客戶端的選擇
m = np.max([int(self.C * self.K), 1]) index = random.sample(range(0, self.K), m)
index中存儲中m個0~10間的整數,表示被選中客戶端的序號。
客戶端的更新
for k in index: self.client_update(self.nns[k])
服務器端匯總客戶端模型的參數
關于模型匯總方式,可以參考一下我的另一篇文章:對FedAvg中模型聚合過程的理解。
當然,這只是一種很簡單的匯總方式,還有一些其他類型的匯總方式。
論文Electricity Consumer Characteristics Identification: A Federated Learning Approach中總結了三種匯總方式:
normal:原始論文中的方式,即根據樣本數量來決定客戶端參數在最終組合時所占比例。
LA:根據客戶端模型的損失占所有客戶端損失和的比重來決定最終組合時參數所占比例。
LS:根據損失與樣本數量的乘積所占的比重來決定。 將更新后的參數分發給被選中的客戶端
def dispatch(self, index): params = {} with torch.no_grad(): for k, v in self.nn.named_parameters(): params[k] = copy.deepcopy(v) for j in index: with torch.no_grad(): for k, v in self.nns[j].named_parameters(): v.copy_(params[k])
客戶端只需要利用本地數據來進行更新就行了:
def client_update(self, index): # update nn for k in index: self.nns[k] = train(self.nns[k])
其中train():
def train(ann): ann.train() # print(p) if ann.type == 'load': Dtr, Dte = nn_seq(ann.name, ann.B, ann.type) else: Dtr, Dte = nn_seq_wind(ann.named, ann.B, ann.type) ann.len = len(Dtr) # print(len(Dtr)) loss_function = nn.MSELoss().to(device) loss = 0 optimizer = torch.optim.Adam(ann.parameters(), lr=ann.lr) for epoch in range(ann.E): cnt = 0 for (seq, label) in Dtr: cnt += 1 seq = seq.to(device) label = label.to(device) y_pred = ann(seq) loss = loss_function(y_pred, label) optimizer.zero_grad() loss.backward() optimizer.step() print('epoch', epoch, ':', loss.item()) return ann
def global_test(self): model = self.nn model.eval() c = clients if self.type == 'load' else clients_wind for client in c: model.name = client test(model)
本次實驗的參數選擇為:
K | C | E | B | r |
---|---|---|---|---|
10 | 0.5 | 50 | 50 | 5 |
if __name__ == '__main__': K, C, E, B, r = 10, 0.5, 50, 50, 5 type = 'load' input_dim = 30 if type == 'load' else 28 _client = clients if type == 'load' else clients_wind lr = 0.08 options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client, 'input_dim': input_dim, 'lr': lr} fedavg = FedAvg(options) fedavg.server() fedavg.global_test()
各個客戶端單獨訓練(訓練50輪,batch大小為50)后在本地的測試集上的表現為:
客戶端編號 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
MAPE / % | 5.33 | 4.11 | 3.03 | 4.20 | 3.02 | 2.70 | 2.94 | 2.99 | 2.30 | 4.10 |
可以看到,由于各個客戶端的數據都十分充足,所以每個客戶端自己訓練的本地模型的預測精度已經很高了。
服務器與客戶端通信5輪后,服務器上的全局模型在10個客戶端測試集上的表現如下所示:
客戶端編號 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
MAPE / % | 6.84 | 4.54 | 3.56 | 5.11 | 3.75 | 4.47 | 4.30 | 3.90 | 3.15 | 4.58 |
可以看到,經過聯邦學習框架得到全局模型在各個客戶端上表現同樣很好ÿ0c;這是因為十個地區上的數據分布類似。
給出numpy和PyTorch的對比:
客戶端編號 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
本地 | 5.33 | 4.11 | 3.03 | 4.20 | 3.02 | 2.70 | 2.94 | 2.99 | 2.30 | 4.10 |
numpy | 6.58 | 4.19 | 3.17 | 5.13 | 3.58 | 4.69 | 4.71 | 3.75 | 2.94 | 4.77 |
PyTorch | 6.84 | 4.54 | 3.56 | 5.11 | 3.75 | 4.47 | 4.30 | 3.90 | 3.15 | 4.58 |
同樣本地模型的效果是最好的,PyTorch搭建的網絡和numpy搭建的網絡效果差不多,但推薦使用PyTorch,不要造輪子。
讀到這里,這篇“PyTorch怎么實現基本算法FedAvg”文章已經介紹完畢,想要掌握這篇文章的知識點還需要大家自己動手實踐使用過才能領會,如果想了解更多相關內容的文章,歡迎關注億速云行業資訊頻道。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。