针对验证损失的 Optuna 剪枝

发布于 2025-01-16 19:55:41 字数 502 浏览 3 评论 0原文

我在我的深度学习项目中引入了以下几行代码,以便在验证损失在 10 个 epoch 内没有改善时尽早停止:

if best_valid_loss is None or valid_loss < best_valid_loss:
    best_valid_loss = valid_loss
    counter = 0 
else:
    counter += 1 
    if counter == 10: 
        break

现在我想使用 Optuna 来调整一些超参数,但我不太明白 Optuna 中的剪枝是如何工作的。 Optuna 修剪器是否可以按照上面代码中的方式进行操作?我假设我必须使用以下内容:

optuna.pruners.PatientPruner(???, patience=10)

但我不知道我可以在 PatientPruner 中使用哪个修剪器。顺便说一句,在 Optuna 中,我正在最大限度地减少验证损失。

I introduced the following lines in my deep learning project in order to early stop when the validation loss has not improved for 10 epochs:

if best_valid_loss is None or valid_loss < best_valid_loss:
    best_valid_loss = valid_loss
    counter = 0 
else:
    counter += 1 
    if counter == 10: 
        break

Now I want to use Optuna to tune some hyperparameters, but I don't really understand how pruning works in Optuna. Is it possible for Optuna pruners to act the same way as in the code above? I assume I have to use the following:

optuna.pruners.PatientPruner(???, patience=10)

But I don't know which pruner I could use inside PatientPruner. Btw in Optuna I'm minimizing the validation loss.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

醉殇 2025-01-23 19:55:41

简短回答:是的。

大家好,我是 Optuna 中 PatientPruner 的作者之一。如果我们执行普通提前停止,则 wrapped_pruner=None 会按我们的预期工作。例如,

import optuna

def objective(t):
    for step in range(30):
        if step == 5:
            t.report(0., step=step)
        else:
            t.report(step * 0.1, step=step)
        if t.should_prune():
            print("pruned at {}".format(step))
            raise optuna.exceptions.TrialPruned()
            
    return 1.

study = optuna.create_study(pruner=optuna.pruners.PatientPruner(None, patience=9), direction="minimize")
study.optimize(objective, n_trials=1)

输出将修剪为 15

Short answer: Yes.

Hi, I'm one of the authors of PatientPruner in Optuna. If we perform vanilla early-stopping, wrapped_pruner=None works as we expected. For example,

import optuna

def objective(t):
    for step in range(30):
        if step == 5:
            t.report(0., step=step)
        else:
            t.report(step * 0.1, step=step)
        if t.should_prune():
            print("pruned at {}".format(step))
            raise optuna.exceptions.TrialPruned()
            
    return 1.

study = optuna.create_study(pruner=optuna.pruners.PatientPruner(None, patience=9), direction="minimize")
study.optimize(objective, n_trials=1)

The output will be pruned at 15.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文