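How can I automatically get the number of colors, and the colors themselves, returned by a dendrogram of an agglomerative hierarchy?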

Posted on 2025-01-19 01:15:53

scikit-learn gives an example of Python code that generates a dendrogram. I copy/paste this code below. The resulting dendrogram displays 3 different colors: blue, green, and orange.

Question: which code, used alongside this dendrogram example, could automatically return:

  • the number of colors generated by the dendrogram?
  • the list of those colors (or their color codes)?
import numpy as np

from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering


def plot_dendrogram(model, **kwargs):
    # Create linkage matrix and then plot the dendrogram

    # create the counts of samples under each node
    counts = np.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)


iris = load_iris()
X = iris.data

# setting distance_threshold=0 ensures we compute the full tree.
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)

model = model.fit(X)
plt.title("Hierarchical Clustering Dendrogram")
# plot the top three levels of the dendrogram
plot_dendrogram(model, truncate_mode="level", p=3)
plt.xlabel("Number of points in node (or index of point if no parenthesis).")
plt.show()

dendrogram image returned by the above code

Comments (1)

浅听莫相离 2025-01-26 01:15:53

According to the documentation of scipy.cluster.hierarchy.dendrogram, the number of colors is determined by color_threshold, which defaults to 0.7*max(Z[:,2]). So you only have to count the clusters that form when the tree is cut at that height, plus one color for the links above it:

First modify your code to get the linkage matrix:

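import numpy as np
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering


def get_linkage(model):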
    # Create linkage matrix

    # create the counts of samples under each node
    counts = np.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)
    return linkage_matrix


iris = load_iris()
X = iris.data

# setting distance_threshold=0 ensures we compute the full tree.
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)

model = model.fit(X)
linkage_matrix = get_linkage(model)

Then calculate the number of colors from it:

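from scipy.cluster.hierarchy import cut_tree

# default color_threshold used by scipy's dendrogram
color_threshold = 0.7 * max(linkage_matrix[:, 2])
# one color per cluster below the threshold, plus one for the links above it
# (a cluster made of a single leaf has no links of its own, so it would be over-counted here)
n_color = 1 + len(np.unique(cut_tree(linkage_matrix, height=color_threshold)))
# matplotlib's default color-cycle codes: 'C0', 'C1', ...
color_codes = ['C' + str(i) for i in range(n_color)]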
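
As a cross-check on the count above, the colors can also be read straight from the dictionary that scipy's dendrogram returns, whose color_list entry holds one color per drawn link. The following is only a minimal sketch: the helper name dendrogram_colors and the synthetic three-blob data are assumptions made so the example runs without scikit-learn, and it is not part of the original answer.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage


def dendrogram_colors(linkage_matrix, **kwargs):
    """Return the distinct link colors scipy assigns, in first-seen order.

    Hypothetical helper for illustration; kwargs are passed on to dendrogram.
    """
    # no_plot=True computes the layout and the colors without drawing anything
    result = dendrogram(linkage_matrix, no_plot=True, **kwargs)
    # color_list has one entry per drawn link; dict.fromkeys drops
    # duplicates while keeping the order of first appearance
    return list(dict.fromkeys(result["color_list"]))


# Stand-in data (an assumption): three synthetic 2-D blobs
rng = np.random.default_rng(0)
X_demo = np.vstack(
    [rng.normal(loc=c, scale=0.3, size=(20, 2)) for c in (0.0, 5.0, 10.0)]
)
Z_demo = linkage(X_demo, method="ward")

colors = dendrogram_colors(Z_demo)
n_colors = len(colors)
print(n_colors, colors)
```

With the code from this answer, the same helper could be called as dendrogram_colors(linkage_matrix, truncate_mode="level", p=3) so that it sees the same links as the plotted figure. If hex values are preferred over the 'C<n>' codes, matplotlib.colors.to_hex should be able to convert them under the active style.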