数学基础
- 线性代数
- 概率论与随机过程
- 数值计算
- 蒙特卡洛方法与 MCMC 采样
- 机器学习方法概论
统计学习
深度学习
- 深度学习简介
- 深度前馈网络
- 反向传播算法
- 正则化
- 深度学习中的最优化问题
- 卷积神经网络
- CNN:图像分类
- 循环神经网络 RNN
- Transformer
- 一、Transformer [2017]
- 二、Universal Transformer [2018]
- 三、Transformer-XL [2019]
- 四、GPT1 [2018]
- 五、GPT2 [2019]
- 六、GPT3 [2020]
- 七、OPT [2022]
- 八、BERT [2018]
- 九、XLNet [2019]
- 十、RoBERTa [2019]
- 十一、ERNIE 1.0 [2019]
- 十二、ERNIE 2.0 [2019]
- 十三、ERNIE 3.0 [2021]
- 十四、ERNIE-Huawei [2019]
- 十五、MT-DNN [2019]
- 十六、BART [2019]
- 十七、mBART [2020]
- 十八、SpanBERT [2019]
- 十九、ALBERT [2019]
- 二十、UniLM [2019]
- 二十一、MASS [2019]
- 二十二、MacBERT [2019]
- 二十三、Fine-Tuning Language Models from Human Preferences [2019]
- 二十四 Learning to summarize from human feedback [2020]
- 二十五、InstructGPT [2022]
- 二十六、T5 [2020]
- 二十七、mT5 [2020]
- 二十八、ExT5 [2021]
- 二十九、Muppet [2021]
- 三十、Self-Attention with Relative Position Representations [2018]
- 三十一、USE [2018]
- 三十二、Sentence-BERT [2019]
- 三十三、SimCSE [2021]
- 三十四、BERT-Flow [2020]
- 三十五、BERT-Whitening [2021]
- 三十六、Comparing the Geometry of BERT, ELMo, and GPT-2 Embeddings [2019]
- 三十七、CERT [2020]
- 三十八、DeCLUTR [2020]
- 三十九、CLEAR [2020]
- 四十、ConSERT [2021]
- 四十一、Sentence-T5 [2021]
- 四十二、ULMFiT [2018]
- 四十三、Scaling Laws for Neural Language Models [2020]
- 四十四、Chinchilla [2022]
- 四十七、GLM-130B [2022]
- 四十八、GPT-NeoX-20B [2022]
- 四十九、Bloom [2022]
- 五十、PaLM [2022] (粗读)
- 五十一、PaLM2 [2023](粗读)
- 五十二、Self-Instruct [2022]
- 句子向量
- 词向量
- 传统CTR 预估模型
- CTR 预估模型
- 一、DSSM [2013]
- 二、FNN [2016]
- 三、PNN [2016]
- 四、DeepCrossing [2016]
- 五、Wide 和 Deep [2016]
- 六、DCN [2017]
- 七、DeepFM [2017]
- 八、NFM [2017]
- 九、AFM [2017]
- 十、xDeepFM [2018]
- 十一、ESMM [2018]
- 十二、DIN [2017]
- 十三、DIEN [2019]
- 十四、DSIN [2019]
- 十五、DICM [2017]
- 十六、DeepMCP [2019]
- 十七、MIMN [2019]
- 十八、DMR [2020]
- 十九、MiNet [2020]
- 二十、DSTN [2019]
- 二十一、BST [2019]
- 二十二、SIM [2020]
- 二十三、ESM2 [2019]
- 二十四、MV-DNN [2015]
- 二十五、CAN [2020]
- 二十六、AutoInt [2018]
- 二十七、Fi-GNN [2019]
- 二十八、FwFM [2018]
- 二十九、FM2 [2021]
- 三十、FiBiNET [2019]
- 三十一、AutoFIS [2020]
- 三十三、AFN [2020]
- 三十四、FGCNN [2019]
- 三十五、AutoCross [2019]
- 三十六、InterHAt [2020]
- 三十七、xDeepInt [2023]
- 三十九、AutoDis [2021]
- 四十、MDE [2020]
- 四十一、NIS [2020]
- 四十二、AutoEmb [2020]
- 四十三、AutoDim [2021]
- 四十四、PEP [2021]
- 四十五、DeepLight [2021]
- 图的表达
- 一、DeepWalk [2014]
- 二、LINE [2015]
- 三、GraRep [2015]
- 四、TADW [2015]
- 五、DNGR [2016]
- 六、Node2Vec [2016]
- 七、WALKLETS [2016]
- 八、SDNE [2016]
- 九、CANE [2017]
- 十、EOE [2017]
- 十一、metapath2vec [2017]
- 十二、GraphGAN [2018]
- 十三、struc2vec [2017]
- 十四、GraphWave [2018]
- 十五、NetMF [2017]
- 十六、NetSMF [2019]
- 十七、PTE [2015]
- 十八、HNE [2015]
- 十九、AANE [2017]
- 二十、LANE [2017]
- 二十一、MVE [2017]
- 二十二、PMNE [2017]
- 二十三、ANRL [2018]
- 二十四、DANE [2018]
- 二十五、HERec [2018]
- 二十六、GATNE [2019]
- 二十七、MNE [2018]
- 二十八、MVN2VEC [2018]
- 二十九、SNE [2018]
- 三十、ProNE [2019]
- Graph Embedding 综述
- 图神经网络
- 一、GNN [2009]
- 二、Spectral Networks 和 Deep Locally Connected Networks [2013]
- 三、Fast Localized Spectral Filtering On Graph [2016]
- 四、GCN [2016]
- 五、神经图指纹 [2015]
- 六、GGS-NN [2016]
- 七、PATCHY-SAN [2016]
- 八、GraphSAGE [2017]
- 九、GAT [2017]
- 十、R-GCN [2017]
- 十一、 AGCN [2018]
- 十二、FastGCN [2018]
- 十三、PinSage [2018]
- 十四、GCMC [2017]
- 十五、JK-Net [2018]
- 十六、PPNP [2018]
- 十七、VRGCN [2017]
- 十八、ClusterGCN [2019]
- 十九、LDS-GNN [2019]
- 二十、DIAL-GNN [2019]
- 二十一、HAN [2019]
- 二十二、HetGNN [2019]
- 二十三、HGT [2020]
- 二十四、GPT-GNN [2020]
- 二十五、Geom-GCN [2020]
- 二十六、Graph Network [2018]
- 二十七、GIN [2019]
- 二十八、MPNN [2017]
- 二十九、UniMP [2020]
- 三十、Correct and Smooth [2020]
- 三十一、LGCN [2018]
- 三十二、DGCNN [2018]
- 三十三、AS-GCN
- 三十四、DGI [2018]
- 三十五、DIFFPOLL [2018]
- 三十六、DCNN [2016]
- 三十七、IN [2016]
- 图神经网络 2
- 图神经网络 3
- 推荐算法(传统方法)
- 一、Tapestry [1992]
- 二、GroupLens [1994]
- 三、ItemBased CF [2001]
- 四、Amazon I-2-I CF [2003]
- 五、Slope One Rating-Based CF [2005]
- 六、Bipartite Network Projection [2007]
- 七、Implicit Feedback CF [2008]
- 八、PMF [2008]
- 九、SVD++ [2008]
- 十、MMMF 扩展 [2008]
- 十一、OCCF [2008]
- 十二、BPR [2009]
- 十三、MF for RS [2009]
- 十四、 Netflix BellKor Solution [2009]
- 推荐算法(神经网络方法 1)
- 一、MIND [2019](用于召回)
- 二、DNN For YouTube [2016]
- 三、Recommending What Video to Watch Next [2019]
- 四、ESAM [2020]
- 五、Facebook Embedding Based Retrieval [2020](用于检索)
- 六、Airbnb Search Ranking [2018]
- 七、MOBIUS [2019](用于召回)
- 八、TDM [2018](用于检索)
- 九、DR [2020](用于检索)
- 十、JTM [2019](用于检索)
- 十一、Pinterest Recommender System [2017]
- 十二、DLRM [2019]
- 十三、Applying Deep Learning To Airbnb Search [2018]
- 十四、Improving Deep Learning For Airbnb Search [2020]
- 十五、HOP-Rec [2018]
- 十六、NCF [2017]
- 十七、NGCF [2019]
- 十八、LightGCN [2020]
- 十九、Sampling-Bias-Corrected Neural Modeling [2019](检索)
- 二十、EGES [2018](Matching 阶段)
- 二十一、SDM [2019](Matching 阶段)
- 二十二、COLD [2020 ] (Pre-Ranking 模型)
- 二十三、ComiRec [2020](https://www.wenjiangs.com/doc/0b4e1736-ac78)
- 二十四、EdgeRec [2020]
- 二十五、DPSR [2020](检索)
- 二十六、PDN [2021](mathcing)
- 二十七、时空周期兴趣学习网络ST-PIL [2021]
- 推荐算法之序列推荐
- 一、FPMC [2010]
- 二、GRU4Rec [2015]
- 三、HRM [2015]
- 四、DREAM [2016]
- 五、Improved GRU4Rec [2016]
- 六、NARM [2017]
- 七、HRNN [2017]
- 八、RRN [2017]
- 九、Caser [2018]
- 十、p-RNN [2016]
- 十一、GRU4Rec Top-k Gains [2018]
- 十二、SASRec [2018]
- 十三、RUM [2018]
- 十四、SHAN [2018]
- 十五、Phased LSTM [2016]
- 十六、Time-LSTM [2017]
- 十七、STAMP [2018]
- 十八、Latent Cross [2018]
- 十九、CSRM [2019]
- 二十、SR-GNN [2019]
- 二十一、GC-SAN [2019]
- 二十二、BERT4Rec [2019]
- 二十三、MCPRN [2019]
- 二十四、RepeatNet [2019]
- 二十五、LINet(2019)
- 二十六、NextItNet [2019]
- 二十七、GCE-GNN [2020]
- 二十八、LESSR [2020]
- 二十九、HyperRec [2020]
- 三十、DHCN [2021]
- 三十一、TiSASRec [2020]
- 推荐算法(综述)
- 多任务学习
- 系统架构
- 实践方法论
- 深度强化学习 1
- 自动代码生成
工具
- CRF
- lightgbm
- xgboost
- scikit-learn
- spark
- numpy
- matplotlib
- pandas
- huggingface_transformer
- 一、Tokenizer
- 二、Datasets
- 三、Model
- 四、Trainer
- 五、Evaluator
- 六、Pipeline
- 七、Accelerate
- 八、Autoclass
- 九、应用
- 十、Gradio
Scala
- 环境搭建
- 基础知识
- 函数
- 类
- 样例类和模式匹配
- 测试和注解
- 集合 collection(一)
- 集合collection(二)
- 集成 Java
- 并发
八、GraphSAGE [2017]
在大型图中节点的低维向量
embedding
已被证明作为特征输入非常有用,可用于各种预测和图分析graph analysis
任务。node embedding
方法背后的基本思想是:使用降维技术将关于节点的graph neighborhood
的高维信息蒸馏成稠密的、低维的向量embedding
。然后可以将这些node embedding
馈入到下游机器学习系统,并帮助完成节点分类、节点聚类、以及链接预测等任务。然而,以前的工作集中在从单个固定图
a single fixed graph
上的节点的embedding
,许多实际application
需要为unseen
的节点、或全新的图快速生成embedding
。这种归纳能力inductive capability
对于高吞吐量、生产型的机器学习系统至关重要,其中这些机器学习系统在不断演变的图上运行并不断遇到unseen
的节点(如Reddit
上的帖子、Youtube
上的用户和视频)。生成node embedding
的归纳方法inductive approach
还有助于跨具有相同形式特征的图进行泛化:例如,可以在源自模型器官model organism
的protein-protein
交互图上训练一个embedding generator
,然后使用经过训练的embedding generator
轻松地为在新器官上收集的数据生成node embedding
。与直推式配置
transductive setting
相比,归纳式inductive
的node embedding
问题特别困难,因为泛化到unseen
的节点需要将新观察到的子图observed subgraph
与算法已经优化的node embedding
进行对齐aligning
。归纳式框架inductive framework
必须学会识别节点领域的结构属性,这些属性揭示了节点在图中的局部角色local role
及其全局位置global position
。大多数现有的生成
node embedding
的方法本质上都是直推式的。这些方法中的大多数使用基于矩阵分解的目标直接优化每个节点的embedding
,并且无法自然地泛化到unseen
的数据,因为它们在单个固定图上对节点进行预测。这些方法可以被修改从而在归纳式配置中运行,但是这些修改往往在计算上代价很大,需要额外的梯度下降轮次才能作出新的预测。最近还有一些使用卷积算子来学习图结构的方法,这些方法提供了作为embedding
方法的承诺(《Semi-supervised classification with graph convolutional networks》
)。到目前为止,图卷积网络graph convolutional network: GCN
仅应用于具有固定图fixed graph
的直推式配置。在论文《Inductive Representation Learning on Large Graphs》
中,作者将GCN
泛化到归纳式无监督学习的任务,并提出了一个框架,该框架泛化了GCN
方法从而使用可训练的聚合函数(超越了简单的卷积)。《Semi-supervised classification with graph convolutional networks》
提出的GCN
要求在训练过程中已知完整的图拉普拉斯算子,而测试期间unseen
的节点必然会改变图拉普拉斯算子,因此该方法也是直推式的。论文的工作:
作者提出了一个通用框架,称作
GraphSAGE
(SAmple and aggreGatE
),用于归纳式node embedding
。与基于矩阵分解的embedding
方法不同,GraphSAGE
利用节点特征(如,文本属性、节点画像node profile
信息、节点degree
)来学习一个embedding
函数,该embedding
函数可以泛化到unseen
的节点。通过在学习算法中加入节点特征,GraphSAGE
同时学习了每个节点邻域的拓扑结构、以及该邻域内节点特征的分布。虽然GraphSAGE
聚焦于特征丰富的graph
(如,具有文本属性的引文数据,具有功能标记/分子标记的生物数据),但是GraphSAGE
还可以利用所有图中存在的结构特征(如,节点degree
)。因此,GraphSAGE
也可以应用于没有节点特征的图。GraphSAGE
不是为每个节点训练一个distinct
的embedding
向量,而是训练一组聚合器函数aggregator function
,这些函数学习从节点的局部邻域来聚合特征信息(如下图所示)。每个聚合器函数聚合来自远离给定节点的不同hop
数(或搜索深度)的信息。在测试或推断时,GraphSAGE
通过应用学到的聚合函数为unseen
的节点生成embedding
。遵从之前的
node embedding
工作,作者设计了一个无监督损失函数,允许在没有task-specific
监督信息的情况下训练GraphSAGE
。作者还表明GraphSAGE
可以通过完全监督的方式进行训练。作者在三个关于节点/图分类
benchmark
上评估GraphSAGE
,这些benchmark
测试了GraphSAGE
在unseen
数据上生成有效embedding
的能力。作者使用基于引文数据和Reddit
帖子数据(分别预测论文类别和帖子类别)的两个不断演变的文档图,以及基于protein-protein
交互的数据集(预测蛋白质功能)的多图泛化multigraph generalization
实验。使用这些
benchmark
,作者表明GraphSAGE
能够有效地为unseen
的节点生成representation
,并大大优于相关baseline
:跨所有这些不同的领域,与单独使用节点特征相比,GraphSAGE
的监督方法将分类F1
分数平均提高了51%
,并且GraphSAGE
始终优于强大的直推式的baseline
,并且该baseline
需要100
轮迭代甚至更长的时间才能预测unseen
的节点。作者还表明,与受图卷积网络(
《Semi-supervised classification with graph convolutional networks》
)启发的聚合器相比,论文提出的新聚合器架构提供了显著的增益(平均增益7.4%
)。最后,作者探讨了
GraphSAGE
的表达能力expressive capability
,并通过理论分析表明:GraphSAGE
能够学到有关节点在图中的角色的结构信息,尽管它本质上是基于特征的。
相关工作:我们的算法在概念上与之前的
node embedding
方法、图上学习的通用监督方法general supervised approache
、以及将卷积神经网络应用于图结构数据的最新进展等等相关。基于分解的
embedding
方法:最近有许多node embedding
方法使用随机游走统计和基于矩阵分解的学习目标来学习低维embedding
(GraRep, node2vec, Deepwalk, Line, SDNE
)。这些方法还与更经典的谱聚类spectral clustering
方法、多维缩放multi-dimensional scaling
、以及PageRank
算法密切相关。由于这些
embedding
算法直接为单个节点individual node
训练node embedding
,因此它们本质上是直推式的,并且至少需要昂贵的额外训练(如,通过随机梯度下降)来对unseen
节点进行预测。此外,对于大多数这些方法,目标函数对于embedding
的正交变换是不变的,这意味着embedding
空间不会自然地在图之间泛化,并且在re-training
期间可能会漂移drift
。因为这些方法是基于矩阵分解的,而矩阵分解的内积函数
$ \mathbf{\vec v}_i\cdot \mathbf{\vec v}_j $ 是embedding
空间的正交不变的,即:将embeddign
空间旋转任意角度,原始内积函数和新内积函数的结果是相等的。这一趋势的一个显著例外是
Planetoid-I
算法,它是一种归纳式的、基于embedding
的半监督学习方法。但是,Planetoid-I
在推断过程中不使用任何图结构信息,相反,它在训练期间使用图结构信息作为正则化的一种形式。与先前的这些方法不同,我们利用特征信息来训练模型从而为
unseen
节点生成embedding
。图上的监督学习:除了
node embedding
方法之外,还有大量关于图结构数据的监督学习的工作。这包括各种各样的kernel-based
方法,其中图的特征向量来自于各种graph kernel
。最近还有许多神经网络方法可以对图结构数据进行监督学习。我们的方法在概念上受到大多数这些算法的启发。然而,这些方法试图对整个图(或子图)进行分类,但是我们这项工作的重点是为每个节点生成有用的representation
。图卷积网络:近年来,人们已经提出了几种用于图上学习的卷积神经网络架构。这些方法中的大多数无法扩展到大型图、或者设计用于整个图的分类。然而,我们的方法与
《Semi-supervised classification with graph convolutional networks》
提出的图卷积网络graph convolutional network: GCN
密切相关。原始的GCN
算法是为直推式setting
的半监督学习而设计的,确切exact
的算法要求在训练期间知道整个图的拉普拉斯算子。我们算法的一个简单变体可以视作GCN
框架对归纳式setting
的扩展,我们将在正文部分重新讨论这一点。
8.1 模型
- 我们方法背后的关键思想是:我们学习如何从节点的局部邻域聚合特征信息(如,邻域节点的
degree
或文本属性)。我们首先描述GraphSAGE
的embedding
生成(即,前向传播)算法,该算法在假设GraphSAGE
模型参数已经学到的情况下为节点生成embedding
。然后,我们描述了如何使用标准随机梯度下降和反向传播技术来学习GraphSAGE
模型参数。
8.1.1 前向传播
这里我们将描述前向传播算法(也叫
embedding
生成算法),其中假设模型已经训练好并且参数是固定的。具体而言,假设我们已经学到了 $ K $ 个聚合函数 $ \text{AGG}_k,k\in \{1,2,\cdots,K\} $ ,这些聚合函数用于聚合节点的邻域信息。假设我们也学到了 $ K $ 个权重矩阵 $ \mathbf W^{(k)},k\in \{1,2,\cdots,K\} $ ,它们用于在不同层之间传递信息。 $ K $ 也称作搜索深度,或layer
层数。GraphSAGE
的embedding
生成算法为:输入:
- 图
$ \mathcal G(\mathcal V,\mathcal E) $ ,输入特征 $ \left\{\mathbf{\vec x}_v\mid v\in \mathcal V\right\} $ ,搜索深度 $ K $ ,邻域函数 $ \mathcal N(\cdot) $ $ K $ 个权重矩阵 $ \mathbf W^{(k)} $ , $ K $ 个聚合函数 $ \text{AGG}_k $ , $ k\in \{1,\cdots,K\} $- 非线性激活函数
$ \sigma(\cdot) $
- 图
输出:节点的
embedding
向量 $ \left\{\mathbf{\vec z}_v\mid v \in \mathcal V \right\} $算法步骤:
初始化:
$ \mathbf{\vec h}_v^{(0)} = \mathbf{\vec x}_v, v\in \mathcal V $对每一层迭代,迭代条件为:
$ k=1,2,\cdots,K $ 。迭代步骤:遍历每个节点
$ v\in \mathcal V $ ,执行:其中
concat()
表示向量拼接。这里是拼接融合,也可以考虑其它类型的融合方式。
对每个节点
$ v $ 的隐向量归一化:
$ \mathbf{\vec z}_v= \mathbf{\vec h}_v^{(K)} $
GraphSAGE
前向传播算法的背后直觉是:在每次迭代或搜索深度,节点都会聚合来自其局部邻域的信息;并且随着这个过程的迭代,节点将从图的更远范围逐渐获取越来越多的信息。在算法的外层循环中的每个
step
如下进行,其中 $ k $ 表示外层循环中的current step
(也叫做搜索深度), $ \mathbf{\vec h}^{(k)} $ 表示该step
中的node representation
:首先,每个节点
$ v\in \mathcal V $ 聚合其直接邻域中节点的representation
$ \left\{\mathbf{\vec h}_u^{(k-1)}\mid u\in \mathcal N(v)\right\} $ 到一个向量 $ \mathbf{\vec h}_{\mathcal N(v)}^{(k-1)} $ 中。注意,这个聚合步骤依赖于第 $ k-1 $ 轮迭代产生的node representation
(即 $ \mathbf{\vec h}^{(k-1)} $ ),并且 $ k=0 $ 时的representation
被定义为节点输入特征 $ \mathbf{\vec x} $ 。邻域
representation
可以通过各种聚合器架构(以AGGREGATE
占位符来表达)来完成,接下来我们会讨论不同的架构选择。然后,在聚合邻域特征向量之后,
GraphSAGE
将节点的当前representation
$ \mathbf{\vec h}_v^{(k-1)} $ 和聚合后的邻域向量 $ \mathbf{\vec h}_{\mathcal N(v)}^{(k-1)} $ 拼接起来,然后通过一个带非线性激活函数 $ \sigma(\cdot) $ 的全连接层。这个全连接层的输出就是下一个step
要用到的representation
,即 $ \mathbf{\vec h}^{(k)} $ 。大多数节点
embedding
方法将学到的embedding
归一化为单位向量,这里也做类似处理。
为了记号方便,我们将第
$ K $ 步的final representation
记做 $ \mathbf{\vec z}_v=\mathbf{\vec h}_v^{(K)},\forall v\in \mathcal V $ 。
a. mini-batch 训练
为了将算法扩展到
mini-batch setting
,给定一组输入节点,我们首先前向采样forward sample
所需要的邻域集合(直到深度 $ K $ )然后执行内层循环,而不是迭代所有节点。我们仅计算满足每个 $ k $ 所需的representation
(而不是所有节点的representation
)。为了使用随机梯度下降算法,我们需要对
GraphSAGE
的前向传播算法进行修改,从而允许mini-batch
中每个节点能够执行前向传播、反向传播。即:确保前向传播、反向传播过程中用到的节点都在同一个
mini-batch
中。GraphSAGE mini-batch
前向传播算法(这里 $ \mathcal B $ 包含了我们想要为其生成representation
的节点):算法输入:
- 图
$ \mathcal G(\mathcal V,\mathcal E) $ ,输入特征 $ \left\{\mathbf{\vec x}_v\mid v\in \mathcal B\right\} $ ,搜索深度 $ K $ ,邻域函数 $ \mathcal N(\cdot) $ $ K $ 个权重矩阵 $ \mathbf W^{(k)} $ , $ K $ 个聚合函数 $ \text{AGG}_k $ , $ k\in \{1,\cdots,K\} $- 非线性激活函数
$ \sigma(\cdot) $
- 图
输出:节点的
embedding
向量 $ \left\{\mathbf{\vec z}_v\mid v \in \mathcal B\right\} $算法步骤:
初始化:
$ \mathcal B^{(K)} = \mathcal B $迭代
$ k=K,\cdots,1 $ ,迭代步骤为: $ \mathcal B^{(k-1)} = \mathcal B^{(k)} $- 遍历
$ u\in \mathcal B^{(k)} $ ,计算 $ \mathcal B^{(k-1)} = \mathcal B^{(k-1)}\bigcup \mathcal N_k(u) $
初始化:
$ \mathbf{\vec h}_v^{(0)} = \mathbf{\vec x}_v, v\in \mathcal B^{(0)} $对每一层迭代,迭代条件为:
$ k=1,2,\cdots,K $ 。迭代步骤:遍历每个节点
$ v\in \mathcal B^{(k)} $ ,执行:这里用
$ \mathcal N_k(v) $ 表示节点 $ v $ 的邻域在每个深度 $ k $ 都不相同,依赖于前向采样的结果。对每个节点
$ v $ 的隐向量归一化:
$ \mathbf{\vec z}_v = \mathbf{\vec h}_v^{(K)},v\in \mathcal B $
mini-batch
前向传播算法的主要思想是:首先采样所有所需的节点。集合 $ \mathcal B^{(k-1)} $ 包含了第 $ k $ 轮迭代计算representation
的节点所依赖的节点集合。由于 $ \mathcal B^{(k)} \sube \mathcal B^{(k-1)} $ ,所以在计算 $ \mathbf{\vec h}_v^{(k)} $ 时依赖的 $ \mathbf{\vec h}_v^{(k-1)} $ 已经在第 $ k-1 $ 轮已被计算。另外第 $ k $ 轮需要计算representation
的节点更少,这避免计算不必要的节点。然后计算目标节点的
representation
,这一步和batch
前向传播算法相同。mini-batch
前向传播和batch
前向传播的主要区别在于:mini-batch
前向传播还有一个前向采样的步骤。我们使用
$ \mathcal N_k(\cdot) $ 的 $ k $ 来表明:不同层之间使用独立的random walk
采样。这里我们使用均匀采样,并且当节点邻域节点数量少于指定数量时采用有放回的采样,否则使用无放回的采样。有一些算法聚焦于如何更好地进行采样,从而优化最终效果。
mini-batch
算法的采样过程在概念上与batch
算法的迭代过程是相反的。我们从需要以深度 $ K $ 生成representation
的节点开始,然后我们对它们的邻域进行采样(即,深度 $ K-1 $ ),依此类推。这样做的一个后果是邻域采样规模的定义可能有点违反直觉。具体而言,假设 $ K=2 $ :在
batch
算法中,我们在 $ k=1 $ 时对节点邻域内采样 $ S_1 $ 个节点,在 $ k=2 $ 时对节点邻域内采样 $ S_2 $ 个节点。在
mibi-batch
算法中,我们在 $ k= 2 $ 时对节点邻域内采样 $ S_2 $ 个节点,然后在 $ k=1 $ 时对节点邻域内采样 $ S_1\times S_2 $ 个节点。这样才能保证我们的目标
$ \mathcal B $ 中包含mibi-batch
所需要计算的所有节点。
b. 和 WL-Test 关系
GraphSAGE
算法在概念上受到图的同构性检验的经典算法的启发。在前向传播过程中,如果令 $ K=|\mathcal V| $ 、 $ \mathbf W^{(k)} = \mathbf I $ ,并选择合适的hash
函数来作为聚合函数,同时移除非线性函数,则该算法是Weisfeiler-Lehman:WL
同构性检验算法的一个特例,被称作naive vertex refinement
。如果算法输出的
node representation
$ \left\{\mathbf{\vec z}_v,v\in \mathcal V\right\} $ 在两个子图是相等的,则WL-test
算法认为这两个子图是同构的。虽然在某些情况下该检验会失败,但是大多数情况下该检验是有效的。GraphSAGE
是WL test
算法的一个continous
近似,其中GraphSAGE
使用可训练的神经网络聚合函数代替了不连续的哈希函数。虽然GraphSAGE
的目标是生成节点的有效embedding
而不是检验图的同构性,但是GraphSAGE
和WL test
之间的联系为我们设计学习节点邻域拓扑结构的算法提供了理论背景。可以证明:即使我们提供的是节点特征信息,
GraphSAGE
也能够学到图的结构信息。参考 “理论分析” 部分。
c. 邻域定义
在
GraphSAGE
中我们并没有使用完整的邻域,而是均匀采样一组固定大小的邻域,从而确保每个batch
的计算代价是固定的。因此我们定义 $ \mathcal N(v) $ 为:从集合 $ \{u\mid u\in \mathcal V,(u,v)\in \mathcal E\} $ 中均匀采样的、固定大小的集合,并且我们在算法的每轮迭代 $ k $ 中采样不同的邻域。如果对每个节点使用完整的邻域,则每个
batch
的内存需求和运行时间是不确定的,最坏情况为 $ O(|\mathcal V|) $ 。如果使用采样后的邻域,则每个batch
的时间和空间复杂度固定为 $ O(\prod_{k=1}^KS_k) $ ,其中 $ S_k $ 表示第 $ k $ 轮迭代时的邻域大小。 $ K $ 以及 $ S_k $ 均为用户指定的超参数,实验发现 $ K=2, S_1\times S_2\le 500 $ 时的效果较好。 $ K $ 和 $ S_k $ 依赖于具体的数据集和任务。
8.1.2 模型学习
为了在完全无监督的环境中学习有用的、预测性的
representation
,我们将一个graph-based
损失函数应用于output representation
$ \mathbf{\vec z}_u,\forall u\in \mathcal V $ ,并且通过随机梯度下降来学习模型参数。这个graph-based
损失函数鼓励临近的节点具有相似的representation
,同时迫使不相近的节点具有高度不相似的representation
:其中:
$ v $ 是和节点 $ u $ 在一个长度为 $ l $ 的random walk
上共现的节点。sigmoid(.)
为sigmoid
函数。 $ P_n(\cdot) $ 为负采样用到的分布函数, $ v_n $ 为负采样到的negative node
, $ Q $ 为负采样的样本数。
重要的是,与之前的
embedding
方法不同,GraphSAGE
中的节点representation
$ \mathbf{\vec z}_u $ 是从节点局部邻域中包含的特征而生成的,而不是通过embedding look-up
而生成的。可以看到,
GraphSAGE
和DeepWalk
类似,也依赖于图上的随机游走过程。为了提高训练效率,通常在训练之前执行一次随机游走过程(避免在训练的每轮迭代中进行随机游走)。以无监督方式学到的节点
embedding
可以作为通用service
来服务于下游的机器学习任务。但是如果仅在特定的任务上应用,则可以简单地将特定于任务的监督学习损失替代或增强原始的无监督损失。通过结合监督损失和无监督损失,那么可以同时利用
labeled
数据和unlabeled
数据,即半监督学习。
8.1.3 聚合函数
和网格型数据(如文本、图像)不同,图的节点之间没有任何顺序关系,因此算法中的聚合函数必须能够在无序的节点集合上运行。理想的聚合函数是对称的,同时可训练并保持较高的表达能力。这种对称性可以确保我们的神经网络模型可以用于任意顺序的节点邻域的训练和测试。
对称性是指:对于给定的一组节点集合,无论它们以何种顺序输入到聚合函数,聚合后的输出结果不变。
聚合函数有多种形式,我们检查了三种主要的聚合函数:均值聚合函数
mean aggregator
、LSTM
聚合函数LSTM aggregator
、池化聚合函数pooling aggregator
。mean aggregator
:简单的使用邻域节点的特征向量的逐元素均值来作为聚合结果。这几乎等价于直推式GCN
框架中的卷积传播规则。具体而言,如果我们将前向传播:
替换为:
则这得到直推式
GCN
的一个inductive
变种,我们称之为基于均值聚合的卷积mean-based aggregator convolutional
。它是局部谱卷积localized spectral convolution
的一个粗糙的线性近似。GCN
的前向传播为:其中:
$ \tilde{\mathbf A} = \mathbf A + \mathbf I $ , $ \mathbf A $ 为邻接矩阵, $ \tilde{\mathbf D} $ 为 $ \tilde{\mathbf A} $ 的degree
矩阵。因此有:
注意,
GCN
的 $ \mathbf{\vec h}_v^{(0)} $ 是通过embedding look-up
而生成的(而不是输入特征 $ \mathbf{\vec x}_v $ )。这个卷积聚合器与我们提出的其它聚合器之间的一个重要区别在于:它并未执行拼接操作(即,将
$ \mathbf{\vec h}_v^{(k-1)} $ 和 $ \mathbf{\vec h}_{\mathcal N(v)}^{(k-1)} $ 拼接起来) 。这种拼接操作可以视为GraphSAGE
算法的不同search depth
(或layer
)之间的skip connection
的一种简单形式,它可以显著提高性能。事实上其它聚合器在拼接操作之后执行了带非线性激活函数的投影,因此破坏了这种
skip connection
。是否修改为以下形式更好?LSTM aggregator
:和均值聚合相比,LSTM
具有更强大的表达能力。但是LSTM
原生的是非对称的(即,LSTM
不是permutation invariant
的),它依赖于节点的输入顺序。因此我们通过简单地将LSTM
应用于邻域节点的随机排序,从而使得LSTM
可以应用于无序的节点集合。pooling aggregator
:池化聚合器是对称的、可训练的。在这种池化方法中,邻域每个节点的特征向量都通过全连接神经网络独立馈入,然后通过一个逐元素的最大池化来聚合邻域信息:其中
max
表示逐元素的max
运算符, $ \sigma(\cdot) $ 是非线性激活函数。理论上可以在最大池化之前使用任意深度的多层感知机,但是我们这里专注于简单的单层网络结构。直观上看,可以将多层感知机视为一组函数,这组函数为邻域集合内的每个节点
representation
计算特征。通过将最大池化应用到这些计算到的特征上,模型可以有效捕获邻域集合的不同方面aspect
。理论上可以使用任何的对称向量函数(如逐元素均值)来替代
max
运算符。但是我们在实验中发现最大池化和均值池化之间没有显著差异,因此我们专注于最大池化。
8.1.4 理论分析
这里我们将探讨
GraphSAGE
的表达能力,以便深入了解GraphSAGE
如何学习图结构,即使它本质上是基于特征的。作为案例研究,我们考虑GraphSAGE
是否可以学习预测节点的聚类系数clustering coefficient
,即:在节点的1-hop
邻域内,闭合的三角形占所有三角形(闭合的和未闭合的)的比例。聚类系数是衡量节点局部邻域聚类程度的常用指标,它可以作为许多更复杂的结构主题structural motif
的building block
。可以证明:GraphSAGE
算法能够将聚类系数逼近到任意精度。定理:令
$ \mathbf{\vec x}_v\in \mathbb U,v\in \mathcal V $ 作为GraphSAGE
算法针对图 $ \mathcal G=(\mathcal V,\mathcal E) $ 的输入,其中 $ \mathbb U $ 是 $ \mathbb R^d $ 的一个紧致子集compact subset
。假设存在一个固定的正的常数 $ C\in \mathbb R^+ $ 使得 $ \left\|\mathbf{\vec x}_v - \mathbf{\vec x}_{v^\prime}\right\|_2\gt C $ 对任意节点pair
$ (v,v^\prime) $ 成立,那么我们有:对于任意 $ \epsilon \gt 0 $ ,这里存在一个参数setting
$ \mathbf \Theta^* $ ,使得GraphSAGE
算法在 $ K=4 $ 轮迭代之后有:其中:
$ z_v\in \mathbb R $ 为GraphSAGE
算法的final output
值, $ c_v $ 为节点的聚类系数。注意,这里假设
output representation
是一维的。上述定理指出:对于任意的图,
GraphSAGE
算法都存在一个参数setting
,如果每个每个节点的特征都是不同的(并且如果模型足够高维),那么算法可以将图的聚类系数逼近到任意精度。证明见原始论文。注意:作为该定理的推论,
GraphSAGE
可以了解局部图结构,即使节点特征输入是从连续随机分布中采样的(因此特征输入与图结构无关)。证明背后的基本思想是:如果每个节点都有一个
unique
的特征,那么我们可以学习将节点映射到indicator
向量并识别节点邻域。定理的证明依赖于池化聚合器的一些属性,这也提供了为什么GraphSAGE-pool
优于GCN
、以及mean-based
聚合器的洞察。
8.2 实验
我们在三个
benchmark
任务上检验GraphSAGE
的效果:Web of Science Citation
数据集的论文分类任务、Reddit
数据集的帖子分类任务、PPI
数据集的蛋白质分类任务。前两个数据集是对训练期间
unseen
的节点进行预测,最后一个数据集是对训练期间unseen
的图进行预测。数据集:
Web of Science Cor Collection
数据集:包含2000
年到2005
年六个生物学相关领域的所有论文,每篇论文属于六种主题类别之一。数据集包含302424
个节点,节点的平均degree
为9.15
。其中:Immunology
免疫学的标签为NI
,节点数量77356
。Ecology
生态学的标签为GU
,节点数量37935
。Biophysics
生物物理学的标签为DA
,节点数量36688
。Endocrinology and Metabolism
内分泌与代谢的标签为IA
,节点数量52225
。Cell Biology
细胞生物学的标签为DR
,节点数量84231
。Biology(other)
生物学其它的标签为CU
,节点数量13988
。
任务目标是预测论文主题的类别。我们根据
2000-2004
年的数据来训练所有算法,并用2005
年的数据进行进行测试(其中30%
用于验证)。我们使用节点
degree
和文章的摘要作为节点的特征,其中节点摘要根据Arora
等人的方法使用sentence embedding
方法来处理文章的摘要,并使用Gensim word2vec
的实现来训练了300
维的词向量。Reddit
数据集:包含2014
年9
月Reddit
上发布帖子的一个大型图数据集,节点标签为帖子所属的社区。我们采样了50
个大型社区,并构建一个帖子到帖子的图。如果一个用户同时在两个帖子上发表评论,则这两个帖子将链接起来。数据集包含232965
个节点,节点的平均degree
为492
。为了对社区进行采样,我们按照每个社区在
2014
年的评论总数对社区进行排名,并选择排名在[11,50]
(包含)的社区。我们忽略了最大的那些社区,因为它们是大型的、通用的默认社区,会严重扭曲类别的分布。我们选择这些社区上定义的最大连通图largest connected component
。任务的目标是预测帖子的社区
community
。我们将该月前20
天用于训练,剩下的天数作为测试(其中30%
用于验证)。我们使用帖子的以下特征:标题的平均
embedding
、所有评论的平均embedding
、帖子评分、帖子评论数。其中embedding
直接使用现有的300
维的GloVe CommonCral
词向量,而不是在所有帖子中重新训练。PPI
数据集:包含Molecular Signatures Dataset
中的图,每个图对应于不同的人类组织,节点标签采用gene ontology sets
,一共121
种标签。平均每个图包含2373
个节点,所有节点的平均degree
为28.8
。任务的目的是评估模型的跨图泛化的能力。我们在
20
个随机选择的图上进行训练、2
个图进行验证、2
个图进行测试。其中训练集中每个图至少有15000
条边,验证集和测试集中每个图都至少包含35000
条边。注意:对于所有的实验,验证集和测试集是固定选择的,训练集是随机选择的。我们最后给出测试图上的micro-F1
指标。我们使用
positional gene sets
、motif gene sets
以及immunological signatures
作为节点特征。我们选择至少在10%
的蛋白质上出现过的特征,低于该比例的特征不被采纳。最终节点特征非常稀疏,有42%
的节点没有非零特征(即,42%
的节点的特征全是空的),这使得节点之间的链接非常重要。
Baseline
模型:- 随机分类器。
- 基于节点特征的逻辑回归分类器(完全忽略图的结构信息)。
- 代表因子分解方法的
DeepWalk
算法+逻辑回归分类器(完全忽略节点的特征)。 - 拼接了
DeepWalk
的embedding
以及节点特征的方法(融合图的节点特征和结构特征)。
我们使用了不同聚合函数的
GraphSAGE
的四个变体。由于卷积的变体是GCN
的inductive
扩展,因此我们称其为GraphSAGE-GCN
。我们使用了
GraphSAGE
的无监督版本,也直接使用分类交叉熵作为损失的有监督版本。模型配置:
GrahSage
:- 所有
GraphSAGE
模型都在Tensorflow
中使用Adam
优化器实现, 而DeepWalk
在普通的随机梯度优化器中表现更好。 - 为防止
GraphSAGE
聚合函数的效果比较时出现意外的超参数hacking
,我们对所有GraphSAGE
版本进行了相同的超参数配置:根据验证集的性能为每个版本提供最佳配置。 - 对于所有的
GraphSAGE
版本设置 $ K=2 $ 以及邻域采样大小 $ S_1=25, S_2= 10 $ 。 - 对于所有的
GraphSAGE
,我们对每个节点执行以该节点开始的50
轮长度为5
的随机游走序列,从而得到pair
节点对。我们的随机游走序列生成完全基于Python
代码实现。 - 由于节点
degree
分布的长尾效应,我们将GraphSAGE
算法中所有图的边执行降采样预处理。经过降采样之后,使得没有任何节点的degree
超过128
。由于我们每个节点最多采样25
个邻居,因此这是一个合理的权衡。
- 所有
为公平比较,所有模型都采样相同的
mini-batch
迭代器、损失函数(当然监督损失和无监督损失不同)、邻域采样器。对于原生特征模型,以及基于无监督模型的
embedding
进行预测时,我们使用scikit-learn
中的SGDClassifier
逻辑回归分类器,并使用默认配置。在所有配置中,我们都对学习率和模型的维度以及
batch-size
等等进行超参数选择:除了
DeepWalk
之外,我们为监督学习模型设置初始学习率的搜索空间为 $ \{0.01,0.001,0.0001\} $ ,为无监督学习模型设置初始学习率的搜索空间为 $ \{2\times 10^{-6},2\times 10^{-7},2\times 10^{-8}\} $ 。最初实验表明
DeepWalk
在更大的学习率下表现更好,因此我们选择DeepWalk
的初始学习率搜索空间为 $ \{0.2,0.4,0.8\} $ 。我们测试了每个
GraphSAGE
模型的big
版本和small
版本。- 对于池化聚合函数,
big
模型的池化层维度为1024
,small
模型的池化层维度为512
。 - 对于
LSTM
聚合函数,big
模型的隐层维度为256
,small
模型的隐层维度为128
。
注意,这里设置的是聚合器的维度,而不是
hidden representation
的维度。- 对于池化聚合函数,
所有实验中,我们将
GraphSAGE
每一层的 $ \mathbf{\vec h}_i^{(k)} $ 的维度设置为256
。所有的
GraphSAGE
以及DeepWalk
的非线性激活函数为ReLU
。对于无监督
GraphSAGE
和DeepWalk
模型,我们使用20
个负采样的样本,并且使用0.75
的平滑参数对节点的degree
进行上下文分布平滑。对于监督
GraphSAGE
,我们为每个模型运行10
个epoch
。我们对
GraphSAGE
选择batch-size = 512
。对于DeepWalk
我们使用batch-size=64
,因为我们发现这个较小的batch-size
收敛速度更快。
硬件配置:
DeepWalk
在CPU
密集型机器上速度更快,它的硬件参数为144 core
的Intel Xeon CPU(E7-8890 V3 @ 2.50 GHz)
,2T
内存。- 其它模型在单台机器上实验,该机器具有
4
个NVIDIA Titan X Pascal GPU
(12 Gb
显存,10Gbps
访问速度),16 core
的Intel Xeon CPU(E5-2623 v4 @ 2.60GHz)
,以及256 Gb
内存。
所有实验在共享资源环境下大约进行了
3
天。我们预期在消费级的单GPU
机器上(如配备了Titan X GPU
)的全部资源专用,可以在4
到7
天完成所有实验。DeepWalk
测试阶段:对于
Reddit
和引文数据集,我们按照Perozzi
等人的描述对DeepWalk
执行oneline
训练。对于新的测试节点,我们进行了新一轮的SGD
优化,从而得到新节点的embedding
。现有的
DeepWalk
实现仅仅是word2vec
代码的封装,它难以支持embedding
新节点以及其它变体。这里我们根据tensorflow
中的官方word2vec
教程实现了DeepWalk
。为了得到新节点的embedding
,我们在保持已有节点的embedding
不变的情况下,对每个新的节点执行50
个长度为5
的随机游走序列,然后更新新节点的embedding
。我们还测试了两种变体:一种是将采样的随机游走“上下文节点”限制为仅来自已经训练过的旧节点集合,这可以缓解统计漂移;另一种是没有该限制。我们总数选择性能最强的那个。
尽管
DeepWalk
在inductive
任务上的表现很差,但是在transductive
环境下测试时它表现出更强的竞争力。因为在该环境下DeepWalk
可以在单个固定的图上进行持续的训练。我们观察到在inductive
环境下DeepWalk
的性能可以通过进一步的训练来提高。并且在某种情况下,如果让它比其它方法运行的时间长1000
倍,则它能够达到与无监督GraphSAGE
(而不是有监督GraphSAGE
)差不多的性能。但是我们不认为这种比较对于inductive
是有意义的。在
PPI
数据集中我们无法应用DeepWalk
,因为在不同的、不相交的图上运行DeepWalk
算法生成的embedding
空间可以相对于彼此任意旋转。参考最后一小节的证明。
GraphSAGE
及baseline
在这三个任务上的表现如下表所示。这里给出的是测试集上的micro-F1
指标,对于macro-F1
结果也有类似的趋势。其中Unsup
表示无监督学习,Sup
表示监督学习。GraphSAGE
的性能明显优于所有的baseline
模型。根据
GraphSAGE
不同版本可以看到:与GCN
聚合方式相比,可训练的神经网络聚合函数具有明显的优势。注意,这里的
GraphSAGE-mean
是将GraphSAGE-pool
的max
函数替换为mean
得到。尽管
LSTM
这种聚合函数是为有序数据进行设计而不是为无序set
准备的,但是通过随机排列的方式,它仍然表现出出色的性能。和监督版本的
GraphSAGE
相比,无监督GraphSAGE
版本的性能具有相当竞争力。这表明我们的框架无需特定于具体任务就可以实现强大的性能。
通过在
Reddit
数据集上不同模型的训练和测试的运行时间如下表所示,其中batch size = 512
,测试集包含79534
个节点。可以看到:- 这些方法的训练时间相差无几,其中
GraphSAGE-LSTM
最慢。 - 除了
DeepWalk
之外,其它方法的测试时间也相差无几。由于DeepWalk
需要采样新的随机游走序列,并运行多轮SGD
随机梯度下降来生成unseen
节点的embedding
,这使得DeepWalk
在测试期间慢了100~500
倍。
- 这些方法的训练时间相差无几,其中
对于
GraphSAGE
变体,我们发现和 $ K=1 $ 相比,设置 $ K=2 $ 使得平均准确性可以一致性的提高大约10%~15%
。但是当 $ K $ 增加到2
以上时会导致性能的回报较低(0~5%
) ,但是运行时间增加到夸张的10~100
倍,具体取决于采样邻域的大小。另外,随着采样邻域大小逐渐增加,模型获得的收益递减。因此,尽管对邻域的采样引起了更高的方差,但是
GraphSAGE
仍然能够保持较强的预测准确性,同时显著改善运行时间。下图给出了在引文网络数据集上GraphSAGE-mean
模型采用不同邻域大小对应的模型性能以及运行时间,其中 $ K=2 $ 以及 $ S_1=S_2 $ 。总体而言我们发现就平均性能和超参数而言,基于
LSTM
聚合函数和池化聚合函数的表现最好。为了定量的刻画这种比较优势,我们将三个数据集、监督学习/无监督学习两种方式一共六种配置作为实验,然后使用Wilcoxon Signed-Rank Test
来量化不同模型的性能。结论:
- 基于
LSTM
聚合函数和池化聚合函数的效果确实最好。 - 基于
LSTM
聚合函数的效果和基于池化聚合函数的效果相差无几,但是由于GraphSAGE-LSTM
比GraphSAGE-pool
慢得多(大约2
倍),这使得基于池化的聚合函数总体上略有优势。
- 基于
8.3 DeepWalk Embedding 旋转不变性
DeepWalk,node2vec
以及其它类似的node embedding
方法的目标函数都有类似的形式:其中:
$ f(\cdot),g(\cdot) $ 为平滑、连续的函数。 $ \mathbf{\vec z}_i $ 为直接优化的node embedding
(通过embedding
的look up
得到)。 $ \mathcal A,\mathcal B $ 为满足某些条件的节点pair
对。
事实上这类方法可以认为是一个隐式的矩阵分解
$ \mathbf Z^\top\mathbf Z \simeq \mathbf M \in \mathbb R^{|\mathcal V|\times |\mathcal V|} $ ,其中: $ \mathbf Z\in \mathbb R^{d\times |\mathcal V|} $ 的每一列代表一个节点的embedding
。 $ \mathbf M\in \mathbb R^{|\mathcal V|\times |\mathcal V|} $ 是一个包含某些随机游走统计量的矩阵。
这类方法的一个重要结果是:
embedding
可以通过任意单位正交矩阵变换,从而不影响矩阵分解:其中
$ \mathbf Q\in \mathbb R^{d\times d} $ 为任意单位正交矩阵。所以整个embedding
空间在训练过程中可以自由旋转。embedding
矩阵可以在embedding
空间可以自由旋转带来两个明显后果:如果我们在两个单独的图
A
和B
上基于 $ \mathcal L $ 来训练embedding
方法,如果没有一些明确的惩罚机制来强制两个图的节点对齐,则两个图学到的embedding
空间将相对于彼此可以任意旋转。因此,对于在图A
的节点embedding
上训练的任何节点分类模型,如果直接灌入图B
的节点embedding
,这这等效于对该分类模型灌入随机数据。如果我们有办法在图之间对齐节点,从而在图之间共享信息,则可以缓解该问题。研究如何对齐是未来的方向,但是对齐过程不可避免地在新数据集上运行缓慢。
而
GraphSAGE
完全无需做额外地节点对齐,它可以简单地为新节点生成embedding
信息。如果在时刻
$ t $ 对图A
基于 $ \mathcal L $ 来训练embedding
方法,然后在学到的embedding
上训练分类器。如果在时刻 $ t+1 $ ,图A
添加了一批新的节点,并通过运行新一轮的随机梯度下降来更新所有节点的embedding
,则这会导致两个问题:- 首先,类似于上面提到的第一点,如果新节点仅连接到少量的旧节点,则新节点的
embedding
空间实际上可以相对于原始节点的embedding
空间任意旋转。 - 其次,如果我们在训练过程中更新所有节点的
embedding
,则相比于我们训练分类模型所依赖的原始embedding
空间相比,我们新的embedding
空间可以任意旋转。
- 首先,类似于上面提到的第一点,如果新节点仅连接到少量的旧节点,则新节点的
这类
embedding
空间旋转问题对于依赖成对节点距离的任务(如通过embedding
的点积来预测链接)没有影响。因为不管
embedding
空间怎么旋转,节点之间的距离不变(如通过内积的距离,或通过欧式距离的距离)。缓解这类统计漂移问题(即
embedding
空间旋转)的一些方法为:- 为新节点训练
embedding
时,不要更新已经训练的embedding
。 - 在采样的随机游走序列中,仅保留旧节点为上下文节点,从而确保
skip-gram
目标函数中的每个点积操作都是一个旧节点和一个新节点。
我们尝试了这两种方式,并始终选择效果最好的
DeepWalk
变体。- 为新节点训练
从经验来讲,
DeepWalk
在引文网络上的效果要比Reddit
网络更好。因为和引文网络相比,Reddit
的这种统计漂移更为严重:Reddit
数据集中,从测试集链接到训练集的边更少。在引文网络中,测试集有96%
的新节点链接到训练集;在Reddit
数据集中,测试集只有73%
的新节点链接到训练集。
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论