LightGBM: train() vs update() vs refit()

I'm integrating LightGBM (Python) into a continuous-learning pipeline. My goal is to train an initial model and then update it (e.g. every day) with newly available data. Most examples load an already-trained model and apply train() once again:

updated_model = lightgbm.train(params=last_model_params, train_set=new_data, init_model=last_model)

However, I'm wondering whether this is actually the correct way to approach continuous learning within the LightGBM library, since the number of fitted trees (num_trees()) grows by n_estimators with every application of train(). To my understanding, a model update should take an initial model definition (under a given set of model parameters) and refine it without ever growing the number of trees / the size of the model definition.
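
For example (with made-up numbers):

print(last_model.num_trees())     # e.g. 100 after initial training
print(updated_model.num_trees())  # e.g. 200 -- grown by another n_estimators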

I don't find the documentation regarding train(), update() and refit() particularly helpful. What would be considered the right approach to implement continuous learning with LightGBM?

CodePudding user response:

In lightgbm (the Python package for LightGBM), the entry points you've mentioned do have different purposes.

The main lightgbm model object is a Booster. A fitted Booster is produced by training on input data. Given an initial trained Booster...

  • Booster.refit() does not change the structure of an already-trained model. It just updates the leaf counts and leaf values based on the new data. It will not add any trees to the model.
  • Booster.update() will perform exactly 1 additional round of gradient boosting on an existing Booster. It will add at most 1 tree to the model.
  • train() with an init_model will perform gradient boosting for num_iterations additional rounds. It also allows for lots of other functionality, like custom callbacks (e.g. to change the learning rate from iteration-to-iteration) and early stopping (to stop adding trees if performance on a validation set fails to improve). It will add up to num_iterations trees to the model.
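
To make the comparison concrete, here is a minimal setup sketch that the snippets below build on. The synthetic data and parameter values are made up for illustration, and keep_training_booster=True is my assumption about what is needed to keep calling update() on the returned Booster -- check that against the docs for your version:

import numpy as np
import lightgbm

# tiny synthetic regression problem, purely for illustration
rng = np.random.default_rng(42)
X_old = rng.normal(size=(1000, 5))
y_old = X_old[:, 0] + rng.normal(scale=0.1, size=1000)
X_new = rng.normal(size=(200, 5))   # the "newly-arrived" batch
y_new = X_new[:, 0] + rng.normal(scale=0.1, size=200)

params = {"objective": "regression", "num_leaves": 7, "learning_rate": 0.1, "verbose": -1}

# initial model: 50 boosting rounds on the original data;
# keep_training_booster=True keeps the Booster usable for later update() calls
old_data = lightgbm.Dataset(X_old, label=y_old)
booster = lightgbm.train(params, old_data, num_boost_round=50, keep_training_booster=True)
print(booster.num_trees())  # 50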

What would be considered the right approach to implement continuous learning with LightGBM?

There are trade-offs involved in this choice, and no single one of these is the globally "right" way to achieve the goal "modify an existing model based on newly-arrived data".

Booster.refit() is the only one of these approaches that meets your definition of "refine [the model] without ever growing the number of trees / the size of the model definition". But it could lead to drastic changes in the predictions produced by the model, especially if the batch of newly-arrived data is much smaller than the original training data, or if the distribution of the target is very different.
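
As a rough sketch of what that looks like, continuing from the setup above (refit() takes the raw feature matrix and labels rather than a Dataset, and decay_rate controls how much of the old leaf values is kept):

# same trees: only the leaf values are re-estimated from the new batch
refitted = booster.refit(data=X_new, label=y_new, decay_rate=0.9)
print(refitted.num_trees() == booster.num_trees())  # True -- no trees added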

Booster.update() is the simplest interface for this, but a single iteration might not be enough to get most of the information from the newly-arrived data into the model. For example, if you're using fairly shallow trees (say, num_leaves=7) and a very small learning rate, even newly-arrived data that is very different from the original training data might not change the model's predictions by much.
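
Continuing from the same setup, a single update could look roughly like this. I'm assuming the new Dataset has to share bin mappers with the original training data (hence reference=old_data); verify that for your LightGBM version:

# exactly one more boosting round, here on the new batch
new_data = lightgbm.Dataset(X_new, label=y_new, reference=old_data)
n_before = booster.num_trees()
booster.update(train_set=new_data)
print(booster.num_trees() - n_before)  # 1 -- at most one tree is added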

train(init_model=previous_model) is the most flexible and powerful option, but it also introduces more parameters and choices. If you choose to use train(init_model=previous_model), pay attention to the parameters num_iterations and learning_rate. Lower values of these parameters will decrease the impact of newly-arrived data on the trained model; higher values will allow a larger change to the model. Finding the right balance between those is a concern for your evaluation framework.
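
A sketch of that, again continuing from the setup above (num_boost_round here plays the role of num_iterations, and the smaller learning rate is just an arbitrary example):

# continue boosting for 10 more rounds on the new batch
n_before = booster.num_trees()
continued = lightgbm.train(
    {**params, "learning_rate": 0.05},  # e.g. a smaller learning rate for the update step
    lightgbm.Dataset(X_new, label=y_new),
    num_boost_round=10,
    init_model=booster,
)
print(continued.num_trees() - n_before)  # up to 10 new trees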
