数学基础
- 线性代数
- 概率论与随机过程
- 数值计算
- 蒙特卡洛方法与 MCMC 采样
- 机器学习方法概论
统计学习
深度学习
- 深度学习简介
- 深度前馈网络
- 反向传播算法
- 正则化
- 深度学习中的最优化问题
- 卷积神经网络
- CNN:图像分类
- 循环神经网络 RNN
- Transformer
- 一、Transformer [2017]
- 二、Universal Transformer [2018]
- 三、Transformer-XL [2019]
- 四、GPT1 [2018]
- 五、GPT2 [2019]
- 六、GPT3 [2020]
- 七、OPT [2022]
- 八、BERT [2018]
- 九、XLNet [2019]
- 十、RoBERTa [2019]
- 十一、ERNIE 1.0 [2019]
- 十二、ERNIE 2.0 [2019]
- 十三、ERNIE 3.0 [2021]
- 十四、ERNIE-Huawei [2019]
- 十五、MT-DNN [2019]
- 十六、BART [2019]
- 十七、mBART [2020]
- 十八、SpanBERT [2019]
- 十九、ALBERT [2019]
- 二十、UniLM [2019]
- 二十一、MASS [2019]
- 二十二、MacBERT [2019]
- 二十三、Fine-Tuning Language Models from Human Preferences [2019]
- 二十四 Learning to summarize from human feedback [2020]
- 二十五、InstructGPT [2022]
- 二十六、T5 [2020]
- 二十七、mT5 [2020]
- 二十八、ExT5 [2021]
- 二十九、Muppet [2021]
- 三十、Self-Attention with Relative Position Representations [2018]
- 三十一、USE [2018]
- 三十二、Sentence-BERT [2019]
- 三十三、SimCSE [2021]
- 三十四、BERT-Flow [2020]
- 三十五、BERT-Whitening [2021]
- 三十六、Comparing the Geometry of BERT, ELMo, and GPT-2 Embeddings [2019]
- 三十七、CERT [2020]
- 三十八、DeCLUTR [2020]
- 三十九、CLEAR [2020]
- 四十、ConSERT [2021]
- 四十一、Sentence-T5 [2021]
- 四十二、ULMFiT [2018]
- 四十三、Scaling Laws for Neural Language Models [2020]
- 四十四、Chinchilla [2022]
- 四十七、GLM-130B [2022]
- 四十八、GPT-NeoX-20B [2022]
- 四十九、Bloom [2022]
- 五十、PaLM [2022] (粗读)
- 五十一、PaLM2 [2023](粗读)
- 五十二、Self-Instruct [2022]
- 句子向量
- 词向量
- 传统CTR 预估模型
- CTR 预估模型
- 一、DSSM [2013]
- 二、FNN [2016]
- 三、PNN [2016]
- 四、DeepCrossing [2016]
- 五、Wide 和 Deep [2016]
- 六、DCN [2017]
- 七、DeepFM [2017]
- 八、NFM [2017]
- 九、AFM [2017]
- 十、xDeepFM [2018]
- 十一、ESMM [2018]
- 十二、DIN [2017]
- 十三、DIEN [2019]
- 十四、DSIN [2019]
- 十五、DICM [2017]
- 十六、DeepMCP [2019]
- 十七、MIMN [2019]
- 十八、DMR [2020]
- 十九、MiNet [2020]
- 二十、DSTN [2019]
- 二十一、BST [2019]
- 二十二、SIM [2020]
- 二十三、ESM2 [2019]
- 二十四、MV-DNN [2015]
- 二十五、CAN [2020]
- 二十六、AutoInt [2018]
- 二十七、Fi-GNN [2019]
- 二十八、FwFM [2018]
- 二十九、FM2 [2021]
- 三十、FiBiNET [2019]
- 三十一、AutoFIS [2020]
- 三十三、AFN [2020]
- 三十四、FGCNN [2019]
- 三十五、AutoCross [2019]
- 三十六、InterHAt [2020]
- 三十七、xDeepInt [2023]
- 三十九、AutoDis [2021]
- 四十、MDE [2020]
- 四十一、NIS [2020]
- 四十二、AutoEmb [2020]
- 四十三、AutoDim [2021]
- 四十四、PEP [2021]
- 四十五、DeepLight [2021]
- 图的表达
- 一、DeepWalk [2014]
- 二、LINE [2015]
- 三、GraRep [2015]
- 四、TADW [2015]
- 五、DNGR [2016]
- 六、Node2Vec [2016]
- 七、WALKLETS [2016]
- 八、SDNE [2016]
- 九、CANE [2017]
- 十、EOE [2017]
- 十一、metapath2vec [2017]
- 十二、GraphGAN [2018]
- 十三、struc2vec [2017]
- 十四、GraphWave [2018]
- 十五、NetMF [2017]
- 十六、NetSMF [2019]
- 十七、PTE [2015]
- 十八、HNE [2015]
- 十九、AANE [2017]
- 二十、LANE [2017]
- 二十一、MVE [2017]
- 二十二、PMNE [2017]
- 二十三、ANRL [2018]
- 二十四、DANE [2018]
- 二十五、HERec [2018]
- 二十六、GATNE [2019]
- 二十七、MNE [2018]
- 二十八、MVN2VEC [2018]
- 二十九、SNE [2018]
- 三十、ProNE [2019]
- Graph Embedding 综述
- 图神经网络
- 一、GNN [2009]
- 二、Spectral Networks 和 Deep Locally Connected Networks [2013]
- 三、Fast Localized Spectral Filtering On Graph [2016]
- 四、GCN [2016]
- 五、神经图指纹 [2015]
- 六、GGS-NN [2016]
- 七、PATCHY-SAN [2016]
- 八、GraphSAGE [2017]
- 九、GAT [2017]
- 十、R-GCN [2017]
- 十一、 AGCN [2018]
- 十二、FastGCN [2018]
- 十三、PinSage [2018]
- 十四、GCMC [2017]
- 十五、JK-Net [2018]
- 十六、PPNP [2018]
- 十七、VRGCN [2017]
- 十八、ClusterGCN [2019]
- 十九、LDS-GNN [2019]
- 二十、DIAL-GNN [2019]
- 二十一、HAN [2019]
- 二十二、HetGNN [2019]
- 二十三、HGT [2020]
- 二十四、GPT-GNN [2020]
- 二十五、Geom-GCN [2020]
- 二十六、Graph Network [2018]
- 二十七、GIN [2019]
- 二十八、MPNN [2017]
- 二十九、UniMP [2020]
- 三十、Correct and Smooth [2020]
- 三十一、LGCN [2018]
- 三十二、DGCNN [2018]
- 三十三、AS-GCN
- 三十四、DGI [2018]
- 三十五、DIFFPOLL [2018]
- 三十六、DCNN [2016]
- 三十七、IN [2016]
- 图神经网络 2
- 图神经网络 3
- 推荐算法(传统方法)
- 一、Tapestry [1992]
- 二、GroupLens [1994]
- 三、ItemBased CF [2001]
- 四、Amazon I-2-I CF [2003]
- 五、Slope One Rating-Based CF [2005]
- 六、Bipartite Network Projection [2007]
- 七、Implicit Feedback CF [2008]
- 八、PMF [2008]
- 九、SVD++ [2008]
- 十、MMMF 扩展 [2008]
- 十一、OCCF [2008]
- 十二、BPR [2009]
- 十三、MF for RS [2009]
- 十四、 Netflix BellKor Solution [2009]
- 推荐算法(神经网络方法 1)
- 一、MIND [2019](用于召回)
- 二、DNN For YouTube [2016]
- 三、Recommending What Video to Watch Next [2019]
- 四、ESAM [2020]
- 五、Facebook Embedding Based Retrieval [2020](用于检索)
- 六、Airbnb Search Ranking [2018]
- 七、MOBIUS [2019](用于召回)
- 八、TDM [2018](用于检索)
- 九、DR [2020](用于检索)
- 十、JTM [2019](用于检索)
- 十一、Pinterest Recommender System [2017]
- 十二、DLRM [2019]
- 十三、Applying Deep Learning To Airbnb Search [2018]
- 十四、Improving Deep Learning For Airbnb Search [2020]
- 十五、HOP-Rec [2018]
- 十六、NCF [2017]
- 十七、NGCF [2019]
- 十八、LightGCN [2020]
- 十九、Sampling-Bias-Corrected Neural Modeling [2019](检索)
- 二十、EGES [2018](Matching 阶段)
- 二十一、SDM [2019](Matching 阶段)
- 二十二、COLD [2020 ] (Pre-Ranking 模型)
- 二十三、ComiRec [2020](https://www.wenjiangs.com/doc/0b4e1736-ac78)
- 二十四、EdgeRec [2020]
- 二十五、DPSR [2020](检索)
- 二十六、PDN [2021](mathcing)
- 二十七、时空周期兴趣学习网络ST-PIL [2021]
- 推荐算法之序列推荐
- 一、FPMC [2010]
- 二、GRU4Rec [2015]
- 三、HRM [2015]
- 四、DREAM [2016]
- 五、Improved GRU4Rec [2016]
- 六、NARM [2017]
- 七、HRNN [2017]
- 八、RRN [2017]
- 九、Caser [2018]
- 十、p-RNN [2016]
- 十一、GRU4Rec Top-k Gains [2018]
- 十二、SASRec [2018]
- 十三、RUM [2018]
- 十四、SHAN [2018]
- 十五、Phased LSTM [2016]
- 十六、Time-LSTM [2017]
- 十七、STAMP [2018]
- 十八、Latent Cross [2018]
- 十九、CSRM [2019]
- 二十、SR-GNN [2019]
- 二十一、GC-SAN [2019]
- 二十二、BERT4Rec [2019]
- 二十三、MCPRN [2019]
- 二十四、RepeatNet [2019]
- 二十五、LINet(2019)
- 二十六、NextItNet [2019]
- 二十七、GCE-GNN [2020]
- 二十八、LESSR [2020]
- 二十九、HyperRec [2020]
- 三十、DHCN [2021]
- 三十一、TiSASRec [2020]
- 推荐算法(综述)
- 多任务学习
- 系统架构
- 实践方法论
- 深度强化学习 1
- 自动代码生成
工具
- CRF
- lightgbm
- xgboost
- scikit-learn
- spark
- numpy
- matplotlib
- pandas
- huggingface_transformer
- 一、Tokenizer
- 二、Datasets
- 三、Model
- 四、Trainer
- 五、Evaluator
- 六、Pipeline
- 七、Accelerate
- 八、Autoclass
- 九、应用
- 十、Gradio
Scala
- 环境搭建
- 基础知识
- 函数
- 类
- 样例类和模式匹配
- 测试和注解
- 集合 collection(一)
- 集合collection(二)
- 集成 Java
- 并发
二、GNNEXPLAINER [2019]
图是强大的数据表示形式,而图神经网络
Graph Neural Network: GNN
是处理图的最新技术。GNN
能够递归地聚合图中邻域节点的信息,从而自然地捕获图结构和节点特征。尽管
GNN
效果很好,但是其可解释性较差。由于以下几个原因,GNN
的可解释性非常重要:- 可以提高对
GNN
模型的信任程度。 - 在越来越多有关公平
fairness
、隐私privacy
、以及其它安全挑战safety challenge
的关键决策应用application
中,提高模型的透明度transparency
。 - 允许从业人员理解模型特点,从而在实际环境中部署之前就能识别并纠正模型的错误。
尽管目前尚无用于解释
GNN
的方法。但是在更高层次,我们可以将包括GNN
和Non-GNN
的可解释性方法分为两个主要系列:使用简单的代理模型
surrogate model
局部逼近locally approximate
完整模型full model
,然后探索这个代理模型以进行解释。这可能是以模型无关
model-agnostic
的方式来完成的,通常是通过线性模型linear model
或者规则集合set of rules
来学习预测的局部近似,可以充分代表预测结果。仔细检查模型的相关特征
relevant feature
,并找到high-level
特征的良好定性解释good qualitative interpretation
,或者识别有影响力的输入样本influential input instance
。如通过特征梯度
feature gradients
、神经元反向传播中输入特征的贡献、反事实推理counterfactual reasoning
等。
上述两类方法专注于研究模型的固有解释,而后验
post-hoc
的可解性方法将模型视为黑盒,然后对其进行探索从而获得相关信息。但是所有这些方法无法融合关系信息,即图的结构信息。由于关系信息对于图上机器学习任务的成功至关重要,因此对
GNN
预测的任何解释都应该利用图提供的丰富关系信息、以及节点特征。论文《GNNExplainer: Generating Explanations for Graph Neural Networks》
提出了 一种解释GNN
预测的方法,称作GNNEXPLAINER
。GNNEXPLAINER
接受训练好的GNN
及其预测,返回对于预测最有影响力的、输入图的一个小的子图small subgraph
,以及节点特征的一个小的子集small subset
。如下图所示,这里展示了一个节点分类任务,其中在社交网络上训练了
GNN
模型 $ \mathbf\Phi $ 从而预测体育活动。给定训练好的GNN
$ \mathbf\Phi $ 以及对节点 $ v_i $ 的预测 $ \hat y_i= \text{basketball} $ ,GNNEXPLAINER
通过识别对预测 $ \hat y_i $ 有影响力的输入图的一个小的子图、以及节点特征的一个小的子集来生成解释explanation
,如右图所示。- 通过检查
$ \hat y_i $ 的解释explanation
,我们发现 $ v_i $ 社交圈中很多朋友都喜欢玩球类游戏,因此GNN
预测 $ v_i $ 可能会喜欢篮球。 - 同样地,通过检查
$ \hat y_j $ 的解释explanation
我们发现 $ v_j $ 社交圈中很多朋友都喜欢水上运动和沙滩运动,因此GNN
预测 $ \hat y_j=\text{sailing} $ (帆船运动)。
GNNEXPLAINER
方法和GNN
模型无关,可以解释任何GNN
在图的任何机器学习任务上的预测,包括节点分类、链接预测、图分类任务。它可以解释单个实例的预测single-instance explanation
,也可以解释一组实例的预测multi-instance explanation
。- 在单个实例预测的情况下,
GNNEXPLAINER
解释了GNN
对于特定样本的预测。 - 在多个实例预测的情况下,
GNNEXPLAINER
提供了对于一组样本(如类别为“篮球”的所有节点)的预测的一致性解释(即这些预测的共同的解释)。
GNNEXPLAINER
将解释指定为训练GNN
的整个输入图的某个子图,其中子图基于GNN
的预测最大化互信息mutual information
。这是通过平均场变分近似mean field variational approximation
,以及学习一个实值graph mask
来实现的。这个graph mask
选择了GNN
输入图的最重要子图。同时,GNNEXPLAINER
还学习了一个feature mask
,它可以掩盖不重要的节点特征。论文在人工合成图、以及真实的图上评估
GNNEXPLAINER
的效果。实验表明:GNNEXPLAINER
为GNN
的预测提供了一致而简洁的解释。- 可以提高对
虽然解释
GNN
问题没有得到很好的研究,但相关的可解释性interpretability
和神经调试neural debugging
的问题在机器学习中得到了大量的关注。在high-level
上,我们可以将那些non-graph neural network
的可解释性方法分为两个主要方向:- 第一个方向的方法制定了完整神经网络
full neural network
的简单代理模型。这可以通过模型无关model-agnostic
的方式完成,通常是通过学习prediction
周围的局部良好近似(例如通过线性模型或规则集合),代表预测的充分条件sufficient condition
。 - 第二个方向方法确定了计算的重要方面,例如,特征梯度
feature gradient
、神经元的反向传播对输入特征的贡献、以及反事实推理。然而,这些方法产生的显著性映射saliency map
已被证明在某些情况下具有误导性,并且容易出现梯度饱和等问题。这些问题在离散输入(如图邻接矩阵)上更加严重,因为梯度值可能非常大而且位于一个非常小的区域interval
。正因为如此,这种方法不适合用来解释神经网络在图上的预测。
事后可解释性
post-hoc interpretability
方法不是创建新的、固有可解释的模型,而是将模型视为黑箱,然后探测模型的相关信息。然而,还没有利用关系型结构(如graph
)方面的工作。解释图结构数据预测的方法的缺乏是有问题的,因为在很多情况下,图上的预测是由节点和它们之间的边的路径的复杂组合引起的。例如,在一些任务中,只有当图中存在另一条替代路径形成一个循环时,一条边才是重要的,而这两个特征,只有在一起考虑时,才能准确预测节点标签。因此,它们的联合贡献不能被建模为单个贡献的简单线性组合。最后,最近的
GNN
模型通过注意力机制增强了可解释性。然而,尽管学到的edge attention
值可以表明重要的图结构,但这些值对所有节点的预测都是一样的。因此,这与许多应用相矛盾,在这些应用中,一条边对于预测一个节点的标签是至关重要的,但对于另一个节点的标签则不是。此外,这些方法要么仅限于特定的GNN
架构,要么不能通过共同考虑图结构和节点特征信息来解释预测结果。- 第一个方向的方法制定了完整神经网络
GNNEXPLAINER
提供了各种好处 ,包括可视化语义相关结构以进行解释的能力、以及提供洞察GNN
的错误的能力。
2.1 方法
定义图
$ \mathcal G=(\mathcal V, \mathcal E, \mathbf X) $ ,其中: $ \mathcal V=\{v_1,\cdots,v_n\} $ 为节点集合,节点数量为 $ n $ 。 $ \mathcal E=\{e_{i,j}\} $ 为边集合, $ e_{ i,j } = (v_i,v_j) $ ,边数量为 $ m $ 。- 每个节点
$ v_i $ 关联一个特征向量 $ \mathbf{\vec x}_i\in \mathbb R^{d} $ ,所有节点的特征向量拼接为特征矩阵 $ \mathbf X\in \mathbb R^{n\times d} $ , $ d $ 为特征向量维度。
不失一般性,我们考虑节点分类问题的可解释性。定义
$ f $ 为一个节点到类别标签的映射函数: $ f:\mathcal V\rightarrow \{1,\cdots,C\} $ ,其中 $ C $ 为类别数量。GNN
模型 $ \mathbf\Phi $ 在训练集的所有节点上优化从而逼近 $ f $ ,然后用于预测未标记节点的label
。我们假设
GNN
模型 $ \mathbf\Phi $ 采用消息传递机制。在第 $ l $ 层,GNN
模型 $ \mathbf \Phi $ 的更新涉及三个关键计算:首先,模型计算每对节点
pair
对之间传递的消息。节点pair
对 $ (v_i,v_j) $ 之间的消息定义为:其中:
MSG(.)
为消息函数; $ \mathbf{\vec h}_i^{(l-1)} $ 为节点 $ v_i $ 在第 $ l-1 $ 层的representation
; $ r_{i,j} $ 为节点 $ v_i $ 和 $ v_j $ 之间的关系。然后,对于每个节点
$ v_i $ ,GNN
聚合来自其邻域的所有消息:其中:
AGG(.)
为一个邻域聚合函数; $ \mathcal N_i $ 为节点 $ v_i $ 的邻域集合。最后,对于每个节点
$ v_i $ ,GNN
根据 $ \mathbf{\vec a}_i^{(l)} , \mathbf{\vec h}_i^{(l-1)} $ 来计算 $ v_i $ 在第 $ l $ 层的representation
:其中
UPDATE(.)
为节点状态更新函数。
最终节点
$ v_i $ 的embedding
为 $ v_i $ 在第 $ L $ 层输出的representation
:对于采用
MSG,AGG,UPDATE
计算组成的任何GNN
,我们的GNNEXPLAINER
可以提供解释。我们的洞察
insight
是观察到:节点 $ v $ 的计算图computation graph
是由GNN
的neighborhood-based
聚合来定义,如下图所示。这个计算图完全决定了用于生成节点 $ v $ 的预测 $ \hat y $ 的所有信息。具体而言,节点 $ v $ 的计算图告诉GNN
如何生成节点 $ v $ 的embedding
$ \mathbf{\vec z} $ 。定义节点
$ v $ 的计算图computation graph
为 $ \mathcal G_c(v) $ ,它关联一个二元binary
邻接矩阵 $ \mathbf A_c(v)\in \mathbb R^{n\times n} $ ,其中 $ A_c(v)[i,j]\in \{0,1\} $ 取值为0
或1
; 也关联一个特征矩阵 $ \mathbf X_c(v) = \{\mathbf{\vec x}_j\mid v_j\in \mathcal G_c(v)\} $ 。GNN
模型 $ \mathbf\Phi $ 学习一个条件分布 $ P_{\mathbf \Phi}\left( Y_v\mid \mathcal G_c(v),\mathbf X_c(v)\right) $ ,表示给定节点计算图 $ \mathcal G_c(v) $ 、节点特征矩阵 $ \mathbf X_c(v) $ 的条件下,节点 $ v $ 属于各类别的概率。其中 $ Y_v\in \{1,\cdots,C\} $ 为一个随机变量。一旦
GNN
模型学到这样的分布之后,对于节点 $ v $ ,GNN
的类别预测结果为 $ \hat y = \mathbf\Phi(\mathcal G_c(v),\mathbf X_c(v)) $ ,意味着它完全由三个因素决定:模型 $ \mathbf\Phi $ 、图结构信息 $ \mathcal G_c(v) $ 、节点特征信息 $ \mathbf X_c(v) $ 。这个观察结果意味着我们只需要考虑图结构 $ \mathcal G_c(v) $ 和节点特征 $ \mathbf X_c(v) $ 来解释 $ \hat y $ ,如下图A
所示。正式地讲,
GNNEXPLAINER
为预测 $ \hat y $ 生成解释explanation
,记作 $ \left(\mathcal G_S,\mathbf X_S^F\right) $ 。其中: $ \mathcal G_S $ 为计算图的一个小的子图small subgraph
,如图A
所示。 $ \mathbf X_S $ 为 $ \mathcal G_S $ 关联的特征, $ \mathbf X_S^F $ 为节点特征的一个小的特征子集small subset
。F
表示通过mask F
来遮盖,即: $ \mathbf X_S^F = \{\mathbf{\vec x}_j^F\mid v_j\in \mathcal G_S\} $ ,如图B
所示。假设原始的节点特征集合为
$ \mathbb A=\{\mathcal A_1,\cdots,\mathcal A_d\} $ ,则经过mask F
遮盖之后的特征集合为:它是原始特征集合的一个小的特征子集,且有:
$ \mathbf{\vec x}_j^F=(x_{j,t_1},x_{j,t_2},\cdots,x_{j,t_F})^\top\in \mathbb R^F $ 为遮盖后的特征向量。
下图中:
图
A
给出了一个GNN
在节点 $ v $ 处的计算图 $ \mathcal G_c(v) $ ,它用于得到节点 $ v $ 的类别预测 $ \hat y $ 。 $ \mathcal G_c(v) $ 中的某些边构成重要的消息传播路径(绿色),这些路径允许有用的节点消息跨 $ \mathcal G_c(v) $ 传播并在节点 $ v $ 处聚合从而进行预测;相反, $ \mathcal G_c(v) $ 中的另一些边不重要(橙色)。但是无论消息重不重要,在节点 $ v $ 处GNN
都会具有所有消息(包括不重要的消息)从而进行预测,这可能会稀释重要的消息。GNNEXPLAINER
的目标是识别少量对于预测至关重要的重要特征和路径(绿色)。图
B
表示GNNEXPLAINER
通过学习节点特征mask
来确定 $ \mathcal G_S $ 中节点的那些特征维度对于预测至关重要。
接下来我们详细描述
GNNEXPLAINER
。给定训练好的GNN
模型 $ \mathbf\Phi $ 以及一个预测prediction
(即单实例解释single-instance explanation
)、或者一组预测(即多实例解释multi-instance explanation
),GNNEXPLAINER
将通过识别对模型 $ \mathbf\Phi $ 的预测影响最大的计算图的子图、节点特征的子集从而生成解释。在多实例解释中,
GNNEXPLAINER
将每个实例的解释聚合在一起并自动抽取为一个原型proto
。这个原型代表每个实例解释的公共部分,即proto
可以对所有这些实例进行解释。
2.1.1 单实例解释
给定一个节点
$ v $ ,我们的目标是识别对于该节点的GNN
预测 $ \hat y $ 很重要的子图 $ \mathcal G_S\sube \mathcal G_c $ ,以及关联特征矩阵 $ \mathbf X_S=\{ \mathbf{\vec x}_j \mid v_j\in \mathcal G_S\} $ 。现在我们不考虑特征mask
,这留待下一步讨论。我们使用互信息
mutual information:MI
来刻画子图的重要性,并将GNNEXPLAINER
形式化为以下最优化问题:其中:
$ H(\cdot) $ 为熵,它表示GNN
对于节点 $ v $ 类别预测结果的不确定性程度。 $ Y\in \{1,\cdots,C\} $ 为代表GNN
预测节点 $ v $ 类别的随机变量。注意:这里没有任何关于节点
$ v $ 真实类别的信息。也就是我们不关心GNN
预测得准不准,而是仅关心哪些因素和GNN
预测结果相关。 $ H(Y) $ 其实是 $ H(Y\mid \mathcal G=\mathcal G, \mathbf X=\mathbf X) $ ,即以原始图、原始特征矩阵来进行的预测所得到的熵。 $ H(\cdot\mid \cdot) $ 为条件熵,它表当节点 $ v $ 的计算图被限制为子图 $ \mathcal G_S $ 、节点特征被限制为 $ \mathbf X_S $ 后,GNN
预测结果不确定性程度。
MI
刻画了当节点 $ v $ 的计算图被限制为子图 $ \mathcal G_S $ 、节点特征被限制为 $ \mathbf X_S $ 后,预测结果为 $ \hat y=\mathbf\Phi(\mathcal G_c, \mathbf X_c) $ 的概率的变化。例如:- 考虑
$ v_j\in \mathcal G_c(v_i),v_j\ne v_i $ 。如果从 $ \mathcal G_c(v_i) $ 中移除 $ v_j $ ,使得预测结果为 $ \hat y_i $ 的概率急剧降低,则节点 $ v_j $ 是 $ v_i $ 预测的一个很好的解释。 - 类似地,考虑
$ v_j,v_k \in \mathcal G_v(v_i),v_j,v_k\ne v_i $ 。如果移除 $ v_j,v_k $ 之间的边,使得预测结果为 $ \hat y_i $ 的概率急剧降低,则 $ (v_j,v_k) $ 之间的边是 $ v_i $ 预测的一个很好的解释。
$ \text{MI}( Y,(\mathcal G_S,\mathbf X_S)) $ 的第一项 $ H(Y) $ 为常数项,因为对于给定的、训练好的GNN
,节点 $ v $ 的预测结果是 $ \hat y $ 的概率是已知的,和 $ \mathcal G_S $ 无关。则有:因此,对于预测
$ \hat y $ 的解释是一个子图 $ \mathcal G_S $ ,当GNN
被限制在 $ \mathcal G_S $ 时最小化 $ \mathbf\Phi $ 不确定性uncertainty
。在效果上, $ \mathcal G_S $ 最大化预测为 $ \hat y $ 的概率。理论上当
$ \mathcal G_S = \mathcal G_c $ 时,上式最大化。为了获得更紧凑的解释,我们对 $ \mathcal G_S $ 的大小施加约束,如 $ |\mathcal G_S|\le K_M $ ,使得 $ \mathcal G_S $ 最多只有 $ K_M $ 个节点。实际上这意味着GNNEXPLAINER
旨在通过采取对预测提供最高互信息的 $ K_M $ 个节点进行降噪。直接优化
GNNEXPLAINER
的目标函数很困难,因为 $ \mathcal G_c $ 有指数级的子图 $ \mathcal G_S $ 作为 $ \hat y $ 的候选解释。因此,我们考虑子图 $ \mathcal G_S $ 的分数邻接矩阵fractional adjacency matrix
,即 $ \mathbf A_S\in \mathbb R^{n\times n} $ , 其中 $ A_S[i,j] \in [0.0,1.0] $ 在0~1.0
之间。此外我们施加约束 $ A_S[j,k]\le A_c[j,k] $ ,使得没有边的节点之间 $ A_S[j,k] $ 也为零。这种连续性松弛
continuous relaxation
可以解释为 $ \mathcal G_c $ 子图分布的变分近似variational approximation
。具体而言,我们将 $ \mathcal G_S\in \mathcal G $ 视为一个随机图变量random graph variable
,则目标函数变为:我们假设目标函数是凸函数,则
Jensen
不等式给出以下的上界:实际上由于神经网络的复杂性,凸性假设不成立。但是通过实验我们发现:优化带正则化的上述目标函数通常求得一个局部极小值,该局部极小值具有高质量的解释性。
为精确地估计
$ \mathbb E_\mathcal G $ ,我们使用平均场变分近似mean-field variational approximation
,并将 $ \mathcal G $ 分解为多元伯努利分布multivariate Bernoulli distribution
:这允许我们估计对于平均场近似的期望从而获得
$ \mathbf A_S $ ,其中 $ \mathbf A_S $ 的第 $ (j,k) $ 元素代表:节点 $ (v_j,v_k) $ 之间存在边的期望。我们从实验观察到:尽管
GNN
是非凸的,但是这种近似approximation
结合一个可以提升离散型discreteness
的正则化器一起,结果可以收敛到良好的局部极小值。可以通过使用邻接矩阵的计算图的掩码
$ \mathbf A_c\odot \sigma(\mathbf M) $ 替换要优化的 $ \mathbb E_\mathcal G[\mathcal G_S] $ ,从而优化上式中的条件熵。即:其中:
$ \mathbf M\in \mathbb R^{n\times n} $ 表示我们需要学习的mask
矩阵。 $ \odot $ 表示逐元素乘积。 $ \sigma(\cdot) $ 为sigmoid
函数,它将mask
映射到0.0~1.0
之间。
GNNExplainer
的核心在于:用0.0 ~ 1.0
之间的mask
矩阵(待学习)来调整邻接矩阵,从而最小化预测的熵。但是,这种方法只关心哪个子图对预测结果最重要,并不关心哪个子图对ground-truth
最有帮助。可以通过标签类别和模型预测之间的交叉熵来修改上式中的条件熵,从而得到哪个子图对
ground-truth
最有帮助。 $ \mathbf M $ 通过随机梯度下降来学习。在某些应用
application
中,我们不关心模型预测结果的 $ \hat y $ 的解释性,而更关注如何使得模型能够预测所需要类别的标签。这里我们可以使用标签类别和模型预测之间的交叉熵来修改上式中的条件熵,即:尽管有不同的动机和目标,在
Neural Relational Inference
中也发现了masking
方法。最后,我们计算
$ \sigma(\mathbf M) $ 和 $ \mathbf A_c $ 的逐元素乘积,并通过阈值移除 $ \mathbf M $ 中的较小的值,从而得出节点 $ v $ 处模型预测 $ \hat y $ 的解释 $ \mathcal G_S $ 。
2.1.2 图结构 & 节点特征
为确定哪些节点特征对于预测
$ \hat y $ 最重要,GNNEXPLAINER
针对 $ \mathcal G_S $ 中的节点学习一个特征选择器 $ F $ 。GNNEXPLAINER
考虑 $ \mathcal G_S $ 中节点的特征子集 $ \mathbb A^F = \{\mathcal A_{t1},\cdots,\mathcal A_{t_F}\},\quad 1\le t_1\le\cdots\le t_F\le d $ ,其中每个节点特征选择后的特征向量为 $ \mathbf{\vec x}_j^F=(x_{j,t_1},x_{j,t_2},\cdots,x_{j,t_F})^\top $ 。所有 $ \mathcal G_S $ 中节点的选择后的特征矩阵为 $ \mathbf X_S^F $ 。我们通过一个
mask
来定义特征选择器:其中
$ f_i\in\{0,1\} $ 取值为0
或1
,当它为1
时表示保留对应特征,否则遮盖对应特征。因此 $ \mathbf{\vec x}_j^F $ 包含未被 $ F $ 掩盖mask out
的节点特征。我们定义特征
mask
矩阵为:则有:
$ \mathbf X_S^F = \mathbf X_S\odot \mathbf F $ 。其中 $ \odot $ 表示逐元素乘积。现在我们在互信息目标函数中考虑节点特征,从而得到解释
explanation
$ (\mathcal G_S, \mathbf X_S^F) $ :该目标函数同时考虑了对预测
$ \hat y $ 的子图结构解释、节点特征解释。从直觉上看:
- 如果某个节点特征不重要,则
GNN
权重矩阵中的相应权重应该接近于零。mask
这类特征对于预测结果没有影响。 - 如果某个节点特征很重要,则
GNN
权重矩阵中相应权重应该较大。mask
这类特征会降低预测为 $ \hat y $ 的概率。
但是在某些情况下,这种方法会忽略对于预测很重要、但是特征取值接近于零的特征。为解决该问题,我们对所有特征子集边际化
marginalize
,并在训练过程中使用蒙特卡洛估计从 $ \mathbf X_S $ 中节点的经验边际分布中采样得到边际分布。此外,我们使用
reparametrization
技巧将目标函的梯度反向传播到mask
矩阵 $ \mathbf F $ 。具体而言,为了通过
$ \mathbf X $ 反向传播,我们reparametrize
$ \mathbf X $ 为:其中:
$ \mathbf Z $ 维从经验分布中采样到的随机变量。 $ K_F $ 为要保留的最大特征数量。
上式等价于:
$ \mathbf X = (1-\mathbf F)\odot \mathbf Z + \mathbf F\odot\mathbf X_S $ 。因此 $ \mathbf X $ 由两部分加权和得到: $ \mathbf Z $ :来自于每个维度边际分布采样得到的,权重为 $ 1-\mathbf F $ ,代表噪音部分。这是为了解决特征取值接近于零但是又对于预测很重要的特征的问题。 $ \mathbf X_S $ :来自于子图节点的特征向量,权重为 $ \mathbf F $ ,代表真实信号部分。
这种特征可解释方法可以用于普通的神经网络模型。
- 如果某个节点特征不重要,则
为了在解释
explanation
中加入更多属性,可以使用正则化项扩展GNNEXPLAINER
的目标函数。可以包含很多正则化项从而产生具有所需属性的解释。- 例如,我们使用逐元素的熵来鼓励结构
mask
和节点特征mask
是离散的。 - 例如,我们可以将
mask
参数的所有元素之和作为正则化项,从而惩罚规模太大的mask
。 - 此外,
GNNEXPLAINER
可以通过诸如拉格朗日乘子Lagrange multiplier
约束、或者额外的正则化项等技术来编码domain-specific
约束。
- 例如,我们使用逐元素的熵来鼓励结构
最后需要重点注意的是:每个解释
explanation
必须是一个有效的计算图。具体而言, $ (\mathcal G_S, \mathbf X_S) $ 需要允许GNN
的消息流向节点 $ v $ ,从而允许GNN
做出预测 $ \hat y $ 。重要的是,
GNNEXPLAINER
的解释一定是有效的计算图,因为它在整个计算图上优化结构mask
。即使一条断开的边对于消息传递很重要,GNNEXPLAINER
也不会选择它作为解释,因为它不会影响GNN
的预测结果。实际上,这意味着 $ \mathcal G_S $ 倾向于是一个小的连通子图small connected subgraph
。这是因为
GNNExplainer
会运行GNN
,如果计算图无效则运行GNN
的结果失败或者预测效果很差,因此也就不会作为可解释结果。
2.1.3 多实例解释
有时候我们需要回答诸如 “为什么
GNN
对于一组给定的节点预测都是类别c
” 之类的问题。因此我们需要获得对于类别c
的全局解释。这里我们提出一个基于
GNNEXPLAINER
的解决方案,从而在类别c
中的一组不同节点的各自单实例解释中,找到针对类别c
的通用的解释。这个问题与寻找每个解释图中最大公共子图密切相关,这是一个NP-hard
问题。这里我们采用了解决该问题的神经网络方案,案称作基于对齐alignment-based
的multi-instance GNNEXPLAINER
。对于给定的类
c
,我们首先选择一个参考节点reference node
$ v_c $ 。直观地看,该节点是能够代表该类别的原型节点prototypical node
。- 可以通过计算类别
c
中所有节点的embedding
均值,然后选择类别c
中节点embedding
和这个均值最近的节点作为参考节点。 - 也可以使用有关先验知识,选择和先验知识最匹配的节点作为类别
c
的参考节点。
给定类别
c
的参考节点 $ v_c $ ,以及它关联的reference
解释图 $ \mathcal G_S(v_c) $ ,我们将类别 $ c $ 中所有节点的解释图都对齐到 $ \mathcal G_S(v_c) $ 。利用微分池化
differentiable pooling
的思想,我们使用一个松弛relaxed
的对齐矩阵alignment matrix
来找到解释图 $ \mathcal G_S(v) $ 中的节点和reference
解释图 $ \mathcal G_S(v_c) $ 中的节点之间的对应关系。设节点 $ v $ 待对齐的解释图的邻接矩阵和特征矩阵分别为 $ \mathbf A_v, \mathbf X_v $ ,设参考节点的解释图的邻接矩阵和特征矩阵分别为 $ \mathbf A^*,\mathbf X^* $ 。我们定义松弛对齐矩阵relaxed alignment matrix
$ \mathbf P\in \mathbb R^{n_v\times n^*} $ ,则优化目标为:其中:
$ n_v $ 为 $ \mathcal G_S(v) $ 中节点数量, $ n^* $ 为 $ \mathcal G_S(v_c) $ 中节点数量。 $ \mathbf P $ 的元素大于零且每一行的和为1.0
。
上式第一项表示:经过对齐之后,
$ \mathcal G_S(v) $ 对齐后的邻接关系应该尽可能接近 $ \mathbf A^* $ ;第二项表示:经过对齐之后, $ \mathcal G_S(v) $ 对齐后的特征矩阵应该尽可能接近 $ \mathbf X^* $ 。实际上对于两个大图
$ \mathcal G_S(v) $ 和 $ \mathcal G_S(v_c) $ ,上述最优化问题很难求解。但是由于单实例解释生成的 $ \mathcal G_S(v) $ 和 $ \mathcal G_S(v_c) $ 都是简洁的、很小的图,因此可以有效地计算几乎最优的对齐方式。- 可以通过计算类别
一旦得到类别
c
中所有节点对齐后的邻接矩阵,我们就可以使用中位数来生成一个原型prototype
。之所以使用中位数,是因为中位数可以有效对抗异常值。即:其中
$ \tilde{\mathbf A}_i $ 为类别c
中第 $ i $ 个节点的explanation
的对齐后的邻接矩阵(即 $ \mathbf P^\top \mathbf A_i\mathbf P $ )。原型
$ \mathbf A_{\text{proto}} $ 允许我们深入了解属于某一类的节点之间共享的图结构模式。然后对于特定的节点,用户可以通过将节点explanation
和类别原型进行比较,从而研究该特定节点。在多个解释图的邻接矩阵对齐过程中,也可以使用现在的图库
graph library
来寻找这些解释图的最大公共子图,从而替换掉神经网络部分。在多实例解释中,解释器
explainer
不仅必须突出与单个预测的局部相关信息,还需要强调不同实例之间更高level
的相关性。这些实例之间可以通过任意方式产生关联,但是最常见的还是类成员
class-membership
关联。假设类的不同样本之间存在共同特征,那么解释器需要捕获这种共同的特征。例如,通常发现诱变化合物mutagenic compounds
具有某些特定属性的功能团,如NO2
。如下图所示,经验丰富的专家可能已经注意到这些功能团的存在。当
GNNEXPLAINER
生成原型prototype
时,可以进一步加强这方面的证据。下图来自于MUTAG
数据集的诱变化合物。
2.1.4 扩展
机器学习任务的扩展:除了解释节点分类之外,
GNNEXPLAINER
还可以解释链接预测和图分类,无需更改其优化算法。- 在预测链接
$ (v_j,v_k) $ 时,GNNEXPLAINER
为链接的两个端点学习两个mask
$ \mathbf X_S(v_j), \mathbf X_S(v_k) $ 。 - 在图分类时,目标函数中的邻接矩阵是图中所有节点邻接矩阵的并集
union
。
注意:图分类任务和节点分类任务不同。由于图分类任务存在节点
embedding
的聚合,因此解释 $ \mathcal G_S $ 不必是一个连通子图。根据不同的场景,某些情况下要求解释是一个连通子图,此时可以提取解释中的最大连通分量。- 在预测链接
模型扩展:
GNNEXPLAINER
能够处理所有基于消息传递的GNN
,包括:Graph Convolutional Networks:GCN
、Gated Graph Sequence Neural Networks:GGS-NNs
、Jumping Knowledge Networks:JK-Net
、Attention Networks-GAT
、Graph Networks:GN
、具有各种聚合方案的GNN
、Line-Graph NNs
、position-aware GNN
、以及很多其它GNN
架构。GNNEXPLAINER
优化中的参数规模取决于节点 $ v $ 的计算图 $ \mathcal G_c $ 的大小。具体而言, $ \mathcal G_c(v) $ 的邻接矩阵 $ \mathbf A_c(v) $ 等于掩码矩阵 $ \mathbf M $ 的大小,其中 $ \mathbf M $ 是需要被GNNEXPLAINER
学习的。但是,由于单个节点的计算图通常较小,因此即使完整的输入图很大
GNNEXPLAINER
仍然可以有效地生成解释。
2.2 实验
数据集:
人工合成数据集:我们人工构建了四种节点分类数据集,如下表所示。
BA-SHAPES
数据集:我们从300
个节点的Barabasi-Albert:BA
基础图、以及一组80
个五节点的房屋house
结构的主题motif
开始,这些motif
被随机添加到基础图的随机选择的节点上。进一步地我们添加 $ 0.1\times n $ 条随机边,从而得到经过扰动的合成图。根据节点的结构角色,节点为以下四种类型之一:
house
顶部节点、house
中间节点、house
底部节点、非house
节点。BA-COMMUNITY
数据集:是两个BA-SHAPES
图的并集。节点具有正态分布的特征向量,并且根据其结构角色、社区成员(两种社区)分配为8
种类别之一。TREE-CYCLES
:从8-level
平衡二叉树为基础图、一组80
个 六节点的环状motif
开始,这些motif
随机添加到基础图的随机选择的节点上。TREE-GRID
:和TREE-CYCLES
相同,除了使用3x3
的网格motif
代替六节点的环motif
之外。
真实数据集:我们考虑两个图分类数据集。
MUTAG
:包含4337
个分子图的数据集,根据分子对于革兰氏阴性菌伤寒沙门氏菌Gram-negative bacterium S.typhimurium
的诱变作用mutagenic effect
进行标记。REDDIT-BINARY
:包含2000
个图的数据集,每个图代表Reddit
上讨论的话题thread
。在每个图中,节点代表话题下参与讨论的用户,边代表一个用户对另一个用户的评论进行了回复。图根据话题中用户交互类型进行标记:
r/IAmA, r/AskReddit
包含Question-Answer
交互,r/TrollXChromosomes and r/atheism
包含Online-Discussion
交互。
Baseline
方法:很多可解释性方法无法直接应用于图,尽管如此我们考虑了以下baseline
方法,这些方法可以为GNN
的预测提供解释。GRAD
:基于梯度的方法。我们计算损失函数对于邻接矩阵的梯度、损失函数对于节点特征的梯度,这类似于显著性映射方法saliency map approach
。ATT
:基于graph attention GNN:GAT
的方法。它学习计算图中边的注意力权重,并将其视为边的重要性。尽管
ATT
考虑了图结构,但是它并未考虑节点特征的解释,而且仅能解释GAT
模型。此外,由于环
cycle
的存在(如下图所示),节点的1hop
邻居也是它的2-hop
邻居。因此使用哪个注意力权重(1hop vs 2hop
)也不是很清楚。通常我们将这些hop
的注意力权重取均值。
实验配置:对于每个数据集,我们首先为这个数据集训练一个
GNN
,然后使用GARD
和GNNEXPLAINER
来对GNN
的预测做出解释。注意,
ATT baseline
需要使用GAT
之类的图注意力架构,因此我们在同一个数据集上单独训练了一个GAT
模型,并使用学到的边注意力权重进行解释。我们对所有的节点分类任务、图分类任务中调整权重正则化参数。这些超参数在所有实验中使用。
- 子图大小正则化超参数为
0.005
,该正则化倾向于得到尽可能小的子图。 - 拉普拉斯正则化参数为
0.5
。 - 特征数量正则化参数为
0.1
,该正则化倾向于得到尽可能少的unmasked
特征。
- 子图大小正则化超参数为
我们使用
Adam
优化器训练GNN
和 解释方法explaination methods
。所有
GNN
模型都训练1000
个epoch
,学习率为0.001
, 从而对节点分类数据集达到至少85%
的准确率、对于图分类数据集达到至少95%
的准确率。对于所有数据集,
train/valid/test
拆分比例为80%:10%:10%
。GNNEXPLAINER
使用相同的优化器和学习率,并训练100 ~300
个epoch
。因为
GNNEXPLAINER
仅需要在少于100
个节点的局部计算图上进行训练,因此训练epoch
要更少一些。
为了抽取解释子图
$ \mathcal G_S $ ,我们首先计算边的重要性权重(GRAD
的梯度、ATT
的注意力权重、GNNEXPLAINER
的masked
邻接矩阵)。然后我们使用一个阈值来删除权重较低的边,从而得到 $ \mathcal G_S $ 。对于所有方法,我们执行线性搜索从而找到临界阈值,使得
$ \mathcal G_S $ 至少包含 $ K_M $ 个节点。所有数据集的
ground truth explanation
是连接的子图。对于节点分类,我们将不同方法得到的
$ \mathcal G_S $ 中抽取连通分量(如前所述,对于GNNEXPLAINER
方法来讲, $ \mathcal G_S $ 已经是连通的)来作为最终的解释。对于图分类,我们抽取
$ \mathcal G_S $ 的最大连通分量来作为最终的解释。
超参数
$ K_M $ 和 $ K_F $ 控制解释中的子图大小和特征数量,这可以从数据集相关的先验知识得到。- 对于人工合成数据集,我们将
$ K_M $ 设置为ground truth
的大小。 - 对于真实世界数据集,我们设置
$ K_M=10,K_F=5 $ 。
- 对于人工合成数据集,我们将
定量分析:对于人工合成数据集,我们已有
ground-truth
解释,然后使用这些ground-truth
来评估所有方法解释的准确性。具体而言,我们将解释问题形式化为二元分类任务,其中真实解释中的边视为label
,而将可解释性方法给出的重要性权重视为预测得分。一种更好的可解释性方法对于真实解释的边的预测得分较高,从而获得更好的解释准确率。下表给出了人工合成数据集节点分类评估结果。实验结果表明:
GNNEXPLAINER
的平均效果相比其它方法高出17.1%
。定性分析:
在没有节点特征的
topology-based
预测任务中(如BA-SHAPES、TREE-CYCLES
),GNNEXPLAINER
正确地识别解释节点标签的motif
。如下图所示,
A-B
给出了四个人工合成数据集上节点分类任务的单实例解释子图,每种方法都为红色节点的预测提供解释(绿色表示重要的节点,橙色表示不重要的节点)。可以看到GNNEXPLAINER
能识别到house, cycle, trid
等motif
,而baseline
方法无法识别。我们研究图分类任务的解释。
在
MUTAG
实例中,颜色表示节点特征,这代表原子类型(氢H
、碳C
等)。GNNEXPLAINER
可以正确的识别对于图类别比较重要的碳环、以及化学基团NH2
和NO2
,它们确实已知是诱变的mutagenic
官能团。在
REDDIT-BINARY
示例中,我们看到Question-Answe
图(B
的第二行)具有2~3
个同时连接到很多低degree
节点的高degree
节点。这是讲得通的,因为在Reddit
的问答模式的话题中,通常具有2~3
位专家都回答了许多不同的问题。相反,在
REDDIT-BINARY
的讨论模式discussion pattern
图(A
的第二行),通常表现出树状模式。GRAD,ATT
方法给出了错误的或者不完整的解释。例如两种baseline
都错过了MUTAG
数据集中的碳环。此外,尽管
ATT
可以将边注意力权重视为消息传递的重要性得分,但是权重在输入图中的所有节点之间共享,因此ATT
无法提供高质量的单实例解释。
解释
explanations
的基本要求是它们必须是可解释的interpretable
,即,提供对输入节点和预测之间关系的定性理解。下图显式了一个实验结果,其中给出不同方法预测的解释的特征。特征重要性通过热力度可视化。可以看到:
GNNEXPLAINER
确实识别出了重要的特征;而gradient-based
无法识别,它为无关特征提供了较高的重要性得分。ground-truth
特征从何而来?作者并未讲清楚。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论