Understanding Intermediate Values and Pruning in Optuna

Time:11-19

I'd like to understand what an intermediate step actually is, and how to use pruning with an ML library that isn't covered in the tutorial section, e.g. XGBoost or PyTorch.

For example:

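import numpy as np
import optuna
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
classes = np.unique(y)
n_train_iter = 100

def objective(trial):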
    alpha = trial.suggest_float("alpha", 0.0, 1.0)
    clf = SGDClassifier(alpha=alpha)
    for step in range(n_train_iter):
        clf.partial_fit(X_train, y_train, classes=classes)

        intermediate_value = clf.score(X_valid, y_valid)
        trial.report(intermediate_value, step)

        if trial.should_prune():
            raise optuna.TrialPruned()

    return clf.score(X_valid, y_valid)


study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.HyperbandPruner(
        min_resource=1, max_resource=n_train_iter, reduction_factor=3
    ),
)
study.optimize(objective, n_trials=30)

What is the point of the for step in range() section? Doesn't it just make the optimisation take longer, and won't every step in the loop yield the same result?

I'm really trying to understand why for step in range() is needed. Is it required every time you want to use pruning?

Answer:

A basic model can be fit with a single pass over the complete training dataset. Some models, however, can still improve (their accuracy increases) when trained again on the same training dataset. Each call to partial_fit in the loop is one more such pass, so the score is generally not the same at every step.
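To see this for yourself, here is a minimal sketch that records the validation accuracy after each partial_fit call. The fixed alpha=0.01, the random_state values, and the 10-step budget are my own assumptions for illustration, not values from the question.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
classes = np.unique(y)

# alpha and random_state are arbitrary choices for this illustration.
clf = SGDClassifier(alpha=0.01, random_state=0)

scores = []
for step in range(10):
    # One more pass of SGD over the same training data.
    clf.partial_fit(X_train, y_train, classes=classes)
    # Validation accuracy after this pass: the "intermediate value".
    scores.append(clf.score(X_valid, y_valid))

for step, score in enumerate(scores):
    print(f"step {step}: validation accuracy = {score:.3f}")
```

The printed values are what trial.report() would receive at each step, so the loop is what gives the pruner something to look at before training has finished.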

To avoid wasting resources, we check the accuracy on the validation dataset after every step and report it to Optuna as intermediate_value. If the pruner judges the trial unpromising at that step (typically by comparing it with what earlier trials reported at the same step), we prune the whole trial and skip its remaining steps. We then move on to the next trial, which samples another value of alpha, the hyperparameter we are tuning to get the highest accuracy on the validation dataset.
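As a rough mental model of what a pruner does with the reported values, here is a simplified median rule in plain Python. It is only a sketch of the idea, not Optuna's actual implementation; the function should_prune_median and its arguments are hypothetical names I made up for this example.

```python
from statistics import median


def should_prune_median(history, current_value, step, n_startup_trials=2):
    """Simplified median rule for a study that maximizes its objective.

    history: one list of intermediate values per earlier trial.
    Returns True when current_value at `step` is below the median of the
    values earlier trials reported at that same step.
    """
    # Earlier trials that stopped before `step` have nothing to compare with.
    peers = [values[step] for values in history if len(values) > step]
    if len(peers) < n_startup_trials:
        return False  # not enough earlier trials to judge against
    return current_value < median(peers)


# Intermediate accuracies of three earlier trials (made-up numbers).
history = [
    [0.60, 0.70, 0.80],
    [0.55, 0.72, 0.85],
    [0.50, 0.65, 0.75],
]

print(should_prune_median(history, 0.40, step=1))  # True: below the median 0.70
print(should_prune_median(history, 0.90, step=1))  # False: above the median 0.70
```

Without a loop that reports a value at every step there would be nothing to compare, which is why pruning only makes sense for training that proceeds in steps.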

For other libraries the approach is the same: decide what you want from your model and report that metric at each step. Accuracy is certainly a good criterion for measuring a model's quality, but there can be others.

Here is an example of pruning with my own conditions: the model keeps re-training, but the trial is pruned once it is past the halfway point of the maximum number of iterations and its intermediate value is below best_accuracy.

best_accuracy = 0.0


def objective(trial):
    global best_accuracy

    alpha = trial.suggest_float("alpha", 0.0, 1.0)
    clf = SGDClassifier(alpha=alpha)

    for step in range(n_train_iter):
        clf.partial_fit(X_train, y_train, classes=classes)

        if step > n_train_iter//2:
            intermediate_value = clf.score(X_valid, y_valid)

            if intermediate_value < best_accuracy:
                raise optuna.TrialPruned()

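    final_accuracy = clf.score(X_valid, y_valid)
    # Keep the best score seen so far; a weaker finished trial must not lower the bar.
    best_accuracy = max(best_accuracy, final_accuracy)

    return final_accuracy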

Optuna also provides ready-made pruners; see https://optuna.readthedocs.io/en/stable/reference/pruners.html
