
Andreas Messalas for Code4Thought

Originally published at code4thought.eu

Don't be wrong because you might be fooled: Tips on how to secure your ML model

Figuring out why your ML model is consistently less accurate on certain classes than on others can help you increase not only its overall accuracy but also its adversarial robustness.

Introduction

Machine Learning (ML) models, and especially Deep Learning (DL) ones, can achieve impressive results, particularly on unstructured data like images and text. However, there is a fundamental limitation with the (supervised) ML framework: *the distributions we end up using ML models on are NOT always the ones we train them on.*

This leads to models that seem accurate on the one hand but are brittle on the other.
Adversarial attacks exploit this brittleness by introducing imperceptible perturbations to images that force the model to flip its predictions.

For example, lighting a "Stop" sign in a specific way can make a traffic-sign model predict it as a "Speed 30" sign (Fig. 1, left). In other examples, simply rotating an image or replacing some words in a sentence with their synonyms can fool models into giving a wrong medical diagnosis or risk score, respectively (Fig. 1, right).


Fig. 1 (references [1], [2])

It is evident that adversarial robustness of a model is associated with its security and safety, thus it is important to be aware of its existence and its implications.
In this article, we will use a model trained on the public CIFAR10 dataset to:
📌 Investigate the intuition that the model's inability to correctly predict an image also makes it more susceptible to adversarial attacks.
📌 Measure class disparities, i.e. check how the model performs across the 10 classes.
📌 Conduct a root cause analysis that pinpoints the causes of these class differences, which will hopefully help us not only fix the misclassifications but also increase the adversarial robustness of the model.


The "cat", "bird" and "dog" classes are harder to correctly classify and easier to attack

We trained a simple model, using the well-known ResNet architectural pattern on a 20-layer-deep network, which achieves an 89.4% accuracy on the validation set. We then plotted the per-class miss rate to check whether there are any disparities between the classes (Fig. 2).
It is evident that the "cat", "bird" and "dog" classes are harder to classify correctly than the rest of the classes.


Fig. 2: Miss rates across the 10 classes of CIFAR10
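
For reference, here is a minimal PyTorch sketch of such a per-class miss-rate computation on the CIFAR10 test split. It assumes an already-trained `model` (e.g. a ResNet-20) that lives on `device`; the normalization constants are the commonly used CIFAR10 statistics and should match whatever your model was trained with.

```python
# Per-class miss rate on the CIFAR10 test split (assumes a trained `model`).
import torch
import torchvision
import torchvision.transforms as T

device = "cuda" if torch.cuda.is_available() else "cpu"

transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
testset = torchvision.datasets.CIFAR10(root="./data", train=False,
                                       download=True, transform=transform)
loader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False)

num_classes = 10
misses = torch.zeros(num_classes)
totals = torch.zeros(num_classes)

model.eval()
with torch.no_grad():
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        for c in range(num_classes):
            mask = labels == c
            totals[c] += mask.sum()
            misses[c] += (preds[mask] != labels[mask]).sum()

miss_rate = misses / totals   # one value per class
for name, rate in zip(testset.classes, miss_rate):
    print(f"{name:10s} miss rate: {rate.item():.3f}")
```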

We then applied two kinds of adversarial attacks:

  • An untargeted attack, where an attack is considered successful when the predicted class label changes (to any other label).
  • A targeted attack with the least-likely target, where an attack is considered successful only when the predicted class label changes to the label for which the model has the least confidence on that specific instance (both settings are sketched below).
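
Our conclusions do not depend on a particular attack implementation, but for illustration here is a minimal single-step, FGSM-style sketch of both settings. The epsilon value and the assumption that inputs lie in [0, 1] are placeholders; adjust them to your own preprocessing pipeline.

```python
# Illustrative FGSM-style attacks (single gradient step).
# Assumes image tensors in [0, 1]; eps is a placeholder perturbation budget.
import torch
import torch.nn.functional as F

def untargeted_fgsm(model, images, labels, eps=8 / 255):
    """Push the prediction away from the true label."""
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), labels)
    loss.backward()
    # Step in the direction that increases the loss on the true label.
    return (images + eps * images.grad.sign()).clamp(0, 1).detach()

def least_likely_fgsm(model, images, eps=8 / 255):
    """Push the prediction toward the class the model finds least likely."""
    with torch.no_grad():
        target = model(images).argmin(dim=1)   # least-likely label per image
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), target)
    loss.backward()
    # Step in the direction that decreases the loss on the least-likely label.
    return (images - eps * images.grad.sign()).clamp(0, 1).detach()
```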

Afterwards, we plotted the attack success rate per class, i.e. the percentage of successful attacks per class (note: each class in the test set has 1,000 images). We can observe that the most successfully attacked classes are the same ones that are also misclassified most often (Fig. 3).
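
Counting successful attacks per class follows the same pattern as the miss-rate computation above. The sketch below assumes an `attack_fn(model, images, labels)` callable (for instance one of the FGSM sketches above, an assumed signature rather than a library API) and uses the untargeted definition of success, i.e. the predicted label changes to any other label.

```python
# Sketch: per-class attack success rate for a generic attack function.
import torch

def attack_success_rate_per_class(model, loader, attack_fn, num_classes=10):
    success = torch.zeros(num_classes)
    totals = torch.zeros(num_classes)
    model.eval()
    for images, labels in loader:
        adv = attack_fn(model, images, labels)
        with torch.no_grad():
            clean_pred = model(images).argmax(dim=1)
            adv_pred = model(adv).argmax(dim=1)
        # Untargeted definition: the attack succeeds when the prediction flips.
        # (For the targeted setting, compare adv_pred to the chosen target instead.)
        flipped = adv_pred != clean_pred
        for c in range(num_classes):
            mask = labels == c
            totals[c] += mask.sum()
            success[c] += flipped[mask].sum()
    return success / totals
```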

This is intuitive and, to some extent, expected: the fact that the model misclassifies some instances means that it pays attention to features that are not very relevant to that class, so adding perturbed features makes the model's job much harder and the attacker's goal easier.


Fig. 3: Attack success rate for targeted and untargeted attacks

Root cause analysis

We identified three possible root causes for the class disparities in the model's predictions:

1. Mislabeled/Confusing Training Data

Data collection is probably the most costly and time-consuming part of most machine learning projects. It is perfectly reasonable to expect that this arduous process will entail some mistakes. We discovered that the CIFAR10 training set contains images that are either mislabeled or inherently confusing, even for humans (Fig. 4).


Fig. 4: Confusing/Mislabeled images from the CIFAR10 training dataset

This can be considered a kind of involuntary data poisoning: a significant number of such poisoned images can shift the data distribution in the wrong direction, making it harder for the model to learn meaningful features and, consequently, to make correct predictions.
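
A simple heuristic for surfacing such suspect images (one option among many, and not necessarily the approach we used here) is to rank training examples by the cross-entropy loss a trained model assigns to their given label: mislabeled or inherently confusing images tend to cluster near the top of that ranking. The function name and signature below are illustrative.

```python
# Heuristic sketch: flag training images whose given label the model
# finds most surprising (highest cross-entropy loss).
import torch
import torch.nn.functional as F

def rank_suspect_examples(model, loader, device="cpu", top_k=50):
    """Return dataset indices (and losses) of the `top_k` highest-loss examples.
    `loader` must iterate over the training set WITHOUT shuffling."""
    model.eval()
    losses = []
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device))
            losses.append(F.cross_entropy(logits, labels.to(device),
                                          reduction="none").cpu())
    losses = torch.cat(losses)
    top = losses.topk(top_k)
    return top.indices, top.values   # indices into the (unshuffled) dataset
```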

Data poisoning is also linked with adversarial attacks, although in that case the poisoned data is carefully crafted. It has been shown that even a single poisoned image can affect multiple test images (Fig. 5).


Fig. 5 (Koh et al.)

2. Is this a "cat" or a "dog"?

The "cat" class has the worst miss-rate of all the classes, followed by the "dog" class which has the third worst miss-rate.

Since these animals share similar features (four legs, ears, a tail), our intuition is that the model either could not extract meaningful features to distinguish the two animals or has learned better "dog" features than "cat" ones.

Using saliency map explanations, we can verify this intuition: the model has not learned distinctive cat characteristics, such as the pointy ears and nose, and instead focuses on the animal's whole face or body (Fig. 6).


Fig. 6: Saliency map explanations for images of "cat" misclassified as "dog"
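
There are several ways to produce saliency maps; as one illustration (we do not prescribe a specific explanation method here), a vanilla gradient saliency map back-propagates the predicted class score to the input pixels:

```python
# Vanilla gradient saliency sketch: how much does each pixel influence
# the score of the (predicted or given) class?
import torch

def saliency_map(model, image, target_class=None):
    """`image` is a single preprocessed (C, H, W) tensor."""
    model.eval()
    image = image.unsqueeze(0).clone().detach().requires_grad_(True)
    logits = model(image)
    if target_class is None:
        target_class = logits.argmax(dim=1).item()
    # Gradient of the target class score with respect to the input pixels.
    logits[0, target_class].backward()
    # Collapse the color channels by taking the maximum absolute gradient.
    return image.grad.abs().max(dim=1).values.squeeze(0)   # (H, W) heat map
```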

3. Is this a "bird" or an "airplane"?

A similar situation occurs with the "bird" and "airplane" classes. In this case, the model is confused by the blue background, since most airplane images show an object against a blue background (Fig. 7).


Fig. 7: Saliency map explanations for images of "bird" misclassified as "airplane"


Takeaways

👉 Good data means a good model: spend some time investigating your data and try to identify if there are any systematic errors in your training set.
👉 Use explanation methods as a debugger, in order to understand why your model misses certain groups of instances more than others.
👉 Adversarial attacks are a cost-effective way to check the adversarial robustness of your model.
