无效参数:logits 和标签必须可广播:logits_size=[16,8] labels_size=[16,4]

发布于 2025-01-18 15:07:59 字数 1978 浏览 1 评论 0原文

我正在进行一个用于分类玉米疾病图像的CNN项目(4类),它使用VGG16作为基本模型。我创建并保存了模型。现在,是否可以将该模型用作另一个转移学习任务的基础,以对棉叶疾病图像进行分类(4个类),并保留从玉米病图像中获得的知识以及棉叶病图像?如果是这样,我应该如何修改玉米病模型。我是否需要使输出层神经元为8(棉4,4用于玉米病)? 这是我使用玉米植物模型作为基础的棉花植物CNN的代码。我通过删除最后两层(输出和致密层),然后添加了新的致密层,然后用8个神经元添加了新的致密层,但是当我训练模型时,我会发现logits和标签必须可以广播的错误。

from tensorflow.keras.models import load_model
savedmodel = load_model('corn.h5')
train_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
val_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
traindata = train_gen.flow_from_directory('/Users/saibalaji/Desktop/data/train/',target_size=(224,224),batch_size=16)

class_labels = []
for class_label,class_mode in traindata.class_indices.items():
    print(class_label)
    class_labels.append(class_label)
nmodel = tf.keras.Sequential()

for layer in savedmodel.layers[0:-1]:
    print(layer)
    nmodel.add(layer)
for layer in nmodel.layers:
    layer.trainable = False

nmodel.add(tf.keras.layers.Dense(units=15,activation='relu',name='dense_3'))
nmodel.add(tf.keras.layers.Dense(units=8,name='cf',activation='softmax'))
nmodel.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss='categorical_crossentropy',metrics=['accuracy'])
nmodel.fit(traindata,epochs=5)

这是我对玉米叶病图像的保存模型摘要(4类),我在这里使用了特征提取器传输学习,这是通过Reatining Conv2d和VGG16的Maxpool层进行的。

我应该如何通过使用转移学习来保留从玉米植物中获得的知识来修改棉叶疾病的模型。

这是我用于棉花植物疾病的改良模型。

但是我得到了这个错误

I’m doing a CNN project for classification of Corn disease images(4 classes) It uses VGG16 as its base model. I have created and saved the model. Now is it possible to use that model as a base for another transfer learning task to classify cotton leaf disease images( 4 classes) with retaining the knowledge gained from corn disease images along with cotton leaf disease images? If so how should I modify the corn disease model. Should I need to make the output layer neurons as 8( 4 for cotton, 4 for corn disease) ?
Here is my code for Cotton plant CNN using corn plant model as base. I tried it by removing last two layers (output and dense layer) then added new dense layer followed by output layer with 8 neurons but when I train the model I get error that logits and labels must be broadcastable.

from tensorflow.keras.models import load_model
savedmodel = load_model('corn.h5')
train_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
val_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
traindata = train_gen.flow_from_directory('/Users/saibalaji/Desktop/data/train/',target_size=(224,224),batch_size=16)

class_labels = []
for class_label,class_mode in traindata.class_indices.items():
    print(class_label)
    class_labels.append(class_label)
nmodel = tf.keras.Sequential()

for layer in savedmodel.layers[0:-1]:
    print(layer)
    nmodel.add(layer)
for layer in nmodel.layers:
    layer.trainable = False

nmodel.add(tf.keras.layers.Dense(units=15,activation='relu',name='dense_3'))
nmodel.add(tf.keras.layers.Dense(units=8,name='cf',activation='softmax'))
nmodel.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss='categorical_crossentropy',metrics=['accuracy'])
nmodel.fit(traindata,epochs=5)

Here is my saved model summary for corn leaf disease images(4 classes) here I used feature extractor transfer learning by reatining conv2d and maxpool layers of vgg16.
enter image description here

How should I need to modify the model for cotton leaf disease by retaining the knowledge gained from corn plant by using transfer learning.

And here is my modified model for Cotton plant disease.
enter image description here

But i get this error
enter image description here

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

桃扇骨 2025-01-25 15:07:59

在玉米图像上训练的模型的知识属于模型权重。只需加载模型,然后将其训练在棉花图像上。

The knowledge of the model trained on corn images is resident in the model weights. Just load the model, then train it on the cotton images.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文