Overfitting LSTM PyTorch

Posted on 2025-02-13 07:10:12

I was following the tutorial on CoderzColumn to implement an LSTM for text classification using PyTorch. I tried to apply the implementation to the bbc-news dataset from Kaggle; however, it overfits heavily and reaches a maximum accuracy of only about 60%.

See the train/loss curve for example:

[Image: train/loss curve]

Is there any advice (I am quite new to RNNs/LSTMs) on how to adapt the model to prevent such heavy overfitting?

The model is taken from the above tutorial and looks kind of like this:

import torch
from torch import nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab, target_classes, embed_len=50, hidden_dim=75, n_layers=1):
        super().__init__()
        self.n_layers = n_layers
        self.embed_len = embed_len
        self.hidden_dim = hidden_dim
        self.embedding_layer = nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_len)
        # Variant with dropout between stacked LSTM layers (only has an effect when n_layers > 1):
        # self.lstm = nn.LSTM(input_size=embed_len, hidden_size=hidden_dim, dropout=0.2, num_layers=n_layers, batch_first=True)
        self.lstm = nn.LSTM(input_size=embed_len, hidden_size=hidden_dim, num_layers=n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, len(target_classes))

    def forward(self, X_batch):
        embeddings = self.embedding_layer(X_batch)
        # Random initial hidden and cell states, redrawn on every forward pass
        hidden, carry = torch.randn(self.n_layers, len(X_batch), self.hidden_dim), torch.randn(self.n_layers, len(X_batch), self.hidden_dim)
        output, (hidden, carry) = self.lstm(embeddings, (hidden, carry))

        # Classify from the LSTM output at the last time step
        return self.fc(output[:, -1])

I would be really thankful for any advice on how to adapt the version in the tutorial so that it works more effectively on other datasets.

Comments (1)

感情废物 2025-02-20 07:10:12

Have you tried adding an nn.Dropout layer before self.fc?
Check what p = 0.1 / 0.2 / 0.3 will do.
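
Here is a minimal sketch of what that could look like, assuming the model from the question with one dropout layer inserted before `self.fc`. The class name `LSTMClassifierWithDropout`, the `p_drop` parameter, and the toy sizes are hypothetical. The sketch also lets the LSTM start from its default zero state instead of `torch.randn`, which is a simplification on my side and not part of the suggestion.

```python
import torch
from torch import nn


class LSTMClassifierWithDropout(nn.Module):
    """Hypothetical variant of the question's model with dropout before the final layer."""

    def __init__(self, vocab_size, n_classes, embed_len=50, hidden_dim=75,
                 n_layers=1, p_drop=0.2):
        super().__init__()
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size,
                                            embedding_dim=embed_len)
        self.lstm = nn.LSTM(input_size=embed_len, hidden_size=hidden_dim,
                            num_layers=n_layers, batch_first=True)
        self.dropout = nn.Dropout(p=p_drop)  # only active in train() mode
        self.fc = nn.Linear(hidden_dim, n_classes)

    def forward(self, X_batch):
        embeddings = self.embedding_layer(X_batch)
        # No initial state is passed, so nn.LSTM starts from zeros
        # (an assumption of this sketch, not the question's code).
        output, _ = self.lstm(embeddings)
        last_step = output[:, -1]  # features at the final time step
        return self.fc(self.dropout(last_step))


# Toy usage: 1000-token vocabulary, 5 classes, batch of 4 sequences of length 25.
model = LSTMClassifierWithDropout(vocab_size=1000, n_classes=5, p_drop=0.2)
tokens = torch.randint(0, 1000, (4, 25))
logits = model(tokens)  # shape: (4, 5)
```

Call `model.train()` while training and `model.eval()` before measuring accuracy, so that dropout is switched off during evaluation.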

Another thing you can do is add regularisation to your training via the weight_decay parameter:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5) 

Start with a small value and increase it by a factor of 10 each time to see which one gives you the best result.
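
The sweep described above could look roughly like this. It is only a sketch on synthetic data with a stand-in linear model; `train_and_validate`, the data tensors, and the number of epochs are hypothetical placeholders for your own training loop and validation set.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic stand-in data (hypothetical): 20 features, 5 classes.
X_train, y_train = torch.randn(200, 20), torch.randint(0, 5, (200,))
X_val, y_val = torch.randn(50, 20), torch.randint(0, 5, (50,))


def train_and_validate(weight_decay, epochs=20):
    """Train a small stand-in model and return its validation loss."""
    model = nn.Linear(20, 5)  # replace with your LSTM classifier
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4,
                                 weight_decay=weight_decay)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(X_train), y_train)
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        return loss_fn(model(X_val), y_val).item()


# Start small and multiply by 10 each time: 1e-5, 1e-4, 1e-3, 1e-2.
candidates = [1e-5 * 10 ** i for i in range(4)]
results = {wd: train_and_validate(wd) for wd in candidates}
best_wd = min(results, key=results.get)
print(f"best weight_decay: {best_wd:g}")
```

Compare the candidates on a validation set rather than on the training loss, since the training loss alone will not show whether overfitting got better.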

Also, it goes without saying: make sure that no test data points end up in the train set, and make sure you did not forget to shuffle your train set:

train_loader = DataLoader(train_dataset, batch_size=1024, collate_fn=vectorize_batch, shuffle=True)
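
Both checks can be sketched as follows, assuming the split is done with `torch.utils.data.random_split`. The toy `TensorDataset` and the 80/20 split sizes are hypothetical, and the `collate_fn` from the line above is left out because the toy samples are already tensors.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical toy dataset: 100 samples of 25 token ids each, 5 classes.
dataset = TensorDataset(torch.randint(0, 1000, (100, 25)),
                        torch.randint(0, 5, (100,)))

# Split once, with a fixed seed, so that no sample lands in both sets.
generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset = random_split(dataset, [80, 20], generator=generator)

# random_split returns Subset objects whose .indices can be compared directly.
overlap = set(train_dataset.indices) & set(test_dataset.indices)
assert not overlap, f"{len(overlap)} samples appear in both train and test"

# Shuffle only the training loader; the evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
```

This only compares indices. If the raw data contains the same text more than once, those duplicates would have to be removed before splitting.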