pytorch NaNLabelEncoder 用于编码和解码分类目标

发布于 2025-01-13 06:42:28 字数 974 浏览 0 评论 0原文

我是 pytorch 的新手,但感觉这应该很简单。如何对这个张量进行逆变换?

classification_dataset = TimeSeriesDataSet(
    df,
    group_ids=['group'],
    target="target_col",  # categorical target
    time_idx="time_idx",
    min_encoder_length= 60 * 60, # how much history to use
    max_encoder_length= 60 * 60,
    min_prediction_length=5,
    max_prediction_length=5,  # how far to predict into future
    time_varying_unknown_reals=[
        
        #...list of columns here
    ],
    #time_varying_unknown_reals=[time_varying_unknown_reals[0]],
    target_normalizer=NaNLabelEncoder(),  # Use the NaNLabelEncoder to encode categorical target
)

x, y = next(iter(classification_dataset.to_dataloader(batch_size=4)))
y[0]  # target values are encoded categories

输出

tensor([[6, 6, 6, 6, 6],
        [5, 5, 5, 5, 5],
        [5, 5, 5, 5, 5],
        [1, 1, 1, 1, 1]])

classification_dataset.target_normalizer返回NaNLabelEncoder()但它不适合。

I'm new to pytorch but it feels like this should be simple. How do I inverse transform this tensor?

classification_dataset = TimeSeriesDataSet(
    df,
    group_ids=['group'],
    target="target_col",  # categorical target
    time_idx="time_idx",
    min_encoder_length= 60 * 60, # how much history to use
    max_encoder_length= 60 * 60,
    min_prediction_length=5,
    max_prediction_length=5,  # how far to predict into future
    time_varying_unknown_reals=[
        
        #...list of columns here
    ],
    #time_varying_unknown_reals=[time_varying_unknown_reals[0]],
    target_normalizer=NaNLabelEncoder(),  # Use the NaNLabelEncoder to encode categorical target
)

x, y = next(iter(classification_dataset.to_dataloader(batch_size=4)))
y[0]  # target values are encoded categories

output

tensor([[6, 6, 6, 6, 6],
        [5, 5, 5, 5, 5],
        [5, 5, 5, 5, 5],
        [1, 1, 1, 1, 1]])

classification_dataset.target_normalizer returns NaNLabelEncoder() but it's not fitted.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

Saygoodbye 2025-01-20 06:42:28

啊,就这么简单

classification_dataset.target_normalizer.classes_

ah it's as simple as

classification_dataset.target_normalizer.classes_
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文