Example: q-FFL

This section describes how to use FLGo to implement algorithms that modify the communication phase. The example used here is q-FFL, a method that requires only small changes to the communication phase, proposed by Tian Li et al. in 2019 and published at ICLR 2020 (link to paper), which aims to improve the fairness of federated learning. The following explains how to implement the algorithm with FLGo.

The algorithm is inspired by load balancing in networks and proposes a fairer optimization objective:

\[\min_w f_q(w)=\sum_{k=1}^m \frac{p_k}{q+1}F_k^{q+1}(w)\]

where \(q\) is a manually set hyperparameter, \(F_k(w)\) is the local loss of client \(k\), and \(p_k\) is the weight of client \(k\) in the original objective. Observing this objective, one can see that as long as \(q>0\), each client's effective loss \(F'_k=\frac{F_k^{q+1}}{q+1}\) grows faster than \(F_k\) itself as \(F_k\) increases, so the global objective \(f_q\) also grows rapidly. To keep \(f_q\) from blowing up, optimizing this objective is forced to automatically balance the loss values of different clients and suppress any unusually large one. Here \(q\) controls the growth rate of \(F'_k\): the larger \(q\) is, the stronger the fairness.
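
For intuition, here is a minimal numeric sketch (an illustration only; the names and values are made up and not part of FLGo or the paper) showing how raising \(q\) amplifies the contribution of a client with a larger loss to \(f_q\):

# two clients with equal weights p_k = 0.5 but unequal local losses F_k
losses = {'client_0': 0.5, 'client_1': 2.0}
p = 0.5

for q in (0, 2):
    # each client contributes p_k * F_k^(q+1) / (q+1) to f_q
    contribs = {k: p*(F**(q+1))/(q+1) for k, F in losses.items()}
    print(q, contribs)
# q=0: 0.25 vs 1.0 (ratio 4), i.e. the original weighted average of losses
# q=2: ~0.021 vs ~1.333 (ratio 64 = 4^3), so the larger loss dominates f_q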

In order to optimize this fairness objective, the authors propose the q-FedAvg algorithm, whose core steps are as follows:

  1. After client \(k\) receives the global model \(w^t\), it evaluates the loss on the local training set and obtains \(F_k(w^t)\);

  2. Client \(k\) trains the global model to obtain \(\bar{w}_k^{t+1}\) and computes the following quantities:

\[\Delta w_k^t=L(w^t-\bar{w}_k^{t+1})\approx\frac{1}{\eta}(w^t-\bar{w}_k^{t+1})\\ \Delta_k^t=F_k^q(w^t)\Delta w_k^t\\ h_k^t=qF_k^{q-1}(w^t)\|\Delta w_k^t\|^2+LF_k^q(w^t)\]
  3. Client \(k\) uploads \(h_k^t\) and \(\Delta_k^t\);

  4. The server aggregates the global model as:

\[w^{t+1}=w^t-\frac{\sum_{k\in S_t}\Delta_k^t}{\sum_{k\in S_t}h_k^t}\]
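
As a quick check of the aggregation rule, setting \(q=0\) (so that \(f_q\) reduces to the original averaged objective) gives \(\Delta_k^t=\Delta w_k^t\) and \(h_k^t=L\), and the update becomes plain averaging of the locally trained models:

\[w^{t+1}=w^t-\frac{\sum_{k\in S_t}L(w^t-\bar{w}_k^{t+1})}{|S_t|L}=\frac{1}{|S_t|}\sum_{k\in S_t}\bar{w}_k^{t+1}\]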

Implementation

Compared with FedAvg, which communicates the global model, q-FFL communicates \(h_k^t\) and \(\Delta_k^t\). Accordingly, the computation of these two quantities is completed in the client's local pack function, which returns the modified dictionary. On the server side, the received packages then contain more than models, so the keywords dk and hk are used to take the results out of the packages, and the aggregation strategy in iterate is changed directly to the q-FFL form.

import flgo
import flgo.algorithm.fedbase as fedbase
import torch
import flgo.utils.fmodule as fmodule
import flgo.algorithm.fedavg as fedavg
import copy
import os

class Client(fedbase.BasicClient):
    def unpack(self, package):
        model = package['model']
        # keep a copy of the received global model w^t for later use in pack
        self.global_model = copy.deepcopy(model)
        return model

    def pack(self, model):
        # F_k(w^t): loss of the global model on the local training set
        Fk = self.test(self.global_model, 'train')['loss']+1e-8
        # L = 1/η estimates the Lipschitz constant
        L = 1.0/self.learning_rate
        # Δw_k^t = L(w^t - \bar{w}_k^{t+1})
        delta_wk = L*(self.global_model - model)
        # Δ_k^t = F_k^q(w^t)·Δw_k^t
        dk = (Fk**self.q)*delta_wk
        # h_k^t = q·F_k^{q-1}(w^t)·||Δw_k^t||^2 + L·F_k^q(w^t)
        hk = self.q*(Fk**(self.q-1))*(delta_wk.norm()**2) + L*(Fk**self.q)
        # free the cached global model and upload dk and hk instead of the model
        self.global_model = None
        return {'dk':dk, 'hk':hk}

class Server(fedbase.BasicServer):
    def initialize(self, *args, **kwargs):
        # register the algorithm hyperparameter q (default 1.0)
        self.init_algo_para({'q': 1.0})

    def iterate(self):
        self.selected_clients = self.sample()
        # collect dk and hk from the selected clients
        res = self.communicate(self.selected_clients)
        # w^{t+1} = w^t - Σ_k Δ_k^t / Σ_k h_k^t
        self.model = self.model - fmodule._model_sum(res['dk'])/sum(res['hk'])
        return len(self.received_clients)>0

class qffl:
    Server = Server
    Client = Client

Experiment

# generate the federated task Synthetic(1,1) with 100 clients
task = './synthetic11_client100'
config = {'benchmark':{'name':'flgo.benchmark.synthetic_regression', 'para':{'alpha':1, 'beta':1, 'num_clients':100}}}
if not os.path.exists(task): flgo.gen_task(config, task_path=task)
# run fedavg and qffl under the same option for a fair comparison
option = {'num_rounds':2000, 'num_epochs':1, 'batch_size':10, 'learning_rate':0.1, 'gpu':0, 'proportion':0.1, 'lr_scheduler':0}
fedavg_runner = flgo.init(task, fedavg, option=option)
qffl_runner = flgo.init(task, qffl, option=option)
fedavg_runner.run()
qffl_runner.run()
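
To study the effect of \(q\), runners with other values can also be created by overriding the default registered by init_algo_para. A minimal sketch, assuming FLGo reads algorithm hyperparameters from the option key 'algo_para' (the value 5.0 is only an example):

# hypothetical: run qffl with q=5.0 instead of the default q=1.0
option_q5 = dict(option, algo_para=5.0)
qffl_q5_runner = flgo.init(task, qffl, option=option_q5)
qffl_q5_runner.run()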
analysis_on_q = {
    'Selector':{
        'task': task,
        'header':['fedavg', 'qffl']
    },
    'Painter':{
        'Curve':[
            {'args':{'x': 'communication_round', 'y':'test_accuracy'}, 'fig_option':{'title':'test accuracy on Synthetic(1,1)'}},
            {'args':{'x': 'communication_round', 'y':'std_valid_loss'}, 'fig_option':{'title':'std_valid_loss on Synthetic(1,1)'}},
            {'args':{'x': 'communication_round', 'y':'mean_valid_accuracy'}, 'fig_option':{'title':'mean valid accuracy on Synthetic(1,1)'}},
        ]
    }
}
flgo.experiment.analyzer.show(analysis_on_q)