Skip to content

Example: MNIST

1. Prepare the following dataset configuration file my_dataset.py (note: the file name can be chosen freely)

The dataset configuration .py file must define the variable train_data, optionally the variable test_data, and the function get_model().

import os
import torchvision
import torch

# Preprocessing applied to every image: convert to a tensor, then normalize
# with the commonly used MNIST mean (0.1307) and standard deviation (0.3081).
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Store (and download) the dataset next to this configuration file.
path = os.path.dirname(os.path.abspath(__file__))

# Define Variable: train_data (the training split)
train_data = torchvision.datasets.MNIST(
    root=path,
    train=True,
    download=True,
    transform=transform,
)

# Define Variable: test_data (the held-out test split)
test_data = torchvision.datasets.MNIST(
    root=path,
    train=False,
    download=True,
    transform=transform,
)

class mlp(torch.nn.Module):
    """Fully connected classifier: 784 -> 200 -> 200 -> 10 with ReLU activations.

    The input is flattened per sample before the first linear layer, so any
    batch whose non-batch dimensions multiply to 784 is accepted, e.g.
    ``(B, 1, 28, 28)``, ``(B, 28, 28)`` or ``(B, 784)``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 200)
        self.fc2 = torch.nn.Linear(200, 200)
        self.fc3 = torch.nn.Linear(200, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the raw class scores (logits) of shape ``(B, 10)`` for batch ``x``."""
        # Flatten everything after the batch dimension. Unlike ``x.view`` this
        # also works for non-contiguous tensors and for inputs without an
        # explicit channel axis; a wrong feature count still fails in ``fc1``
        # instead of being silently folded into the batch dimension.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # No activation on the output layer: the scores are returned as-is.
        x = self.fc3(x)
        return x

# Define Model: get_model()
def get_model():
    """Build and return a new, freshly initialized ``mlp`` instance."""
    model = mlp()
    return model

2. Use flgo.gen_benchmark_from_file to generate the benchmark from the configuration file

import flgo
import os
bmkname = 'my_mnist_classification'           # the name of benchmark
bmk_config = './my_dataset.py'                # the path of the configuration file
# Construct the benchmark with flgo.gen_benchmark_from_file, but only if it
# has not been generated by a previous run.
if not os.path.exists(bmkname):
    bmk = flgo.gen_benchmark_from_file(bmkname, bmk_config, target_path='.', data_type='cv', task_type='classification')
else:
    # The benchmark directory already exists: refer to it by name so that
    # `bmk` is always defined (previously this path raised NameError below).
    bmk = bmkname
print(bmk)

3. Test

Use the constructed benchmark in the same way as any other plugin.

import flgo.algorithm.fedavg as fedavg
# Name of the benchmark generated in the previous step.
bmk = 'my_mnist_classification'
# Directory in which the federated task is stored.
task = './my_mnist'
task_config = {'benchmark': bmk}

# Generate the task only once; later runs reuse the existing directory.
if not os.path.exists(task):
    flgo.gen_task(task_config, task_path=task)

# Run the fedavg algorithm
# NOTE(review): 'gpu': [0] presumably selects GPU device 0 — confirm against
# the flgo option documentation before running on a CPU-only machine.
runner = flgo.init(
    task,
    fedavg,
    {'gpu': [0], 'log_file': True, 'num_steps': 5, 'num_rounds': 3},
)
runner.run()