Example: MNIST
1. Prepare the following dataset configuration file, my_dataset.py (remark: the file name can be chosen arbitrarily).
The dataset configuration .py file must define the variables train_data and test_data (optional), and the function get_model().
import os
import torchvision
import torch
# Preprocessing applied to every image: convert it to a tensor, then
# normalize with mean 0.1307 and standard deviation 0.3081.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Directory that contains this configuration file; used as the dataset root.
path = os.path.dirname(os.path.abspath(__file__))

# Define Variable: train_data
train_data = torchvision.datasets.MNIST(
    root=path,
    train=True,
    download=True,
    transform=transform,
)

# Define Variable: test_data
test_data = torchvision.datasets.MNIST(
    root=path,
    train=False,
    download=True,
    transform=transform,
)
class mlp(torch.nn.Module):
    """Fully connected classifier with two hidden layers.

    Maps 784 input features (e.g. a flattened 1x28x28 image) through two
    200-unit hidden layers with ReLU activations to 10 output logits.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 200)
        self.fc2 = torch.nn.Linear(200, 200)
        self.fc3 = torch.nn.Linear(200, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return class logits for a batch of inputs.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold 784 values per sample, e.g. ``(N, 1, 28, 28)``,
                ``(N, 28, 28)`` or ``(N, 784)``.

        Returns:
            Tensor of shape ``(N, 10)`` containing unnormalized class scores.
        """
        # Flatten everything except the batch dimension. Unlike view(), this
        # also accepts non-contiguous tensors and does not assume a 4-D input.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Define Model: get_model()
def get_model():
    """Build and return a new ``mlp`` instance."""
    model = mlp()
    return model
2. Use flgo.gen_benchmark_from_file to generate a benchmark from the configuration file
import flgo
import os
bmkname = 'my_mnist_classification'  # the name of benchmark
bmk_config = './my_dataset.py'  # the path of the configuration file

# Construct the benchmark with flgo.gen_benchmark_from_file, unless a
# benchmark directory with this name already exists in the working directory.
if not os.path.exists(bmkname):
    bmk = flgo.gen_benchmark_from_file(
        bmkname,
        bmk_config,
        target_path='.',
        data_type='cv',
        task_type='classification',
    )
else:
    # Keep `bmk` bound on re-runs so the print below cannot raise NameError.
    # NOTE(review): assumes the benchmark can be referred to by its name,
    # as the task configuration in the next step does -- confirm.
    bmk = bmkname
print(bmk)
3. Test
Use the constructed benchmark in the same way as any other benchmark plugin.
import flgo.algorithm.fedavg as fedavg
# Name of the benchmark to use and the directory where the task is stored.
bmk = 'my_mnist_classification'
task = './my_mnist'
task_config = {'benchmark': bmk}

# Generate the task only if its directory does not exist yet.
if not os.path.exists(task):
    flgo.gen_task(task_config, task_path=task)

# Run the FedAvg algorithm
option = {'gpu': [0,], 'log_file': True, 'num_steps': 5, 'num_rounds': 3}
runner = flgo.init(task, fedavg, option)
runner.run()