How to freeze some layers of BERT when fine-tuning in tf2.keras


I am trying to fine-tune 'bert-base-uncased' on a dataset for a text classification task. Here is how I download the model:

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

num_labels = 13  # number of target classes; 13 matches the 9,997 classifier parameters in the summary below
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

Since bert-base has 12 encoder layers, I want to fine-tune only the last 2 of them to prevent overfitting. Setting model.layers[i].trainable = False does not help, because model.layers[0] is the entire BERT base model: setting its trainable attribute to False freezes all of BERT's layers. Here is the architecture of the model:

Model: "tf_bert_for_sequence_classification"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 bert (TFBertMainLayer)      multiple                  109482240 
                                                                 
 dropout_37 (Dropout)        multiple                  0         
                                                                 
 classifier (Dense)          multiple                  9997      
                                                                 
=================================================================
Total params: 109,492,237
Trainable params: 109,492,237
Non-trainable params: 0
_________________________________________________________________

I also considered setting model.layers[0].weights[j]._trainable = False, but the weights list has 199 elements (with shapes such as TensorShape([30522, 768])), and I could not figure out which of those weights belong to the last 2 layers.
Can anyone help me fix this?
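
One way to see which variables belong to which encoder block is to print their names: in the TF Hugging Face checkpoints the block index is part of the variable name (e.g. .../encoder/layer_._11/... for the last block). The loop below is a minimal sketch that assumes the model loaded above:

# Sketch: list BERT's variables so the per-block naming
# (.../layer_._0/... through .../layer_._11/...) becomes visible.
for w in model.layers[0].weights:
    print(w.name, w.shape)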

月隐月明月朦胧 2025-01-18 11:33:21

I found the answer and am sharing it here; I hope it helps others.
With the help of this article, which is about fine-tuning BERT using PyTorch, the equivalent in tensorflow2.keras is as follows:

model.bert.encoder.layer[i].trainable = False

where i is the index of the encoder layer you want to freeze.
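
For the use case in the question (training only the last 2 of the 12 encoder blocks), a minimal sketch could look like the following. It assumes the model from the question; the embedding freeze and the optimizer/loss choices are illustrative additions rather than part of the original answer, and the model must be (re)compiled after changing the trainable flags for them to take effect.

# Sketch: freeze the embeddings and encoder blocks 0-9, leaving blocks 10-11
# (the last 2) and the classification head trainable.
model.bert.embeddings.trainable = False
for layer in model.bert.encoder.layer[:-2]:
    layer.trainable = False

# (Re)compile so the new trainable flags take effect before training.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.summary()  # most of the ~109M parameters should now be non-trainable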
