Staleness
Staleness arises when the server aggregates a set of models in which some were not trained from the latest global model. For example, at the 10th aggregation round the server may receive a model returned by a client that was sampled in the first round. At this point, the server can either discard it outright or find a way to exploit the stale information. In FLGo, being able to receive stale models means that the server does not have to wait until all currently sampled clients have returned their models before aggregating. Therefore, the server needs an aggregation condition: once the condition is met, the server aggregates the received models, and the next aggregation round (sampling, waiting, aggregation) starts after the aggregation. To achieve this, FLGo allows defining the behavior of the server at each moment, rather than only the behavior of each round, which enables more complex policy designs. Here, we set up three aggregation conditions as examples and give the corresponding implementation:
- Cond=0: always wait for all the selected clients before aggregation
- Cond=1: aggregate once the waiting time exceeds a specific value
- Cond=2: aggregate once the number of received models is no smaller than K
The three ways produce different degrees of staleness when aggregating models, and staleness is known to harm FL performance. For example, the first way (i.e. cond=0) won't produce any staleness, but largely increases the time cost of each communication round because the server has to wait for the slowest clients. The second way waits at most a fixed time for clients, and usually suffers non-trivial staleness. The third way introduces a strong preference towards clients with fast responsiveness. Now we show how to implement the three ways in FLGo:
import copy
import numpy as np
import flgo.utils.fmodule as fmodule
from flgo.algorithm.fedbase import BasicServer
from flgo.algorithm.fedbase import BasicClient
class Server(BasicServer):
    """Server that aggregates under a configurable condition and tolerates stale models.

    Instead of blocking until every sampled client has answered, the server
    buffers whatever models arrive and aggregates as soon as
    ``aggregation_condition()`` holds. The condition is selected by the
    algorithm parameter ``cond``:

    * ``cond == 0``: aggregate only when every selected client has returned.
    * ``cond == 1``: aggregate when ``time_budget`` has elapsed since sampling
      (or every selected client has returned) and the buffer is non-empty.
    * ``cond == 2``: aggregate when at least ``min(len(selected_clients), K)``
      models are buffered.
    """

    def initialize(self, *args, **kwargs):
        """Register the algorithm parameters and create the receive buffer."""
        # cond selects the aggregation condition; time_budget is read by
        # cond == 1 and K by cond == 2 (see aggregation_condition).
        # NOTE(review): the example script passes 'algo_para' as a list,
        # presumably mapped onto these keys in order -- confirm against FLGo.
        self.init_algo_para({'cond': 0, 'time_budget': 100, 'K': 10})
        # True when the next call to iterate() must start a new round by sampling.
        self.round_finished = True
        # Parallel lists: entry i of every list describes the i-th buffered package.
        self.buffer = {
            'model': [],      # received models
            'round': [],      # model version each received model was trained from
            't': [],          # clock time at which the package was buffered
            'client_id': [],  # id of the client that sent the package
        }
        # Clock time at which the current round's clients were sampled.
        self.sampling_timestamp = 0
        self.sample_option = 'uniform_available'

    def pack(self, client_id, mtype=0, *args, **kwargs):
        """Build the package sent to a client.

        Returns:
            dict: a deep copy of the global model under ``'model'`` and the
            current round under ``'round'``, which serves as the model version.
        """
        return {
            'model': copy.deepcopy(self.model),
            'round': self.current_round,  # model version
        }

    def iterate(self):
        """Run one step of the server: sample or listen, buffer, maybe aggregate.

        Returns:
            bool: ``True`` if an aggregation round was completed in this call,
            ``False`` if the server is still waiting for models.
        """
        if self.round_finished:
            # Start a new round: sample clients and remember when we did so.
            self.selected_clients = self.sample()
            self.sampling_timestamp = self.gv.clock.current_time
            self.round_finished = False
            res = self.communicate(self.selected_clients, asynchronous=True)
        else:
            # Round in progress: only listen for packages that have arrived.
            res = self.communicate([], asynchronous=True)

        if res:
            # An empty result means nothing arrived at this moment.
            now = self.gv.clock.current_time
            self.buffer['model'].extend(res['model'])
            self.buffer['round'].extend(res['round'])
            self.buffer['t'].extend([now] * len(res['model']))
            self.buffer['client_id'].extend(res['__cid'])

        if self.aggregation_condition():
            # A buffered model is stale when it was trained from an older version.
            stale_clients = []
            stale_rounds = []
            for cid, version in zip(self.buffer['client_id'], self.buffer['round']):
                if version < self.current_round:
                    stale_clients.append(cid)
                    stale_rounds.append(version)
            if stale_rounds:
                # Logged staleness is (model version - current round), i.e. non-positive.
                staleness = [version - self.current_round for version in stale_rounds]
                self.gv.logger.info('Receiving stale models from clients: {}'.format(stale_clients))
                self.gv.logger.info('The staleness are {}'.format(staleness))
                self.gv.logger.info('Averaging Staleness: {}'.format(np.mean(staleness)))
            if self.buffer['model']:
                self.model = fmodule._model_average(self.buffer['model'])
            # else: the condition held with nothing buffered (e.g. cond == 2 with
            # K <= 0, or an empty selection); keep the current global model
            # instead of averaging an empty list.
            self.round_finished = True
            # Clear the buffer for the next round.
            for key in self.buffer:
                self.buffer[key] = []
        return self.round_finished

    def aggregation_condition(self):
        """Decide whether the buffered models should be aggregated now.

        Returns:
            bool: ``True`` when the condition selected by ``self.cond`` holds.

        Raises:
            ValueError: if ``self.cond`` is not 0, 1 or 2.
        """
        if self.cond == 0:
            # Aggregate only when every selected client has returned a package.
            received = set(self.buffer['client_id'])
            return all(cid in received for cid in self.selected_clients)
        if self.cond == 1:
            # Aggregate when the waiting budget is exhausted (or everyone has
            # already answered), provided at least one model was received.
            waited = self.gv.clock.current_time - self.sampling_timestamp
            if waited < self.time_budget:
                received = set(self.buffer['client_id'])
                if not all(cid in received for cid in self.selected_clients):
                    return False
            return len(self.buffer['model']) > 0
        if self.cond == 2:
            # Aggregate when at least K models are buffered; K is capped by the
            # number of selected clients.
            return len(self.buffer['client_id']) >= min(len(self.selected_clients), self.K)
        # An unknown condition used to fall through and return None, so the
        # server would wait forever without any diagnostic.
        raise ValueError('Unknown aggregation condition cond={}; expected 0, 1 or 2'.format(self.cond))
class Client(BasicClient):
    """Client that reports which global-model version its update was trained from."""

    def unpack(self, received_pkg):
        """Remember the version of the received global model and return the model.

        Args:
            received_pkg: package from the server holding ``'round'`` (the
                model version) and ``'model'``.

        Returns:
            The object stored under ``'model'`` in the package.
        """
        version = received_pkg['round']
        self.round = version
        global_model = received_pkg['model']
        return global_model

    def pack(self, model, *args, **kwargs):
        """Wrap the given model together with the version it was trained from.

        Returns:
            dict: ``model`` under ``'model'`` and the version recorded by
            ``unpack`` under ``'round'``.
        """
        package = dict(model=model, round=self.round)
        return package
if __name__ == '__main__':
    import flgo
    import flgo.benchmark.mnist_classification as mnist
    import os

    # Create the i.i.d.-partitioned MNIST task with 100 clients on first use.
    task = './mnist_100clients'
    task_exists = os.path.exists(task)
    if not task_exists:
        partitioner = {'name': 'IIDPartitioner', 'para': {'num_clients': 100}}
        flgo.gen_task({'benchmark': mnist, 'partitioner': partitioner}, task)

    # Bundle the Server/Client pair in the form handed to flgo.init.
    class algo:
        Server = Server
        Client = Client

    def _make_option(algo_para):
        """Return the shared run options with the given algorithm parameters."""
        # NOTE(review): algo_para is presumably [cond, time_budget, K], matching
        # the order used in Server.initialize -- confirm against FLGo.
        return {'num_rounds': 10, 'algo_para': algo_para, "gpu": 0, 'proportion': 0.2, 'num_steps': 5, 'responsiveness': 'UNI-5-1000'}

    # Cond 0: wait for every selected client.
    runner0 = flgo.init(task, algo, option=_make_option([0, 200, 0]))
    runner0.run()
    # Cond 1: aggregate once the time budget of 200 is exhausted.
    runner1 = flgo.init(task, algo, option=_make_option([1, 200, 0]))
    runner1.run()
    # Cond 2: aggregate once 10 models have been received.
    runner2 = flgo.init(task, algo, option=_make_option([2, 0, 10]))
    runner2.run()
Experiment
We first create heterogeneity in the responsiveness of different clients. Here we let the response time of clients follow the distribution \(UNIFORM(5,1000)\) by setting the keyword 'responsiveness' to 'UNI-5-1000' in option. We conduct a simple experiment on i.i.d.-partitioned MNIST with 100 clients, sampling a proportion of 0.2 of them (i.e. about 20 clients) in each round. Now let's look at the logs of the three ways at the 6th round:

Cond 0
2023-08-29 09:49:45,257 fedbase.py run [line:246] INFO --------------Round 6--------------
2023-08-29 09:49:45,257 simple_logger.py log_once [line:14] INFO Current_time:5545
2023-08-29 09:49:46,919 simple_logger.py log_once [line:28] INFO test_accuracy 0.8531
2023-08-29 09:49:46,919 simple_logger.py log_once [line:28] INFO test_loss 0.6368
2023-08-29 09:49:46,920 simple_logger.py log_once [line:28] INFO val_accuracy 0.8400
2023-08-29 09:49:46,920 simple_logger.py log_once [line:28] INFO mean_val_accuracy 0.8400
2023-08-29 09:49:46,920 simple_logger.py log_once [line:28] INFO std_val_accuracy 0.0508
2023-08-29 09:49:46,920 simple_logger.py log_once [line:28] INFO val_loss 0.6688
2023-08-29 09:49:46,920 simple_logger.py log_once [line:28] INFO mean_val_loss 0.6688
2023-08-29 09:49:46,920 simple_logger.py log_once [line:28] INFO std_val_loss 0.0898
2023-08-29 09:49:46,920 fedbase.py run [line:251] INFO Eval Time Cost: 1.6629s
2023-08-29 09:49:48,299 con_fedavg.py iterate [line:50] INFO Receiving stale models from clients: []
2023-08-29 09:49:48,299 con_fedavg.py iterate [line:51] INFO The staleness are []
2023-08-29 09:49:48,299 con_fedavg.py iterate [line:52] INFO Averaging Staleness: nan

Cond 1
2023-08-29 09:50:17,109 fedbase.py run [line:246] INFO --------------Round 6--------------
2023-08-29 09:50:17,109 simple_logger.py log_once [line:14] INFO Current_time:1206
2023-08-29 09:50:18,785 simple_logger.py log_once [line:28] INFO test_accuracy 0.6874
2023-08-29 09:50:18,785 simple_logger.py log_once [line:28] INFO test_loss 1.4880
2023-08-29 09:50:18,785 simple_logger.py log_once [line:28] INFO val_accuracy 0.6745
2023-08-29 09:50:18,785 simple_logger.py log_once [line:28] INFO mean_val_accuracy 0.6745
2023-08-29 09:50:18,785 simple_logger.py log_once [line:28] INFO std_val_accuracy 0.0569
2023-08-29 09:50:18,785 simple_logger.py log_once [line:28] INFO val_loss 1.5130
2023-08-29 09:50:18,785 simple_logger.py log_once [line:28] INFO mean_val_loss 1.5130
2023-08-29 09:50:18,785 simple_logger.py log_once [line:28] INFO std_val_loss 0.0650
2023-08-29 09:50:18,785 fedbase.py run [line:251] INFO Eval Time Cost: 1.6764s
2023-08-29 09:50:20,072 con_fedavg.py iterate [line:50] INFO Receiving stale models from clients: [12, 1, 16, 93, 15, 60, 25, 54, 68, 46, 23, 10, 30, 87, 64, 66, 0, 47, 73, 51, 62, 26]
2023-08-29 09:50:20,072 con_fedavg.py iterate [line:51] INFO The staleness are [-1, -1, -2, -2, -2, -2, -4, -1, -2, -1, -1, -2, -3, -1, -1, -4, -4, -1, -4, -3, -1, -2]
2023-08-29 09:50:20,072 con_fedavg.py iterate [line:52] INFO Averaging Staleness: -2.0454545454545454

Cond 2
2023-08-29 09:50:48,869 fedbase.py run [line:246] INFO --------------Round 6--------------
2023-08-29 09:50:48,869 simple_logger.py log_once [line:14] INFO Current_time:1113
2023-08-29 09:50:50,556 simple_logger.py log_once [line:28] INFO test_accuracy 0.7133
2023-08-29 09:50:50,557 simple_logger.py log_once [line:28] INFO test_loss 1.5453
2023-08-29 09:50:50,557 simple_logger.py log_once [line:28] INFO val_accuracy 0.6957
2023-08-29 09:50:50,557 simple_logger.py log_once [line:28] INFO mean_val_accuracy 0.6957
2023-08-29 09:50:50,557 simple_logger.py log_once [line:28] INFO std_val_accuracy 0.0558
2023-08-29 09:50:50,557 simple_logger.py log_once [line:28] INFO val_loss 1.5653
2023-08-29 09:50:50,557 simple_logger.py log_once [line:28] INFO mean_val_loss 1.5653
2023-08-29 09:50:50,557 simple_logger.py log_once [line:28] INFO std_val_loss 0.0565
2023-08-29 09:50:50,557 fedbase.py run [line:251] INFO Eval Time Cost: 1.6880s
2023-08-29 09:50:51,886 con_fedavg.py iterate [line:50] INFO Receiving stale models from clients: [83, 12, 76, 84, 71, 46, 16]
2023-08-29 09:50:51,886 con_fedavg.py iterate [line:51] INFO The staleness are [-4, -2, -4, -2, -3, -3, -4]
2023-08-29 09:50:51,886 con_fedavg.py iterate [line:52] INFO Averaging Staleness: -3.142857142857143

The logged staleness is the round a model was trained from minus the current round, so a value of -2 means the model is two rounds behind. Cond 0 takes the longest time to complete 6 communication rounds, but enjoys the highest aggregation efficiency (i.e. a testing accuracy of about 85% v.s. about 69% and 71% for the others). Both Cond 1 and Cond 2 produce non-trivial staleness, but significantly reduce the time cost of communication (i.e. to nearly 1/5 of that of Cond 0).
The results suggest that there is an urgent need to develop methods that are both more aggregation-efficient and less time-consuming for practical FL.