如何显示用于聚类图像的高斯混合模型?

发布于 2025-02-08 20:14:57 字数 2246 浏览 1 评论 0原文

我使用附件代码来获取一些图像的GMM。我还想在图像的直方图上显示GMM。我已经做到了。但是,我也想显示GMM簇分布。我将GMM的输出附加在直方图上,以及我想要获得的另一个图像。

谢谢在此处输入图像描述

            # Code for GMM
            
            import os
            import matplotlib.pyplot as plt
            import numpy as np
            import cv2
            
            
            img = cv2.imread("test.jpg")
            
            #Convert MxNx3 image into Kx3 where K=MxN
            img2 = img.reshape((-1,3))  #-1 reshape means, in this case MxN
            
            from sklearn.mixture import GaussianMixture as GMM
            
            #covariance choices, full, tied, diag, spherical
            gmm_model = GMM(n_components=6, covariance_type='full').fit(img2)  #tied works better than full
            gmm_labels = gmm_model.predict(img2)
            
            #Put numbers back to original shape so we can reconstruct segmented image
            original_shape = img.shape
            segmented = gmm_labels.reshape(original_shape[0], original_shape[1])
            cv2.imwrite("test_segmented.jpg", segmented)
            
            
            gmm_model.means_
            
            gmm_model.covariances_
            
            gmm_model.weights_
            
            print(gmm_model.means_, gmm_model.covariances_, gmm_model.weights_)
            
            data = img2.ravel()
            data = data[data != 0]
            data = data[data != 1]  #Removes background pixels (intensities 0 and 1)
            gmm = GMM(n_components = 6)
            gmm = gmm.fit(X=np.expand_dims(data,1))
            gmm_x = np.linspace(0,255,256)
            gmm_y = np.exp(gmm.score_samples(gmm_x.reshape(-1,1)))
            
            
            #Plot histograms and gaussian curves
            fig, ax = plt.subplots()
            ax.hist(img.ravel(),255,[2,256], density=True, stacked=True)
            ax.plot(gmm_x, gmm_y, color="crimson", lw=2, label="GMM")
            
            ax.set_ylabel("Frequency")
            ax.set_xlabel("Pixel Intensity")
            
            plt.legend()
            plt. grid(False)
            
            plt.show()

I used the attached code to get the GMM for some images. I also want to show the GMM on the histogram of the image. I already did that. However, I also wanna show the GMM clusters distribution. I attached the output of the GMM on the histogram and another image of what I wanna get.

Thanksenter image description here

            # Code for GMM
            
            import os
            import matplotlib.pyplot as plt
            import numpy as np
            import cv2
            
            
            img = cv2.imread("test.jpg")
            
            #Convert MxNx3 image into Kx3 where K=MxN
            img2 = img.reshape((-1,3))  #-1 reshape means, in this case MxN
            
            from sklearn.mixture import GaussianMixture as GMM
            
            #covariance choices, full, tied, diag, spherical
            gmm_model = GMM(n_components=6, covariance_type='full').fit(img2)  #tied works better than full
            gmm_labels = gmm_model.predict(img2)
            
            #Put numbers back to original shape so we can reconstruct segmented image
            original_shape = img.shape
            segmented = gmm_labels.reshape(original_shape[0], original_shape[1])
            cv2.imwrite("test_segmented.jpg", segmented)
            
            
            gmm_model.means_
            
            gmm_model.covariances_
            
            gmm_model.weights_
            
            print(gmm_model.means_, gmm_model.covariances_, gmm_model.weights_)
            
            data = img2.ravel()
            data = data[data != 0]
            data = data[data != 1]  #Removes background pixels (intensities 0 and 1)
            gmm = GMM(n_components = 6)
            gmm = gmm.fit(X=np.expand_dims(data,1))
            gmm_x = np.linspace(0,255,256)
            gmm_y = np.exp(gmm.score_samples(gmm_x.reshape(-1,1)))
            
            
            #Plot histograms and gaussian curves
            fig, ax = plt.subplots()
            ax.hist(img.ravel(),255,[2,256], density=True, stacked=True)
            ax.plot(gmm_x, gmm_y, color="crimson", lw=2, label="GMM")
            
            ax.set_ylabel("Frequency")
            ax.set_xlabel("Pixel Intensity")
            
            plt.legend()
            plt. grid(False)
            
            plt.show()

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文