Abstract
In order to improve communication efficiency and adversarial robustness, the well-known XGBoost library has been modified to execute in a federated setting, allowing multiple trees to be trained with standard XGBoost functionality on the associated local dataset(s) before the trees are passed to a centralised server for aggregation.
In a conventional approach, each client device trains a single tree and sends that single tree to the server for aggregation. The inventive approach is referred to as “tree plus” (Tree+) in that client devices train trees locally and pass a number of trees to a centralised service or server, significantly improving processing performance and reducing power costs on the devices. In some embodiments, normalisation of one or more parameters, such as the learning rate, allows different client devices to process a different number of datapoints.
Description of FLeXGBoost
In a conventional distributed IoT environment, endpoints (e.g., client devices) access and process local data using centrally derived and distributed models. These models are typically trained in a central processing server using data collected from the local endpoint client devices. Unfortunately, there are numerous implementations where transmitting locally sourced data to a central service is expensive (from a processing power and/or network bandwidth perspective) and can run afoul of laws, regulations or data usage policies regarding the use of private data.
Federated Learning (“FL”) offers an alternative to traditional centralised machine learning approaches: clients are not required to share their private and sensitive data, yet still benefit from system-wide learning. In a federated learning setting, in each training epoch the server calculates a global update of the model parameters upon receiving local updates from clients, each of which computes its local update privately on its own data, thus creating a powerful shared model without having to share potentially private data.
There are numerous approaches to implementing federated learning, based on how the data is partitioned and the particular model(s) being used. Herein we use a federated learning approach in a horizontal setting, meaning all the client devices share the same feature space (e.g., the same column names) but each holds different samples, and the model chosen is XGBoost.
Neural network (NN) based models are one of the most widely-used model architectures while tree-based methods (like XGBoost) are far less explored due to the challenges in overcoming the iterative and additive characteristics of the algorithm. However, there are certain commercial advantages of using tree-based methods, such as XGBoost. For example, XGBoost is plug and play, whereas use of NN-based models requires some fairly technical background knowledge of the method and the associated software packages, and though neural networks can provide customised solutions specifically designed to each dataset, that comes with a high implementation cost. Moreover, using a tree-based approach such as XGBoost is more efficient from a processing perspective, saving power and costs at the client device. XGBoost also provides a more straightforward approach to understanding feature importance than traditional neural network models, and can process data that is not independent and identically distributed (IID), which is critical given the decentralised nature of the data being collected. XGBoost can also address missing values and sparse datasets.
XGBoost, or eXtreme Gradient Boosting, is a machine learning model for supervised learning tasks such as regression and classification. XGBoost is an advanced variant of the gradient boosting machine, using an additive strategy (adding one new tree at a time), and belongs to the sub-class of tree-based ensemble algorithms that also includes Random Forest.
In XGBoost, training proceeds iteratively: new trees are added that predict the residuals or errors of the prior trees, and these are combined with the previous trees to make the final prediction. It is referred to as “gradient boosting” because it uses a gradient descent algorithm to minimise the loss when adding new models. The “federation” of this tree-based model is illustrated in Fig. 1 and described further below; a minimal local-training sketch follows the figure.
Figure 1
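For context, the following minimal sketch (not part of the disclosed system) uses the standard xgboost Python package with synthetic data and illustrative parameter values to show this additive behaviour: continued training simply appends new trees to the existing booster.
# Minimal local-training sketch: synthetic data stands in for a client's
# local dataset; parameter values are illustrative only.
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 28))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)
dtrain = xgb.DMatrix(X, label=y)

params = {"objective": "binary:logistic", "eta": 0.1, "max_depth": 8,
          "tree_method": "hist", "eval_metric": "auc"}

# Train an initial booster, then continue training from it: the extra
# rounds append new trees on top of the existing ensemble.
booster = xgb.train(params, dtrain, num_boost_round=3)
booster = xgb.train(params, dtrain, num_boost_round=3, xgb_model=booster)
print("total trees:", booster.num_boosted_rounds())  # 6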
In one embodiment, the process for federating the XGBoost process follows these steps:
- Assume we have N clients, each with its own local dataset, and that all clients participate in each FL round.
- At the start of the first FL round, each client trains XGBoost on its local data for a given number of local rounds (say n). Each client sends its n local trees to the server at the end of the first FL round.
- At the server side, the n trees coming from each of the N clients (N × n in total) are aggregated to form the first global model. This global model is then passed to each of the N clients at the beginning of the second FL round. The server also stores this global model centrally.
- Upon receiving the global model consisting of N × n trees, each client generates n further trees using its local data. At the end of the second FL round, each client sends the newly generated n trees to the server for processing. Note that clients do not need to send all (nN + n) trees to the server, as the server already has a copy of the first nN trees.
- At the end of the second round, the server has a global model consisting of (nN + nN) = 2nN trees, where the first nN trees come from the first round and the second nN come from the n trees contributed by each of the N clients during the second round. The server then passes the global model of 2nN trees to each client.
- This process continues until the preselected total number of FL rounds has been completed.
- The above method is valid for both the Tree and Tree+ cases, in both binary and multi-class classification problems. The main difference between binary and multi-class classification problems is the way the trees are stored and indexed in a particular iteration, so extra care must be taken when bagging the trees at the server to maintain the proper indexing. A conceptual sketch of this aggregation loop is given below.
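The following conceptual sketch summarises this loop. It is not the production aggregation code: trees are represented abstractly, and client.train_local_trees is a stand-in for client-side XGBoost training on local data.
# Conceptual sketch of the server-side bagging loop described above.
def run_federated_rounds(clients, num_rounds, n_local_trees):
    global_model = []  # ordered list of trees held by the server
    for _ in range(num_rounds):
        new_trees = []
        for client in clients:
            # Each client receives the current global model, trains
            # n_local_trees further trees on its local data, and returns
            # only the newly generated trees (the server already holds
            # the rest of the ensemble).
            new_trees.extend(client.train_local_trees(global_model, n_local_trees))
        # Append the N x n new trees, preserving their order and indexing.
        global_model.extend(new_trees)
    return global_model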
Note that there are two important parameters: n, the number of local rounds, and the learning rate (eta). The learning rate, or shrinkage rate, is used to reduce the effect of any single tree and thus reduce overfitting. Based on these two parameters, there are different ways to implement FLeXGBoost using the instant inventive techniques:
Methodology | Reason behind methodology |
n : n, n=[2,3,5] | Model convergence is faster if the clients are allowed to pass more than a single tree in each FL round. |
n : n, n=[2,3,5] + normalised learning rate | The learning rate (shrinkage rate) reduces the effect of any single tree in order to reduce overfitting of the model. In many scenarios, different clients have different numbers of datapoints, and clients with few datapoints may introduce more noise when generating trees than clients with many datapoints. In FLeXGBoost, each client is initialised with the same learning rate. The normalised learning rate further reduces the effect of clients with smaller datasets by scaling the learning rate in proportion to the number of datapoints: each client multiplies its learning rate by the number of datapoints at that client divided by the total number of datapoints across all clients. The normalised learning rate is applied regardless of how many trees each client generates. |
n : n, n=[2,3,5] + learning rate normalised using feature importance (useful for security purposes) | This approach uses an alternative form of normalisation based on gain rather than on the number of datapoints. Clients with a greater number of datapoints may nevertheless hold noisy datapoints and hence generate noisy trees. The Gain parameter in XGBoost is maximised when finding the split point for a particular feature. After generating a tree structure, the gain against each feature for each tree is stored and summed over all the features, thus measuring the overall gain associated with each tree for each client. The learning rate/shrinkage rate is then normalised by the gain per client divided by the total gain. Using this approach effectively excludes clients whose trees have much lower gain values compared to other clients. Both normalisation schemes are sketched below the table. |
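The two normalisation schemes can be sketched as follows. This is an illustrative sketch rather than the actual implementation; the helper names are assumptions, and the per-client gain could, for instance, be obtained by summing booster.get_score(importance_type="total_gain").
def normalise_eta_by_datapoints(base_eta, client_num_datapoints):
    # eta_i = base_eta * n_i / sum_j n_j  (datapoint-count normalisation)
    total = sum(client_num_datapoints.values())
    return {cid: base_eta * n / total for cid, n in client_num_datapoints.items()}

def normalise_eta_by_gain(base_eta, client_total_gain):
    # eta_i = base_eta * gain_i / sum_j gain_j  (feature-gain normalisation)
    total = sum(client_total_gain.values())
    return {cid: base_eta * g / total for cid, g in client_total_gain.items()}

# Example: a client with 1,000 datapoints vs one with 9,000, base eta of 0.1.
print(normalise_eta_by_datapoints(0.1, {"a": 1000, "b": 9000}))
# {'a': 0.01, 'b': 0.09}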
Configuration parameters such as the following may be used:
- pool_size: Total number of Clients.
- num_rounds: Number of total FL rounds.
- num_clients_per_round: Number of participating clients per FL round. In certain instances, every client participates in each FL round, in which case num_clients_per_round = pool_size. In other cases, not every client participates in each round.
- num-local-round: Number of trees (n) each client generates locally in each FL round. If n = 1, the model is the bare-bone (conventional) Tree approach; if n > 1, the model is Tree+.
- normalized_lr: Boolean. If yes, the strategy will be FedXgbBaggingNormLr, implementing the learning-rate normalisation method; if no, the strategy will be FedXgbBagging. If n = 1 and normalized_lr = yes, the method is NormLr; for n > 1 and normalized_lr = yes, the method is NormLr+. An illustrative run configuration combining these options is sketched below.
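For illustration only, these options might be collected into a single configuration dictionary; the exact mechanism for passing them to the server and clients is implementation specific, and the values shown are illustrative.
RUN_CONFIG = {
    "pool_size": 5,               # total number of clients
    "num_rounds": 36,             # total number of FL rounds
    "num_clients_per_round": 5,   # here every client participates in each round
    "num-local-round": 3,         # n > 1, so the model is Tree+
    "normalized_lr": True,        # strategy becomes FedXgbBaggingNormLr (NormLr+)
}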
BST_PARAMS = {
"objective": "binary:logistic",
"eta": 0.015, # Learning rate
"max_depth": 8,
"eval_metric": "auc",
"nthread": 16,
"num_parallel_tree": 1,
"subsample": 1,
"tree_method": "hist",
}
In this instance, “eta” is the learning-rate/shrinkage-rate parameter that is normalised under the normalised-learning-rate hypothesis. Initially, each client receives the same BST_PARAMS, and the normalisation is performed inside the code itself. Note that, with increasing client or pool size, the eta value decreases as additional trees are added into the model. For the comparison analysis between the n = 1 and n > 1 cases, we kept the same “eta” value for a fixed number of clients. Note that in a multi-class classification problem there may be an additional key “num_class” in the parameter dictionary, and the objective will be “multi:softmax”:
BST_PARAMS = {
"objective": "multi:softmax",
"num_class": 9,
}
If num_class is set to 2, the multi:softmax objective will still work, but the aggregation method will not unless the objective function is set to binary:logistic; a small helper guarding this case is sketched below. Using the Higgs dataset as an example classification problem, the technique may be used to distinguish between a signal process which produces Higgs bosons and a background process which does not. The dataset contains 11,000,000 datapoints, of which 1,000,000 are used as a hold-out test dataset at the server side, and the total number of features is 28. The metrics are calculated at the server side using the hold-out test dataset.
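A minimal helper guarding the two-class case might therefore select the objective from the number of classes (the helper name is illustrative, not part of the actual codebase):
def objective_params(num_class):
    # Two-class problems must use "binary:logistic" so that the
    # aggregation method works; otherwise use "multi:softmax".
    if num_class <= 2:
        return {"objective": "binary:logistic"}
    return {"objective": "multi:softmax", "num_class": num_class}

print(objective_params(2))   # {'objective': 'binary:logistic'}
print(objective_params(9))   # {'objective': 'multi:softmax', 'num_class': 9}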
- Metrics used: AUC (Area under the curve).
- Benchmarking is done against Flower and NVIDIA results.
This test run uses five clients with a uniform data distribution (each client has an equal number of datapoints) and eta set to 0.1. To reach the same accuracy (e.g., 0.83), the Tree+ method (with n = 3) converges in approximately 20 rounds while conventional methods take approximately 36 rounds. For the comparison study, we assess the number of trees exchanged between the server and clients after a particular round. For conventional techniques, after round 36 the total number of trees passed is 36*5*1 = 180, while with Tree+ the same number of trees has been passed by round 12 (12*3*5 = 180). The time taken for bare-bone and Tree+ to pass exactly 180 trees is 357 secs and 391 secs respectively, with accuracies of 0.83 and 0.825. Hence the disclosed technique accumulates decision trees and model accuracy at a similar rate per second, within 10% of the time for conventional approaches, but using the Tree+ strategy there are 66% fewer (1 - 12/36) communication rounds between the server and client. As a result, significant bandwidth and data transmission issues can be avoided.
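The communication accounting above follows directly from trees exchanged = rounds × clients × trees per round, as the short sketch below shows:
def trees_exchanged(rounds, clients, trees_per_round):
    return rounds * clients * trees_per_round

print(trees_exchanged(36, 5, 1))  # conventional: 180 trees after 36 rounds
print(trees_exchanged(12, 5, 3))  # Tree+ (n = 3): 180 trees after 12 rounds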
When using the normalised FLeXGBoost method with a square data distribution, the learning rate is a dominating factor in the final value of AUC. Learning in the normalised case is much faster (almost 10 times) than in the unnormalised case, as the learning rate is distributed among the clients. The same approach was extended to cover a use case with 20 clients, and the results are graphed below:
Conclusion
The implementation of Federated XGBoost within the OctaiFL framework marks a significant advancement in distributed machine learning technology. By enabling local training of multiple decision trees and their subsequent aggregation at a centralised server, FLeXGBoost not only enhances processing efficiency and reduces power consumption but also addresses critical concerns related to data privacy and network bandwidth constraints.
This approach leverages the robustness of XGBoost in handling diverse and non-IID datasets, making it especially suited for real-world applications where data privacy is crucial. As we continually refine these techniques, our goal is to promote greater adoption of federated learning models, ensuring that organisations can benefit from collective insights without compromising individual data security. This evolution in machine learning paradigms has the ability to reshape how data intelligence is approached in industries across the globe.