Example: q-FFL
This section describes how to use FLGo to implement algorithms that modify the communication phase. The example used here is qffl (q-FFL), a method that makes only minor changes to the communication phase, proposed by Tian Li et al. in 2019 and published at ICLR 2020 (link to paper), which aims to improve the fairness of federated learning. The following explains how to implement the algorithm with FLGo.
The algorithm is inspired by load balancing in computer networks and proposes a fairer optimization objective:

$$\min_w f_q(w) = \sum_{k=1}^{m} \frac{p_k}{q+1} F_k^{q+1}(w)$$

where $F_k(w)$ is the local empirical loss of client $k$, $p_k$ is the weight of client $k$, and $q \ge 0$ controls the trade-off between fairness and average performance (setting $q=0$ recovers the FedAvg objective).
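To get a feel for how $q$ re-weights clients, the short standalone snippet below (illustrative loss values only, not FLGo code) evaluates each client's contribution to $f_q(w)$ for several values of $q$: the larger $q$ is, the more the client with the worse loss dominates the objective.

```python
# Toy illustration: how q re-weights client losses in the objective
# f_q(w) = sum_k p_k / (q+1) * F_k(w)^(q+1).
losses = {'client_A': 0.5, 'client_B': 2.0}   # hypothetical local losses F_k(w)
p = 0.5                                        # uniform client weights p_k

for q in [0.0, 1.0, 5.0]:
    terms = {k: p / (q + 1) * F ** (q + 1) for k, F in losses.items()}
    print(q, terms)
# q = 0 weights each client by its plain loss (FedAvg's objective);
# as q grows, client_B's larger loss dominates, so the optimizer is
# pushed to reduce the worst losses first, i.e., toward a fairer solution.
```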
To optimize this fairness objective, the authors propose the q-FedAvg algorithm, whose core steps in round $t$ are as follows:
- After client $k$ receives the global model $w^t$, it evaluates $w^t$ on its local training set and obtains the loss $F_k(w^t)$;
- Client $k$ trains the global model locally to obtain $\bar{w}_k^{t+1}$ and, with $L = 1/\eta$ (the inverse of the learning rate), calculates the following variables:

$$\Delta w_k^t = L\,(w^t - \bar{w}_k^{t+1}),\qquad \Delta_k^t = F_k^q(w^t)\,\Delta w_k^t,\qquad h_k^t = q\,F_k^{q-1}(w^t)\,\lVert\Delta w_k^t\rVert^2 + L\,F_k^q(w^t);$$

- Client $k$ uploads $\Delta_k^t$ and $h_k^t$;
- The server aggregates the global model as (a sanity check follows this list):

$$w^{t+1} = w^t - \frac{\sum_{k\in S_t}\Delta_k^t}{\sum_{k\in S_t} h_k^t}$$

where $S_t$ is the set of clients selected in round $t$.
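As a quick sanity check on these update rules, setting $q=0$ gives $\Delta_k^t = \Delta w_k^t$ and $h_k^t = L$, so the aggregation reduces to $w^{t+1} = \frac{1}{m}\sum_k \bar{w}_k^{t+1}$, i.e., plain averaging of the local models as in FedAvg. The toy script below (standalone scalars, not FLGo code) verifies this numerically:

```python
# Scalar sanity check: with q = 0 the q-FedAvg update reduces to
# plain averaging of the locally trained models.
q, L = 0.0, 10.0                      # L = 1/learning_rate
wt = 1.0                              # global model w^t (a scalar here)
local_models = [0.8, 1.1, 0.9]        # hypothetical \bar{w}_k^{t+1}
Fk = [0.5, 2.0, 1.0]                  # hypothetical local losses F_k(w^t)

delta_wk = [L * (wt - wk) for wk in local_models]
dk = [(f ** q) * d for f, d in zip(Fk, delta_wk)]
hk = [q * (f ** (q - 1)) * d ** 2 + L * (f ** q) for f, d in zip(Fk, delta_wk)]

w_next = wt - sum(dk) / sum(hk)
print(w_next, sum(local_models) / len(local_models))   # both print ~0.9333
```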
Implementation
Compared with FedAvg, where each client uploads its locally trained model, qffl instead communicates $\Delta_k^t$ and $h_k^t$. Accordingly, the client's unpack method is overridden to cache the received global model, its pack method computes and uploads the two statistics, and the server's iterate method applies the aggregation rule above.
```python
import os
import copy

import flgo
import flgo.algorithm.fedbase as fedbase
import flgo.algorithm.fedavg as fedavg
import flgo.utils.fmodule as fmodule

class Client(fedbase.BasicClient):
    def unpack(self, package):
        # Keep a copy of the received global model w^t for later use in pack
        model = package['model']
        self.global_model = copy.deepcopy(model)
        return model

    def pack(self, model):
        # F_k(w^t): loss of the global model on the local training set
        Fk = self.test(self.global_model, 'train')['loss'] + 1e-8
        # L = 1/eta, the inverse of the learning rate
        L = 1.0 / self.learning_rate
        # Δw_k^t = L(w^t - \bar{w}_k^{t+1})
        delta_wk = L * (self.global_model - model)
        # Δ_k^t = F_k^q(w^t) Δw_k^t
        dk = (Fk ** self.q) * delta_wk
        # h_k^t = q F_k^{q-1}(w^t) ||Δw_k^t||^2 + L F_k^q(w^t)
        hk = self.q * (Fk ** (self.q - 1)) * (delta_wk.norm() ** 2) + L * (Fk ** self.q)
        self.global_model = None
        # Upload Δ_k^t and h_k^t instead of the local model
        return {'dk': dk, 'hk': hk}

class Server(fedbase.BasicServer):
    def initialize(self, *args, **kwargs):
        # Register the algorithmic hyper-parameter q (default 1.0)
        self.init_algo_para({'q': 1.0})

    def iterate(self):
        self.selected_clients = self.sample()
        res = self.communicate(self.selected_clients)
        # w^{t+1} = w^t - (Σ_k Δ_k^t) / (Σ_k h_k^t)
        self.model = self.model - fmodule._model_sum(res['dk']) / sum(res['hk'])
        return len(self.received_clients) > 0

class qffl:
    Server = Server
    Client = Client
```
Experiment
The algorithm is compared with FedAvg on the Synthetic(1,1) regression benchmark with 100 clients:

```python
task = './synthetic11_client100'
config = {'benchmark': {'name': 'flgo.benchmark.synthetic_regression',
                        'para': {'alpha': 1, 'beta': 1, 'num_clients': 100}}}
# Generate the federated task once if it does not exist yet
if not os.path.exists(task): flgo.gen_task(config, task_path=task)

option = {'num_rounds': 2000, 'num_epochs': 1, 'batch_size': 10,
          'learning_rate': 0.1, 'gpu': 0, 'proportion': 0.1, 'lr_scheduler': 0}
fedavg_runner = flgo.init(task, fedavg, option=option)
qffl_runner = flgo.init(task, qffl, option=option)
fedavg_runner.run()
qffl_runner.run()
```
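Both runners above use qffl's default $q=1.0$. To try other values of $q$, FLGo's algorithmic hyper-parameters registered via init_algo_para can be overridden when initializing a runner; the sketch below assumes the option key 'algo_para' accepts the new value, so double-check it against your FLGo version:

```python
# Hedged sketch: override the fairness parameter q at init time.
# Assumes FLGo's 'algo_para' option key overrides parameters registered
# through init_algo_para (a list is used when an algorithm has several
# such parameters); verify the key against your FLGo version.
option_q5 = dict(option, algo_para=5.0)
qffl_q5_runner = flgo.init(task, qffl, option=option_q5)
qffl_q5_runner.run()
```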
Finally, the recorded results of the two runners are compared with FLGo's analyzer:

```python
analysis_on_q = {
    'Selector': {
        'task': task,
        'header': ['fedavg', 'qffl'],
    },
    'Painter': {
        'Curve': [
            {'args': {'x': 'communication_round', 'y': 'test_accuracy'},
             'fig_option': {'title': 'test accuracy on Synthetic(1,1)'}},
            {'args': {'x': 'communication_round', 'y': 'std_valid_loss'},
             'fig_option': {'title': 'std_valid_loss on Synthetic(1,1)'}},
            {'args': {'x': 'communication_round', 'y': 'mean_valid_accuracy'},
             'fig_option': {'title': 'mean valid accuracy on Synthetic(1,1)'}},
        ]
    }
}
flgo.experiment.analyzer.show(analysis_on_q)
```