返回介绍

数学基础

统计学习

深度学习

工具

Scala

二、GNNEXPLAINER [2019]

发布于 2023-07-17 23:38:24 字数 148889 浏览 0 评论 0 收藏 0

  1. 图是强大的数据表示形式,而图神经网络Graph Neural Network: GNN 是处理图的最新技术。 GNN 能够递归地聚合图中邻域节点的信息,从而自然地捕获图结构和节点特征。

    尽管GNN 效果很好,但是其可解释性较差。由于以下几个原因,GNN 的可解释性非常重要:

    • 可以提高对 GNN 模型的信任程度。
    • 在越来越多有关公平fairness 、隐私privacy 、以及其它安全挑战safety challenge 的关键决策应用application 中,提高模型的透明度 transparency
    • 允许从业人员理解模型特点,从而在实际环境中部署之前就能识别并纠正模型的错误。

    尽管目前尚无用于解释 GNN 的方法。但是在更高层次,我们可以将包括 GNNNon-GNN 的可解释性方法分为两个主要系列:

    • 使用简单的代理模型 surrogate model 局部逼近locally approximate 完整模型full model,然后探索这个代理模型以进行解释。

      这可能是以模型无关model-agnostic 的方式来完成的,通常是通过线性模型 linear model 或者规则集合 set of rules 来学习预测的局部近似,可以充分代表预测结果。

    • 仔细检查模型的相关特征relevant feature,并找到high-level 特征的良好定性解释good qualitative interpretation ,或者识别有影响力的输入样本influential input instance

      如通过特征梯度 feature gradients、神经元反向传播中输入特征的贡献、反事实推理counterfactual reasoning 等。

    上述两类方法专注于研究模型的固有解释,而后验post-hoc 的可解性方法将模型视为黑盒,然后对其进行探索从而获得相关信息。

    但是所有这些方法无法融合关系信息,即图的结构信息。由于关系信息对于图上机器学习任务的成功至关重要,因此对 GNN 预测的任何解释都应该利用图提供的丰富关系信息、以及节点特征。论文 《GNNExplainer: Generating Explanations for Graph Neural Networks》 提出了 一种解释 GNN 预测的方法,称作 GNNEXPLAINERGNNEXPLAINER 接受训练好的GNN 及其预测,返回对于预测最有影响力的、输入图的一个小的子图small subgraph ,以及节点特征的一个小的子集small subset

    如下图所示,这里展示了一个节点分类任务,其中在社交网络上训练了 GNN 模型Φ$ \mathbf\Phi $ 从而预测体育活动。给定训练好的GNNΦ$ \mathbf\Phi $ 以及对节点vi$ v_i $ 的预测y^i=basketball$ \hat y_i= \text{basketball} $ ,GNNEXPLAINER 通过识别对预测y^i$ \hat y_i $ 有影响力的输入图的一个小的子图、以及节点特征的一个小的子集来生成解释explanation ,如右图所示。

    • 通过检查y^i$ \hat y_i $ 的解释 explanation,我们发现vi$ v_i $ 社交圈中很多朋友都喜欢玩球类游戏,因此 GNN 预测vi$ v_i $ 可能会喜欢篮球。
    • 同样地,通过检查y^j$ \hat y_j $ 的解释 explanation我们发现vj$ v_j $ 社交圈中很多朋友都喜欢水上运动和沙滩运动,因此 GNN 预测y^j=sailing$ \hat y_j=\text{sailing} $ (帆船运动)。

    GNNEXPLAINER 方法和 GNN 模型无关,可以解释任何 GNN 在图的任何机器学习任务上的预测,包括节点分类、链接预测、图分类任务。它可以解释单个实例的预测 single-instance explanation ,也可以解释一组实例的预测 multi-instance explanation

    • 在单个实例预测的情况下,GNNEXPLAINER 解释了 GNN 对于特定样本的预测。
    • 在多个实例预测的情况下,GNNEXPLAINER 提供了对于一组样本(如类别为“篮球”的所有节点)的预测的一致性解释(即这些预测的共同的解释)。

    GNNEXPLAINER 将解释指定为训练 GNN 的整个输入图的某个子图,其中子图基于 GNN 的预测最大化互信息mutual information 。这是通过平均场变分近似 mean field variational approximation ,以及学习一个实值graph mask 来实现的。这个 graph mask 选择了 GNN 输入图的最重要子图。同时,GNNEXPLAINER 还学习了一个 feature mask,它可以掩盖不重要的节点特征。

    论文在人工合成图、以及真实的图上评估GNNEXPLAINER 的效果。实验表明:GNNEXPLAINERGNN 的预测提供了一致而简洁的解释。

  2. 虽然解释GNN 问题没有得到很好的研究,但相关的可解释性 interpretability 和神经调试neural debugging 的问题在机器学习中得到了大量的关注。在high-level 上,我们可以将那些 non-graph neural network 的可解释性方法分为两个主要方向:

    • 第一个方向的方法制定了完整神经网络full neural network 的简单代理模型。这可以通过模型无关model-agnostic 的方式完成,通常是通过学习 prediction 周围的局部良好近似(例如通过线性模型或规则集合),代表预测的充分条件sufficient condition
    • 第二个方向方法确定了计算的重要方面,例如,特征梯度feature gradient、神经元的反向传播对输入特征的贡献、以及反事实推理。然而,这些方法产生的显著性映射 saliency map 已被证明在某些情况下具有误导性,并且容易出现梯度饱和等问题。这些问题在离散输入(如图邻接矩阵)上更加严重,因为梯度值可能非常大而且位于一个非常小的区域interval 。正因为如此,这种方法不适合用来解释神经网络在图上的预测。

    事后可解释性post-hoc interpretability 方法不是创建新的、固有可解释的模型,而是将模型视为黑箱,然后探测模型的相关信息。然而,还没有利用关系型结构(如graph)方面的工作。解释图结构数据预测的方法的缺乏是有问题的,因为在很多情况下,图上的预测是由节点和它们之间的边的路径的复杂组合引起的。例如,在一些任务中,只有当图中存在另一条替代路径形成一个循环时,一条边才是重要的,而这两个特征,只有在一起考虑时,才能准确预测节点标签。因此,它们的联合贡献不能被建模为单个贡献的简单线性组合。

    最后,最近的GNN 模型通过注意力机制增强了可解释性。然而,尽管学到的edge attention 值可以表明重要的图结构,但这些值对所有节点的预测都是一样的。因此,这与许多应用相矛盾,在这些应用中,一条边对于预测一个节点的标签是至关重要的,但对于另一个节点的标签则不是。此外,这些方法要么仅限于特定的GNN架构,要么不能通过共同考虑图结构和节点特征信息来解释预测结果。

  3. GNNEXPLAINER 提供了各种好处 ,包括可视化语义相关结构以进行解释的能力、以及提供洞察GNN 的错误的能力。

2.1 方法

  1. 定义图G=(V,E,X)$ \mathcal G=(\mathcal V, \mathcal E, \mathbf X) $ ,其中:

    • V={v1,,vn}$ \mathcal V=\{v_1,\cdots,v_n\} $ 为节点集合,节点数量为n$ n $ 。
    • E={ei,j}$ \mathcal E=\{e_{i,j}\} $ 为边集合,ei,j=(vi,vj)$ e_{ i,j } = (v_i,v_j) $ ,边数量为m$ m $ 。
    • 每个节点vi$ v_i $ 关联一个特征向量xiRd$ \mathbf{\vec x}_i\in \mathbb R^{d} $ ,所有节点的特征向量拼接为特征矩阵XRn×d$ \mathbf X\in \mathbb R^{n\times d} $ ,d$ d $ 为特征向量维度。

    不失一般性,我们考虑节点分类问题的可解释性。定义f$ f $ 为一个节点到类别标签的映射函数:f:V{1,,C}$ f:\mathcal V\rightarrow \{1,\cdots,C\} $ ,其中C$ C $ 为类别数量。GNN 模型Φ$ \mathbf\Phi $ 在训练集的所有节点上优化从而逼近f$ f $ ,然后用于预测未标记节点的label

  2. 我们假设GNN 模型Φ$ \mathbf\Phi $ 采用消息传递机制。在第l$ l $ 层,GNN 模型Φ$ \mathbf \Phi $ 的更新涉及三个关键计算:

    • 首先,模型计算每对节点pair 对之间传递的消息。节点pair(vi,vj)$ (v_i,v_j) $ 之间的消息定义为:

      mi,j(l)=MSG(hi(l1),hj(l1),ri,j)

      其中:MSG(.) 为消息函数;hi(l1)$ \mathbf{\vec h}_i^{(l-1)} $ 为节点vi$ v_i $ 在第l1$ l-1 $ 层的 representationri,j$ r_{i,j} $ 为节点vi$ v_i $ 和vj$ v_j $ 之间的关系。

    • 然后,对于每个节点vi$ v_i $ ,GNN 聚合来自其邻域的所有消息:

      ai(l)=AGG({mi,j(l)vjNi})

      其中:AGG(.) 为一个邻域聚合函数;Ni$ \mathcal N_i $ 为节点vi$ v_i $ 的邻域集合。

    • 最后,对于每个节点vi$ v_i $ ,GNN 根据ai(l),hi(l1)$ \mathbf{\vec a}_i^{(l)} , \mathbf{\vec h}_i^{(l-1)} $ 来计算vi$ v_i $ 在第l$ l $ 层的 representation

      hi(l)=UPDATE(ai(l),hi(l1))

      其中 UPDATE(.) 为节点状态更新函数。

    最终节点vi$ v_i $ 的embeddingvi$ v_i $ 在第L$ L $ 层输出的 representation

    zi=hi(L)

    对于采用 MSG,AGG,UPDATE 计算组成的任何 GNN ,我们的 GNNEXPLAINER 可以提供解释。

  3. 我们的洞察insight 是观察到:节点v$ v $ 的计算图computation graph 是由 GNNneighborhood-based 聚合来定义,如下图所示。这个计算图完全决定了用于生成节点v$ v $ 的预测y^$ \hat y $ 的所有信息。具体而言,节点v$ v $ 的计算图告诉 GNN 如何生成节点v$ v $ 的 embeddingz$ \mathbf{\vec z} $ 。

    定义节点v$ v $ 的计算图 computation graphGc(v)$ \mathcal G_c(v) $ ,它关联一个二元binary 邻接矩阵Ac(v)Rn×n$ \mathbf A_c(v)\in \mathbb R^{n\times n} $ ,其中Ac(v)[i,j]{0,1}$ A_c(v)[i,j]\in \{0,1\} $ 取值为01 ; 也关联一个特征矩阵Xc(v)={xjvjGc(v)}$ \mathbf X_c(v) = \{\mathbf{\vec x}_j\mid v_j\in \mathcal G_c(v)\} $ 。

    GNN 模型Φ$ \mathbf\Phi $ 学习一个条件分布PΦ(YvGc(v),Xc(v))$ P_{\mathbf \Phi}\left( Y_v\mid \mathcal G_c(v),\mathbf X_c(v)\right) $ ,表示给定节点计算图Gc(v)$ \mathcal G_c(v) $ 、节点特征矩阵Xc(v)$ \mathbf X_c(v) $ 的条件下,节点v$ v $ 属于各类别的概率。其中Yv{1,,C}$ Y_v\in \{1,\cdots,C\} $ 为一个随机变量。

    一旦GNN 模型学到这样的分布之后,对于节点v$ v $ ,GNN 的类别预测结果为y^=Φ(Gc(v),Xc(v))$ \hat y = \mathbf\Phi(\mathcal G_c(v),\mathbf X_c(v)) $ ,意味着它完全由三个因素决定:模型Φ$ \mathbf\Phi $ 、图结构信息Gc(v)$ \mathcal G_c(v) $ 、节点特征信息Xc(v)$ \mathbf X_c(v) $ 。这个观察结果意味着我们只需要考虑图结构Gc(v)$ \mathcal G_c(v) $ 和节点特征Xc(v)$ \mathbf X_c(v) $ 来解释y^$ \hat y $ ,如下图 A 所示。

    正式地讲,GNNEXPLAINER 为预测y^$ \hat y $ 生成解释explanation,记作(GS,XSF)$ \left(\mathcal G_S,\mathbf X_S^F\right) $ 。其中:

    • GS$ \mathcal G_S $ 为计算图的一个小的子图small subgraph ,如图 A 所示。

    • XS$ \mathbf X_S $ 为GS$ \mathcal G_S $ 关联的特征,XSF$ \mathbf X_S^F $ 为节点特征的一个小的特征子集 small subsetF 表示通过 mask F 来遮盖,即:XSF={xjFvjGS}$ \mathbf X_S^F = \{\mathbf{\vec x}_j^F\mid v_j\in \mathcal G_S\} $ ,如图B 所示。

      假设原始的节点特征集合为A={A1,,Ad}$ \mathbb A=\{\mathcal A_1,\cdots,\mathcal A_d\} $ ,则经过 mask F 遮盖之后的特征集合为:

      AF={At1,,AtF},1t1tFd

      它是原始特征集合的一个小的特征子集,且有:xjF=(xj,t1,xj,t2,,xj,tF)RF$ \mathbf{\vec x}_j^F=(x_{j,t_1},x_{j,t_2},\cdots,x_{j,t_F})^\top\in \mathbb R^F $ 为遮盖后的特征向量。

    下图中:

    • A 给出了一个 GNN 在节点v$ v $ 处的计算图Gc(v)$ \mathcal G_c(v) $ ,它用于得到节点v$ v $ 的类别预测y^$ \hat y $ 。

      Gc(v)$ \mathcal G_c(v) $ 中的某些边构成重要的消息传播路径(绿色),这些路径允许有用的节点消息跨Gc(v)$ \mathcal G_c(v) $ 传播并在节点v$ v $ 处聚合从而进行预测;相反,Gc(v)$ \mathcal G_c(v) $ 中的另一些边不重要(橙色)。但是无论消息重不重要,在节点v$ v $ 处GNN 都会具有所有消息(包括不重要的消息)从而进行预测,这可能会稀释重要的消息。

      GNNEXPLAINER 的目标是识别少量对于预测至关重要的重要特征和路径(绿色)。

    • B 表示 GNNEXPLAINER 通过学习节点特征mask 来确定GS$ \mathcal G_S $ 中节点的那些特征维度对于预测至关重要。

  4. 接下来我们详细描述 GNNEXPLAINER。给定训练好的 GNN 模型Φ$ \mathbf\Phi $ 以及一个预测 prediction (即单实例解释single-instance explanation)、或者一组预测(即多实例解释multi-instance explanation ), GNNEXPLAINER 将通过识别对模型Φ$ \mathbf\Phi $ 的预测影响最大的计算图的子图、节点特征的子集从而生成解释。

    在多实例解释中,GNNEXPLAINER 将每个实例的解释聚合在一起并自动抽取为一个原型proto。这个原型代表每个实例解释的公共部分,即proto 可以对所有这些实例进行解释。

2.1.1 单实例解释

  1. 给定一个节点v$ v $ ,我们的目标是识别对于该节点的 GNN 预测y^$ \hat y $ 很重要的子图GSGc$ \mathcal G_S\sube \mathcal G_c $ ,以及关联特征矩阵XS={xjvjGS}$ \mathbf X_S=\{ \mathbf{\vec x}_j \mid v_j\in \mathcal G_S\} $ 。现在我们不考虑特征 mask,这留待下一步讨论。

    我们使用互信息mutual information:MI 来刻画子图的重要性,并将 GNNEXPLAINER 形式化为以下最优化问题:

    maxGSMI(Y,(GS,XS))=H(Y)H(YG=GS,X=XS)

    其中:

    • H()$ H(\cdot) $ 为熵,它表示GNN 对于节点v$ v $ 类别预测结果的不确定性程度。Y{1,,C}$ Y\in \{1,\cdots,C\} $ 为代表 GNN 预测节点v$ v $ 类别的随机变量。

      注意:这里没有任何关于节点v$ v $ 真实类别的信息。也就是我们不关心 GNN 预测得准不准,而是仅关心哪些因素和 GNN 预测结果相关。

      H(Y)$ H(Y) $ 其实是H(YG=G,X=X)$ H(Y\mid \mathcal G=\mathcal G, \mathbf X=\mathbf X) $ ,即以原始图、原始特征矩阵来进行的预测所得到的熵。

    • H()$ H(\cdot\mid \cdot) $ 为条件熵,它表当节点v$ v $ 的计算图被限制为子图GS$ \mathcal G_S $ 、节点特征被限制为XS$ \mathbf X_S $ 后,GNN 预测结果不确定性程度。

    MI 刻画了当节点v$ v $ 的计算图被限制为子图GS$ \mathcal G_S $ 、节点特征被限制为XS$ \mathbf X_S $ 后,预测结果为y^=Φ(Gc,Xc)$ \hat y=\mathbf\Phi(\mathcal G_c, \mathbf X_c) $ 的概率的变化。例如:

    • 考虑vjGc(vi),vjvi$ v_j\in \mathcal G_c(v_i),v_j\ne v_i $ 。如果从Gc(vi)$ \mathcal G_c(v_i) $ 中移除vj$ v_j $ ,使得预测结果为y^i$ \hat y_i $ 的概率急剧降低,则节点vj$ v_j $ 是vi$ v_i $ 预测的一个很好的解释。
    • 类似地,考虑vj,vkGv(vi),vj,vkvi$ v_j,v_k \in \mathcal G_v(v_i),v_j,v_k\ne v_i $ 。如果移除vj,vk$ v_j,v_k $ 之间的边,使得预测结果为y^i$ \hat y_i $ 的概率急剧降低,则(vj,vk)$ (v_j,v_k) $ 之间的边是vi$ v_i $ 预测的一个很好的解释。
  2. MI(Y,(GS,XS))$ \text{MI}( Y,(\mathcal G_S,\mathbf X_S)) $ 的第一项H(Y)$ H(Y) $ 为常数项,因为对于给定的、训练好的 GNN,节点v$ v $ 的预测结果是y^$ \hat y $ 的概率是已知的,和GS$ \mathcal G_S $ 无关。则有:

    maxGSMI(Y,(GS,XS))=minGSH(YG=GS,X=XS)=maxGSEYGS,XS[logPΦ(YG=GS,X=XS)]

    因此,对于预测y^$ \hat y $ 的解释是一个子图GS$ \mathcal G_S $ ,当 GNN 被限制在GS$ \mathcal G_S $ 时最小化Φ$ \mathbf\Phi $ 不确定性 uncertainty 。在效果上,GS$ \mathcal G_S $ 最大化预测为y^$ \hat y $ 的概率。

    理论上当GS=Gc$ \mathcal G_S = \mathcal G_c $ 时,上式最大化。为了获得更紧凑的解释,我们对GS$ \mathcal G_S $ 的大小施加约束,如|GS|KM$ |\mathcal G_S|\le K_M $ ,使得GS$ \mathcal G_S $ 最多只有KM$ K_M $ 个节点。实际上这意味着 GNNEXPLAINER 旨在通过采取对预测提供最高互信息的KM$ K_M $ 个节点进行降噪。

  3. 直接优化 GNNEXPLAINER 的目标函数很困难,因为Gc$ \mathcal G_c $ 有指数级的子图GS$ \mathcal G_S $ 作为y^$ \hat y $ 的候选解释。因此,我们考虑子图GS$ \mathcal G_S $ 的分数邻接矩阵fractional adjacency matrix,即ASRn×n$ \mathbf A_S\in \mathbb R^{n\times n} $ , 其中AS[i,j][0.0,1.0]$ A_S[i,j] \in [0.0,1.0] $ 在 0~1.0 之间。此外我们施加约束AS[j,k]Ac[j,k]$ A_S[j,k]\le A_c[j,k] $ ,使得没有边的节点之间AS[j,k]$ A_S[j,k] $ 也为零。

    这种连续性松弛continuous relaxation 可以解释为Gc$ \mathcal G_c $ 子图分布的变分近似variational approximation 。具体而言,我们将GSG$ \mathcal G_S\in \mathcal G $ 视为一个随机图变量random graph variable,则目标函数变为:

    minGEGSGH(YG=GS,X=XS)

    我们假设目标函数是凸函数,则 Jensen 不等式给出以下的上界:

    minGH(YG=EG[GS],X=XS)

    实际上由于神经网络的复杂性,凸性假设不成立。但是通过实验我们发现:优化带正则化的上述目标函数通常求得一个局部极小值,该局部极小值具有高质量的解释性。

  4. 为精确地估计EG$ \mathbb E_\mathcal G $ ,我们使用平均场变分近似mean-field variational approximation ,并将G$ \mathcal G $ 分解为多元伯努利分布multivariate Bernoulli distribution

    PG(GS)=(j,k)GcAS[j,k]

    这允许我们估计对于平均场近似的期望从而获得AS$ \mathbf A_S $ ,其中AS$ \mathbf A_S $ 的第(j,k)$ (j,k) $ 元素代表:节点(vj,vk)$ (v_j,v_k) $ 之间存在边的期望。

    • 我们从实验观察到:尽管 GNN 是非凸的,但是这种近似approximation 结合一个可以提升离散型discreteness 的正则化器一起,结果可以收敛到良好的局部极小值。

    • 可以通过使用邻接矩阵的计算图的掩码Acσ(M)$ \mathbf A_c\odot \sigma(\mathbf M) $ 替换要优化的EG[GS]$ \mathbb E_\mathcal G[\mathcal G_S] $ ,从而优化上式中的条件熵。即:

      minGH(YG=Acσ(M),X=XS)=minMc=1CPΦ(Y=cG=Acσ(M),X=XS)logPΦ(Y=cG=Acσ(M),X=XS)

      其中:

      • MRn×n$ \mathbf M\in \mathbb R^{n\times n} $ 表示我们需要学习的mask 矩阵。
      • $ \odot $ 表示逐元素乘积。
      • σ()$ \sigma(\cdot) $ 为 sigmoid 函数,它将 mask 映射到 0.0~1.0 之间。

    GNNExplainer 的核心在于:用 0.0 ~ 1.0 之间的 mask 矩阵(待学习)来调整邻接矩阵,从而最小化预测的熵。但是,这种方法只关心哪个子图对预测结果最重要,并不关心哪个子图对 ground-truth 最有帮助。

    可以通过标签类别和模型预测之间的交叉熵来修改上式中的条件熵,从而得到哪个子图对 ground-truth 最有帮助。

    M$ \mathbf M $ 通过随机梯度下降来学习。

  5. 在某些应用application中,我们不关心模型预测结果的y^$ \hat y $ 的解释性,而更关注如何使得模型能够预测所需要类别的标签。这里我们可以使用标签类别和模型预测之间的交叉熵来修改上式中的条件熵,即:

    minMc=1CI[y=c]logPΦ(Y=cG=Acσ(M),X=Xc)
  6. 尽管有不同的动机和目标,在 Neural Relational Inference 中也发现了masking 方法。

  7. 最后,我们计算σ(M)$ \sigma(\mathbf M) $ 和Ac$ \mathbf A_c $ 的逐元素乘积,并通过阈值移除M$ \mathbf M $ 中的较小的值,从而得出节点v$ v $ 处模型预测y^$ \hat y $ 的解释GS$ \mathcal G_S $ 。

2.1.2 图结构 & 节点特征

  1. 为确定哪些节点特征对于预测y^$ \hat y $ 最重要,GNNEXPLAINER 针对GS$ \mathcal G_S $ 中的节点学习一个特征选择器F$ F $ 。GNNEXPLAINER 考虑GS$ \mathcal G_S $ 中节点的特征子集AF={At1,,AtF},1t1tFd$ \mathbb A^F = \{\mathcal A_{t1},\cdots,\mathcal A_{t_F}\},\quad 1\le t_1\le\cdots\le t_F\le d $ ,其中每个节点特征选择后的特征向量为xjF=(xj,t1,xj,t2,,xj,tF)$ \mathbf{\vec x}_j^F=(x_{j,t_1},x_{j,t_2},\cdots,x_{j,t_F})^\top $ 。所有GS$ \mathcal G_S $ 中节点的选择后的特征矩阵为XSF$ \mathbf X_S^F $ 。

    我们通过一个 mask 来定义特征选择器:

    f=(f1,f2,,fd)

    其中fi{0,1}$ f_i\in\{0,1\} $ 取值为01 ,当它为1 时表示保留对应特征,否则遮盖对应特征。因此xjF$ \mathbf{\vec x}_j^F $ 包含未被F$ F $ 掩盖mask out 的节点特征。

    我们定义特征 mask 矩阵为:

    F=[f1f2fdf1f2fdf1f2fd]R|GS|×d

    则有:XSF=XSF$ \mathbf X_S^F = \mathbf X_S\odot \mathbf F $ 。其中$ \odot $ 表示逐元素乘积。

  2. 现在我们在互信息目标函数中考虑节点特征,从而得到解释explanation(GS,XSF)$ (\mathcal G_S, \mathbf X_S^F) $ :

    maxGS,FMI(Y,(GS,F))=H(Y)H(YG=GS,X=XSF)

    该目标函数同时考虑了对预测y^$ \hat y $ 的子图结构解释、节点特征解释。

  3. 从直觉上看:

    • 如果某个节点特征不重要,则 GNN 权重矩阵中的相应权重应该接近于零。mask 这类特征对于预测结果没有影响。
    • 如果某个节点特征很重要,则 GNN 权重矩阵中相应权重应该较大。mask 这类特征会降低预测为y^$ \hat y $ 的概率。

    但是在某些情况下,这种方法会忽略对于预测很重要、但是特征取值接近于零的特征。为解决该问题,我们对所有特征子集边际化marginalize,并在训练过程中使用蒙特卡洛估计从XS$ \mathbf X_S $ 中节点的经验边际分布中采样得到边际分布。

    此外,我们使用 reparametrization 技巧将目标函的梯度反向传播到 mask 矩阵F$ \mathbf F $ 。

    具体而言,为了通过X$ \mathbf X $ 反向传播,我们 reparametrizeX$ \mathbf X $ 为:

    X=Z+(XSZ)F,s.t.jfjKF

    其中:

    • Z$ \mathbf Z $ 维从经验分布中采样到的随机变量。
    • KF$ K_F $ 为要保留的最大特征数量。

    上式等价于:X=(1F)Z+FXS$ \mathbf X = (1-\mathbf F)\odot \mathbf Z + \mathbf F\odot\mathbf X_S $ 。因此X$ \mathbf X $ 由两部分加权和得到:

    • Z$ \mathbf Z $ :来自于每个维度边际分布采样得到的,权重为1F$ 1-\mathbf F $ ,代表噪音部分。这是为了解决特征取值接近于零但是又对于预测很重要的特征的问题。
    • XS$ \mathbf X_S $ :来自于子图节点的特征向量,权重为F$ \mathbf F $ ,代表真实信号部分。

    这种特征可解释方法可以用于普通的神经网络模型。

  4. 为了在解释explanation中加入更多属性,可以使用正则化项扩展 GNNEXPLAINER 的目标函数。可以包含很多正则化项从而产生具有所需属性的解释。

    • 例如,我们使用逐元素的熵来鼓励结构mask 和节点特征mask 是离散的。
    • 例如,我们可以将 mask 参数的所有元素之和作为正则化项,从而惩罚规模太大的mask
    • 此外, GNNEXPLAINER 可以通过诸如拉格朗日乘子Lagrange multiplier 约束、或者额外的正则化项等技术来编码domain-specific 约束。
  5. 最后需要重点注意的是:每个解释explanation 必须是一个有效的计算图。具体而言,(GS,XS)$ (\mathcal G_S, \mathbf X_S) $ 需要允许 GNN 的消息流向节点v$ v $ ,从而允许 GNN 做出预测y^$ \hat y $ 。

    重要的是,GNNEXPLAINER 的解释一定是有效的计算图,因为它在整个计算图上优化结构 mask。即使一条断开的边对于消息传递很重要,GNNEXPLAINER 也不会选择它作为解释,因为它不会影响 GNN 的预测结果。实际上,这意味着GS$ \mathcal G_S $ 倾向于是一个小的连通子图small connected subgraph

    这是因为 GNNExplainer 会运行 GNN,如果计算图无效则运行 GNN 的结果失败或者预测效果很差,因此也就不会作为可解释结果。

2.1.3 多实例解释

  1. 有时候我们需要回答诸如 “为什么 GNN 对于一组给定的节点预测都是类别 c ” 之类的问题。因此我们需要获得对于类别 c 的全局解释。

    这里我们提出一个基于 GNNEXPLAINER 的解决方案,从而在类别 c 中的一组不同节点的各自单实例解释中,找到针对类别c 的通用的解释。这个问题与寻找每个解释图中最大公共子图密切相关,这是一个 NP-hard 问题。这里我们采用了解决该问题的神经网络方案,案称作基于对齐alignment-basedmulti-instance GNNEXPLAINER

  2. 对于给定的类 c,我们首先选择一个参考节点 reference nodevc$ v_c $ 。直观地看,该节点是能够代表该类别的原型节点 prototypical node

    • 可以通过计算类别 c 中所有节点的 embedding 均值,然后选择类别 c 中节点 embedding 和这个均值最近的节点作为参考节点。
    • 也可以使用有关先验知识,选择和先验知识最匹配的节点作为类别 c 的参考节点。

    给定类别 c 的参考节点vc$ v_c $ ,以及它关联的reference 解释图GS(vc)$ \mathcal G_S(v_c) $ ,我们将类别c$ c $ 中所有节点的解释图都对齐到GS(vc)$ \mathcal G_S(v_c) $ 。

    利用微分池化differentiable pooling 的思想,我们使用一个松弛relaxed 的对齐矩阵alignment matrix来找到解释图GS(v)$ \mathcal G_S(v) $ 中的节点和 reference解释图GS(vc)$ \mathcal G_S(v_c) $ 中的节点之间的对应关系。设节点v$ v $ 待对齐的解释图的邻接矩阵和特征矩阵分别为Av,Xv$ \mathbf A_v, \mathbf X_v $ ,设参考节点的解释图的邻接矩阵和特征矩阵分别为A,X$ \mathbf A^*,\mathbf X^* $ 。我们定义松弛对齐矩阵 relaxed alignment matrixPRnv×n$ \mathbf P\in \mathbb R^{n_v\times n^*} $ ,则优化目标为:

    minP|PAvPA|+|PXvX|

    其中:

    • nv$ n_v $ 为GS(v)$ \mathcal G_S(v) $ 中节点数量,n$ n^* $ 为GS(vc)$ \mathcal G_S(v_c) $ 中节点数量。
    • P$ \mathbf P $ 的元素大于零且每一行的和为 1.0

    上式第一项表示:经过对齐之后,GS(v)$ \mathcal G_S(v) $ 对齐后的邻接关系应该尽可能接近A$ \mathbf A^* $ ;第二项表示:经过对齐之后,GS(v)$ \mathcal G_S(v) $ 对齐后的特征矩阵应该尽可能接近X$ \mathbf X^* $ 。

    实际上对于两个大图GS(v)$ \mathcal G_S(v) $ 和GS(vc)$ \mathcal G_S(v_c) $ ,上述最优化问题很难求解。但是由于单实例解释生成的GS(v)$ \mathcal G_S(v) $ 和GS(vc)$ \mathcal G_S(v_c) $ 都是简洁的、很小的图,因此可以有效地计算几乎最优的对齐方式。

  3. 一旦得到类别 c 中所有节点对齐后的邻接矩阵,我们就可以使用中位数来生成一个原型prototype 。之所以使用中位数,是因为中位数可以有效对抗异常值。即:

    Aproto=median(A~i)

    其中A~i$ \tilde{\mathbf A}_i $ 为类别 c 中第i$ i $ 个节点的 explanation 的对齐后的邻接矩阵(即PAiP$ \mathbf P^\top \mathbf A_i\mathbf P $ )。

    原型Aproto$ \mathbf A_{\text{proto}} $ 允许我们深入了解属于某一类的节点之间共享的图结构模式。然后对于特定的节点,用户可以通过将节点 explanation 和类别原型进行比较,从而研究该特定节点。

  4. 在多个解释图的邻接矩阵对齐过程中,也可以使用现在的图库 graph library 来寻找这些解释图的最大公共子图,从而替换掉神经网络部分。

  5. 在多实例解释中,解释器explainer 不仅必须突出与单个预测的局部相关信息,还需要强调不同实例之间更高level 的相关性。

    这些实例之间可以通过任意方式产生关联,但是最常见的还是类成员class-membership关联。假设类的不同样本之间存在共同特征,那么解释器需要捕获这种共同的特征。例如,通常发现诱变化合物 mutagenic compounds 具有某些特定属性的功能团,如 NO2

    如下图所示,经验丰富的专家可能已经注意到这些功能团的存在。当 GNNEXPLAINER 生成原型prototype 时,可以进一步加强这方面的证据。下图来自于 MUTAG 数据集的诱变化合物。

2.1.4 扩展

  1. 机器学习任务的扩展:除了解释节点分类之外,GNNEXPLAINER 还可以解释链接预测和图分类,无需更改其优化算法。

    • 在预测链接(vj,vk)$ (v_j,v_k) $ 时,GNNEXPLAINER 为链接的两个端点学习两个maskXS(vj),XS(vk)$ \mathbf X_S(v_j), \mathbf X_S(v_k) $ 。
    • 在图分类时,目标函数中的邻接矩阵是图中所有节点邻接矩阵的并集 union

    注意:图分类任务和节点分类任务不同。由于图分类任务存在节点 embedding的聚合,因此解释GS$ \mathcal G_S $ 不必是一个连通子图。根据不同的场景,某些情况下要求解释是一个连通子图,此时可以提取解释中的最大连通分量。

  2. 模型扩展: GNNEXPLAINER 能够处理所有基于消息传递的GNN,包括:Graph Convolutional Networks:GCNGated Graph Sequence Neural Networks:GGS-NNsJumping Knowledge Networks:JK-NetAttention Networks-GATGraph Networks:GN、具有各种聚合方案的 GNNLine-Graph NNsposition-aware GNN、以及很多其它 GNN 架构。

  3. GNNEXPLAINER 优化中的参数规模取决于节点v$ v $ 的计算图Gc$ \mathcal G_c $ 的大小。具体而言,Gc(v)$ \mathcal G_c(v) $ 的邻接矩阵Ac(v)$ \mathbf A_c(v) $ 等于掩码矩阵M$ \mathbf M $ 的大小,其中M$ \mathbf M $ 是需要被 GNNEXPLAINER 学习的。

    但是,由于单个节点的计算图通常较小,因此即使完整的输入图很大 GNNEXPLAINER 仍然可以有效地生成解释。

2.2 实验

  1. 数据集:

    • 人工合成数据集:我们人工构建了四种节点分类数据集,如下表所示。

      • BA-SHAPES 数据集:我们从 300 个节点的 Barabasi-Albert:BA 基础图、以及一组80 个五节点的房屋house 结构的主题 motif 开始,这些 motif 被随机添加到基础图的随机选择的节点上。进一步地我们添加0.1×n$ 0.1\times n $ 条随机边,从而得到经过扰动的合成图。

        根据节点的结构角色,节点为以下四种类型之一:house 顶部节点、house 中间节点、house 底部节点、非house 节点。

      • BA-COMMUNITY 数据集:是两个 BA-SHAPES 图的并集。节点具有正态分布的特征向量,并且根据其结构角色、社区成员(两种社区)分配为8种类别之一。

      • TREE-CYCLES:从 8-level 平衡二叉树为基础图、一组 80 个 六节点的环状 motif 开始,这些 motif 随机添加到基础图的随机选择的节点上。

      • TREE-GRID:和 TREE-CYCLES 相同,除了使用 3x3 的网格 motif 代替六节点的环 motif 之外。

    • 真实数据集:我们考虑两个图分类数据集。

      • MUTAG:包含 4337 个分子图的数据集,根据分子对于革兰氏阴性菌伤寒沙门氏菌Gram-negative bacterium S.typhimurium 的诱变作用mutagenic effect 进行标记。

      • REDDIT-BINARY:包含 2000个图的数据集,每个图代表Reddit 上讨论的话题 thread 。在每个图中,节点代表话题下参与讨论的用户,边代表一个用户对另一个用户的评论进行了回复。

        图根据话题中用户交互类型进行标记:r/IAmA, r/AskReddit 包含 Question-Answer 交互, r/TrollXChromosomes and r/atheism 包含Online-Discussion 交互。

  2. Baseline 方法:很多可解释性方法无法直接应用于图,尽管如此我们考虑了以下baseline 方法,这些方法可以为 GNN 的预测提供解释。

    • GRAD:基于梯度的方法。我们计算损失函数对于邻接矩阵的梯度、损失函数对于节点特征的梯度,这类似于显著性映射方法 saliency map approach

    • ATT:基于graph attention GNN:GAT 的方法。它学习计算图中边的注意力权重,并将其视为边的重要性。

      尽管 ATT 考虑了图结构,但是它并未考虑节点特征的解释,而且仅能解释 GAT 模型。

      此外,由于环cycle 的存在(如下图所示),节点的 1hop 邻居也是它的 2-hop 邻居。因此使用哪个注意力权重(1hop vs 2hop)也不是很清楚。通常我们将这些 hop 的注意力权重取均值。

  3. 实验配置:对于每个数据集,我们首先为这个数据集训练一个 GNN,然后使用 GARDGNNEXPLAINER 来对 GNN 的预测做出解释。

    注意,ATT baseline 需要使用 GAT 之类的图注意力架构,因此我们在同一个数据集上单独训练了一个 GAT 模型,并使用学到的边注意力权重进行解释。

    • 我们对所有的节点分类任务、图分类任务中调整权重正则化参数。这些超参数在所有实验中使用。

      • 子图大小正则化超参数为 0.005 ,该正则化倾向于得到尽可能小的子图。
      • 拉普拉斯正则化参数为 0.5
      • 特征数量正则化参数为 0.1,该正则化倾向于得到尽可能少的unmasked 特征。
    • 我们使用 Adam 优化器训练 GNN 和 解释方法 explaination methods

      • 所有 GNN 模型都训练 1000epoch,学习率为 0.001, 从而对节点分类数据集达到至少 85%的准确率、对于图分类数据集达到至少95%的准确率。

        对于所有数据集,train/valid/test 拆分比例为 80%:10%:10%

      • GNNEXPLAINER 使用相同的优化器和学习率,并训练 100 ~300epoch

        因为 GNNEXPLAINER 仅需要在少于 100 个节点的局部计算图上进行训练,因此训练 epoch 要更少一些。

    • 为了抽取解释子图GS$ \mathcal G_S $ ,我们首先计算边的重要性权重(GRAD 的梯度、ATT 的注意力权重、GNNEXPLAINERmasked 邻接矩阵)。然后我们使用一个阈值来删除权重较低的边,从而得到GS$ \mathcal G_S $ 。

      • 对于所有方法,我们执行线性搜索从而找到临界阈值,使得GS$ \mathcal G_S $ 至少包含KM$ K_M $ 个节点。

      • 所有数据集的 ground truth explanation 是连接的子图。

        对于节点分类,我们将不同方法得到的GS$ \mathcal G_S $ 中抽取连通分量(如前所述,对于 GNNEXPLAINER 方法来讲,GS$ \mathcal G_S $ 已经是连通的)来作为最终的解释。

        对于图分类,我们抽取GS$ \mathcal G_S $ 的最大连通分量来作为最终的解释。

    • 超参数KM$ K_M $ 和KF$ K_F $ 控制解释中的子图大小和特征数量,这可以从数据集相关的先验知识得到。

      • 对于人工合成数据集,我们将KM$ K_M $ 设置为 ground truth 的大小。
      • 对于真实世界数据集,我们设置KM=10,KF=5$ K_M=10,K_F=5 $ 。
  4. 定量分析:对于人工合成数据集,我们已有 ground-truth 解释,然后使用这些ground-truth 来评估所有方法解释的准确性。具体而言,我们将解释问题形式化为二元分类任务,其中真实解释中的边视为label,而将可解释性方法给出的重要性权重视为预测得分。一种更好的可解释性方法对于真实解释的边的预测得分较高,从而获得更好的解释准确率。

    下表给出了人工合成数据集节点分类评估结果。实验结果表明:GNNEXPLAINER 的平均效果相比其它方法高出 17.1%

  5. 定性分析:

    • 在没有节点特征的 topology-based 预测任务中(如 BA-SHAPES、TREE-CYCLES),GNNEXPLAINER 正确地识别解释节点标签的motif

      如下图所示,A-B 给出了四个人工合成数据集上节点分类任务的单实例解释子图,每种方法都为红色节点的预测提供解释(绿色表示重要的节点,橙色表示不重要的节点)。可以看到 GNNEXPLAINER 能识别到 house, cycle, tridmotif,而 baseline 方法无法识别。

    • 我们研究图分类任务的解释。

      • MUTAG 实例中,颜色表示节点特征,这代表原子类型(氢H、碳C 等)。GNNEXPLAINER 可以正确的识别对于图类别比较重要的碳环、以及化学基团 NH2NO2,它们确实已知是诱变的 mutagenic 官能团。

      • REDDIT-BINARY 示例中,我们看到Question-Answe 图(B 的第二行)具有2~3 个同时连接到很多低 degree 节点的高 degree 节点。这是讲得通的,因为在 Reddit 的问答模式的话题中,通常具有 2~3 位专家都回答了许多不同的问题。

        相反,在 REDDIT-BINARY 的讨论模式discussion pattern 图(A 的第二行),通常表现出树状模式。

        GRAD,ATT 方法给出了错误的或者不完整的解释。例如两种baseline 都错过了 MUTAG 数据集中的碳环。

        此外,尽管 ATT 可以将边注意力权重视为消息传递的重要性得分,但是权重在输入图中的所有节点之间共享,因此 ATT 无法提供高质量的单实例解释。

    • 解释explanations 的基本要求是它们必须是可解释的interpretable,即,提供对输入节点和预测之间关系的定性理解。下图显式了一个实验结果,其中给出不同方法预测的解释的特征。特征重要性通过热力度可视化。

      可以看到:GNNEXPLAINER 确实识别出了重要的特征;而 gradient-based 无法识别,它为无关特征提供了较高的重要性得分。

      ground-truth 特征从何而来?作者并未讲清楚。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文