Privacy-Preserving Machine Learning with AIJack - 1: Federated Learning on PyTorch

This post is part of our Privacy-Preserving Machine Learning with AIJack series.

Overview

In this tutorial, we will learn about Federated Learning, a distributed learning algorithm that allows you to train a neural network while preserving privacy.

While deep learning has achieved substantial success in various areas, training deep learning models requires a large amount of data. Achieving high performance while preserving privacy is therefore challenging. One way to address this problem is Federated Learning, where multiple clients collaboratively train a single global model without sharing their local datasets.

The procedure of a typical Federated Learning scheme is as follows:

1. The central server initializes the global model.
2. The server distributes the global model to each client.
3. Each client locally calculates the gradient of the loss function on their dataset.
4. Each client sends the gradient to the server.
5. The server aggregates the received gradients with some method (e.g., average) and updates the global model with the aggregated gradient.
6. Repeat steps 2 - 5 until convergence.

The mathematical formulation, when the aggregation is a weighted average, is as follows:

$$
w_{t} \leftarrow w_{t-1} - \eta \sum_{c=1}^{C} \frac{n_{c}}{N} \nabla \ell(w_{t-1}, X_{c}, Y_{c})
$$

where $w_{t}$ is the parameter of the global model in the $t$-th round, $\nabla \ell(w_{t-1}, X_{c}, Y_{c})$ is the gradient calculated on the $c$-th client's dataset $(X_{c}, Y_{c})$, $n_{c}$ is the number of samples in the $c$-th client's dataset, and $N = \sum_{c=1}^{C} n_{c}$ is the total number of samples.
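
This update rule is easy to express directly in PyTorch. The following sketch is not part of AIJack; it only illustrates the weighted-average aggregation with hypothetical values for client_grads and client_sizes.

import torch

# Hypothetical example: two clients holding different amounts of local data.
client_grads = [torch.randn(10), torch.randn(10)]  # gradient from each client
client_sizes = [6000, 4000]                        # n_c: samples held by each client
global_w = torch.zeros(10)                         # global parameters w_{t-1}
eta = 0.01                                         # learning rate

N = sum(client_sizes)
# Weighted average of the clients' gradients: sum_c (n_c / N) * grad_c
agg_grad = sum((n / N) * g for n, g in zip(client_sizes, client_grads))
# FedAVG update of the global model: w_t = w_{t-1} - eta * aggregated gradient
global_w = global_w - eta * agg_grad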

Code

Next, we will implement FedAVG [1], one of the most representative Federated Learning methods. We use AIJack, an open-source library for simulating the security and privacy risks of machine learning algorithms. AIJack supports both a single-process and an MPI backend.

First, we install AIJack with pip.

apt install -y libboost-all-dev
pip install -U pip
pip install "pybind11[global]"

pip install aijack

Single-process

We import the following modules.

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from aijack.collaborative.fedavg import FedAVGClient, FedAVGServer, FedAVGAPI

The hyper-parameters are as follows.

training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
client_size = 2
criterion = F.nll_loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

This tutorial uses the MNIST dataset.

def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader

AIJack allows you to implement the clients and the server of Federated Learning with standard PyTorch models.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output

clients = [FedAVGClient(Net().to(device), user_id=c) for c in range(client_size)]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedAVGServer(clients, Net().to(device))
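
The training API below also expects the per-client train loaders, a test loader, and an evaluation callback, which are not shown above. A minimal sketch, assuming the prepare_dataloader helper defined earlier and an evaluation callback in the same spirit as the one used in the MPI script below (the api.server(data) call is an assumption about how the single-process API exposes the global model), could look like this:

local_dataloaders = [
    prepare_dataloader(client_size, c + 1, train=True) for c in range(client_size)
]
test_dataloader = prepare_dataloader(client_size, 0, train=False)


def evaluate_global_model(dataloader):
    # Returns a callback that is executed after each communication round.
    def _evaluate_global_model(api):
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(device), target.to(device)
                output = api.server(data)  # assumption: the server wraps the global model
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(f"Test accuracy: {accuracy:.2f}%")

    return _evaluate_global_model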

Then, you can execute the training via the run method of FedAVGAPI.

api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=num_rounds,
    custom_action=evaluate_global_model(test_dataloader),
)
api.run()

MPI

You can easily convert the above code to MPI-compatible code that can run in a parallel programming environment.

# mpi_FedAVG.py

import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import (
    FedAVGClient,
    FedAVGServer,
    MPIFedAVGAPI,
    MPIFedAVGClientManager,
    MPIFedAVGServerManager,
)

logger = getLogger(__name__)

training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=False, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=False, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output


def evaluate_global_model(dataloader):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                output = api.party(data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model


def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    mpi_client_manager = MPIFedAVGClientManager()
    mpi_server_manager = MPIFedAVGServerManager()
    MPIFedAVGClient = mpi_client_manager.attach(FedAVGClient)
    MPIFedAVGServer = mpi_server_manager.attach(FedAVGServer)

    if myid == 0:
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        client_ids = list(range(1, size))
        server = MPIFedAVGServer(comm, client_ids, model)
        api = MPIFedAVGAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_rounds,
            1,
            custom_action=evaluate_global_model(dataloader),
            device=device
        )
    else:
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPIFedAVGClient(comm, model, user_id=myid)
        api = MPIFedAVGAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,
            device=device
        )

    api.run()


if __name__ == "__main__":
    main()

You can run the above script with the standard MPI launcher. Here we launch 3 processes: rank 0 acts as the central server, and ranks 1 and 2 act as clients.


!mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG.py

Summary

In this tutorial, we learned about Federated Learning, a promising approach to training deep learning models without violating privacy. You can find more examples and notebooks in the AIJack documentation. Although this scheme seems safe because each client does not have to share its local dataset, the next tutorial demonstrates that the shared local gradients might leak private information.

Reference

[1] McMahan, Brendan, et al. "Communication-Efficient Learning of Deep Networks from Decentralized Data." Artificial Intelligence and Statistics (AISTATS), PMLR, 2017.
