Example: q-FFL
This section describes how to use FLGo to implement algorithms that modify the communication phase. The example is q-FFL (qffl), a method that changes little apart from the communication phase. It was proposed by Tian Li et al. in 2019 and published at ICLR 2020, and it aims to improve the fairness of federated learning. The following explains how to implement the algorithm with FLGo.
The algorithm is inspired by load balancing in networks and proposes a fairer optimization objective:
\[
\min_w f_q(w) = \sum_{k=1}^{m} \frac{p_k}{q+1} F_k^{q+1}(w)
\]
where \(m\) is the number of users, \(q\) is a manually chosen hyperparameter, \(F_k(w)\) is the local loss of user \(k\), and \(p_k\) is the weight of user \(k\) in the original objective function. As long as \(q>0\), each user's term in this objective, \(F'_k=\frac{F_k^{q+1}}{q+1}\), grows superlinearly in \(F_k\), so a large loss on any single user inflates the global objective \(f_q\) disproportionately. To keep \(f_q\) small, the optimizer is therefore forced to balance the loss values across users and avoid any single large value. The hyperparameter \(q\) determines the growth rate of \(F'_k\): the larger \(q\) is, the stronger the fairness.
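To make the role of \(q\) concrete, the following minimal sketch evaluates the objective on two hypothetical loss profiles that share the same weighted average loss, one balanced and one unbalanced. The function name `q_fair_objective` and all numbers are illustrative assumptions and are not part of FLGo.

```python
def q_fair_objective(losses, weights, q):
    """Evaluate f_q = sum_k p_k / (q + 1) * F_k ** (q + 1) for scalar losses."""
    return sum(p * (f ** (q + 1)) / (q + 1) for f, p in zip(losses, weights))

# Two hypothetical loss profiles with the same weighted average loss (1.0).
weights = [0.5, 0.5]
balanced = [1.0, 1.0]
unbalanced = [0.2, 1.8]

for q in (0.0, 1.0, 5.0):
    gap = (q_fair_objective(unbalanced, weights, q)
           - q_fair_objective(balanced, weights, q))
    print(f"q={q}: extra objective paid by the unbalanced profile = {gap:.4f}")
```

With \(q=0\) the objective reduces to the weighted average loss, so both profiles score the same. As \(q\) grows, the unbalanced profile is penalized more and more heavily.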
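To optimize this fairness objective, the authors propose the q-FedAvg algorithm, whose core steps are as follows:
- After user \(k\) receives the global model \(w^t\), it evaluates the loss of the global model on its local training set and obtains \(F_k(w^t)\);
- User \(k\) trains the global model locally, obtains \(\bar{w}_k^{t+1}\), and calculates the following variables, where \(L\) is an estimate of the Lipschitz constant (the implementation below uses the inverse of the learning rate):
\[
\Delta w_k^t = L\,(w^t - \bar{w}_k^{t+1}), \qquad
\Delta_k^t = F_k^{q}(w^t)\,\Delta w_k^t, \qquad
h_k^t = q\,F_k^{q-1}(w^t)\,\|\Delta w_k^t\|^2 + L\,F_k^{q}(w^t)
\]
- Users upload \(h_k^t\) and \(\Delta_k^t\);
- The server aggregates the uploads into the new global model, where \(S_t\) is the set of users selected in round \(t\):
\[
w^{t+1} = w^t - \frac{\sum_{k \in S_t} \Delta_k^t}{\sum_{k \in S_t} h_k^t}
\]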
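The steps above can be sketched on plain Python lists, independently of FLGo. This is a toy illustration under stated assumptions: models are short lists of floats, `client_update` and `server_aggregate` are hypothetical helper names, the local models and losses are made up, and the constant `L` is taken as the inverse learning rate, as in the implementation in this section.

```python
def client_update(w_global, w_local, loss, q, lr):
    """Return (delta_k, h_k) for one client; models are lists of floats."""
    L = 1.0 / lr  # assumption: Lipschitz constant estimated as 1 / learning rate
    delta_w = [L * (g - l) for g, l in zip(w_global, w_local)]
    delta_k = [(loss ** q) * d for d in delta_w]
    sq_norm = sum(d * d for d in delta_w)
    h_k = q * (loss ** (q - 1)) * sq_norm + L * (loss ** q)
    return delta_k, h_k

def server_aggregate(w_global, deltas, hs):
    """Apply w <- w - (sum_k delta_k) / (sum_k h_k) coordinate by coordinate."""
    total_h = sum(hs)
    summed = [sum(column) for column in zip(*deltas)]
    return [w - s / total_h for w, s in zip(w_global, summed)]

# Made-up round: two clients with different local models and losses.
w_t = [0.0, 0.0]
local_models = [[1.0, 0.0], [0.0, 1.0]]
local_losses = [0.5, 2.0]

for q in (0.0, 1.0):
    updates = [client_update(w_t, w, f, q, lr=0.1)
               for w, f in zip(local_models, local_losses)]
    deltas, hs = zip(*updates)
    print(q, server_aggregate(w_t, deltas, hs))
```

With \(q=0\) the update is the plain average of the local models. With \(q=1\) the new model moves further toward the client with the larger loss, at the cost of a smaller overall step.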
Implementation
Whereas fedavg communicates the global model, qffl communicates \(h_k^t\) and \(\Delta_k^t\). The client therefore computes these two terms in its pack function and returns them in the package dictionary. On the server side, the received package contains items other than models, so the results are retrieved from the package with the keywords dk and hk, and the aggregation strategy in iterate is changed directly to the qffl form.
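The keyword-based access on the server relies on the client replies being grouped by key. The sketch below imitates that grouping with a hypothetical helper, `merge_packages`, which is not an FLGo function. It assumes every client reply is a dictionary with the same keys, and it uses scalars in place of model updates to keep the example self-contained.

```python
def merge_packages(packages):
    """Group a list of per-client dictionaries into one dictionary of lists."""
    merged = {}
    for package in packages:
        for key, value in package.items():
            merged.setdefault(key, []).append(value)
    return merged

# Hypothetical replies from three clients; scalars stand in for model updates.
replies = [
    {'dk': 0.4, 'hk': 2.0},
    {'dk': 0.1, 'hk': 1.0},
    {'dk': 0.5, 'hk': 1.0},
]
res = merge_packages(replies)

# The aggregation step then only needs the two keyword lists.
step = sum(res['dk']) / sum(res['hk'])
print(res, step)
```

Keeping the reply format as a dictionary means new quantities can be added on the client without changing how the server looks up the existing ones.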
import flgo
import flgo.algorithm.fedbase as fedbase
import torch
import flgo.utils.fmodule as fmodule
import flgo.algorithm.fedavg as fedavg
import copy
import os
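class Client(fedbase.BasicClient):
    def unpack(self, package):
        model = package['model']
        # Keep a copy of the received global model; pack() needs it later
        self.global_model = copy.deepcopy(model)
        return model

    def pack(self, model):
        # Loss of the global model on the local training set (epsilon keeps it positive)
        Fk = self.test(self.global_model, 'train')['loss']+1e-8
        # Estimate the Lipschitz constant as the inverse of the learning rate
        L = 1.0/self.learning_rate
        delta_wk = L*(self.global_model - model)
        dk = (Fk**self.q)*delta_wk
        hk = self.q*(Fk**(self.q-1))*(delta_wk.norm()**2) + L*(Fk**self.q)
        # Release the copy of the global model
        self.global_model = None
        return {'dk':dk, 'hk':hk}
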
class Server(fedbase.BasicServer):
    def initialize(self, *args, **kwargs):
        # Register the hyperparameter q with its default value
        self.init_algo_para({'q': 1.0})

    def iterate(self):
        self.selected_clients = self.sample()
        # res['dk'] and res['hk'] hold the values uploaded by the selected clients
        res = self.communicate(self.selected_clients)
        # q-FedAvg update: w <- w - sum(dk) / sum(hk)
        self.model = self.model - fmodule._model_sum(res['dk'])/sum(res['hk'])
        return len(self.received_clients)>0

class qffl:
    Server = Server
    Client = Client

Experiment
task = './synthetic11_client100'
config = {'benchmark':{'name':'flgo.benchmark.synthetic_regression', 'para':{'alpha':1, 'beta':1, 'num_clients':100}}}
if not os.path.exists(task): flgo.gen_task(config, task_path = task)
option = {'num_rounds':2000, 'num_epochs':1, 'batch_size':10, 'learning_rate':0.1, 'gpu':0, 'proportion':0.1,'lr_scheduler':0}
fedavg_runner = flgo.init(task, fedavg, option=option)
qffl_runner = flgo.init(task, qffl, option=option)
fedavg_runner.run()
qffl_runner.run()
analysis_on_q = {
    'Selector':{
        'task': task,
        'header':['fedavg','qffl' ]
    },
    'Painter':{
        'Curve':[
            {'args':{'x': 'communication_round', 'y':'test_accuracy'}, 'fig_option':{'title':'test accuracy on Synthetic(1,1)'}},
            {'args':{'x': 'communication_round', 'y':'std_valid_loss'}, 'fig_option':{'title':'std_valid_loss on Synthetic(1,1)'}},
            {'args':{'x': 'communication_round', 'y':'mean_valid_accuracy'}, 'fig_option':{'title':'mean valid accuracy on Synthetic(1,1)'}},
        ]
    }
}
flgo.experiment.analyzer.show(analysis_on_q)