Confusion matrix for multi-label classification using a prefetch dataset
I am new to Keras and machine learning in general, and I have adapted this Keras tutorial to my case, which involves classifying text into 8 categories.
The tutorial creates prefetch datasets for training and testing, but it does not show how to build a confusion matrix from them. I followed the instructions at https://www.pythonfixing.com/2021/11/fixed-how-to-plot-confusion-matrix-for.html, which also deal with multi-label classification (the iris dataset with 3 classes),
but I am getting an error:
ValueError                                Traceback (most recent call last)
<ipython-input-81-02116c95e29c> in <module>()
----> 1 confusion_matrix(predicted_categories, true_categories)

1 frames
/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py in _check_targets(y_true, y_pred)
     93         raise ValueError(
     94             "Classification metrics can't handle a mix of {0} and {1} targets".format(
---> 95                 type_true, type_pred
     96             )
     97         )

ValueError: Classification metrics can't handle a mix of multiclass and multilabel-indicator targets
My code is below:
epochs = 20
shallow_mlp_model = make_model()
shallow_mlp_model.compile(
loss="binary_crossentropy", optimizer="adam", metrics=["categorical_accuracy"]
)
history = shallow_mlp_model.fit(
train_dataset, validation_data=validation_dataset, epochs=epochs
)
def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()
plot_result("loss")
plot_result("categorical_accuracy")
_, categorical_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Categorical accuracy on the test set: {round(categorical_acc * 100, 2)}%.")
predictions = shallow_mlp_model.predict(test_dataset)
predicted_categories = tf.argmax(predictions, axis=1)
true_categories = tf.concat([y for x, y in test_dataset], axis=0)
confusion_matrix(predicted_categories, true_categories)
So, could anyone help me create a confusion matrix in this case?
Comments (1)
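I found the solution. My true labels were still one-hot encoded, while the predictions had already been reduced to class indices with tf.argmax, which is why sklearn complained about a mix of multiclass and multilabel-indicator targets. So I applied tf.argmax to the variable "true_categories" as well, and the confusion matrix was printed.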
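For anyone landing here with the same error, here is a minimal sketch of that fix. It makes a few assumptions: the arrays are made-up toy data (3 classes, 4 samples) standing in for the real model output and the labels collected from the dataset, and np.argmax is used in place of tf.argmax so the snippet runs without TensorFlow. On a 2-D array the two give the same indices.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy stand-in for tf.concat([y for x, y in test_dataset], axis=0):
# one-hot encoded true labels, one row per sample.
one_hot_labels = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [0, 1, 0],
])

# Toy stand-in for shallow_mlp_model.predict(test_dataset):
# one score per class, one row per sample.
predictions = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.6, 0.3],
    [0.3, 0.5, 0.2],
])

# Reduce BOTH arrays to class indices so sklearn sees two multiclass targets.
true_categories = np.argmax(one_hot_labels, axis=1)
predicted_categories = np.argmax(predictions, axis=1)

# sklearn expects the true labels first, then the predictions.
cm = confusion_matrix(true_categories, predicted_categories)
print(cm)
```

Note that confusion_matrix takes (y_true, y_pred) in that order, so rows are true classes and columns are predicted classes. Passing the predictions first, as in the question, gives the transposed matrix.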