DEV Community

Midas/XIV

Federated Learning

  • Flower is a framework for running distributed machine learning training jobs in a privacy-enhancing way.
  • If the data cannot all be shared, you have to train the model in a distributed way, without collecting the data centrally. Instead of moving the data to the training process, you move the training to the data at each site (for example, different hospitals). You then collect the model parameters centrally, not the raw data. This way you benefit from the data across all hospitals without the raw data ever leaving its source.

Why FL

Data is naturally distributed; for example, government data is spread across many governments. Centralizing it is often not possible because the data is sensitive, the volume is too large, regulations restrict moving it, or it is simply impractical.

When a model is trained without certain types of data, it does not simply fail to handle those cases. It predicts them incorrectly, which may be harmful.
Because it is often impossible to centralize data, a large amount of data goes unused simply because it is sensitive.


How does a federated learning system work?

Architecture: global model => server => clients.
The server usually holds no training data, but it can hold test data to evaluate the global model.

The initial parameters of the global model go to the server, which passes them to all the clients.
The role of the server is to coordinate the training on the clients.

FL Algorithm

  1. Initialization:
    The server initializes the global model.

  2. Communication round:
    For each communication round, the following steps run:

  3. The server sends the global model to the participating clients.

  4. Each client receives the global model.

  5. Client training and model update:
    Each participating client trains the model on its local data, then sends its locally updated model to the server.

  6. Model aggregation:
    The server aggregates the models received from all clients using an aggregation algorithm (for example FedAvg; Flower calls these aggregation algorithms strategies).

  7. Convergence check:
    If the convergence criteria are met, the process ends; otherwise, another round begins.


Federated learning is an iterative process: clients train the model locally, and the server aggregates their model updates.
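To make the loop concrete, here is a minimal plain-Python sketch. It is not Flower code. It assumes a toy model: a two-parameter line y = w·x + b. Three hypothetical clients hold noise-free samples of y = 3x + 1 over different x-ranges, and aggregation uses FedAvg's example-weighted average. The names `local_train`, `fedavg` and `run_federated` are made up for illustration. The numbered comments map to the steps listed above.

```python
import random

# Toy setting (hypothetical): the "model" is [w, b] for y = w * x + b,
# and each client holds noise-free samples of y = 3x + 1 over its own x-range.

def local_train(params, data, epochs=5, lr=0.1):
    """Client step: start from the global params and run SGD on local data only."""
    w, b = params
    for _ in range(epochs):
        for x, y in data:
            err = (w * x + b) - y      # prediction error on one sample
            w -= lr * 2 * err * x      # gradient of err**2 w.r.t. w
            b -= lr * 2 * err          # gradient of err**2 w.r.t. b
    return [w, b]

def fedavg(updates):
    """Server step: average client params, weighted by each client's example count."""
    total = sum(n for _, n in updates)
    dim = len(updates[0][0])
    return [sum(p[i] * n for p, n in updates) / total for i in range(dim)]

def run_federated(clients, rounds=200, fraction=1.0, tol=1e-6, seed=0):
    rng = random.Random(seed)
    global_params = [0.0, 0.0]                      # 1. server initializes the global model
    for rnd in range(1, rounds + 1):                # 2. one iteration = one communication round
        k = max(1, int(fraction * len(clients)))
        selected = rng.sample(clients, k)           # 3. server picks participating clients
        updates = [(local_train(global_params, data), len(data))  # 4-5. local training
                   for data in selected]
        new_params = fedavg(updates)                # 6. aggregation
        change = max(abs(a - b) for a, b in zip(new_params, global_params))
        global_params = new_params
        if change < tol:                            # 7. convergence check
            break
    return global_params, rnd

# Three hypothetical clients with different (partly overlapping) x-ranges.
clients = [[(i / 10, 3 * (i / 10) + 1) for i in range(start, start + 10)]
           for start in (-10, -5, 0)]
final_params, rounds_used = run_federated(clients)
print(f"w={final_params[0]:.4f}, b={final_params[1]:.4f} after {rounds_used} rounds")
```

Because this toy data is noise-free, every client's optimum is the same line, so the average converges to w ≈ 3 and b ≈ 1. With real non-IID data the client optima differ, which is one reason convergence can take many rounds.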


How to tune a federated system

[This entire chapter covered configuring the server and the clients: the hyperparameters you can pass to the Flower framework.]
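The note above does not list the hyperparameters, so here is one hedged illustration. In Flower, strategies such as `FedAvg` accept server-side settings like `fraction_fit` and `min_fit_clients`. They also accept an `on_fit_config_fn` callback that maps the round number to a dict of scalar values, which each client receives as the `config` argument of its `fit` method. The sketch below is only a plain function of that shape. The keys `local_epochs`, `learning_rate` and `batch_size`, and the schedule itself, are hypothetical choices rather than Flower defaults.

```python
def fit_config(server_round: int) -> dict:
    """Hypothetical per-round client config (round number in, dict of scalars out)."""
    return {
        # Train one local epoch in the first two rounds, then two.
        "local_epochs": 1 if server_round < 3 else 2,
        # Exponentially decay the client learning rate by 10% per round.
        "learning_rate": 0.1 * 0.9 ** (server_round - 1),
        "batch_size": 32,
    }

for server_round in range(1, 5):
    print(server_round, fit_config(server_round))
```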

A common way to aggregate the models is to average their weights.
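In FedAvg specifically, the average is weighted by how much data each client trained on. Suppose K clients participate in round t, client k holds n_k training examples, and it returns locally trained weights w_{t+1}^k. The new global weights are then:

```latex
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_{t+1}^{k}, \qquad n = \sum_{k=1}^{K} n_k
```

For example, with two clients holding 100 and 300 examples, their models get weights 0.25 and 0.75.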


What is data privacy in federated learning?

Privacy-enhancing technologies (PETs for short): federated learning is in itself a data-minimization solution, because it prevents direct access to the data. However, the exchanges between client and server can still leak private information.

3 types of attacks

  1. Membership inference attack -> infer whether specific data samples were part of the training data.
  2. Attribute inference attack -> infer unseen attributes of the training data.
  3. Reconstruction attack -> reconstruct specific training data samples.

A malicious server was able to reconstruct data from the model updates sent by a client (not an exact copy, but close in quality).

Here comes differential privacy (DP)
DP protects individual privacy during data analysis.
It obscures individual data by adding calibrated noise to query results. This ensures that the presence or absence of any single data point does not significantly affect the outcome of the analysis.

In other words, a small amount of noise is added to results computed from a sensitive dataset, so that nobody can tell whether a particular individual's data is in the dataset.
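One classic way to add "calibrated noise" is the Laplace mechanism for a counting query, sketched here with only Python's `random` module. It is a swapped-in textbook example; DP on model updates is usually done with Gaussian noise instead. The dataset, the predicate and ε = 0.5 are hypothetical. The key point is that the noise scale equals the query's sensitivity divided by ε.

```python
import random

def laplace_noise(scale, rng):
    """Laplace(0, scale) sample: the difference of two i.i.d. exponentials with mean `scale`."""
    return rng.expovariate(1.0 / scale) - rng.expovariate(1.0 / scale)

def private_count(records, predicate, epsilon, rng):
    """Counting query answered with the Laplace mechanism.

    Adding or removing one record changes a count by at most 1 (sensitivity = 1),
    so Laplace noise with scale 1/epsilon makes the released count epsilon-DP.
    """
    true_count = sum(1 for r in records if predicate(r))
    return true_count + laplace_noise(1.0 / epsilon, rng)

rng = random.Random(42)
ages = [34, 67, 72, 45, 81, 59, 63]          # hypothetical patient ages
noisy = private_count(ages, lambda age: age > 60, epsilon=0.5, rng=rng)
print(f"noisy count of patients over 60: {noisy:.2f}")   # true count is 4
```

Any single answer can be off by a few units. That uncertainty is exactly what hides whether one particular patient is in the data. Averaging many answers would reveal the true count, which is why repeated queries consume more privacy budget.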

The level of privacy protection
If we remove a single data point and retrain the model (M2), how different is it from the original model (M1)? Under DP, M1 and M2 are indistinguishable up to a certain degree, and that degree is set by the level of privacy protection we aim to achieve.
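The standard way to make "indistinguishable up to a certain degree" precise is (ε, δ)-differential privacy. A randomized training procedure \mathcal{M} satisfies it if the following holds for every pair of datasets D and D' that differ in a single data point, and for every set S of possible outputs:

```latex
\Pr[\mathcal{M}(D) \in S] \le e^{\varepsilon} \, \Pr[\mathcal{M}(D') \in S] + \delta
```

Here M1 plays the role of \mathcal{M}(D) and M2 the role of \mathcal{M}(D'). A smaller ε (and δ) means the two output distributions are closer, i.e. stronger privacy; δ = 0 gives pure ε-DP.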

Central DP and local DP

Clipping bounds the sensitivity and mitigates the impact of outliers. Sensitivity is the maximum amount the output can change when a single data point is added or removed.

Noising adds calibrated noise to make the output statistically indistinguishable.

In central DP, the server is responsible for adding noise to the globally aggregated parameters.
In local DP, each client is responsible for applying DP, and this must happen before it sends its updated model to the server.
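Clipping, noising, and the central/local split can be sketched together in plain Python. This is a simplified illustration that assumes Gaussian noise with standard deviation noise_multiplier × max_norm on the sum of updates, as in DP-FedAvg-style training. The helper names are hypothetical, and the sketch does not compute the resulting ε, which requires a privacy accountant.

```python
import math
import random

def clip(update, max_norm):
    """Scale the update so its L2 norm is at most max_norm; this bounds the sensitivity."""
    norm = math.sqrt(sum(v * v for v in update))
    factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [v * factor for v in update]

def add_gaussian_noise(vec, std, rng):
    """Add independent N(0, std^2) noise to every coordinate."""
    return [v + rng.gauss(0.0, std) for v in vec]

def mean(vectors):
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def central_dp_aggregate(updates, max_norm, noise_multiplier, rng):
    """Central DP: the server clips the raw client updates, averages them, and noises the average once."""
    clipped = [clip(u, max_norm) for u in updates]
    # Noise with std noise_multiplier * max_norm on the sum equals that std / n on the mean.
    std = noise_multiplier * max_norm / len(updates)
    return add_gaussian_noise(mean(clipped), std, rng)

def local_dp_update(update, max_norm, noise_multiplier, rng):
    """Local DP: the client clips and noises its own update before sending it."""
    return add_gaussian_noise(clip(update, max_norm), noise_multiplier * max_norm, rng)

rng = random.Random(0)
updates = [[0.5, -1.2, 3.0], [0.1, 0.4, -0.2], [2.0, 2.0, 2.0]]   # hypothetical client updates
central = central_dp_aggregate(updates, max_norm=1.0, noise_multiplier=1.0, rng=rng)
local = mean([local_dp_update(u, 1.0, 1.0, rng) for u in updates])
print("central DP aggregate:", [round(v, 3) for v in central])
print("local DP aggregate:  ", [round(v, 3) for v in local])
```

Let z = noise_multiplier, C = max_norm and n = number of clients. With the same z, the noise left in the averaged model has standard deviation z·C/n per coordinate for central DP, but z·C/√n for local DP. This is why local DP typically costs more accuracy in exchange for not having to trust the server.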

A downside of DP is that the added noise slows convergence, so more training rounds are needed.

How to estimate bandwidth

Bandwidth requirement for a federated learning system: in every round, the server sends the model to each selected client, and each client sends back its updated model (which may be compressed). Traffic therefore flows in both directions:

total_bandwidth = (model_size_out + model_size_in) * cohort_size * fraction_selected * number_of_rounds
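Plugging numbers into the formula is simple arithmetic. The sizes and counts below are made up for illustration, and the result is in whatever unit you measure model size in (MB here).

```python
def total_bandwidth(model_size_out, model_size_in, cohort_size, fraction_selected, number_of_rounds):
    """Per-client traffic per round (download + upload), times clients per round, times rounds."""
    return (model_size_out + model_size_in) * cohort_size * fraction_selected * number_of_rounds

# Hypothetical: 25 MB model each way, 1,000 clients, 10% selected per round, 100 rounds.
full = total_bandwidth(25, 25, 1000, 0.1, 100)
# Same setup if clients compress their uploads to a quarter of the size.
compressed = total_bandwidth(25, 25 / 4, 1000, 0.1, 100)
print(f"uncompressed: {full:,.0f} MB, compressed uploads: {compressed:,.0f} MB")
```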

To reduce bandwidth usage, you can either reduce the update size or communicate less.

Reduce the update size:

  • Sparsification: if a gradient entry to be communicated is below a certain threshold, you skip sending it. This happens more often toward the end of training, when the updates are no longer large.
  • Quantization: reduce the number of bits used to represent each scalar (for example, from 32-bit floats to 8-bit integers), which shrinks the update further.
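Both ideas fit in a few lines of plain Python. The sketch assumes that sparsification is per-entry magnitude thresholding, where entries below the threshold are simply not sent and the rest travel as (index, value) pairs. It also assumes that quantization is simple min-max linear quantization to 8 bits. Real systems often add refinements, such as accumulating the skipped residuals locally, which this sketch omits. The threshold and the gradient values are hypothetical.

```python
def sparsify(update, threshold):
    """Keep only entries whose magnitude reaches the threshold, as (index, value) pairs."""
    return [(i, v) for i, v in enumerate(update) if abs(v) >= threshold]

def densify(pairs, length):
    """Rebuild a dense vector, filling skipped entries with zero."""
    out = [0.0] * length
    for i, v in pairs:
        out[i] = v
    return out

def quantize_8bit(values):
    """Min-max linear quantization of floats to integer codes in 0..255."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 255 or 1.0   # avoid division by zero when all values are equal
    codes = [round((v - lo) / scale) for v in values]
    return codes, lo, scale

def dequantize_8bit(codes, lo, scale):
    return [lo + c * scale for c in codes]

update = [0.002, -0.5, 0.0004, 0.31, -0.0009, 0.07]   # hypothetical gradient
sparse = sparsify(update, threshold=0.01)              # keeps 3 of 6 entries
codes, lo, scale = quantize_8bit([v for _, v in sparse])
indices = [i for i, _ in sparse]
restored = densify(list(zip(indices, dequantize_8bit(codes, lo, scale))), len(update))
print(sparse)
print(codes)
print(restored)
```

Each 8-bit code takes one byte instead of the four bytes of a float32, plus one `lo`/`scale` pair per tensor. The sparse format also has to send the indices, so it only pays off when most entries are dropped.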

Communicate less:

  • Use pre-trained models: we may not need to train every layer, so only the modified (fine-tuned) layers need to be communicated.
  • Train for more epochs locally before sending the model to the server. This may delay convergence, because the client models drift further apart the longer they train without synchronizing with the server.
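For the pre-trained-model idea, a quick way to see the savings is to count the bytes per upload when only the fine-tuned layers are sent. The layer names and parameter counts below are hypothetical, and the sketch assumes float32 storage (4 bytes per parameter).

```python
# Hypothetical parameter counts for a pre-trained network with a small trainable head.
layer_sizes = {
    "backbone.block1": 1_200_000,
    "backbone.block2": 2_400_000,
    "backbone.block3": 4_800_000,
    "head.fc": 51_300,
}
trainable = {"head.fc"}      # the only layer the clients fine-tune and send back
BYTES_PER_PARAM = 4          # float32

def payload_bytes(sizes, layers):
    """Bytes needed to send the given layers."""
    return sum(sizes[name] for name in layers) * BYTES_PER_PARAM

full = payload_bytes(layer_sizes, layer_sizes)
partial = payload_bytes(layer_sizes, trainable)
print(f"full model: {full / 1e6:.1f} MB, head only: {partial / 1e6:.2f} MB "
      f"({partial / full:.2%} of full)")
```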
