Syumei


Privacy-Preserving Machine Learning with AIJack - 2: Model Inversion Attack against Federated Learning on PyTorch

This post is part of our Privacy-Preserving Machine Learning with AIJack series.

  • Part 1: Federated Learning
  • Part 2: Model Inversion Attack against Federated Learning
  • Part 3: Federated Learning with Homomorphic Encryption
  • Part 4: Federated Learning with Differential Privacy
  • Part 5: Federated Learning with Sparse Gradient
  • Part 6: Poisoning Attack against Federated Learning
  • Part 7: Federated Learning with FoolsGold
  • Part 8: Split Learning
  • Part 9: Label Leakage against Split Learning

Overview

Although Federated Learning allows clients to keep their private datasets local, many papers [1, 2, 3] show that a malicious server can recover private training samples from the uploaded local gradient $\nabla \ell(w_{t-1}, X, Y)$.

Since the server already knows the parameters of the global model $w_{t-1}$, it can estimate the private training sample $(X, Y)$ by initializing a dummy sample $(X', Y')$ and updating it with the following optimization:

$$X' \leftarrow X' - \lambda \nabla_{X'} D$$
$$Y' \leftarrow Y' - \lambda \nabla_{Y'} D$$

where $\lambda$ is the step size and $D$ is the distance between the received gradient and the gradient computed on the dummy sample:

$$D = \| \nabla \ell(w_{t-1}, X, Y) - \nabla \ell(w_{t-1}, X', Y') \|_2$$

In other words, the attack reconstructs the private training data by optimizing the dummy data until the gradients it produces are close to the gradients received from the client.
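The optimization above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not AIJack's implementation: the tiny two-layer model, the choice to represent $Y'$ as label logits passed through a softmax (so that it is differentiable), the step size `lam`, and the iteration count are all assumptions. `torch.optim.SGD` performs exactly the plain update $X' \leftarrow X' - \lambda \nabla_{X'} D$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical tiny model standing in for the global model w_{t-1}
model = nn.Sequential(nn.Linear(8, 4), nn.Sigmoid(), nn.Linear(4, 3))
params = list(model.parameters())

# The client's private sample (X, Y), unknown to the server
x_true = torch.randn(1, 8)
y_true = torch.tensor([2])

# Gradient uploaded by the client: nabla l(w_{t-1}, X, Y)
loss = F.cross_entropy(model(x_true), y_true)
true_grads = [g.detach() for g in torch.autograd.grad(loss, params)]

# Dummy sample (X', Y'); Y' holds label logits (an assumption) so it can be optimized
x_dummy = torch.randn(1, 8, requires_grad=True)
y_dummy = torch.randn(1, 3, requires_grad=True)


def gradient_distance():
    """D = || nabla l(w, X, Y) - nabla l(w, X', Y') ||_2 over all parameters."""
    log_probs = F.log_softmax(model(x_dummy), dim=-1)
    # Cross-entropy against the soft label softmax(Y')
    dummy_loss = -(F.softmax(y_dummy, dim=-1) * log_probs).sum()
    # create_graph=True keeps the graph so D can be differentiated w.r.t. X' and Y'
    dummy_grads = torch.autograd.grad(dummy_loss, params, create_graph=True)
    diff = torch.cat([(dg - tg).flatten() for dg, tg in zip(dummy_grads, true_grads)])
    return diff.norm(p=2)


lam = 0.1  # step size lambda (assumed value)
optimizer = torch.optim.SGD([x_dummy, y_dummy], lr=lam)  # X' <- X' - lam * grad D

initial_distance = gradient_distance().item()
for _ in range(300):
    optimizer.zero_grad()
    distance = gradient_distance()
    distance.backward()  # also fills the model parameters' .grad, which is ignored here
    optimizer.step()
final_distance = gradient_distance().item()

print(f"D before: {initial_distance:.4f}, after: {final_distance:.4f}")
```

Only the dummy tensors are passed to the optimizer, so the global model stays fixed, mirroring the server's view in which $w_{t-1}$ is known and constant.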

Code

Many works propose different distance metrics, regularization terms, and optimization methods for this attack, and AIJack supports many of the popular ones.
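To make "distance metrics" concrete, here is a sketch comparing the L2 distance $D$ from the Overview with cosine distance, one commonly used alternative. The helper names `l2_distance` and `cosine_distance` are hypothetical and are not AIJack functions. The example shows that cosine distance ignores the overall scale of the gradients, while L2 distance does not.

```python
import torch
import torch.nn.functional as F


def l2_distance(grads_a, grads_b):
    # ||a - b||_2 with all parameter gradients flattened into one vector
    diff = torch.cat([(a - b).flatten() for a, b in zip(grads_a, grads_b)])
    return diff.norm(p=2)


def cosine_distance(grads_a, grads_b):
    # 1 - cos(a, b) over the flattened gradients; insensitive to gradient scale
    a = torch.cat([g.flatten() for g in grads_a])
    b = torch.cat([g.flatten() for g in grads_b])
    return 1 - F.cosine_similarity(a, b, dim=0)


# Toy "gradients" for a weight matrix and a bias vector
g = [torch.tensor([[1.0, 2.0], [3.0, 4.0]]), torch.tensor([0.5, -0.5])]
g_scaled = [2 * t for t in g]

print(l2_distance(g, g_scaled).item())  # sqrt(30.5), the norm of g
print(cosine_distance(g, g_scaled).item())  # ~0, up to float rounding
```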

First, we need to import the necessary libraries.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

# AIJack components: FedAvg, the gradient inversion attack, and a NumPy-backed dataset
from aijack.collaborative.fedavg import FedAVGAPI, FedAVGClient, FedAVGServer
from aijack.attack.inversion import GradientInversionAttackServerManager
from aijack.utils import NumpyDataset

We use LeNet and MNIST for demonstration purposes.

class LeNet(nn.Module):
    # Small LeNet-style CNN; `hideen` is the flattened feature size fed to the classifier
    def __init__(self, channel=3, hideen=768, num_classes=10):
        super(LeNet, self).__init__()
        act = nn.Sigmoid  # smooth activation applied after every convolution
        self.body = nn.Sequential(
            nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
            nn.BatchNorm2d(12),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            nn.BatchNorm2d(12),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            nn.BatchNorm2d(12),
            act(),
        )
        self.fc = nn.Sequential(nn.Linear(hideen, num_classes))

    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)  # flatten to (batch_size, hideen)
        out = self.fc(out)
        return out

def prepare_dataloader(path="MNIST/.", batch_size=64, shuffle=True):
    mnist_train = torchvision.datasets.MNIST(root=path, train=True, download=True)

    # Scale pixels to [0, 1], then normalize them to [-1, 1]
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    # `.data` and `.targets` replace the deprecated `train_data` / `train_labels`
    dataset = NumpyDataset(
        mnist_train.data.numpy(),
        mnist_train.targets.numpy(),
        transform=transform,
    )

    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0
    )
    return dataloader

The hyper-parameters are as follows:

torch.manual_seed(7777)

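shape_img = (28, 28)  # MNIST image size
num_classes = 10
channel = 1  # grayscale input
hidden = 588  # flattened LeNet features for 28x28 input: 12 channels * 7 * 7
criterion = nn.CrossEntropyLoss()

num_seeds = 5  # number of attack trials, each from a different random initialization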
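The value of `hidden` follows from the convolution output-size formula $\lfloor (n + 2p - k) / s \rfloor + 1$ applied to the three convolutions in `LeNet` (kernel 5, padding 2, strides 2, 2, 1). The helpers below are hypothetical and not part of AIJack; they reproduce 588 for 28x28 MNIST input and the default 768 for 32x32 input.

```python
def conv_out(n, kernel=5, padding=5 // 2, stride=1):
    # Output size of a 2D convolution along one spatial axis (no dilation)
    return (n + 2 * padding - kernel) // stride + 1


def lenet_hidden(image_size, out_channels=12):
    # Mirrors the three conv layers of LeNet above: strides 2, 2, 1
    n = image_size
    for stride in (2, 2, 1):
        n = conv_out(n, stride=stride)
    return out_channels * n * n


print(lenet_hidden(28))  # 588  (28 -> 14 -> 7 -> 7)
print(lenet_hidden(32))  # 768  (32 -> 16 -> 8 -> 8)
```

If you change the input resolution, recomputing `hidden` this way avoids a shape mismatch in the final `nn.Linear` layer.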

We will try to recover the following image, the first sample of a training batch.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataloader = prepare_dataloader()

# Take the first sample of the first batch as the client's private data
xs, ys = next(iter(dataloader))
x = xs[:1]
y = ys[:1]

plt.figure(figsize=(1, 1))
plt.axis("off")
plt.imshow(x.detach().numpy()[0][0], cmap="gray")
plt.show()


As in Part 1, we can easily implement Federated Learning with AIJack. The key difference is that we wrap the FedAVGServer class with GradientInversionAttackServerManager so that the server can execute a gradient-based model inversion attack. In each communication round, the wrapped server estimates the private data from the uploaded gradients. We run the attack five times with different random seeds.
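Conceptually, wrapping a server class this way resembles a class factory: `attach` returns a subclass that keeps the normal server behavior and additionally runs the attack on every received upload. The sketch below is not AIJack's implementation; `ToyServer`, `ToyAttackManager`, and `receive_gradients` are hypothetical names used only to illustrate why the results are stored per communication round, as in `server.attack_results[0]` later in this post.

```python
class ToyServer:
    """Hypothetical stand-in for a server class such as FedAVGServer."""

    def __init__(self):
        self.received = []

    def receive_gradients(self, grads):  # hypothetical method name
        self.received.append(grads)


class ToyAttackManager:
    """Hypothetical manager: attach() returns a subclass that attacks every upload."""

    def __init__(self, attack_fn, num_trials):
        self.attack_fn = attack_fn
        self.num_trials = num_trials

    def attach(self, server_cls):
        manager = self

        class AttackedServer(server_cls):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.attack_results = []  # one list of trials per communication round

            def receive_gradients(self, grads):
                super().receive_gradients(grads)  # keep the normal server behavior
                trials = [manager.attack_fn(grads, seed) for seed in range(manager.num_trials)]
                self.attack_results.append(trials)

        return AttackedServer


# Usage: a dummy "attack" that just records the seed and a gradient summary
AttackedToyServer = ToyAttackManager(lambda g, seed: (seed, sum(g)), num_trials=3).attach(ToyServer)
toy_server = AttackedToyServer()
toy_server.receive_gradients([1.0, 2.0])
print(toy_server.attack_results[0])  # [(0, 3.0), (1, 3.0), (2, 3.0)]
```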

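manager = GradientInversionAttackServerManager(
    (1, 28, 28),  # shape of the data to reconstruct (channel, height, width)
    num_trial_per_communication=num_seeds,  # attack trials per communication round
    log_interval=0,
    num_iteration=100,  # optimization steps per trial
    distancename="l2",  # L2 gradient distance, i.e. D in the Overview
    device=device,
    gradinvattack_kwargs={"lr": 1.0},  # step size for updating the dummy data
)
DLGFedAVGServer = manager.attach(FedAVGServer)  # FedAVGServer subclass that runs the attack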

client = FedAVGClient(
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)
server = DLGFedAVGServer(
    [client],
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)

local_dataloaders = [DataLoader(TensorDataset(x, y))]
local_optimizers = [optim.SGD(client.parameters(), lr=1.0)]

api = FedAVGAPI(
    server,
    [client],
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=1,
    local_epoch=1,
    use_gradients=True,
    device=device,
)

api.run()

We can then confirm that the attacker successfully recovers the original private image with every random seed.

fig = plt.figure(figsize=(5, 2))
for s, result in enumerate(server.attack_results[0]):
    ax = fig.add_subplot(1, len(server.attack_results[0]), s + 1)
    ax.imshow(result[0].cpu().detach().numpy()[0][0], cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()
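Visual inspection can be complemented by a numeric score. The helpers below are a minimal sketch and are not part of AIJack: they compute the mean squared error and the peak signal-to-noise ratio (PSNR) between a reconstruction and the original image. `data_range=2.0` is an assumption matching the `Normalize((0.5,), (0.5,))` transform, which maps pixels to [-1, 1].

```python
import torch


def reconstruction_mse(reconstruction, original):
    # Mean squared error over all pixels; 0 means a perfect reconstruction
    return torch.mean((reconstruction - original) ** 2).item()


def psnr(reconstruction, original, data_range=2.0):
    # Peak signal-to-noise ratio in dB; higher means a closer reconstruction
    mse = reconstruction_mse(reconstruction, original)
    if mse == 0:
        return float("inf")
    return 10 * torch.log10(torch.tensor(data_range**2 / mse)).item()


# Toy check: a uniform error of 0.1 per pixel
original = torch.zeros(1, 1, 28, 28)
noisy = original + 0.1
print(reconstruction_mse(noisy, original))  # ~0.01
print(psnr(noisy, original))  # ~26.02 dB
```

Applied to this post, calling `psnr(result[0].detach().cpu(), x)` for each `result` in `server.attack_results[0]` would score each seed, assuming (as the plotting code suggests) that `result[0]` has the same shape as `x`.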


Summary

This tutorial showed that Federated Learning alone does not guarantee privacy, since a malicious server can reconstruct private training data from the received gradients. You can find more examples of model inversion attacks against Federated Learning in AIJack's documentation. To prevent this attack, the next tutorial introduces Federated Learning with Homomorphic Encryption, in which each client encrypts its local gradients before uploading them.

References

[1] Zhu, Ligeng, Zhijian Liu, and Song Han. "Deep Leakage from Gradients." Advances in Neural Information Processing Systems 32 (2019).
[2] Zhao, Bo, Konda Reddy Mopuri, and Hakan Bilen. "iDLG: Improved Deep Leakage from Gradients." arXiv preprint arXiv:2001.02610 (2020).
[3] Yin, Hongxu, et al. "See Through Gradients: Image Batch Recovery via GradInversion." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2021).
