Federated Pruning for Federated Learning
This publication is a compilation of materials I studied on the topic. It was written to understand the material and organize my own knowledge, a kind of essay or coursework. Since the source material is mainly in English, this article may also be useful to someone else.
Instead of a prologue
Federated Learning (FL) is an approach to machine learning that analyzes data directly at its sources, without collecting it on a central resource, and combines the results so that the learned model is no worse than one obtained with traditional approaches.
There are two “legends” about the origin of FL. The first is about data so large that training on it takes weeks or even months. The data was split into parts, and the parts were handed to separate computing nodes. Eventually the data stopped being collected at all and began to be processed directly at its sources.
The second “legend” is about the confidentiality of training data. Data is hard to depersonalize: however it is hashed, records can still be linked to a specific person, especially when a source has few observations. This led to the idea of transmitting not the data itself but the results of training on it.
(They are “legends” because each version has counterexamples. Medical data, for instance, is not large; on the contrary, it is small and hard to collect in one place, so splitting it is beside the point. The real problem there is that such small samples do not let deep learning work well, so what is needed is not a new way to process the data but machine learning methods that work on small data.)
Whichever problem came first, FL helps to solve both.
Data is important, but not everything
No technology is without limits, and FL has its own. The computing resources of data sources are small, and although each end device processes only its own data, the lack of resources can still be a problem. Modern ML models are often large. For example, Conformer, a successful model architecture for automatic speech recognition (ASR), has 130 million parameters and needs at least 520 MB of memory just to store them during training. (So the phrase “Ok Google!” can load about a quarter of the computing power of an average smartphone. True, ML models train only while the smartphone is idle, but that too is a forced measure caused by resource constraints.) Reducing the size of the model trained on the device helps to address this shortage. (Below, the words “reduction”, “truncation”, “pruning” and “sparsification” are used as synonyms.)
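The 520 MB figure is simple arithmetic if each parameter is stored as a 32-bit float (4 bytes). The sketch below reproduces it under that assumption; the helper name is hypothetical, and gradients, optimizer state and activations are deliberately left out, so real training needs more memory than this.

```python
# Rough memory estimate for storing model parameters only.
# Assumption: each parameter is a 32-bit float (4 bytes); gradients, optimizer
# state and activations are NOT included, so real training needs more.
BYTES_PER_FLOAT32 = 4


def param_memory_mb(num_params: int, bytes_per_param: int = BYTES_PER_FLOAT32) -> float:
    """Return parameter storage in decimal megabytes (1 MB = 10**6 bytes)."""
    return num_params * bytes_per_param / 1e6


conformer_params = 130_000_000
print(param_memory_mb(conformer_params))       # 520.0
# Pruning half of the parameters halves this figure (ignoring mask overhead).
print(param_memory_mb(conformer_params // 2))  # 260.0
```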
There are several ways to achieve this reduction:
Federated dropout. Reduced models on clients (submodels) are generated randomly, while the final training is performed on the full model on the server. When federated pruning is used, the submodels on clients have an error rate independent of the full model on the server, which makes it easier to tune the model size on end devices.
PruneFL. Performs initial pruning on a selected client and further pruning as part of the FL process, dynamically adjusting the model size along the way. Experiments with various datasets on edge devices (e.g., Raspberry Pi) show that training time is greatly reduced compared with conventional FL and various other pruning-based methods, and the automatically sized pruned model converges to an accuracy very close to that of the original model.
Federated Pruning (FP). Removes redundant parameters from the model trained on end devices. It relies on two tools: sparsity patterns (masks) and sparsity methods.
The last method, Federated Pruning, is currently the most popular, so let's look at it in more detail.
Pruning has been studied extensively in centralized learning settings. It has been shown empirically that a well-initialized pruned model can reach the accuracy of a full model trained centrally.
The method exists because, to make training easier, models are usually over-parameterized from the start. Parameters that turn out to be redundant for a particular dataset can therefore be removed. The resulting smaller models take less space on the device, need less memory for training, and require less network bandwidth to send results to the server.
The main problem with this approach is that pruning parameters judged insignificant in an early iteration can hurt later iterations. To address this, various schedules of sparsification over time and across the model are used.
Pruning begins with the server generating a sparse mask for the model variables (1). The mask is binary: zeros mark parameters that are insignificant and should be pruned. The number of pruned parameters corresponds to the specified sparsity level, a hyperparameter given as a fraction of the original number of parameters.
The server prunes the model using the mask and sends the reduced model to clients (2).
Each client then trains this reduced model locally and returns the training result (3).
Finally, the server collects the training results from all clients (4), averages them (5), and moves on to the next round.
Compared with standard federated learning, federated pruning prunes the model at the beginning of each round: the pruned model is trained on clients, and only the trained submodel is transmitted over the network.
In the first stage, pruning (a), the sparsity level is raised from 0 to the target S. Depending on the pruning scheme, this can be done in one shot or gradually.
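The five steps above can be sketched as a single round in plain Python. This is a toy illustration under stated assumptions, not the implementation from the source: weights are a flat list, the mask uses weight magnitude, `local_train` is a hypothetical stand-in that takes gradient steps on a quadratic loss, and the server uses a plain unweighted mean of client results (FedAvg usually weights clients by their number of examples).

```python
from typing import List, Tuple


def magnitude_mask(weights: List[float], sparsity: float) -> List[int]:
    """Step (1): binary mask where 0 marks the `sparsity` fraction of weights
    with the smallest absolute value."""
    n_prune = int(round(sparsity * len(weights)))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    mask = [1] * len(weights)
    for i in order[:n_prune]:
        mask[i] = 0
    return mask


def apply_mask(weights: List[float], mask: List[int]) -> List[float]:
    """Step (2): the reduced model sent to clients (pruned entries zeroed)."""
    return [w * m for w, m in zip(weights, mask)]


def local_train(sub: List[float], mask: List[int], target: List[float],
                lr: float = 0.1, steps: int = 5) -> List[float]:
    """Step (3), hypothetical: gradient steps on 0.5 * ||w - target||^2,
    with the gradient masked so pruned entries stay at zero."""
    w = list(sub)
    for _ in range(steps):
        w = [wi - lr * (wi - ti) * m for wi, ti, m in zip(w, target, mask)]
    return w


def fp_round(server: List[float], sparsity: float,
             client_targets: List[List[float]]) -> Tuple[List[float], List[int]]:
    mask = magnitude_mask(server, sparsity)                        # (1)
    sub = apply_mask(server, mask)                                 # (2)
    results = [local_train(sub, mask, t) for t in client_targets]  # (3) + (4)
    avg = [sum(col) / len(col) for col in zip(*results)]           # (5)
    # Unpruned entries take the averaged value; pruned entries keep their old
    # server value, so they are not lost and may be restored in later rounds.
    new_server = [a if m else w for a, m, w in zip(avg, mask, server)]
    return new_server, mask


weights = [0.05, -0.9, 0.3, 0.01, -0.4, 0.7]
targets = [[0.0] * 6, [0.2] * 6]  # two hypothetical clients
new_weights, mask = fp_round(weights, 0.5, targets)
print(mask)  # [0, 1, 0, 0, 1, 1]
```

Keeping the pruned values on the server, rather than zeroing them, is what later allows the mask to be refined instead of being fixed after the first round.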
The next FP stage after pruning (a) is refinement (b). Here the sparsity level is fixed at S, and the focus is on refining the mask. Note that the reverse transformation is also possible: if a mask value flips from 0 to 1, the parameter pruned at the previous stage is restored. This is exactly what addresses the main problem of the FP approach, the risk of mistakenly pruning the wrong parameters.
In the final tuning stage (c), the sparsity level and the mask are fixed, and the remaining parameters are optimized. This stage ends when training of the pruned model is complete.
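The three stages (a) to (c) amount to a schedule over federated rounds. The sketch below is one possible encoding; the round counts are hypothetical hyperparameters, and for the gradual variant of stage (a) it borrows the cubic ramp from Zhu and Gupta's gradual pruning schedule as an example, not as the schedule used in the source.

```python
from dataclasses import dataclass


@dataclass
class RoundConfig:
    phase: str          # "pruning" (a), "refinement" (b) or "tuning" (c)
    sparsity: float     # sparsity level used in this round
    update_mask: bool   # whether the mask may still change


def fp_schedule(round_idx: int, target: float, prune_rounds: int,
                refine_rounds: int, one_shot: bool = False) -> RoundConfig:
    """Hypothetical mapping from a federated round index to FP stage settings."""
    if round_idx < prune_rounds:
        if one_shot:
            sparsity = target  # jump straight to S
        else:
            # Cubic ramp from Zhu & Gupta's gradual pruning (initial sparsity 0):
            # fast at first, slower as it approaches S at the last pruning round.
            progress = (round_idx + 1) / prune_rounds
            sparsity = target * (1.0 - (1.0 - progress) ** 3)
        return RoundConfig("pruning", sparsity, True)
    if round_idx < prune_rounds + refine_rounds:
        return RoundConfig("refinement", target, True)  # S fixed, mask refined
    return RoundConfig("tuning", target, False)         # S and mask frozen


for r in range(10):
    print(r, fp_schedule(r, target=0.5, prune_rounds=4, refine_rounds=3))
```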
Now let's look at how the significance of variables is determined. FP usually uses L1 or L2 norms computed from one of the following:
the weights of the variables;
the gradient (or the gradient descent momentum);
the product of the weight and the gradient magnitude.
These norms serve as estimates of parameter importance. The importance scores are sorted, and a cutoff threshold is set from the target sparsity level. Matrix regions whose scores fall below the threshold are removed from the model sent to clients, forming the reduced model. On the server, however, these low-importance regions are kept; they are removed permanently only in the final tuning stage (or restored if needed).
Depending on how the cutoff is applied, pruning schemes can be unstructured, where the weakest individual connections anywhere in the network are removed, or structured, where larger structures such as groups of nodes or entire layers are removed.
The following structured pruning patterns are used:
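For a single parameter, the L1 and L2 norms both reduce to its absolute value (the L2 norm only differs when scores are computed over groups such as columns). Under that simplification, the three scoring options and the threshold step look roughly like this; the function names are hypothetical.

```python
from typing import List, Optional


def importance(weights: List[float], grads: Optional[List[float]] = None,
               method: str = "weight") -> List[float]:
    """Per-parameter importance score (absolute value of a single entry).
    method: "weight" -> |w|, "grad" -> |g|, "weight_x_grad" -> |w * g|."""
    if method == "weight":
        return [abs(w) for w in weights]
    if grads is None:
        raise ValueError("gradient-based scores need grads")
    if method == "grad":
        return [abs(g) for g in grads]
    if method == "weight_x_grad":
        return [abs(w * g) for w, g in zip(weights, grads)]
    raise ValueError(f"unknown method: {method}")


def threshold_mask(scores: List[float], sparsity: float) -> List[int]:
    """Sort scores, take the cutoff at the target sparsity, keep scores above it."""
    n_prune = int(round(sparsity * len(scores)))
    if n_prune == 0:
        return [1] * len(scores)
    cutoff = sorted(scores)[n_prune - 1]
    # Ties at the cutoff are all pruned here, which may slightly exceed the
    # target sparsity; a real implementation would break ties explicitly.
    return [0 if s <= cutoff else 1 for s in scores]


w = [0.2, -0.05, 0.8, -0.3]
g = [0.1, 2.0, 0.01, 0.5]
for method in ("weight", "grad", "weight_x_grad"):
    print(method, threshold_mask(importance(w, g, method), 0.5))
```

The example also shows why the choice of score matters: the weight-based mask keeps entries 2 and 3, while both gradient-based masks keep entries 1 and 3.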
Whole row/column: delete an entire row or column of the 2D weight matrix W.
Half row/column: split the 2D matrix W evenly into [W1; W2] and prune half of an original row or column.
For tensors with more than two dimensions, the weights are first reshaped into a 2D matrix, the pattern is applied, and the result is reshaped back to the original shape.
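The structured patterns can be sketched with the L2 norm of each group as its score. Assumptions: a matrix is a list of rows, [W1; W2] means the top and bottom halves of the rows (so half-columns are columns of W1 or of W2, and an even row count is assumed), rows are handled by transposing, and an N-D tensor is flattened to 2D by merging all leading dimensions into rows. All names are hypothetical.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def transpose(w: Matrix) -> Matrix:
    return [list(col) for col in zip(*w)]


def column_mask(w: Matrix, sparsity: float) -> List[List[int]]:
    """Whole-column pattern: prune the columns with the smallest L2 norm."""
    n_rows, n_cols = len(w), len(w[0])
    norms = [math.sqrt(sum(w[r][c] ** 2 for r in range(n_rows))) for c in range(n_cols)]
    n_prune = int(round(sparsity * n_cols))
    pruned = set(sorted(range(n_cols), key=lambda c: norms[c])[:n_prune])
    return [[0 if c in pruned else 1 for c in range(n_cols)] for _ in range(n_rows)]


def row_mask(w: Matrix, sparsity: float) -> List[List[int]]:
    """Whole-row pattern: the column pattern applied to the transpose."""
    return transpose(column_mask(transpose(w), sparsity))


def half_column_mask(w: Matrix, sparsity: float) -> List[List[int]]:
    """Half-column pattern: W = [W1; W2] (top/bottom halves of the rows);
    each column of W1 and of W2 is a group scored by its L2 norm."""
    n_rows, n_cols = len(w), len(w[0])
    half = n_rows // 2  # assumes an even number of rows
    groups = [(h, c) for h in (0, 1) for c in range(n_cols)]

    def norm(group):
        h, c = group
        return math.sqrt(sum(w[r][c] ** 2 for r in range(h * half, (h + 1) * half)))

    n_prune = int(round(sparsity * len(groups)))
    pruned = set(sorted(groups, key=norm)[:n_prune])
    return [[0 if (r // half, c) in pruned else 1 for c in range(n_cols)]
            for r in range(n_rows)]


def to_2d(flat: Sequence[float], shape: Sequence[int]) -> Matrix:
    """N-D tensor (row-major flat list + shape) -> 2D: leading dims become rows."""
    n_cols = shape[-1]
    return [list(flat[i:i + n_cols]) for i in range(0, len(flat), n_cols)]


def from_2d(mask2d: List[List[int]]) -> List[int]:
    """Back to the flat row-major layout, which pairs with the original shape."""
    return [m for row in mask2d for m in row]


W = [[1.0, 0.1, 3.0],
     [1.0, 0.1, -2.0],
     [0.5, 0.2, 0.3],
     [0.5, 0.2, 0.3]]
print(column_mask(W, 1 / 3))     # every row is [1, 0, 1]
print(half_column_mask(W, 0.5))  # top rows [1, 0, 1], bottom rows [1, 0, 0]
print(row_mask(W, 0.5))          # the two weakest rows are zeroed
```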
Refining the mask
During the refinement stage, the mask keeps changing. Since the pruned regions of the full model on the server receive zero updates from the compute nodes, their importance scores do not change, while the scores of the remaining regions are updated. If the importance of a parameter that has not been pruned drops, that parameter can be pruned. At the same time, to keep the target sparsity level, the same number of previously pruned parameters must be restored, chosen according to the updated ranking of all importance scores. In this way, the mask is refined and the model is retrained during this stage.
Note that with gradient-based importance scores, the masked regions receive no gradients and therefore lose the ability to be restored.
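Mask refinement at a fixed sparsity level can be sketched as re-ranking all parameters, pruned or not, and keeping the same number pruned. The sketch assumes weight-magnitude scores, where pruned entries keep the server-side value they had when pruned; the second example shows why gradient-based scores prevent restoration, as noted above. The function name is hypothetical.

```python
from typing import List, Tuple


def refine_mask(scores: List[float], old_mask: List[int],
                sparsity: float) -> Tuple[List[int], List[int], List[int]]:
    """Re-rank ALL parameters by their current scores and prune the same number
    as before, so the sparsity level stays fixed. Returns the new mask plus the
    indices that were restored (0 -> 1) and newly pruned (1 -> 0)."""
    n_prune = int(round(sparsity * len(scores)))
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    new_mask = [1] * len(scores)
    for i in order[:n_prune]:
        new_mask[i] = 0
    pairs = list(zip(old_mask, new_mask))
    restored = [i for i, (o, n) in enumerate(pairs) if o == 0 and n == 1]
    newly_pruned = [i for i, (o, n) in enumerate(pairs) if o == 1 and n == 0]
    return new_mask, restored, newly_pruned


# Weight-magnitude scores: pruned entries 0 and 2 kept their server-side
# values, while the unpruned entry 3 shrank during training.
server_weights = [0.05, 0.6, 0.02, 0.01]
old_mask = [0, 1, 0, 1]
print(refine_mask([abs(w) for w in server_weights], old_mask, 0.5))
# ([1, 1, 0, 0], [0], [3])

# Gradient scores: pruned entries receive no gradient, so their score is 0
# and they stay at the bottom of the ranking; nothing can be restored.
grad_scores = [0.0, 0.5, 0.0, 0.001]
print(refine_mask(grad_scores, old_mask, 0.5))
# ([0, 1, 0, 1], [], [])
```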
In deep neural networks, layers contribute differently to model accuracy. Some can be classified as critical, meaning that pruning them seriously degrades quality, and others as ambient, meaning that pruning them has little effect on the model. The sparsity level should therefore be set for each layer separately, whereas basic FP applies the same sparsity level to all layers.
To set per-layer sparsity levels, the importance of each layer is estimated, for example as the average importance score of its parameters. Each layer is then assigned a sparsity level based on that score: the less important the layer, the higher its sparsity.
Thus, with adaptive sparsity the model architecture changes, since an individual layer of the network can be pruned in whole or in part, whereas classical federated learning leaves the architecture unchanged.
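The text gives no exact formula for turning layer importance into a sparsity level, so the rule below is a hypothetical example: a layer's importance is the mean of its parameter scores, its sparsity is proportional to the inverse of that importance, and the values are scaled so that the parameter-weighted average matches the overall target S (clipping at `max_sparsity` can break that equality). Scores are assumed to be positive.

```python
from typing import List


def layer_sparsities(layer_scores: List[List[float]], target: float,
                     max_sparsity: float = 0.95) -> List[float]:
    """Hypothetical rule: sparsity_l proportional to 1 / mean(scores_l), scaled
    so that sum(sparsity_l * n_l) == target * sum(n_l) before clipping."""
    sizes = [len(s) for s in layer_scores]
    importance = [sum(s) / len(s) for s in layer_scores]  # mean score per layer
    inverse = [1.0 / imp for imp in importance]           # less important -> larger
    scale = target * sum(sizes) / sum(v * n for v, n in zip(inverse, sizes))
    return [min(max_sparsity, v * scale) for v in inverse]


# Two layers of 4 parameters; the second is half as important as the first.
print(layer_sparsities([[1.0] * 4, [0.5] * 4], target=0.5))
# approximately [0.333, 0.667]; the weighted average is still 0.5
```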
What about in practice?
The article “Federated Pruning: Improving Neural Network Efficiency with Federated Learning” presents the results of an experiment conducted in a federated learning simulator. Here we give only the main points, very briefly; for the full description, see the original article.
The experiment studies the applicability of Federated Pruning to speech recognition. The dataset is LibriSpeech, an open anonymized dataset with 970 hours of manually transcribed speech. The base model is a Conformer-Transducer: 17 Conformer encoder layers of dimension 512, a 640-dimensional embedding prediction network, and a 2048-dimensional fully connected joint layer. The model was trained for 30,000 federated rounds.
Experiments with different importance measures showed that the weight-based score achieves the same WER (Word Error Rate) as the gradient-based and weight × gradient scores. However, the weight-based score is more stable, so it was used to measure the importance of both parameters and layers.
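WER, the metric used in these experiments, is the word-level edit distance (substitutions + deletions + insertions) between the recognized and reference transcripts, divided by the number of reference words. A minimal sketch, assuming a non-empty, whitespace-tokenized reference:

```python
def wer(reference: str, hypothesis: str) -> float:
    """Word Error Rate = (substitutions + deletions + insertions) / reference words,
    computed as the word-level Levenshtein distance."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j]: edit distance between the ref words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        cur = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            cur[j] = min(prev[j] + 1,         # deletion
                         cur[j - 1] + 1,      # insertion
                         prev[j - 1] + cost)  # substitution or match
        prev = cur
    return prev[len(hyp)] / len(ref)


# One deleted word out of six reference words.
print(wer("ok google what time is it", "ok google what is it"))
```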
At a sparsity level of 50%, adaptive sparsity outperforms uniform sparsity.
Column-based and half-column pruning outperform row-based pruning, and of the two, pruning whole columns is more efficient. In short, it is advisable to use the weight as the sparsity method and the whole column as the sparsity pattern.
In all pruning schemes, quality clearly degrades at sparsity levels above 30%.
Mask refinement reduces WER.
Federated Pruning reduces the model size while preserving training quality. As they say, I recommend FP to everyone, and adaptive sparsity along with it. Training cost, memory requirements, and network bandwidth are all reduced.
However, implementation complexity increases. This article does not touch on data traffic at all (in a controlled environment, over an ordinary network, with SwitchML, etc.), and in some implementations it is entirely the developer's responsibility. Nor are the time and hardware costs of traffic taken into account, although including them would give a more realistic estimate of FL efficiency; in some cases the gains might shrink to zero. But to keep this text from growing too long, we leave that for other articles.