How to convert a dendrogram into a tree object in Python?

Chr*_*uri 5 python hierarchical-clustering dendrogram scipy

I'm trying to run hierarchical clustering on some text with the scipy.cluster.hierarchy module. Here is what I did:

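from scipy.cluster.hierarchy import linkage, dendrogram

# model is a trained gensim Word2Vec model (pre-4.0 API; gensim 4 renamed
# wv.syn0 / wv.index2word to wv.vectors / wv.index_to_key)
l = linkage(model.wv.syn0, method='complete', metric='cosine')

den = dendrogram(
    l,
    leaf_rotation=0.,
    leaf_font_size=16.,
    orientation='left',
    leaf_label_func=lambda v: str(model.wv.index2word[v])
)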

The dendrogram function returns a dictionary describing the tree, where:

den['ivl'] is the list of labels corresponding to the leaves:

['politics', 'protest', 'characterfirstvo', 'machine', 'writing', 'learning', 'healthcare', 'climate', 'of', 'rights', 'activism', 'resistance', 'apk', 'week', 'challenge', 'water', 'obamacare', 'colorado', 'change', 'voiceovers', '52', 'acting', 'android']

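den['leaves'] lists the observation indices of the leaves in the order they appear from left to right, so den['ivl'][i] is the label of leaf den['leaves'][i]: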
[0, 18, 5, 6, 2, 7, 12, 16, 21, 20, 22, 3, 10, 14, 15, 19, 11, 1, 17, 4, 13, 8, 9]
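A minimal sketch of how the two lists line up, using a small made-up dataset in place of the word vectors (the vectors X and the labels list are hypothetical). With leaf_label_func, den['ivl'][i] is the label of observation den['leaves'][i]; in the question that means 'politics' is word 0 and 'protest' is word 18.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram

# Hypothetical stand-ins for model.wv.syn0 and model.wv.index2word
labels = ['a', 'b', 'c', 'd', 'e']
X = np.array([[0.0, 1.0], [0.1, 1.0], [1.0, 0.0], [1.0, 0.1], [0.7, 0.7]])

Z = linkage(X, method='complete', metric='cosine')
# no_plot=True computes the dendrogram layout without drawing it
den = dendrogram(Z, no_plot=True, leaf_label_func=lambda v: labels[v])

# Position i in the drawing shows observation den['leaves'][i], labelled den['ivl'][i]
for pos, (leaf, label) in enumerate(zip(den['leaves'], den['ivl'])):
    print(pos, leaf, label)
```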

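I know that scipy's to_tree() method converts the hierarchical clustering represented by a linkage matrix into a tree object by returning a reference to the root node (a ClusterNode object), but I'm not sure how this root node corresponds to my leaves/labels. For example, in this case the get_id() method returns the ids root = 44, left = 41, right = 43: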

from scipy.cluster.hierarchy import to_tree

rootnode, nodelist = to_tree(l, rd=True)
rootID = rootnode.get_id()
leftID = rootnode.get_left().get_id()
rightID = rootnode.get_right().get_id()
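The ids follow the numbering of the linkage matrix: leaves are 0..n−1 (the row indices of the input), and the node created by row i of the matrix gets id n + i. A sketch on random, hypothetical data that checks this against to_tree(rd=True), whose second return value is a list indexed by node id:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, to_tree

X = np.random.default_rng(0).normal(size=(6, 3))  # hypothetical data, n = 6
Z = linkage(X, method='complete', metric='cosine')
n = Z.shape[0] + 1  # a linkage matrix has n - 1 rows

root, nodelist = to_tree(Z, rd=True)

# The root is created by the last row, so its id is n + (n - 2) = 2n - 2
assert root.get_id() == 2 * n - 2
for i, row in enumerate(Z):
    node = nodelist[n + i]  # nodelist[k] is the node whose id is k
    assert node.get_id() == n + i
    # row i merged Z[i, 0] (left) and Z[i, 1] (right) at distance Z[i, 2]
    assert node.get_left().get_id() == int(row[0])
    assert node.get_right().get_id() == int(row[1])
    assert node.get_count() == int(row[3])  # number of leaves under the node
print(root.get_id(), root.get_left().get_id(), root.get_right().get_id())
```

In the question n = 23, so the root id is 2·23 − 2 = 44, and 41 and 43 (both ≥ 23) are internal nodes created by rows 18 and 20 of l, holding 11 and 12 words respectively.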

Essentially, my question is: how can I traverse this tree and, for each ClusterNode, get the corresponding positions in den['leaves'] and labels in den['ivl']?

Thanks in advance for any help!

For reference, here is the linkage matrix l:

[[20.         22.          0.72081252  2.        ]
[12.         16.          0.78620636  2.        ]
[ 3.         10.          0.79635815  2.        ]
[ 0.         18.          0.80193474  2.        ]
[15.         19.          0.82297097  2.        ]
[ 2.          7.          0.84152483  2.        ]
[ 1.         17.          0.84453892  2.        ]
[ 4.         13.          0.86098654  2.        ]
[ 8.          9.          0.88163748  2.        ]
[14.         27.          0.91252009  3.        ]
[11.         29.          0.92034739  3.        ]
[21.         23.          0.92406542  3.        ]
[ 5.          6.          0.93213108  2.        ]
[25.         32.          0.98555722  5.        ]
[26.         35.          0.99214198  4.        ]
[30.         31.          1.05624908  4.        ]
[24.         34.          1.0606247   5.        ]
[28.         39.          1.06322889  7.        ]
[37.         40.          1.1455562  11.        ]
[33.         38.          1.15171714  7.        ]
[36.         42.          1.17330334 12.        ]
[41.         43.          1.25056073 23.        ]]
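Each row of the matrix reads as [left id, right id, merge distance, number of observations], and row i creates node n + i. A sketch that decodes a matrix into a child map and lists the leaves under a node without to_tree; the two helper names and the 4-observation matrix are hypothetical:

```python
import numpy as np

def children_from_linkage(Z):
    """Map each internal node id to (left_id, right_id) using only the linkage matrix."""
    n = Z.shape[0] + 1
    return {n + i: (int(a), int(b)) for i, (a, b) in enumerate(Z[:, :2])}

def leaves_under(node_id, children):
    """Observation indices under node_id, left subtree first (iterative DFS)."""
    out, stack = [], [node_id]
    while stack:
        k = stack.pop()
        if k in children:
            left, right = children[k]
            stack.append(right)  # push right first so the left subtree is visited first
            stack.append(left)
        else:
            out.append(k)  # ids below n are original observations
    return out

# Hypothetical 4-observation linkage matrix: rows are [left, right, distance, size]
Z = np.array([[0, 1, 0.5, 2],
              [2, 3, 0.7, 2],
              [4, 5, 1.0, 4]], dtype=float)
children = children_from_linkage(Z)
print(children)                   # {4: (0, 1), 5: (2, 3), 6: (4, 5)}
print(leaves_under(6, children))  # [0, 1, 2, 3]
```

Applied to the matrix above (n = 23), the last row [41, 43] gives children[44] == (41, 43), the same ids get_id() reported.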

小智 0

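You don't need the dendrogram to traverse the cluster tree. Assuming you have linkage_matrix and a cluster_ids array (the output of scipy.cluster.hierarchy.fcluster), you can use the get_node function to get the cluster-tree node that corresponds to a given cluster id: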

import numpy as np
from scipy.cluster.hierarchy import leaders, ClusterNode, to_tree
from typing import Optional, List


def get_node(
    linkage_matrix: np.ndarray,
    clusters_array: np.ndarray,
    cluster_num: int
) -> ClusterNode:
    """
    Returns ClusterNode (the node of the cluster tree) corresponding to the given cluster number.
    :param linkage_matrix: linkage matrix
    :param clusters_array: array of cluster numbers for each point
    :param cluster_num: id of cluster for which we want to get ClusterNode
    :return: ClusterNode corresponding to the given cluster number
    """
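    # L[j] is the id of the tree node that roots flat cluster M[j]
    L, M = leaders(linkage_matrix, clusters_array)
    matches = L[M == cluster_num]
    if matches.size == 0:
        raise ValueError(f"cluster {cluster_num} not found")
    target = int(matches[0])
    tree = to_tree(linkage_matrix)
    result = search_for_node(tree, target)
    if result is None:
        raise ValueError(f"no tree node with id {target}")
    return result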


def search_for_node(
    cur_node: Optional[ClusterNode],
    target: int
) -> Optional[ClusterNode]:
    """
    Searches for the node with the given id of the cluster in the given subtree.
    :param cur_node: root of the cluster subtree to search for target node
    :param target: id of the target node (cluster)
    :return: ClusterNode with the given id if it exists in the subtree, None otherwise
    """
    if cur_node is None:
        return None
    if cur_node.get_id() == target:
        return cur_node
    left = search_for_node(cur_node.get_left(), target)
    if left is not None:
        return left
    return search_for_node(cur_node.get_right(), target)
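Because to_tree(..., rd=True) also returns a list in which index k holds the node with id k, the recursive search can be replaced by a direct lookup, which also avoids Python's recursion limit on very deep (chain-like) trees. A sketch of that variant; get_node_fast and the random data are hypothetical:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster, leaders, to_tree

def get_node_fast(linkage_matrix, clusters_array, cluster_num):
    """Hypothetical variant of get_node: index to_tree's node list instead of searching."""
    L, M = leaders(linkage_matrix, clusters_array)
    matches = L[M == cluster_num]
    if matches.size == 0:
        raise ValueError(f"cluster {cluster_num} not found")
    _, nodelist = to_tree(linkage_matrix, rd=True)
    return nodelist[int(matches[0])]  # nodelist[k] is the ClusterNode with id k

X = np.random.default_rng(2).normal(size=(8, 3))  # hypothetical data
Z = linkage(X, method='complete', metric='cosine')
T = fcluster(Z, t=3, criterion='maxclust')  # flat cluster number for each observation

for c in np.unique(T):
    node = get_node_fast(Z, T, c)
    print(c, node.get_id(), sorted(node.pre_order()))  # cluster, node id, member rows
```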

To get all samples that belong to a cluster node, collect all of its descendant leaf nodes:

def get_leaves_ids(node: ClusterNode) -> List[int]:
    """
    Returns ids of all samples (leaf nodes) that belong to the given ClusterNode (belong to the node's subtree).
    :param node: ClusterNode for which we want to get ids of samples
    :return: list of ids of samples that belong to the given ClusterNode
    """
    res = []

    def dfs(cur: Optional[ClusterNode]):
        if cur is None:
            return
        if cur.is_leaf():
            res.append(cur.get_id())
            return
        dfs(cur.get_left())
        dfs(cur.get_right())
    dfs(node)
    return res
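ClusterNode already has a built-in equivalent of get_leaves_ids: pre_order() walks the subtree left subtree first and, by default, collects leaf.id for every leaf. Combined with a position map built from den['leaves'], this connects any node back to the question's positions and labels. A sketch with hypothetical data and labels (describe is a made-up helper); mapping through den['leaves'] rather than assuming the drawn order equals pre_order keeps it correct if count_sort or distance_sort reorder the drawing:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram, to_tree

# Hypothetical labels and vectors standing in for model.wv.index2word / model.wv.syn0
labels = ['politics', 'protest', 'machine', 'learning', 'water', 'climate']
X = np.random.default_rng(1).normal(size=(len(labels), 4))

Z = linkage(X, method='complete', metric='cosine')
den = dendrogram(Z, no_plot=True, leaf_label_func=lambda v: labels[v])
root, nodelist = to_tree(Z, rd=True)

# Left-to-right position of every observation index in the dendrogram
position = {leaf: pos for pos, leaf in enumerate(den['leaves'])}

def describe(node):
    """Hypothetical helper: (positions in den['leaves'], labels from den['ivl']) for a node."""
    leaf_ids = node.pre_order()  # same list get_leaves_ids(node) would build
    positions = sorted(position[i] for i in leaf_ids)
    return positions, [den['ivl'][p] for p in positions]

for node in nodelist:
    print(node.get_id(), *describe(node))
```

The positions of a subtree's leaves always form one contiguous run in the drawing, since the dendrogram lays out each cluster as a block.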

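Traversing the tree to find siblings, ancestors, and descendants works the same as in any ordinary binary tree. Leaf ids are the row indices of the samples in your dataset, and the ids of internal (non-leaf) nodes can be mapped to flat cluster ids with the leaders function (see the get_node implementation).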
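ClusterNode stores only child links, so finding ancestors or siblings needs a parent map first. A sketch, assuming the node list from to_tree(rd=True); the data and the helper names ancestors and sibling are hypothetical:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, to_tree

X = np.random.default_rng(4).normal(size=(7, 2))  # hypothetical data
Z = linkage(X, method='single')
root, nodelist = to_tree(Z, rd=True)

# parent[k] is the id of the node whose merge absorbed node k (the root has no entry)
parent = {}
for nd in nodelist:
    if not nd.is_leaf():
        parent[nd.get_left().get_id()] = nd.get_id()
        parent[nd.get_right().get_id()] = nd.get_id()

def ancestors(node_id):
    """Ids from node_id's parent up to the root."""
    out = []
    while node_id in parent:
        node_id = parent[node_id]
        out.append(node_id)
    return out

def sibling(node_id):
    """Id of the node merged with node_id, or None for the root."""
    if node_id not in parent:
        return None
    p = nodelist[parent[node_id]]
    left, right = p.get_left().get_id(), p.get_right().get_id()
    return right if left == node_id else left

print(ancestors(0))  # path from leaf 0 up to the root
print(sibling(0))
```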