
Example: q-FFL

This section describes how to use FLGo to implement algorithms that modify the communication phase. The example used here is q-FFL (qffl), a method that changes little apart from the communication phase. It was proposed by Tian Li et al. in 2019 and published at ICLR 2020 (link to paper), and it aims to improve the fairness of federated learning. The following explains how to implement the algorithm with FLGo.

The algorithm is inspired by load balancing in the network and proposes a fairer optimization goal:

$$\min_w f_q(w) = \sum_{k=1}^{m} \frac{p_k}{q+1} F_k^{q+1}(w)$$

where $q$ is a manually set hyperparameter, $F_k(w)$ is the local loss of user $k$, and $p_k$ is the weight of user $k$ in the original objective function. As long as $q>0$, each user's term in this objective, $F_k' = \frac{F_k^{q+1}}{q+1}$, grows faster than $F_k$ itself as $F_k$ increases, so the global objective $f_q$ also grows rapidly. To keep $f_q$ from blowing up, the optimization is therefore forced to balance the loss values of different users and avoid any single large value. Here $q$ determines the growth rate of $F_k'$: the larger $q$ is, the stronger the fairness.
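To make the effect of q concrete, the following minimal sketch evaluates the objective and the relative weight $p_k F_k^q$ that each user's gradient receives in $\nabla f_q$, for a few hand-picked loss values. The function names and the numbers are illustrative assumptions and are not part of FLGo.

```python
def q_fair_objective(losses, weights, q):
    """f_q = sum_k p_k / (q + 1) * F_k ** (q + 1)."""
    return sum(p / (q + 1) * F ** (q + 1) for p, F in zip(weights, losses))


def q_fair_gradient_weights(losses, weights, q):
    """Normalized factors p_k * F_k ** q that multiply each user's gradient in grad f_q."""
    raw = [p * F ** q for p, F in zip(weights, losses)]
    total = sum(raw)
    return [r / total for r in raw]


# Three hypothetical users with equal original weights p_k
losses = [0.5, 1.0, 2.0]
weights = [1 / 3, 1 / 3, 1 / 3]

for q in (0.0, 1.0, 5.0):
    shares = q_fair_gradient_weights(losses, weights, q)
    print(q, q_fair_objective(losses, weights, q), [round(s, 3) for s in shares])
```

With q = 0 every user keeps its original share; as q grows, the user with the largest loss dominates the gradient, which is what pushes the optimizer to reduce the worst losses first.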

In order to optimize this fairness objective function, the authors propose the q-FedAVG algorithm, the core steps of which are as follows:

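  1. After user k receives the global model, it uses the global model $w^t$ to evaluate the loss on its local training set and obtains $F_k(w^t)$;

  2. User k trains the global model, obtains $\bar{w}_k^{t+1}$, and calculates the following variables:

$$\Delta w_k^t = L(w^t - \bar{w}_k^{t+1}) \approx \frac{1}{\eta}(w^t - \bar{w}_k^{t+1})$$

$$\Delta_k^t = F_k^q(w^t)\,\Delta w_k^t$$

$$h_k^t = q F_k^{q-1}(w^t)\,\|\Delta w_k^t\|^2 + L F_k^q(w^t)$$

  3. User k uploads $h_k^t$ and $\Delta_k^t$;

  4. The server aggregates the global model as:

$$w^{t+1} = w^t - \frac{\sum_{k \in S_t} \Delta_k^t}{\sum_{k \in S_t} h_k^t}$$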
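The four steps above can be sketched on plain Python lists, independent of FLGo. The helper names `client_update` and `server_aggregate` and the toy values are assumptions for illustration only; L is estimated as 1/η, where η is the learning rate, and each loss is assumed to be strictly positive.

```python
def client_update(w_global, w_local, loss, q, lr):
    """Compute (delta_k, h_k) for one user from its trained model and its loss F_k(w^t)."""
    L = 1.0 / lr  # Lipschitz constant estimated as 1 / eta
    delta_w = [L * (g - l) for g, l in zip(w_global, w_local)]
    delta = [loss ** q * d for d in delta_w]
    sq_norm = sum(d * d for d in delta_w)
    h = q * loss ** (q - 1) * sq_norm + L * loss ** q
    return delta, h


def server_aggregate(w_global, deltas, hs):
    """w^{t+1} = w^t - sum_k delta_k / sum_k h_k."""
    total_h = sum(hs)
    summed = [sum(column) for column in zip(*deltas)]
    return [w - s / total_h for w, s in zip(w_global, summed)]


# Toy round with two hypothetical users and a two-parameter model
w_t = [1.0, 2.0]
local_models = [[0.0, 2.0], [1.0, 0.0]]
local_losses = [0.5, 2.0]

for q in (0.0, 1.0):
    updates = [client_update(w_t, w, F, q, lr=0.1)
               for w, F in zip(local_models, local_losses)]
    deltas, hs = zip(*updates)
    print(q, server_aggregate(w_t, deltas, hs))
```

In this sketch, q = 0 makes every h equal to L, so the update reduces to the unweighted average of the local models; for q > 0 the user with the larger loss contributes a larger share of the step.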

Implementation

Whereas fedavg communicates models, qffl clients upload $h_k^t$ and $\Delta_k^t$. The client therefore computes these two terms in its pack function and returns them in the dictionary. On the server side, the received package no longer contains just a model, so the keywords dk and hk are used to take the results out of the package, and the aggregation step in iterate is changed directly to the qffl form.

import flgo
import flgo.algorithm.fedbase as fedbase
import torch
import flgo.utils.fmodule as fmodule
import flgo.algorithm.fedavg as fedavg
import copy
import os

class Client(fedbase.BasicClient):
    def unpack(self, package):
        model = package['model']
        self.global_model = copy.deepcopy(model)
        return model

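    def pack(self, model):
        # Loss of the global model on the local training set; 1e-8 keeps Fk strictly positive
        Fk = self.test(self.global_model, 'train')['loss']+1e-8
        # Estimate the Lipschitz constant L as 1/learning_rate
        L = 1.0/self.learning_rate
        delta_wk = L*(self.global_model - model)
        # dk and hk are the two terms uploaded instead of the model
        dk = (Fk**self.q)*delta_wk
        hk = self.q*(Fk**(self.q-1))*(delta_wk.norm()**2) + L*(Fk**self.q)
        self.global_model = None
        return {'dk':dk, 'hk':hk}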

class Server(fedbase.BasicServer):
    def initialize(self, *args, **kwargs):
        self.init_algo_para({'q': 1.0})

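    def iterate(self):
        self.selected_clients = self.sample()
        # res['dk'] and res['hk'] hold the terms uploaded by the selected clients
        res = self.communicate(self.selected_clients)
        # q-FedAvg update: w <- w - sum(dk)/sum(hk)
        self.model = self.model - fmodule._model_sum(res['dk'])/sum(res['hk'])
        return len(self.received_clients)>0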

class qffl:
    Server = Server
    Client = Client

Experiment

task = './synthetic11_client100'
config = {'benchmark':{'name':'flgo.benchmark.synthetic_regression', 'para':{'alpha':1, 'beta':1, 'num_clients':100}}}
if not os.path.exists(task): flgo.gen_task(config, task_path = task)
option = {'num_rounds':2000, 'num_epochs':1, 'batch_size':10, 'learning_rate':0.1, 'gpu':0, 'proportion':0.1,'lr_scheduler':0}
fedavg_runner = flgo.init(task, fedavg, option=option)
qffl_runner = flgo.init(task, qffl, option=option)
fedavg_runner.run()
qffl_runner.run()
analysis_on_q = {
    'Selector':{
        'task': task,
        'header':['fedavg','qffl' ]
    },
    'Painter':{
        'Curve':[
            {'args':{'x': 'communication_round', 'y':'test_accuracy'},  'fig_option':{'title':'test accuracy on Synthetic(1,1)'}},
            {'args':{'x': 'communication_round', 'y':'std_valid_loss'}, 'fig_option':{'title':'std_valid_loss on Synthetic(1,1)'}},
            {'args':{'x': 'communication_round', 'y':'mean_valid_accuracy'},  'fig_option':{'title':'mean valid accuracy on Synthetic(1,1)'}},
        ]
    }
}
import flgo.experiment.analyzer
flgo.experiment.analyzer.show(analysis_on_q)