我该如何做早期停止;在Tsai的Fit()期间

发布于 2025-02-10 10:19:25 字数 803 浏览 2 评论 0原文

以下是我的示例代码来自

from tsai.all import *

dsid = 'AppliancesEnergy'
arch_config = {
    'hidden_size':100, 
    'n_layers':2, 
    'rnn_dropout':0.2, 
    'fc_dropout':0.5, 
    'bidirectional':True
    }

X, y, splits = get_regression_data(dsid, split_data=False)
learn = TSRegressor(
    X, 
    y, 
    splits=splits, 
    bs=128, 
    batch_tfms=[TSStandardize(by_sample=True)], 
    arch=LSTM, 
    arch_config=arch_config, 
    metrics=[mae, rmse], 
    cbs=ShowGraph(), 
    verbose=True)

learn.fit_one_cycle(100, lr_max=1e-3)
learn.plot_metrics()

这很好。我想做的是在fit期间早点停止()。 我在Fastai中找到了回调函数'terminateOnnancallback()',我在下面的Import fastai上应用了它。

learn.fit_one_cycle(100, lr_max=1e-3, cbs=TerminateOnNaNCallback())

但这不起作用。如果有人知道,请告诉我。 谢谢。

Below is my sample code from tsai notebook for time series regression problem.

from tsai.all import *

dsid = 'AppliancesEnergy'
arch_config = {
    'hidden_size':100, 
    'n_layers':2, 
    'rnn_dropout':0.2, 
    'fc_dropout':0.5, 
    'bidirectional':True
    }

X, y, splits = get_regression_data(dsid, split_data=False)
learn = TSRegressor(
    X, 
    y, 
    splits=splits, 
    bs=128, 
    batch_tfms=[TSStandardize(by_sample=True)], 
    arch=LSTM, 
    arch_config=arch_config, 
    metrics=[mae, rmse], 
    cbs=ShowGraph(), 
    verbose=True)

learn.fit_one_cycle(100, lr_max=1e-3)
learn.plot_metrics()

This works good. What I want to do is Early Stopping during fit().
I found callbacks function 'TerminateOnNaNCallback()' for that in fastai and I applied it like below with import fastai.

learn.fit_one_cycle(100, lr_max=1e-3, cbs=TerminateOnNaNCallback())

But this does not work. If somebody knows, please let me know.
Thank you.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

无远思近则忧 2025-02-17 10:19:25

您只需导入并使用早期停止回调来自fastai with fastai with tsai库:

from fastai.callback.all import EarlyStoppingCallback

然后设置回调:

cbs = [EarlyStoppingCallback(), ShowGraph()]

您可以定义参数,例如监视哪个度量/损失,以及在未经改进的情况下终止了多少转。

You can just import and use the early stopping callback from fastai with the tsai library:

from fastai.callback.all import EarlyStoppingCallback

Then set your callbacks:

cbs = [EarlyStoppingCallback(), ShowGraph()]

You can define parameters such as which metric/loss is monitored and after how many turns of no improvement the training is terminated.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文