Early Stopping Strategies for Federated Learning
What is early stopping?
Early stopping is a neural network training technique that halts training before the model starts to overfit, with the aim of preserving the generalisability of the trained model. It is one of the classical regularisation methods in deep learning. It is also commonly combined with model checkpointing, a technique that saves snapshots of the model during training.
Examples of early stopping
The strategies for early stopping fall into two broad classes:
- When the validation error exceeds a threshold. For example, training is stopped as soon as the current validation error is higher than it was at the previous check.
- When a prespecified number of intervals (e.g. epochs or validation checks) have elapsed without improvement. Sometimes referred to as ‘patience’, this strategy stops the neural network training when the validation error has been increasing, or has failed to decrease, over a prespecified number of successive checks.
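The two classes of criteria can be sketched in a few lines of plain Python. This is a minimal illustration rather than a reference implementation: the function names, the `min_delta` tolerance and the list of validation errors are all made up for the example.

```python
from typing import Optional, Sequence


def stop_on_first_increase(val_errors: Sequence[float]) -> Optional[int]:
    """Index of the first check whose error is higher than the previous check."""
    previous = None
    for i, err in enumerate(val_errors):
        if previous is not None and err > previous:
            return i
        previous = err
    return None  # the error never increased


def stop_with_patience(
    val_errors: Sequence[float], patience: int, min_delta: float = 0.0
) -> Optional[int]:
    """Index of the check at which `patience` successive checks showed no improvement."""
    best = float("inf")
    checks_without_improvement = 0
    for i, err in enumerate(val_errors):
        if err < best - min_delta:
            best = err
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return i
    return None  # patience was never exhausted


# Hypothetical, non-monotonic validation errors, one per validation check.
errors = [0.90, 0.70, 0.72, 0.60, 0.61, 0.63, 0.65]
first_increase = stop_on_first_increase(errors)  # stops at check 2
patient = stop_with_patience(errors, patience=3)  # stops at check 6
```

On this toy curve the first criterion stops at the small bump at check 2 and misses the lower error reached at check 3, whereas the patience criterion trains through the bump at the cost of three extra checks.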
Why is early stopping important?
In neural network training, we want to obtain a network with the best possible generalisation performance. If the neural network is trained for too long, it tends to learn the noise that is specific to the training dataset. This is the so-called overfitting or overtraining phenomenon of neural network training. A symptom of overfitting is high variance, that is, the amount by which the model output varies with the dataset: an overfitted model has learnt the noise, and the noise differs between datasets.
Overfitting is best illustrated by the sketch below.
Sketch of idealised training and validation curves.
The sketch above shows the idealised training and validation error curves in supervised neural network training. Let’s assume that the validation dataset is representative of the real-world dataset, so that we can treat the validation error as a proxy metric for the generalisation error. It is desirable to avoid training models that yield a large generalisation error, that is, the red section of the curves. This calls for mechanisms that tell us when to stop training, i.e. early stopping techniques.
The challenges are compounded in real-world use cases. Firstly, it is widely accepted among machine learning practitioners that neural network training on real-world datasets produces more complex, non-monotonic validation error curves. Secondly, in federated learning, where a global model is trained collaboratively across N datasets from N clients, it is crucial that each client yields its optimal model in every aggregation round so that an optimal global model is aggregated.
This scenario is sketched below:
Real world validation curves in federated learning with N clients.
Therefore, for federated learning, the speed of training should be taken into account when selecting an early stopping criterion. As a rule of thumb, generalisability of the model is sacrificed when moving from a slower to a faster early stopping criterion.
In short, the modelling mindset for implementing early stopping for collaborative learning is to judiciously select a criterion that has the best “price-performance ratio”, i.e. minimum training time for a given generalisation error.
Early Stopping for Federated Learning
Given the common server-client topology of federated learning, two early stopping criteria hierarchies can be utilised:
- Client early stopping
- Server early stopping
In the client hierarchy, early stopping is implemented on the client side and is exactly the same as described above. For example, we can specify a validation error threshold, a fixed interval, or a combination of both. These strategies are well documented for deep learning frameworks, for example in PyTorch Lightning, and can be straightforwardly implemented in the fit method of your custom OctaiPipe FL client. Additionally, since clients may have training datasets of different sizes, it has proven useful to stop local training once the local dataset is exhausted and then use this state of the local model for all future federated learning update steps until the round is completed.
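As an illustration of the client-side approach, the sketch below combines an error threshold with a patience interval inside a local training loop. It is framework-agnostic and deliberately does not use the OctaiPipe or PyTorch Lightning APIs: `fit_local`, its parameters, and the `train_one_epoch` and `validate` callables are hypothetical names, and the validation errors are simulated.

```python
from typing import Callable, Tuple


def fit_local(
    train_one_epoch: Callable[[int], None],
    validate: Callable[[int], float],
    max_epochs: int,
    patience: int,
    error_threshold: float = float("inf"),
) -> Tuple[int, float, int]:
    """Train locally with early stopping; return (best_epoch, best_error, epochs_run)."""
    best_epoch, best_error = -1, float("inf")
    epochs_run = 0
    for epoch in range(max_epochs):
        train_one_epoch(epoch)
        error = validate(epoch)
        epochs_run = epoch + 1
        if error < best_error:
            # A real client would also checkpoint the model weights here.
            best_epoch, best_error = epoch, error
        # Stop if the error exceeds the threshold, or if there has been
        # no improvement for `patience` successive epochs.
        if error > error_threshold or epoch - best_epoch >= patience:
            break
    return best_epoch, best_error, epochs_run


# Simulated validation errors standing in for a real local training run.
simulated_errors = [1.0, 0.8, 0.7, 0.75, 0.72, 0.9, 0.95]
result = fit_local(
    train_one_epoch=lambda epoch: None,  # placeholder for the real training step
    validate=lambda epoch: simulated_errors[epoch],
    max_epochs=len(simulated_errors),
    patience=2,
    error_threshold=2.0,
)
```

Here training stops after five epochs and the model from epoch 2 (error 0.7) would be the one sent back to the server, assuming the client checkpoints its best state.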
In the server hierarchy, a criterion similar to the client-side approach can be implemented, this time applied to the aggregated neural network and using the error metric on a server-side validation dataset. The optimal global model can then be selected with the appropriate criterion as before.
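A server-side loop along these lines might look as follows. This is a sketch under stated assumptions: aggregation is a dataset-size-weighted average (FedAvg-style), which the text above does not prescribe, and `aggregate`, `run_server`, `run_round` and `evaluate` are hypothetical names. The model is a single toy weight so that the example runs without any framework.

```python
from typing import Callable, List, Sequence, Tuple

Weights = List[float]


def aggregate(client_weights: Sequence[Weights], client_sizes: Sequence[int]) -> Weights:
    """Average client weights, weighted by local dataset size (assumed FedAvg-style)."""
    total = sum(client_sizes)
    n_params = len(client_weights[0])
    return [
        sum(w[j] * n for w, n in zip(client_weights, client_sizes)) / total
        for j in range(n_params)
    ]


def run_server(
    run_round: Callable[[Weights], Tuple[List[Weights], List[int]]],
    evaluate: Callable[[Weights], float],
    initial: Weights,
    max_rounds: int,
    patience: int,
) -> Tuple[Weights, float, int]:
    """Return (best_weights, best_error, best_round) under a patience criterion."""
    global_weights = initial
    best_weights, best_error, best_round = initial, evaluate(initial), 0
    for rnd in range(1, max_rounds + 1):
        client_weights, client_sizes = run_round(global_weights)
        global_weights = aggregate(client_weights, client_sizes)
        error = evaluate(global_weights)  # error on the server-side validation set
        if error < best_error:
            best_weights, best_error, best_round = global_weights, error, rnd
        elif rnd - best_round >= patience:
            break  # no improvement for `patience` aggregation rounds
    return best_weights, best_error, best_round


def toy_round(global_weights: Weights) -> Tuple[List[Weights], List[int]]:
    """Two simulated clients that each shift the weight by a fixed step."""
    client_a = [w + 0.25 for w in global_weights]
    client_b = [w + 0.75 for w in global_weights]
    return [client_a, client_b], [10, 30]  # local dataset sizes


def toy_error(weights: Weights) -> float:
    """Squared distance from a made-up optimum at 1.0."""
    return (weights[0] - 1.0) ** 2


best_weights, best_error, best_round = run_server(
    toy_round, toy_error, initial=[0.0], max_rounds=10, patience=2
)
```

In this toy run the global weight overshoots the optimum after round 2, so the server stops at round 4 and keeps the aggregated model from round 2.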