Skip to content

flgo.algorithm.vflbase

ActiveParty

Bases: PassiveParty

This is the implementation of the active party in vertical FL. The active party owns the data label information and may also own parts of data features. If an active party owns data features, it is also a passive party simultaneously.

Parameters:

Name Type Description Default
option dict

running-time option

required
Source code in flgo\algorithm\vflbase.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
class ActiveParty(PassiveParty):
    r"""
    This is the implementation of the active party in vertical FL. The active party owns
    the data label information and may also own parts of data features. If an active party owns
    data features, it is also a passive party simultaneously.

    Args:
        option (dict): running-time option
    """
    def __init__(self, option):
        super().__init__(option)
        # message type -> handler (0: forward, 1: backward, 2: forward for evaluation)
        self.actions = {0: self.forward, 1: self.backward, 2: self.forward_test}
        self.device = torch.device('cpu') if option['server_with_cpu'] else self.gv.apply_for_device()
        self.calculator = self.gv.TaskCalculator(self.device, optimizer_name = option['optimizer'])
        # basic configuration
        self.task = option['task']
        self.eval_interval = option['eval_interval']
        self.num_parallels = option['num_parallels']
        # hyper-parameters during training process
        self.num_rounds = option['num_rounds']
        self.proportion = option['proportion']
        self.batch_size = option['batch_size']
        self.decay_rate = option['learning_rate_decay']
        self.lr_scheduler_type = option['lr_scheduler']
        self.lr = option['learning_rate']
        self.sample_option = option['sample']
        self.aggregation_option = option['aggregate']
        # systemic option
        self.tolerance_for_latency = 999999
        self.sending_package_buffer = [None for _ in range(9999)]
        # algorithm-dependent parameters
        self.algo_para = {}
        self.current_round = 1
        # all options
        self.option = option
        # the active party is assigned id 0
        self.id = 0

    def communicate(self, selected_clients, mtype=0, asynchronous=False):
        """
        The whole simulated communication procedure with the selected clients.

        Each distinct id in ``selected_clients`` receives exactly one package built by
        ``self.pack(party_id, mtype=mtype)`` with ``'__mtype__'`` attached. Replies are then
        re-ordered to follow ``selected_clients`` (duplicates repeated), falsy replies
        (e.g. ``None``) are dropped, and the rest are merged by ``self.unpack``.

        Args:
            selected_clients (list): the ids of the parties to communicate with
            mtype (Any): the message type passed to ``self.pack`` and attached to each package
            asynchronous (bool): currently unused

        Returns:
            the unpacked response from clients that is created by self.unpack()
        """
        packages_received_from_clients = []
        received_package_buffer = {}
        communicate_clients = list(set(selected_clients))
        # prepare packages for clients
        for cid in communicate_clients:
            received_package_buffer[cid] = None
        # communicate with selected clients
        if self.num_parallels <= 1:
            # computing iteratively
            for client_id in communicate_clients:
                server_pkg = self.pack(client_id, mtype=mtype)
                server_pkg['__mtype__'] = mtype
                response_from_client_id = self.communicate_with(client_id, package=server_pkg)
                packages_received_from_clients.append(response_from_client_id)
        else:
            # computing in parallel with torch.multiprocessing
            pool = mp.Pool(self.num_parallels)
            for client_id in communicate_clients:
                server_pkg = self.pack(client_id, mtype=mtype)
                server_pkg['__mtype__'] = mtype
                # NOTE(review): relies on a `clients` attribute defined outside this class — confirm it exists for VFL parties
                self.clients[client_id].update_device(self.gv.apply_for_device())
                args = (int(client_id), server_pkg)
                packages_received_from_clients.append(pool.apply_async(self.communicate_with, args=args))
            pool.close()
            pool.join()
            packages_received_from_clients = list(map(lambda x: x.get(), packages_received_from_clients))
        for i, cid in enumerate(communicate_clients):
            received_package_buffer[cid] = packages_received_from_clients[i]
        packages_received_from_clients = [received_package_buffer[cid] for cid in selected_clients if received_package_buffer[cid]]
        self.received_clients = selected_clients
        return self.unpack(packages_received_from_clients)

    def unpack(self, packages_received_from_clients):
        """
        Unpack the received packages into per-field lists.

        Every field name appearing in any package maps to the list of values received for it,
        in package order. Packages may carry different fields.

        Args:
            packages_received_from_clients (list of dict): the replies from the parties

        Returns:
            res (collections.defaultdict): field name -> list of values; missing fields read as ``[]``
        """
        res = collections.defaultdict(list)
        for cpkg in packages_received_from_clients:
            for pname, pval in cpkg.items():
                res[pname].append(pval)
        return res

    def run(self):
        """
        Start the federated learning system where the global model is trained iteratively.

        The round counter only advances when ``iterate`` returns ``True`` or ``None``; evaluation
        happens before training and whenever the logger's ``check_if_log`` allows it.
        """
        self.gv.logger.time_start('Total Time Cost')
        self.gv.logger.info("--------------Initial Evaluation--------------")
        self.gv.logger.time_start('Eval Time Cost')
        self.gv.logger.log_once()
        self.gv.logger.time_end('Eval Time Cost')
        while self.current_round <= self.num_rounds:
            # iterate
            updated = self.iterate()
            # using logger to evaluate the model if the model is updated
            if updated is True or updated is None:
                self.gv.logger.info("--------------Round {}--------------".format(self.current_round))
                # check log interval
                if self.gv.logger.check_if_log(self.current_round, self.eval_interval):
                    self.gv.logger.time_start('Eval Time Cost')
                    self.gv.logger.log_once()
                    self.gv.logger.time_end('Eval Time Cost')
                self.current_round += 1
        self.gv.logger.info("=================End==================")
        self.gv.logger.time_end('Total Time Cost')
        # save results as .json file
        self.gv.logger.save_output_as_json()
        return

    def iterate(self):
        r"""
        The standard VFL process.

         1. The active party first generates the batch information.

         2. Then, it collects activations from all the passive parties.

         3. Thirdly, it continues the forward passing and backward passing to update the decoder part of the model, and distributes the derivations to parties.

         4. Finally, each passive party will update its local modules according to the derivations and activations.

        Returns:
            updated (bool): whether the model is updated in this iteration
        """
        self._data_type = 'train'
        self.crt_batch = self.get_batch_data()
        activations = self.communicate([p.id for p in self.parties], mtype=0)['activation']
        # update_global_module returns the per-party derivations (also kept in self.defusion)
        self.defusions = self.update_global_module(activations, self.global_module)
        # derivations are indexed by position in self.parties (see pack, mtype=1)
        _ = self.communicate([pid for pid in range(len(self.parties))], mtype=1)
        return True

    def pack(self, party_id, mtype=0):
        r"""
        Pack the necessary information to parties into packages.

        Args:
            party_id (int): the id of the party
            mtype (Any): the message type (0: training batch, 1: derivation, 2: evaluation batch)

        Returns:
            package (dict): the package, or ``None`` for an unrecognised ``mtype``
        """
        if mtype == 0:
            # element 2 of the batch is what passive parties pass to get_batch_by_id
            return {'batch': self.crt_batch[2], 'data_type': self._data_type}
        elif mtype == 1:
            return {'derivation': self.defusion[party_id]}
        elif mtype == 2:
            return {'batch': self.crt_test_batch[2], 'data_type': self._data_type}

    def get_batch_data(self):
        """
        Get the next training batch, restarting the data loader when it is exhausted.

        A fresh iterator over ``self.train_data`` is created via ``self.calculator.get_dataloader``
        when no loader exists yet or the current one raises ``StopIteration``. Any other error
        raised while fetching a batch propagates instead of being silently swallowed.

        Returns:
            batch_data (Any): a batch of data
        """
        loader = getattr(self, 'data_loader', None)
        if loader is not None:
            try:
                return next(loader)
            except StopIteration:
                pass
        self.data_loader = iter(self.calculator.get_dataloader(self.train_data, batch_size=self.batch_size))
        return next(self.data_loader)

    def update_global_module(self, activations:list, model:torch.nn.Module|flgo.utils.fmodule.FModule):
        r"""
        Update the global module by computing the forward passing and the backward passing. The attribute
        self.defusion and self.fusion.grad will be changed after calling this method.

        Args:
            activations (list): a list of activations from all the passive parties
            model (torch.nn.Module|flgo.utils.fmodule.FModule): the model

        Returns:
            derivations (list): the per-party derivations, identical to ``self.defusion``
        """
        self.fusion = self.fuse(activations)
        # track gradients on the fused activations so fusion.grad can be sent back to the parties
        self.fusion.requires_grad = True
        optimizer = self.calculator.get_optimizer(model, lr=self.lr)
        # backward() accumulates into .grad, so clear gradients left from previous iterations
        optimizer.zero_grad()
        loss = self.calculator.compute_loss(model, (self.fusion, self.crt_batch[1]))['loss']
        loss.backward()
        optimizer.step()
        self.defusion = self.defuse(self.fusion)
        return self.defusion

    def fuse(self, activations:list):
        r"""
        Fuse the activations into one by averaging them element-wise.

        Args:
            activations (list): a list of activations from all the passive parties

        Returns:
            fusion (Any): the fused result
        """
        return torch.stack(activations).mean(dim=0)

    def defuse(self, fusion):
        r"""
        Defuse the fusion into derivations: every party receives ``fusion.grad``.

        Args:
            fusion (Any): the fused result

        Returns:
            derivations (list): a list of derivations, one per entry of ``self.parties``
        """
        return [fusion.grad for _ in self.parties]

    def test(self, flag:str='test') -> dict:
        r"""
        Test the performance of the model.

        Evaluation runs under ``torch.no_grad()``. All party modules are switched to eval mode
        and are restored to train mode afterwards, even if evaluation raises.

        Args:
            flag (str): the type of dataset ('test', 'train' or 'val')

        Returns:
            result (dict): a dict that contains the testing result ('accuracy' and 'loss')
        """
        self.set_model_mode('eval')
        try:
            flag_dict = {'test': self.test_data, 'train': self.train_data, 'val': self.val_data}
            dataset = flag_dict[flag]
            self._data_type = flag
            dataloader = self.calculator.get_dataloader(dataset, batch_size=128)
            total_loss = 0.0
            num_correct = 0
            with torch.no_grad():
                for batch_data in dataloader:
                    self.crt_test_batch = batch_data
                    activations = self.communicate([pid for pid in range(len(self.parties))], mtype=2)['activation']
                    fusion = self.fuse(activations)
                    outputs = self.global_module(fusion.to(self.device))
                    batch_mean_loss = self.calculator.criterion(outputs, batch_data[1].to(self.device)).item()
                    y_pred = outputs.data.max(1, keepdim=True)[1].cpu()
                    correct = y_pred.eq(batch_data[1].data.view_as(y_pred)).long().cpu().sum()
                    num_correct += correct.item()
                    # criterion yields a batch mean; re-weight by batch size for a dataset mean
                    total_loss += batch_mean_loss * len(batch_data[1])
        finally:
            self.set_model_mode('train')
        return {'accuracy': 1.0 * num_correct / len(dataset), 'loss': total_loss / len(dataset)}

    def set_model_mode(self, mode = 'train'):
        r"""
        Set all the modes of the modules owned by all the parties.

        Args:
            mode (str): 'train' puts modules in train mode; any other value selects eval mode
        """
        for party in self.parties:
            for attr_name in ('local_module', 'global_module'):
                module = getattr(party, attr_name, None)
                if module is None:
                    continue
                if mode == 'train':
                    module.train()
                else:
                    module.eval()

communicate(selected_clients, mtype=0, asynchronous=False)

The whole simulated communication procedure with the selected clients. This part supports simulating clients dropping out.

Parameters:

Name Type Description Default
selected_clients

the clients to communicate with

required

Returns:

Type Description

the unpacked response from clients that is created by self.unpack()

Source code in flgo\algorithm\vflbase.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def communicate(self, selected_clients, mtype=0, asynchronous=False):
    """
    Run one simulated communication round with the given parties.

    Each distinct id in ``selected_clients`` is sent exactly one package built by
    ``self.pack(party_id, mtype=mtype)`` with ``'__mtype__'`` attached. The replies are then
    listed in the order of ``selected_clients`` (duplicates repeated), falsy replies are
    dropped, and the remainder is merged by ``self.unpack``.

    Args:
        selected_clients: the clients to communicate with
        mtype: the message type forwarded to ``self.pack``
        asynchronous: currently unused

    Returns:
        the unpacked response from clients that is created by self.unpack()
    """
    unique_ids = list(set(selected_clients))
    replies_by_id = dict.fromkeys(unique_ids)

    def build_package(party_id):
        # every outgoing package carries its message type
        package = self.pack(party_id, mtype=mtype)
        package['__mtype__'] = mtype
        return package

    if self.num_parallels > 1:
        # fan out through a torch.multiprocessing pool
        pool = mp.Pool(self.num_parallels)
        handles = []
        for party_id in unique_ids:
            package = build_package(party_id)
            self.clients[party_id].update_device(self.gv.apply_for_device())
            handles.append(pool.apply_async(self.communicate_with, args=(int(party_id), package)))
        pool.close()
        pool.join()
        replies = [handle.get() for handle in handles]
    else:
        # sequential: pack and send one party at a time
        replies = [self.communicate_with(party_id, package=build_package(party_id)) for party_id in unique_ids]

    for party_id, reply in zip(unique_ids, replies):
        replies_by_id[party_id] = reply
    kept = [replies_by_id[cid] for cid in selected_clients if replies_by_id[cid]]
    self.received_clients = selected_clients
    return self.unpack(kept)

defuse(fusion)

Defuse the fusion into derivations.

Parameters:

Name Type Description Default
fusion Any

the fused result

required

Returns:

Name Type Description
derivations list

a list of derivations

Source code in flgo\algorithm\vflbase.py
302
303
304
305
306
307
308
309
310
311
312
def defuse(self, fusion):
    r"""
    Split the fused result back into one derivation per party.

    Every party receives the same object: the gradient stored on ``fusion``.

    Args:
        fusion (Any): the fused result

    Returns:
        derivations (list): ``fusion.grad`` repeated once per entry of ``self.parties``
    """
    derivations = []
    for _party in self.parties:
        derivations.append(fusion.grad)
    return derivations

fuse(activations)

Fuse the activations into one.

Parameters:

Name Type Description Default
activations list

a list of activations from all the passive parties

required

Returns:

Name Type Description
fusion Any

the fused result

Source code in flgo\algorithm\vflbase.py
290
291
292
293
294
295
296
297
298
299
300
def fuse(self, activations:list):
    r"""
    Combine the parties' activations into a single tensor by element-wise averaging.

    Args:
        activations (list): a list of activations from all the passive parties

    Returns:
        fusion (Any): the mean over the stacked activations
    """
    stacked = torch.stack(activations, dim=0)
    return torch.mean(stacked, dim=0)

get_batch_data()

Get the batch of data

Returns:

Name Type Description
batch_data Any

a batch of data

Source code in flgo\algorithm\vflbase.py
260
261
262
263
264
265
266
267
268
269
270
271
def get_batch_data(self):
    """
    Get the next training batch, restarting the data loader when it is exhausted.

    A fresh iterator over ``self.train_data`` is created via ``self.calculator.get_dataloader``
    when no loader exists yet (attribute missing or ``None``) or the current one raises
    ``StopIteration``. Any other error raised while fetching a batch propagates instead of
    being silently swallowed, as the former bare ``except:`` did.

    Returns:
        batch_data (Any): a batch of data
    """
    loader = getattr(self, 'data_loader', None)
    if loader is not None:
        try:
            return next(loader)
        except StopIteration:
            pass
    self.data_loader = iter(self.calculator.get_dataloader(self.train_data, batch_size=self.batch_size))
    return next(self.data_loader)

iterate()

The standard VFL process.

  1. The active party first generates the batch information.

  2. Then, it collects activations from all the passive parties.

  3. Thirdly, it continues the forward passing and backward passing to update the decoder part of the model, and distributes the derivations to parties.

  4. Finally, each passive party will update its local modules according to the derivations and activations.

Returns:

Name Type Description
updated bool

whether the model is updated in this iteration

Source code in flgo\algorithm\vflbase.py
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def iterate(self):
    r"""
    Run one round of the standard VFL process.

     1. Draw the next training batch on the active party.

     2. Gather the activations that every passive party computes for that batch.

     3. Fuse them, run forward and backward through the global module, and keep the per-party derivations.

     4. Send the derivations back so each passive party can update its local module.

    Returns:
        updated (bool): always ``True``; the model is updated in every call
    """
    self._data_type = 'train'
    self.crt_batch = self.get_batch_data()
    sender_ids = [party.id for party in self.parties]
    replies = self.communicate(sender_ids, mtype=0)
    self.defusions = self.update_global_module(replies['activation'], self.global_module)
    # derivations are addressed by position in self.parties
    receiver_ids = list(range(len(self.parties)))
    self.communicate(receiver_ids, mtype=1)
    return True

pack(party_id, mtype=0)

Pack the necessary information to parties into packages.

Parameters:

Name Type Description Default
party_id int

the id of the party

required
mtype Any

the message type

0

Returns:

Name Type Description
package dict

the package

Source code in flgo\algorithm\vflbase.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def pack(self, party_id, mtype=0):
    r"""
    Build the package the active party sends to one party.

    Args:
        party_id (int): the id of the party (used to index the derivations for mtype 1)
        mtype (Any): 0 = training batch, 1 = derivation, 2 = evaluation batch

    Returns:
        package (dict): the package, or ``None`` when ``mtype`` is not recognised
    """
    if mtype == 0:
        source_batch = self.crt_batch
    elif mtype == 1:
        return {'derivation': self.defusion[party_id]}
    elif mtype == 2:
        source_batch = self.crt_test_batch
    else:
        return None
    return {'batch': source_batch[2], 'data_type': self._data_type}

run()

Start the federated learning system where the global model is trained iteratively.

Source code in flgo\algorithm\vflbase.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def run(self):
    """
    Drive the whole training loop from the initial evaluation to the saved results.

    The model is evaluated once before training. After that, rounds are counted only when
    ``iterate`` returns ``True`` or ``None``, and evaluation runs whenever the logger's
    ``check_if_log`` allows it. Finally the logged output is written as JSON.
    """
    def timed_evaluation():
        self.gv.logger.time_start('Eval Time Cost')
        self.gv.logger.log_once()
        self.gv.logger.time_end('Eval Time Cost')

    self.gv.logger.time_start('Total Time Cost')
    self.gv.logger.info("--------------Initial Evaluation--------------")
    timed_evaluation()
    while self.current_round <= self.num_rounds:
        outcome = self.iterate()
        # any other return value means "no update": retry without advancing the round
        if outcome is not True and outcome is not None:
            continue
        self.gv.logger.info("--------------Round {}--------------".format(self.current_round))
        if self.gv.logger.check_if_log(self.current_round, self.eval_interval):
            timed_evaluation()
        self.current_round += 1
    self.gv.logger.info("=================End==================")
    self.gv.logger.time_end('Total Time Cost')
    self.gv.logger.save_output_as_json()
    return

set_model_mode(mode='train')

Set all the modes of the modules owned by all the parties.

Parameters:

Name Type Description Default
mode str

the mode of models

'train'
Source code in flgo\algorithm\vflbase.py
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
def set_model_mode(self,mode = 'train'):
    r"""
    Switch every module held by every party between training and evaluation mode.

    For each party, ``local_module`` and then ``global_module`` are visited; a module that is
    absent or ``None`` is skipped.

    Args:
        mode (str): 'train' selects train mode; any other value selects eval mode
    """
    for party in self.parties:
        for attr_name in ('local_module', 'global_module'):
            module = getattr(party, attr_name, None)
            if module is None:
                continue
            if mode == 'train':
                module.train()
            else:
                module.eval()

test(flag='test')

Test the performance of the model

Parameters:

Name Type Description Default
flag str

the type of dataset

'test'

Returns:

Name Type Description
result dict

a dict that contains the testing result

Source code in flgo\algorithm\vflbase.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
def test(self, flag:str='test') -> dict:
    r"""
    Test the performance of the model

    Args:
        flag (str): the type of dataset

    Returns:
        result (dict): a dict that contains the testing result
    """
    self.set_model_mode('eval')
    flag_dict = {'test':self.test_data, 'train':self.train_data, 'val':self.val_data}
    dataset = flag_dict[flag]
    self._data_type = flag
    dataloader = self.calculator.get_dataloader(dataset, batch_size=128)
    total_loss = 0.0
    num_correct = 0
    for batch_id, batch_data in enumerate(dataloader):
        self.crt_test_batch = batch_data
        activations = self.communicate([pid for pid in range(len(self.parties))], mtype=2)['activation']
        fusion = self.fuse(activations)
        outputs = self.global_module(fusion.to(self.device))
        batch_mean_loss = self.calculator.criterion(outputs, batch_data[1].to(self.device)).item()
        y_pred = outputs.data.max(1, keepdim=True)[1].cpu()
        correct = y_pred.eq(batch_data[1].data.view_as(y_pred)).long().cpu().sum()
        num_correct += correct.item()
        total_loss += batch_mean_loss * len(batch_data[1])
    self.set_model_mode('train')
    return {'accuracy': 1.0 * num_correct / len(dataset), 'loss': total_loss / len(dataset)}

unpack(packages_received_from_clients)

Unpack the information from the received packages: each field name is mapped to the list of values received for it from the clients.

Parameters:

Name Type Description Default
packages_received_from_clients list of dict required

Returns:

Name Type Description
res dict

collections.defaultdict that contains several lists of the clients' reply

Source code in flgo\algorithm\vflbase.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def unpack(self, packages_received_from_clients):
    """
    Unpack the received packages into per-field lists.

    Every field name appearing in any package maps to the list of values received for it,
    in package order. Packages may carry different fields. Previously, a field missing
    from the first package raised ``KeyError``.

    Args:
        packages_received_from_clients (list of dict): the replies from the parties

    Returns:
        res (collections.defaultdict): field name -> list of values; missing fields read as ``[]``
    """
    res = collections.defaultdict(list)
    for cpkg in packages_received_from_clients:
        for pname, pval in cpkg.items():
            res[pname].append(pval)
    return res

update_global_module(activations, model)

Update the global module by computing the forward passing and the backward passing. The attribute self.defusion and self.fusion.grad will be changed after calling this method.

Parameters:

Name Type Description Default
activations list

a list of activations from all the passive parties

required
model torch.nn.Module | flgo.utils.fmodule.FModule

the model

required
Source code in flgo\algorithm\vflbase.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
def update_global_module(self, activations:list, model:torch.nn.Module|flgo.utils.fmodule.FModule):
    r"""
    Update the global module by computing the forward passing and the backward passing. The attribute
    self.defusion and self.fusion.grad will be changed after calling this method.

    Args:
        activations (list): a list of activations from all the passive parties
        model (torch.nn.Module|flgo.utils.fmodule.FModule): the model to update

    Returns:
        derivations (list): the per-party derivations, identical to ``self.defusion``
    """
    self.fusion = self.fuse(activations)
    # track gradients on the fused activations so fusion.grad can be sent back to the parties
    self.fusion.requires_grad = True
    # optimise the same module the loss is computed on
    optimizer = self.calculator.get_optimizer(model, lr=self.lr)
    # backward() accumulates into .grad, so clear gradients left from previous iterations
    optimizer.zero_grad()
    loss = self.calculator.compute_loss(model, (self.fusion, self.crt_batch[1]))['loss']
    loss.backward()
    optimizer.step()
    self.defusion = self.defuse(self.fusion)
    return self.defusion

PassiveParty

Bases: BasicParty

This is the implementation of the passive party in vertical FL. The passive party owns only a part of data features without label information.

Parameters:

Name Type Description Default
option dict

running-time option

required
Source code in flgo\algorithm\vflbase.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class PassiveParty(BasicParty):
    r"""This is the implementation of the passive party in vertical FL.
    The passive party owns only a part of data features without label information.

    Args:
        option (dict): running-time option
    """
    def __init__(self, option:dict):
        super().__init__()
        self.option = option
        # message type -> handler (0: forward, 1: backward, 2: forward for evaluation)
        self.actions = {0: self.forward, 1: self.backward, 2: self.forward_test}
        self.id = None
        # local data iterator (None until assigned)
        self.data_loader = None
        # local calculator
        self.device = self.gv.apply_for_device()
        self.calculator = self.gv.TaskCalculator(self.device, option['optimizer'])
        # hyper-parameters for training
        self.optimizer_name = option['optimizer']
        self.lr = option['learning_rate']
        self.momentum = option['momentum']
        self.weight_decay = option['weight_decay']
        self.batch_size = option['batch_size']
        self.num_steps = option['num_steps']
        self.num_epochs = option['num_epochs']
        self.model = None
        self.test_batch_size = option['test_batch_size']
        self.loader_num_workers = option['num_workers']
        self.current_steps = 0
        # system setting
        self._effective_num_steps = self.num_steps
        self._latency = 0

    def forward(self, package:dict|None=None):
        r"""
        Local forward to compute the activations on local features.

        The computed activation is kept in ``self.activation`` (with its autograd graph) so that
        ``backward`` can later propagate the received derivation through the local module.

        Args:
            package (dict): the package from the active party that contains batch information
                ('batch') and the type of data ('data_type': 'train', 'val' or 'test')

        Returns:
            passive_package (dict): the package that contains the detached activation to be sent to the active party
        """
        # None replaces the former shared mutable default {}; a missing key still raises KeyError
        if package is None:
            package = {}
        batch_ids = package['batch']
        tmp = {'train': self.train_data, 'val': self.val_data, 'test': self.test_data}
        dataset = tmp[package['data_type']]
        # select samples in batch
        self.activation = self.local_module(dataset.get_batch_by_id(batch_ids)[0].to(self.device))
        return {'activation': self.activation.clone().detach()}

    def backward(self, package):
        r"""
        Local backward to compute the gradients on local modules.

        Args:
            package (dict): the package from the active party that contains the derivations
        """
        derivation = package['derivation']
        self.update_local_module(derivation, self.activation)
        return

    def update_local_module(self, derivation, activation):
        r"""
        Update local modules according to the derivation and the activation.

        Args:
            derivation (Any): the derivation from the active party
            activation (Any): the locally computed activation
        """
        optimizer = self.calculator.get_optimizer(self.local_module, self.lr)
        # backward() accumulates into .grad, so clear gradients left from previous iterations
        optimizer.zero_grad()
        if isinstance(derivation, torch.Tensor):
            # the active party may run on a different device than this party
            derivation = derivation.to(activation.device)
        # d/d(activation) of sum(derivation * activation) equals derivation, so this surrogate
        # back-propagates the received derivation into the local module
        loss_surrogate = (derivation * activation).sum()
        loss_surrogate.backward()
        optimizer.step()
        return

    def forward_test(self, package):
        r"""
        Local forward to compute the activations on local features for testing.

        Runs under ``torch.no_grad()`` because evaluation never back-propagates.

        Args:
            package (dict): the package from the active party that contains batch information and the type of data

        Returns:
            passive_package (dict): the package that contains the activation to be sent to the active party
        """
        batch_ids = package['batch']
        tmp = {'train': self.train_data, 'val': self.val_data, 'test': self.test_data}
        dataset = tmp[package['data_type']]
        # select samples in batch
        with torch.no_grad():
            self.activation = self.local_module(dataset.get_batch_by_id(batch_ids)[0].to(self.device))
        return {'activation': self.activation}

backward(package)

Local backward to compute the gradients on local modules

Parameters:

Name Type Description Default
package dict

the package from the active party that contains the derivations

required
Source code in flgo\algorithm\vflbase.py
59
60
61
62
63
64
65
66
67
68
def backward(self, package):
    r"""
    Apply the derivation received from the active party to the local module.

    The derivation is paired with the activation stored by the most recent forward pass.

    Args:
        package (dict): the package from the active party; its 'derivation' entry is used
    """
    received_derivation = package['derivation']
    self.update_local_module(received_derivation, self.activation)

forward(package={})

Local forward to compute the activations on local features

Parameters:

Name Type Description Default
package dict

the package from the active party that contains batch information and the type of data

{}

Returns:

Name Type Description
passive_package dict

the package that contains the activation to be sent to the active party

Source code in flgo\algorithm\vflbase.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def forward(self, package:dict={}):
    r"""
    Local forward to computing the activations on local_movielens_recommendation features

    Args:
        package (dict): the package from the active party that contains batch information and the type of data

    Returns:
        passive_package (dict): the package that contains the activation to be sent to the active party
    """
    batch_ids = package['batch']
    tmp = {'train': self.train_data, 'val': self.val_data, 'test':self.test_data}
    dataset = tmp[package['data_type']]
    # select samples in batch
    self.activation = self.local_module(dataset.get_batch_by_id(batch_ids)[0].to(self.device))
    return {'activation': self.activation.clone().detach()}

forward_test(package)

Local forward to compute the activations on local features for testing

Parameters:

Name Type Description Default
package dict

the package from the active party that contains batch information and the type of data

required

Returns:

Name Type Description
passive_package dict

the package that contains the activation to be sent to the active party

Source code in flgo\algorithm\vflbase.py
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def forward_test(self, package):
    r"""
    Local forward to compute the activations on local features for testing.

    Runs under ``torch.no_grad()`` because evaluation never back-propagates, so no autograd
    graph is built or retained.

    Args:
        package (dict): the package from the active party that contains batch information and the type of data

    Returns:
        passive_package (dict): the package that contains the activation to be sent to the active party
    """
    batch_ids = package['batch']
    tmp = {'train': self.train_data, 'val': self.val_data, 'test': self.test_data}
    dataset = tmp[package['data_type']]
    # select samples in batch
    with torch.no_grad():
        self.activation = self.local_module(dataset.get_batch_by_id(batch_ids)[0].to(self.device))
    return {'activation': self.activation}

update_local_module(derivation, activation)

Update local modules according to the derivation and the activation

Parameters:

Name Type Description Default
derivation Any

the derivation from the active party

required
activation Any

the locally computed activation

required
Source code in flgo\algorithm\vflbase.py
70
71
72
73
74
75
76
77
78
79
80
81
82
def update_local_module(self, derivation, activation):
    r"""
    Update local modules according to the derivation and the activation.

    Args:
        derivation (Any): the derivation from the active party
        activation (Any): the locally computed activation (still attached to its autograd graph)
    """
    optimizer = self.calculator.get_optimizer(self.local_module, self.lr)
    # backward() accumulates into .grad, so clear gradients left from previous iterations
    optimizer.zero_grad()
    if isinstance(derivation, torch.Tensor):
        # the active party may run on a different device than this party
        derivation = derivation.to(activation.device)
    # d/d(activation) of sum(derivation * activation) equals derivation, so this surrogate
    # back-propagates the received derivation into the local module
    loss_surrogate = (derivation * activation).sum()
    loss_surrogate.backward()
    optimizer.step()
    return