返回介绍

数学基础

统计学习

深度学习

工具

Scala

八、GraphSAGE [2017]

发布于 2023-07-17 23:38:24 字数 126056 浏览 0 评论 0 收藏 0

  1. 在大型图中节点的低维向量 embedding 已被证明作为特征输入非常有用,可用于各种预测和图分析graph analysis 任务。node embedding 方法背后的基本思想是:使用降维技术将关于节点的 graph neighborhood 的高维信息蒸馏成稠密的、低维的向量 embedding 。然后可以将这些 node embedding 馈入到下游机器学习系统,并帮助完成节点分类、节点聚类、以及链接预测等任务。

    然而,以前的工作集中在从单个固定图a single fixed graph上的节点的 embedding ,许多实际 application 需要为 unseen 的节点、或全新的图快速生成 embedding 。这种归纳能力 inductive capability 对于高吞吐量、生产型的机器学习系统至关重要,其中这些机器学习系统在不断演变的图上运行并不断遇到 unseen 的节点(如 Reddit 上的帖子、Youtube 上的用户和视频)。生成 node embedding 的归纳方法 inductive approach 还有助于跨具有相同形式特征的图进行泛化:例如,可以在源自模型器官 model organismprotein-protein 交互图上训练一个 embedding generator ,然后使用经过训练的 embedding generator 轻松地为在新器官上收集的数据生成 node embedding

    与直推式配置 transductive setting 相比,归纳式inductivenode embedding 问题特别困难,因为泛化到 unseen 的节点需要将新观察到的子图observed subgraph 与算法已经优化的 node embedding 进行对齐 aligning 。归纳式框架 inductive framework 必须学会识别节点领域的结构属性,这些属性揭示了节点在图中的局部角色local role 及其全局位置global position

    大多数现有的生成 node embedding 的方法本质上都是直推式的。这些方法中的大多数使用基于矩阵分解的目标直接优化每个节点的 embedding ,并且无法自然地泛化到 unseen 的数据,因为它们在单个固定图上对节点进行预测。这些方法可以被修改从而在归纳式配置中运行,但是这些修改往往在计算上代价很大,需要额外的梯度下降轮次才能作出新的预测。最近还有一些使用卷积算子来学习图结构的方法,这些方法提供了作为 embedding 方法的承诺(《Semi-supervised classification with graph convolutional networks》)。到目前为止,图卷积网络 graph convolutional network: GCN 仅应用于具有固定图 fixed graph 的直推式配置。在论文《Inductive Representation Learning on Large Graphs》 中,作者将 GCN 泛化到归纳式无监督学习的任务,并提出了一个框架,该框架泛化了 GCN 方法从而使用可训练的聚合函数(超越了简单的卷积)。

    《Semi-supervised classification with graph convolutional networks》 提出的 GCN 要求在训练过程中已知完整的图拉普拉斯算子,而测试期间 unseen 的节点必然会改变图拉普拉斯算子,因此该方法也是直推式的。

    论文的工作:

    • 作者提出了一个通用框架,称作 GraphSAGESAmple and aggreGatE),用于归纳式 node embedding。与基于矩阵分解的 embedding 方法不同,GraphSAGE 利用节点特征(如,文本属性、节点画像node profile信息、节点 degree )来学习一个 embedding 函数,该embedding 函数可以泛化到 unseen 的节点。通过在学习算法中加入节点特征,GraphSAGE 同时学习了每个节点邻域的拓扑结构、以及该邻域内节点特征的分布。虽然GraphSAGE 聚焦于特征丰富的 graph(如,具有文本属性的引文数据,具有功能标记/分子标记的生物数据),但是GraphSAGE 还可以利用所有图中存在的结构特征(如,节点 degree)。因此,GraphSAGE 也可以应用于没有节点特征的图。

    • GraphSAGE 不是为每个节点训练一个distinctembedding 向量,而是训练一组聚合器函数 aggregator function ,这些函数学习从节点的局部邻域来聚合特征信息(如下图所示)。每个聚合器函数聚合来自远离给定节点的不同 hop 数(或搜索深度)的信息。在测试或推断时,GraphSAGE 通过应用学到的聚合函数为 unseen 的节点生成 embedding

      遵从之前的 node embedding 工作,作者设计了一个无监督损失函数,允许在没有task-specific 监督信息的情况下训练 GraphSAGE 。作者还表明 GraphSAGE 可以通过完全监督的方式进行训练。

    • 作者在三个关于节点/图分类 benchmark 上评估GraphSAGE ,这些 benchmark 测试了 GraphSAGEunseen 数据上生成有效 embedding 的能力。作者使用基于引文数据和 Reddit 帖子数据(分别预测论文类别和帖子类别)的两个不断演变的文档图,以及基于 protein-protein 交互的数据集(预测蛋白质功能)的多图泛化multigraph generalization实验。

      使用这些 benchmark,作者表明GraphSAGE 能够有效地为 unseen 的节点生成 representation,并大大优于相关 baseline :跨所有这些不同的领域,与单独使用节点特征相比,GraphSAGE 的监督方法将分类 F1 分数平均提高了 51%,并且 GraphSAGE 始终优于强大的直推式的 baseline ,并且该 baseline 需要 100 轮迭代甚至更长的时间才能预测 unseen 的节点。

      作者还表明,与受图卷积网络(《Semi-supervised classification with graph convolutional networks》)启发的聚合器相比,论文提出的新聚合器架构提供了显著的增益(平均增益 7.4%)。

      最后,作者探讨了GraphSAGE 的表达能力expressive capability,并通过理论分析表明:GraphSAGE 能够学到有关节点在图中的角色的结构信息,尽管它本质上是基于特征的。

  2. 相关工作:我们的算法在概念上与之前的 node embedding 方法、图上学习的通用监督方法general supervised approache、以及将卷积神经网络应用于图结构数据的最新进展等等相关。

    • 基于分解的 embedding 方法:最近有许多 node embedding 方法使用随机游走统计和基于矩阵分解的学习目标来学习低维 embeddingGraRep, node2vec, Deepwalk, Line, SDNE)。这些方法还与更经典的谱聚类spectral clustering方法、多维缩放multi-dimensional scaling、以及 PageRank 算法密切相关。

      由于这些 embedding 算法直接为单个节点individual node 训练 node embedding,因此它们本质上是直推式的,并且至少需要昂贵的额外训练(如,通过随机梯度下降)来对 unseen 节点进行预测。此外,对于大多数这些方法,目标函数对于 embedding 的正交变换是不变的,这意味着 embedding 空间不会自然地在图之间泛化,并且在 re-training 期间可能会漂移 drift

      因为这些方法是基于矩阵分解的,而矩阵分解的内积函数vivj$ \mathbf{\vec v}_i\cdot \mathbf{\vec v}_j $ 是 embedding 空间的正交不变的,即:将 embeddign 空间旋转任意角度,原始内积函数和新内积函数的结果是相等的。

      这一趋势的一个显著例外是 Planetoid-I 算法,它是一种归纳式的、基于 embedding 的半监督学习方法。但是,Planetoid-I 在推断过程中不使用任何图结构信息,相反,它在训练期间使用图结构信息作为正则化的一种形式。

      与先前的这些方法不同,我们利用特征信息来训练模型从而为 unseen 节点生成 embedding

    • 图上的监督学习:除了 node embedding 方法之外,还有大量关于图结构数据的监督学习的工作。这包括各种各样的 kernel-based 方法,其中图的特征向量来自于各种 graph kernel 。最近还有许多神经网络方法可以对图结构数据进行监督学习。我们的方法在概念上受到大多数这些算法的启发。然而,这些方法试图对整个图(或子图)进行分类,但是我们这项工作的重点是为每个节点生成有用的 representation

    • 图卷积网络:近年来,人们已经提出了几种用于图上学习的卷积神经网络架构。这些方法中的大多数无法扩展到大型图、或者设计用于整个图的分类。然而,我们的方法与 《Semi-supervised classification with graph convolutional networks》 提出的图卷积网络 graph convolutional network: GCN 密切相关。原始的 GCN 算法是为直推式setting 的半监督学习而设计的,确切 exact 的算法要求在训练期间知道整个图的拉普拉斯算子。我们算法的一个简单变体可以视作 GCN 框架对归纳式setting 的扩展,我们将在正文部分重新讨论这一点。

8.1 模型

  1. 我们方法背后的关键思想是:我们学习如何从节点的局部邻域聚合特征信息(如,邻域节点的 degree 或文本属性)。我们首先描述 GraphSAGEembedding 生成(即,前向传播)算法,该算法在假设 GraphSAGE 模型参数已经学到的情况下为节点生成 embedding 。然后,我们描述了如何使用标准随机梯度下降和反向传播技术来学习 GraphSAGE 模型参数。

8.1.1 前向传播

  1. 这里我们将描述前向传播算法(也叫 embedding 生成算法),其中假设模型已经训练好并且参数是固定的。具体而言,假设我们已经学到了K$ K $ 个聚合函数AGGk,k{1,2,,K}$ \text{AGG}_k,k\in \{1,2,\cdots,K\} $ ,这些聚合函数用于聚合节点的邻域信息。假设我们也学到了K$ K $ 个权重矩阵W(k),k{1,2,,K}$ \mathbf W^{(k)},k\in \{1,2,\cdots,K\} $ ,它们用于在不同层之间传递信息。K$ K $ 也称作搜索深度,或 layer 层数。

    GraphSAGEembedding 生成算法为:

    • 输入:

      • G(V,E)$ \mathcal G(\mathcal V,\mathcal E) $ ,输入特征{xvvV}$ \left\{\mathbf{\vec x}_v\mid v\in \mathcal V\right\} $ ,搜索深度K$ K $ ,邻域函数N()$ \mathcal N(\cdot) $
      • K$ K $ 个权重矩阵W(k)$ \mathbf W^{(k)} $ ,K$ K $ 个聚合函数AGGk$ \text{AGG}_k $ ,k{1,,K}$ k\in \{1,\cdots,K\} $
      • 非线性激活函数σ()$ \sigma(\cdot) $
    • 输出:节点的embedding 向量{zvvV}$ \left\{\mathbf{\vec z}_v\mid v \in \mathcal V \right\} $

    • 算法步骤:

      • 初始化:hv(0)=xv,vV$ \mathbf{\vec h}_v^{(0)} = \mathbf{\vec x}_v, v\in \mathcal V $

      • 对每一层迭代,迭代条件为:k=1,2,,K$ k=1,2,\cdots,K $ 。迭代步骤:

        • 遍历每个节点vV$ v\in \mathcal V $ ,执行:

          hN(v)(k1)=AGGk({hu(k1)uN(v)})hv(k)=σ(W(k)concat(hv(k1),hN(v)(k1)))

          其中 concat() 表示向量拼接。

          这里是拼接融合,也可以考虑其它类型的融合方式。

        • 对每个节点v$ v $ 的隐向量归一化:

          hv(k)=hv(k)hv(k)2,vV
      • zv=hv(K)$ \mathbf{\vec z}_v= \mathbf{\vec h}_v^{(K)} $

  2. GraphSAGE 前向传播算法的背后直觉是:在每次迭代或搜索深度,节点都会聚合来自其局部邻域的信息;并且随着这个过程的迭代,节点将从图的更远范围逐渐获取越来越多的信息。

    在算法的外层循环中的每个 step 如下进行,其中k$ k $ 表示外层循环中的current step (也叫做搜索深度),h(k)$ \mathbf{\vec h}^{(k)} $ 表示该 step 中的 node representation

    • 首先,每个节点vV$ v\in \mathcal V $ 聚合其直接邻域中节点的 representation{hu(k1)uN(v)}$ \left\{\mathbf{\vec h}_u^{(k-1)}\mid u\in \mathcal N(v)\right\} $ 到一个向量hN(v)(k1)$ \mathbf{\vec h}_{\mathcal N(v)}^{(k-1)} $ 中。注意,这个聚合步骤依赖于第k1$ k-1 $ 轮迭代产生的 node representation (即h(k1)$ \mathbf{\vec h}^{(k-1)} $ ),并且k=0$ k=0 $ 时的 representation 被定义为节点输入特征x$ \mathbf{\vec x} $ 。

      邻域 representation 可以通过各种聚合器架构(以 AGGREGATE 占位符来表达)来完成,接下来我们会讨论不同的架构选择。

    • 然后,在聚合邻域特征向量之后,GraphSAGE 将节点的当前 representationhv(k1)$ \mathbf{\vec h}_v^{(k-1)} $ 和聚合后的邻域向量hN(v)(k1)$ \mathbf{\vec h}_{\mathcal N(v)}^{(k-1)} $ 拼接起来,然后通过一个带非线性激活函数σ()$ \sigma(\cdot) $ 的全连接层。这个全连接层的输出就是下一个 step 要用到的 representation,即h(k)$ \mathbf{\vec h}^{(k)} $ 。

      大多数节点 embedding 方法将学到的 embedding 归一化为单位向量,这里也做类似处理。

    为了记号方便,我们将第K$ K $ 步的 final representation 记做zv=hv(K),vV$ \mathbf{\vec z}_v=\mathbf{\vec h}_v^{(K)},\forall v\in \mathcal V $ 。

a. mini-batch 训练

  1. 为了将算法扩展到 mini-batch setting,给定一组输入节点,我们首先前向采样 forward sample 所需要的邻域集合(直到深度K$ K $ )然后执行内层循环,而不是迭代所有节点。我们仅计算满足每个k$ k $ 所需的 representation (而不是所有节点的 representation )。

  2. 为了使用随机梯度下降算法,我们需要对GraphSAGE 的前向传播算法进行修改,从而允许mini-batch 中每个节点能够执行前向传播、反向传播。

    即:确保前向传播、反向传播过程中用到的节点都在同一个 mini-batch 中。

  3. GraphSAGE mini-batch 前向传播算法(这里B$ \mathcal B $ 包含了我们想要为其生成 representation 的节点):

    • 算法输入:

      • G(V,E)$ \mathcal G(\mathcal V,\mathcal E) $ ,输入特征{xvvB}$ \left\{\mathbf{\vec x}_v\mid v\in \mathcal B\right\} $ ,搜索深度K$ K $ ,邻域函数N()$ \mathcal N(\cdot) $
      • K$ K $ 个权重矩阵W(k)$ \mathbf W^{(k)} $ ,K$ K $ 个聚合函数AGGk$ \text{AGG}_k $ ,k{1,,K}$ k\in \{1,\cdots,K\} $
      • 非线性激活函数σ()$ \sigma(\cdot) $
    • 输出:节点的embedding 向量{zvvB}$ \left\{\mathbf{\vec z}_v\mid v \in \mathcal B\right\} $

    • 算法步骤:

      • 初始化:B(K)=B$ \mathcal B^{(K)} = \mathcal B $

      • 迭代k=K,,1$ k=K,\cdots,1 $ ,迭代步骤为:

        • B(k1)=B(k)$ \mathcal B^{(k-1)} = \mathcal B^{(k)} $
        • 遍历uB(k)$ u\in \mathcal B^{(k)} $ ,计算B(k1)=B(k1)Nk(u)$ \mathcal B^{(k-1)} = \mathcal B^{(k-1)}\bigcup \mathcal N_k(u) $
      • 初始化:hv(0)=xvvB(0)$ \mathbf{\vec h}_v^{(0)} = \mathbf{\vec x}_v, v\in \mathcal B^{(0)} $

      • 对每一层迭代,迭代条件为:k=1,2,,K$ k=1,2,\cdots,K $ 。迭代步骤:

        • 遍历每个节点vB(k)$ v\in \mathcal B^{(k)} $ ,执行:

          hNk(v)(k1)=AGGk({hu(k1)uNk(v)})hv(k)=σ(W(k)concat(hv(k1),hNk(v)(k1)))

          这里用Nk(v)$ \mathcal N_k(v) $ 表示节点v$ v $ 的邻域在每个深度k$ k $ 都不相同,依赖于前向采样的结果。

        • 对每个节点v$ v $ 的隐向量归一化:

          hv(k)=hv(k)hv(k)2,vV
      • zv=hv(K),vB$ \mathbf{\vec z}_v = \mathbf{\vec h}_v^{(K)},v\in \mathcal B $

  4. mini-batch 前向传播算法的主要思想是:首先采样所有所需的节点。集合B(k1)$ \mathcal B^{(k-1)} $ 包含了第k$ k $ 轮迭代计算 representation 的节点所依赖的节点集合。由于B(k)B(k1)$ \mathcal B^{(k)} \sube \mathcal B^{(k-1)} $ ,所以在计算hv(k)$ \mathbf{\vec h}_v^{(k)} $ 时依赖的hv(k1)$ \mathbf{\vec h}_v^{(k-1)} $ 已经在第k1$ k-1 $ 轮已被计算。另外第k$ k $ 轮需要计算 representation 的节点更少,这避免计算不必要的节点。

    然后计算目标节点的 representation,这一步和 batch 前向传播算法相同。

    mini-batch 前向传播和 batch 前向传播的主要区别在于:mini-batch 前向传播还有一个前向采样的步骤。

  5. 我们使用Nk()$ \mathcal N_k(\cdot) $ 的k$ k $ 来表明:不同层之间使用独立的 random walk 采样。这里我们使用均匀采样,并且当节点邻域节点数量少于指定数量时采用有放回的采样,否则使用无放回的采样。

    有一些算法聚焦于如何更好地进行采样,从而优化最终效果。

  6. mini-batch 算法的采样过程在概念上与 batch 算法的迭代过程是相反的。我们从需要以深度K$ K $ 生成 representation 的节点开始,然后我们对它们的邻域进行采样(即,深度K1$ K-1 $ ),依此类推。这样做的一个后果是邻域采样规模的定义可能有点违反直觉。具体而言,假设K=2$ K=2 $ :

    • batch 算法中,我们在k=1$ k=1 $ 时对节点邻域内采样S1$ S_1 $ 个节点,在k=2$ k=2 $ 时对节点邻域内采样S2$ S_2 $ 个节点。

    • mibi-batch 算法中,我们在k=2$ k= 2 $ 时对节点邻域内采样S2$ S_2 $ 个节点,然后在k=1$ k=1 $ 时对节点邻域内采样S1×S2$ S_1\times S_2 $ 个节点。

      这样才能保证我们的目标B$ \mathcal B $ 中包含 mibi-batch 所需要计算的所有节点。

b. 和 WL-Test 关系

  1. GraphSAGE 算法在概念上受到图的同构性检验的经典算法的启发。在前向传播过程中,如果令K=|V|$ K=|\mathcal V| $ 、W(k)=I$ \mathbf W^{(k)} = \mathbf I $ ,并选择合适的hash 函数来作为聚合函数,同时移除非线性函数,则该算法是 Weisfeiler-Lehman:WL 同构性检验算法的一个特例,被称作 naive vertex refinement

    如果算法输出的 node representation{zv,vV}$ \left\{\mathbf{\vec z}_v,v\in \mathcal V\right\} $ 在两个子图是相等的,则 WL-test 算法认为这两个子图是同构的。虽然在某些情况下该检验会失败,但是大多数情况下该检验是有效的。

  2. GraphSAGEWL test 算法的一个continous 近似,其中GraphSAGE 使用可训练的神经网络聚合函数代替了不连续的哈希函数。虽然 GraphSAGE 的目标是生成节点的有效embedding 而不是检验图的同构性,但是GraphSAGEWL test 之间的联系为我们设计学习节点邻域拓扑结构的算法提供了理论背景。

  3. 可以证明:即使我们提供的是节点特征信息,GraphSAGE 也能够学到图的结构信息。参考 “理论分析” 部分。

c. 邻域定义

  1. GraphSAGE 中我们并没有使用完整的邻域,而是均匀采样一组固定大小的邻域,从而确保每个 batch 的计算代价是固定的。因此我们定义N(v)$ \mathcal N(v) $ 为:从集合{uuV,(u,v)E}$ \{u\mid u\in \mathcal V,(u,v)\in \mathcal E\} $ 中均匀采样的、固定大小的集合,并且我们在算法的每轮迭代k$ k $ 中采样不同的邻域。

    如果对每个节点使用完整的邻域,则每个 batch 的内存需求和运行时间是不确定的,最坏情况为O(|V|)$ O(|\mathcal V|) $ 。如果使用采样后的邻域,则每个 batch 的时间和空间复杂度固定为O(k=1KSk)$ O(\prod_{k=1}^KS_k) $ ,其中Sk$ S_k $ 表示第k$ k $ 轮迭代时的邻域大小。K$ K $ 以及Sk$ S_k $ 均为用户指定的超参数,实验发现K=2,S1×S2500$ K=2, S_1\times S_2\le 500 $ 时的效果较好。

    K$ K $ 和Sk$ S_k $ 依赖于具体的数据集和任务。

8.1.2 模型学习

  1. 为了在完全无监督的环境中学习有用的、预测性的 representation,我们将一个 graph-based 损失函数应用于 output representationzu,uV$ \mathbf{\vec z}_u,\forall u\in \mathcal V $ ,并且通过随机梯度下降来学习模型参数。这个 graph-based 损失函数鼓励临近的节点具有相似的 representation,同时迫使不相近的节点具有高度不相似的 representation

    JG(zu)=log(sigmoid(zuzv))Q×EvnPn(v)log(sigmoid(zuzvn))

    其中:

    • v$ v $ 是和节点u$ u $ 在一个长度为l$ l $ 的 random walk 上共现的节点。
    • sigmoid(.)sigmoid 函数。
    • Pn()$ P_n(\cdot) $ 为负采样用到的分布函数,vn$ v_n $ 为负采样到的 negative nodeQ$ Q $ 为负采样的样本数。

    重要的是,与之前的 embedding 方法不同,GraphSAGE 中的节点 representationzu$ \mathbf{\vec z}_u $ 是从节点局部邻域中包含的特征而生成的,而不是通过 embedding look-up 而生成的。

    可以看到,GraphSAGEDeepWalk 类似,也依赖于图上的随机游走过程。为了提高训练效率,通常在训练之前执行一次随机游走过程(避免在训练的每轮迭代中进行随机游走)。

  2. 以无监督方式学到的节点 embedding 可以作为通用 service 来服务于下游的机器学习任务。但是如果仅在特定的任务上应用,则可以简单地将特定于任务的监督学习损失替代或增强原始的无监督损失。

    通过结合监督损失和无监督损失,那么可以同时利用 labeled 数据和 unlabeled 数据,即半监督学习。

8.1.3 聚合函数

  1. 和网格型数据(如文本、图像)不同,图的节点之间没有任何顺序关系,因此算法中的聚合函数必须能够在无序的节点集合上运行。理想的聚合函数是对称的,同时可训练并保持较高的表达能力。这种对称性可以确保我们的神经网络模型可以用于任意顺序的节点邻域的训练和测试。

    对称性是指:对于给定的一组节点集合,无论它们以何种顺序输入到聚合函数,聚合后的输出结果不变。

    聚合函数有多种形式,我们检查了三种主要的聚合函数:均值聚合函数mean aggregatorLSTM聚合函数LSTM aggregator 、池化聚合函数 pooling aggregator

  2. mean aggregator:简单的使用邻域节点的特征向量的逐元素均值来作为聚合结果。这几乎等价于直推式 GCN 框架中的卷积传播规则。

    具体而言,如果我们将前向传播:

    hN(v)(k1)=AGGk({hu(k1)uN(v)})hv(k)=σ(W(k)concat(hv(k1),hN(v)(k1)))

    替换为:

    hv(k)=σ(W(k)MEAN({hv(k1)}{hu(k1)uN(v)}))

    则这得到直推式 GCN 的一个 inductive 变种,我们称之为基于均值聚合的卷积 mean-based aggregator convolutional 。它是局部谱卷积localized spectral convolution的一个粗糙的线性近似。

    GCN 的前向传播为:

    H(k)=W(k)(D~1/2A~D~1/2H(k1))

    其中:A~=A+I$ \tilde{\mathbf A} = \mathbf A + \mathbf I $ ,A$ \mathbf A $ 为邻接矩阵,D~$ \tilde{\mathbf D} $ 为A~$ \tilde{\mathbf A} $ 的 degree 矩阵。

    因此有:

    hv(k)=σ(W(k)MEAN({hv(k1)}{hu(k1)uN(v)}))

    注意,GCNhv(0)$ \mathbf{\vec h}_v^{(0)} $ 是通过 embedding look-up 而生成的(而不是输入特征xv$ \mathbf{\vec x}_v $ )。

    这个卷积聚合器与我们提出的其它聚合器之间的一个重要区别在于:它并未执行拼接操作(即,将hv(k1)$ \mathbf{\vec h}_v^{(k-1)} $ 和hN(v)(k1)$ \mathbf{\vec h}_{\mathcal N(v)}^{(k-1)} $ 拼接起来) 。这种拼接操作可以视为 GraphSAGE 算法的不同 search depth (或 layer)之间的 skip connection 的一种简单形式,它可以显著提高性能。

    事实上其它聚合器在拼接操作之后执行了带非线性激活函数的投影,因此破坏了这种 skip connection。是否修改为以下形式更好?

    hv(k)=hv(k1)+σ(W(k)hN(v)(k1))
  3. LSTM aggregator :和均值聚合相比,LSTM 具有更强大的表达能力。但是 LSTM 原生的是非对称的(即,LSTM 不是 permutation invariant 的),它依赖于节点的输入顺序。因此我们通过简单地将 LSTM 应用于邻域节点的随机排序,从而使得 LSTM 可以应用于无序的节点集合。

  4. pooling aggregator :池化聚合器是对称的、可训练的。在这种池化方法中,邻域每个节点的特征向量都通过全连接神经网络独立馈入,然后通过一个逐元素的最大池化来聚合邻域信息:

    hN(v)(k1)=max({σ(Wpoolhu(k1)+bpool)uN(v)})

    其中 max 表示逐元素的 max 运算符,σ()$ \sigma(\cdot) $ 是非线性激活函数。

    理论上可以在最大池化之前使用任意深度的多层感知机,但是我们这里专注于简单的单层网络结构。直观上看,可以将多层感知机视为一组函数,这组函数为邻域集合内的每个节点representation 计算特征。通过将最大池化应用到这些计算到的特征上,模型可以有效捕获邻域集合的不同方面 aspect

    理论上可以使用任何的对称向量函数(如逐元素均值)来替代 max 运算符。但是我们在实验中发现最大池化和均值池化之间没有显著差异,因此我们专注于最大池化。

8.1.4 理论分析

  1. 这里我们将探讨 GraphSAGE 的表达能力,以便深入了解 GraphSAGE 如何学习图结构,即使它本质上是基于特征的。作为案例研究,我们考虑 GraphSAGE 是否可以学习预测节点的聚类系数 clustering coefficient,即:在节点的 1-hop 邻域内,闭合的三角形占所有三角形(闭合的和未闭合的)的比例。聚类系数是衡量节点局部邻域聚类程度的常用指标,它可以作为许多更复杂的结构主题structural motifbuilding block。可以证明:GraphSAGE 算法能够将聚类系数逼近到任意精度。

  2. 定理:令xvU,vV$ \mathbf{\vec x}_v\in \mathbb U,v\in \mathcal V $ 作为 GraphSAGE 算法针对图G=(V,E)$ \mathcal G=(\mathcal V,\mathcal E) $ 的输入,其中U$ \mathbb U $ 是Rd$ \mathbb R^d $ 的一个紧致子集 compact subset。假设存在一个固定的正的常数CR+$ C\in \mathbb R^+ $ 使得xvxv2>C$ \left\|\mathbf{\vec x}_v - \mathbf{\vec x}_{v^\prime}\right\|_2\gt C $ 对任意节点 pair(v,v)$ (v,v^\prime) $ 成立,那么我们有:对于任意ϵ>0$ \epsilon \gt 0 $ ,这里存在一个参数 settingΘ$ \mathbf \Theta^* $ ,使得GraphSAGE 算法在K=4$ K=4 $ 轮迭代之后有:

    |zvcv|ϵ,vV

    其中:zvR$ z_v\in \mathbb R $ 为 GraphSAGE 算法的 final output 值,cv$ c_v $ 为节点的聚类系数。

    注意,这里假设 output representation 是一维的。

  3. 上述定理指出:对于任意的图,GraphSAGE 算法都存在一个参数 setting,如果每个每个节点的特征都是不同的(并且如果模型足够高维),那么算法可以将图的聚类系数逼近到任意精度。证明见原始论文。

    注意:作为该定理的推论,GraphSAGE 可以了解局部图结构,即使节点特征输入是从连续随机分布中采样的(因此特征输入与图结构无关)。

    证明背后的基本思想是:如果每个节点都有一个 unique 的特征,那么我们可以学习将节点映射到 indicator 向量并识别节点邻域。定理的证明依赖于池化聚合器的一些属性,这也提供了为什么 GraphSAGE-pool 优于 GCN 、以及 mean-based 聚合器的洞察。

8.2 实验

  1. 我们在三个 benchmark 任务上检验 GraphSAGE 的效果:Web of Science Citation 数据集的论文分类任务、Reddit 数据集的帖子分类任务、PPI 数据集的蛋白质分类任务。

    前两个数据集是对训练期间unseen 的节点进行预测,最后一个数据集是对训练期间unseen 的图进行预测。

  2. 数据集:

    • Web of Science Cor Collection 数据集:包含 2000 年到 2005 年六个生物学相关领域的所有论文,每篇论文属于六种主题类别之一。数据集包含 302424 个节点,节点的平均degree9.15 。其中:

      • Immunology 免疫学的标签为NI,节点数量 77356
      • Ecology 生态学的标签为 GU,节点数量 37935
      • Biophysics 生物物理学的标签为DA,节点数量 36688
      • Endocrinology and Metabolism 内分泌与代谢的标签为 IA ,节点数量 52225
      • Cell Biology 细胞生物学的标签为 DR,节点数量84231
      • Biology(other) 生物学其它的标签为 CU,节点数量 13988

      任务目标是预测论文主题的类别。我们根据 2000-2004 年的数据来训练所有算法,并用 2005 年的数据进行进行测试(其中 30% 用于验证)。

      我们使用节点degree 和文章的摘要作为节点的特征,其中节点摘要根据Arora 等人的方法使用 sentence embedding 方法来处理文章的摘要,并使用Gensim word2vec 的实现来训练了300 维的词向量。

    • Reddit 数据集:包含20149Reddit 上发布帖子的一个大型图数据集,节点标签为帖子所属的社区。我们采样了 50 个大型社区,并构建一个帖子到帖子的图。如果一个用户同时在两个帖子上发表评论,则这两个帖子将链接起来。数据集包含 232965 个节点,节点的平均degree492

      为了对社区进行采样,我们按照每个社区在 2014 年的评论总数对社区进行排名,并选择排名在 [11,50](包含)的社区。我们忽略了最大的那些社区,因为它们是大型的、通用的默认社区,会严重扭曲类别的分布。我们选择这些社区上定义的最大连通图largest connected component

      任务的目标是预测帖子的社区community。我们将该月前20 天用于训练,剩下的天数作为测试(其中 30% 用于验证)。

      我们使用帖子的以下特征:标题的平均embedding、所有评论的平均 embedding、帖子评分、帖子评论数。其中embedding 直接使用现有的 300 维的 GloVe CommonCral 词向量,而不是在所有帖子中重新训练。

    • PPI 数据集:包含Molecular Signatures Dataset 中的图,每个图对应于不同的人类组织,节点标签采用gene ontology sets ,一共121 种标签。平均每个图包含 2373 个节点,所有节点的平均 degree28.8

      任务的目的是评估模型的跨图泛化的能力。我们在 20 个随机选择的图上进行训练、2 个图进行验证、 2 个图进行测试。其中训练集中每个图至少有 15000 条边,验证集和测试集中每个图都至少包含 35000 条边。注意:对于所有的实验,验证集和测试集是固定选择的,训练集是随机选择的。我们最后给出测试图上的 micro-F1 指标。

      我们使用positional gene setsmotif gene sets 以及 immunological signatures 作为节点特征。我们选择至少在 10% 的蛋白质上出现过的特征,低于该比例的特征不被采纳。最终节点特征非常稀疏,有 42% 的节点没有非零特征(即,42% 的节点的特征全是空的),这使得节点之间的链接非常重要。

  3. Baseline 模型:

    • 随机分类器。
    • 基于节点特征的逻辑回归分类器(完全忽略图的结构信息)。
    • 代表因子分解方法的 DeepWalk 算法+逻辑回归分类器(完全忽略节点的特征)。
    • 拼接了 DeepWalkembedding 以及节点特征的方法(融合图的节点特征和结构特征)。

    我们使用了不同聚合函数的 GraphSAGE 的四个变体。由于卷积的变体是 GCNinductive 扩展,因此我们称其为 GraphSAGE-GCN

    我们使用了 GraphSAGE 的无监督版本,也直接使用分类交叉熵作为损失的有监督版本。

  4. 模型配置:

    • GrahSage

      • 所有GraphSAGE 模型都在 Tensorflow 中使用 Adam 优化器实现, 而 DeepWalk 在普通的随机梯度优化器中表现更好。
      • 为防止 GraphSAGE 聚合函数的效果比较时出现意外的超参数hacking,我们对所有 GraphSAGE 版本进行了相同的超参数配置:根据验证集的性能为每个版本提供最佳配置。
      • 对于所有的 GraphSAGE 版本设置K=2$ K=2 $ 以及邻域采样大小S1=25,S2=10$ S_1=25, S_2= 10 $ 。
      • 对于所有的 GraphSAGE ,我们对每个节点执行以该节点开始的 50 轮长度为 5 的随机游走序列,从而得到pair 节点对。我们的随机游走序列生成完全基于 Python 代码实现。
      • 由于节点 degree 分布的长尾效应,我们将 GraphSAGE 算法中所有图的边执行降采样预处理。经过降采样之后,使得没有任何节点的 degree 超过 128 。由于我们每个节点最多采样 25 个邻居,因此这是一个合理的权衡。
    • 为公平比较,所有模型都采样相同的 mini-batch 迭代器、损失函数(当然监督损失和无监督损失不同)、邻域采样器。

    • 对于原生特征模型,以及基于无监督模型的 embedding 进行预测时,我们使用 scikit-learn 中的 SGDClassifier 逻辑回归分类器,并使用默认配置。

    • 在所有配置中,我们都对学习率和模型的维度以及batch-size 等等进行超参数选择:

      • 除了 DeepWalk 之外,我们为监督学习模型设置初始学习率的搜索空间为{0.01,0.001,0.0001}$ \{0.01,0.001,0.0001\} $ ,为无监督学习模型设置初始学习率的搜索空间为{2×106,2×107,2×108}$ \{2\times 10^{-6},2\times 10^{-7},2\times 10^{-8}\} $ 。

        最初实验表明 DeepWalk 在更大的学习率下表现更好,因此我们选择DeepWalk 的初始学习率搜索空间为{0.2,0.4,0.8}$ \{0.2,0.4,0.8\} $ 。

      • 我们测试了每个GraphSAGE模型的big 版本和 small 版本。

        • 对于池化聚合函数,big 模型的池化层维度为 1024small 模型的池化层维度为 512
        • 对于 LSTM 聚合函数,big 模型的隐层维度为 256small 模型的隐层维度为 128

        注意,这里设置的是聚合器的维度,而不是 hidden representation 的维度。

      • 所有实验中,我们将GraphSAGE 每一层的hi(k)$ \mathbf{\vec h}_i^{(k)} $ 的维度设置为 256

      • 所有的 GraphSAGE 以及 DeepWalk 的非线性激活函数为 ReLU

      • 对于无监督 GraphSAGEDeepWalk 模型,我们使用 20 个负采样的样本,并且使用 0.75 的平滑参数对节点的degree 进行上下文分布平滑。

      • 对于监督 GraphSAGE,我们为每个模型运行 10epoch

      • 我们对 GraphSAGE 选择 batch-size = 512。对于 DeepWalk 我们使用 batch-size=64,因为我们发现这个较小的 batch-size 收敛速度更快。

  5. 硬件配置:

    • DeepWalkCPU 密集型机器上速度更快,它的硬件参数为 144 coreIntel Xeon CPU(E7-8890 V3 @ 2.50 GHz)2T 内存。
    • 其它模型在单台机器上实验,该机器具有 4NVIDIA Titan X Pascal GPU( 12 Gb 显存, 10Gbps 访问速度), 16 coreIntel Xeon CPU(E5-2623 v4 @ 2.60GHz),以及 256 Gb 内存。

    所有实验在共享资源环境下大约进行了3 天。我们预期在消费级的单 GPU 机器上(如配备了 Titan X GPU )的全部资源专用,可以在 47 天完成所有实验。

  6. DeepWalk 测试阶段:

    • 对于 Reddit 和引文数据集,我们按照 Perozzi 等人的描述对 DeepWalk 执行 oneline 训练。对于新的测试节点,我们进行了新一轮的 SGD 优化,从而得到新节点的 embedding

      现有的 DeepWalk 实现仅仅是 word2vec 代码的封装,它难以支持 embedding 新节点以及其它变体。这里我们根据 tensorflow 中的官方 word2vec 教程实现了 DeepWalk 。为了得到新节点的 embedding,我们在保持已有节点的 embedding 不变的情况下,对每个新的节点执行 50 个长度为 5 的随机游走序列,然后更新新节点的 embedding

      我们还测试了两种变体:一种是将采样的随机游走“上下文节点”限制为仅来自已经训练过的旧节点集合,这可以缓解统计漂移;另一种是没有该限制。我们总数选择性能最强的那个。

      尽管 DeepWalkinductive 任务上的表现很差,但是在 transductive 环境下测试时它表现出更强的竞争力。因为在该环境下 DeepWalk 可以在单个固定的图上进行持续的训练。我们观察到在 inductive 环境下 DeepWalk 的性能可以通过进一步的训练来提高。并且在某种情况下,如果让它比其它方法运行的时间长 1000 倍,则它能够达到与无监督 GraphSAGE (而不是有监督 GraphSAGE )差不多的性能。但是我们不认为这种比较对于 inductive 是有意义的。

    • PPI 数据集中我们无法应用 DeepWalk,因为在不同的、不相交的图上运行 DeepWalk 算法生成的 embedding 空间可以相对于彼此任意旋转。参考最后一小节的证明。

  7. GraphSAGEbaseline 在这三个任务上的表现如下表所示。这里给出的是测试集上的 micro-F1 指标,对于 macro-F1 结果也有类似的趋势。其中 Unsup 表示无监督学习,Sup 表示监督学习。

    • GraphSAGE 的性能明显优于所有的 baseline 模型。

    • 根据 GraphSAGE 不同版本可以看到:与GCN 聚合方式相比,可训练的神经网络聚合函数具有明显的优势。

      注意,这里的 GraphSAGE-mean 是将 GraphSAGE-poolmax 函数替换为 mean 得到。

    • 尽管LSTM 这种聚合函数是为有序数据进行设计而不是为无序 set 准备的,但是通过随机排列的方式,它仍然表现出出色的性能。

    • 和监督版本的 GraphSAGE 相比,无监督 GraphSAGE 版本的性能具有相当竞争力。这表明我们的框架无需特定于具体任务就可以实现强大的性能。

  8. 通过在 Reddit 数据集上不同模型的训练和测试的运行时间如下表所示,其中 batch size = 512,测试集包含 79534 个节点。可以看到:

    • 这些方法的训练时间相差无几,其中 GraphSAGE-LSTM 最慢。
    • 除了 DeepWalk 之外,其它方法的测试时间也相差无几。由于 DeepWalk 需要采样新的随机游走序列,并运行多轮SGD 随机梯度下降来生成unseen 节点的 embedding,这使得 DeepWalk 在测试期间慢了 100~500 倍。

  9. 对于 GraphSAGE 变体,我们发现和K=1$ K=1 $ 相比,设置K=2$ K=2 $ 使得平均准确性可以一致性的提高大约 10%~15% 。但是当K$ K $ 增加到 2 以上时会导致性能的回报较低(0~5%) ,但是运行时间增加到夸张的 10~100 倍,具体取决于采样邻域的大小。

    另外,随着采样邻域大小逐渐增加,模型获得的收益递减。因此,尽管对邻域的采样引起了更高的方差,但是 GraphSAGE 仍然能够保持较强的预测准确性,同时显著改善运行时间。下图给出了在引文网络数据集上 GraphSAGE-mean 模型采用不同邻域大小对应的模型性能以及运行时间,其中K=2$ K=2 $ 以及S1=S2$ S_1=S_2 $ 。

  10. 总体而言我们发现就平均性能和超参数而言,基于 LSTM 聚合函数和池化聚合函数的表现最好。为了定量的刻画这种比较优势,我们将三个数据集、监督学习/无监督学习两种方式一共六种配置作为实验,然后使用 Wilcoxon Signed-Rank Test 来量化不同模型的性能。

    结论:

    • 基于 LSTM 聚合函数和池化聚合函数的效果确实最好。
    • 基于LSTM 聚合函数的效果和基于池化聚合函数的效果相差无几,但是由于 GraphSAGE-LSTMGraphSAGE-pool 慢得多(大约2 倍),这使得基于池化的聚合函数总体上略有优势。

8.3 DeepWalk Embedding 旋转不变性

  1. DeepWalk,node2vec 以及其它类似的 node embedding 方法的目标函数都有类似的形式:

    L=αi,jAf(zizj)+βi,jBf(zizj)

    其中:

    • f(),g()$ f(\cdot),g(\cdot) $ 为平滑、连续的函数。
    • zi$ \mathbf{\vec z}_i $ 为直接优化的 node embedding (通过 embeddinglook up 得到)。
    • A,B$ \mathcal A,\mathcal B $ 为满足某些条件的节点 pair 对。

    事实上这类方法可以认为是一个隐式的矩阵分解ZZMR|V|×|V|$ \mathbf Z^\top\mathbf Z \simeq \mathbf M \in \mathbb R^{|\mathcal V|\times |\mathcal V|} $ ,其中:

    • ZRd×|V|$ \mathbf Z\in \mathbb R^{d\times |\mathcal V|} $ 的每一列代表一个节点的 embedding
    • MR|V|×|V|$ \mathbf M\in \mathbb R^{|\mathcal V|\times |\mathcal V|} $ 是一个包含某些随机游走统计量的矩阵。

    这类方法的一个重要结果是:embedding 可以通过任意单位正交矩阵变换,从而不影响矩阵分解:

    (QZ)(QZ)=ZQQZ=ZZM

    其中QRd×d$ \mathbf Q\in \mathbb R^{d\times d} $ 为任意单位正交矩阵。所以整个embedding 空间在训练过程中可以自由旋转。

  2. embedding 矩阵可以在 embedding 空间可以自由旋转带来两个明显后果:

    • 如果我们在两个单独的图 AB 上基于L$ \mathcal L $ 来训练 embedding 方法,如果没有一些明确的惩罚机制来强制两个图的节点对齐,则两个图学到的 embedding 空间将相对于彼此可以任意旋转。因此,对于在图 A 的节点 embedding 上训练的任何节点分类模型,如果直接灌入图 B 的节点 embedding ,这这等效于对该分类模型灌入随机数据。

      如果我们有办法在图之间对齐节点,从而在图之间共享信息,则可以缓解该问题。研究如何对齐是未来的方向,但是对齐过程不可避免地在新数据集上运行缓慢。

      GraphSAGE 完全无需做额外地节点对齐,它可以简单地为新节点生成 embedding 信息。

    • 如果在时刻t$ t $ 对图 A 基于L$ \mathcal L $ 来训练 embedding 方法,然后在学到的 embedding 上训练分类器。如果在时刻t+1$ t+1 $ ,图 A 添加了一批新的节点,并通过运行新一轮的随机梯度下降来更新所有节点的 embedding ,则这会导致两个问题:

      • 首先,类似于上面提到的第一点,如果新节点仅连接到少量的旧节点,则新节点的 embedding 空间实际上可以相对于原始节点的 embedding 空间任意旋转。
      • 其次,如果我们在训练过程中更新所有节点的 embedding,则相比于我们训练分类模型所依赖的原始 embedding 空间相比,我们新的 embedding 空间可以任意旋转。
  3. 这类embedding 空间旋转问题对于依赖成对节点距离的任务(如通过 embedding 的点积来预测链接)没有影响。

    因为不管 embedding 空间怎么旋转,节点之间的距离不变(如通过内积的距离,或通过欧式距离的距离)。

  4. 缓解这类统计漂移问题(即embedding 空间旋转)的一些方法为:

    • 为新节点训练 embedding 时,不要更新已经训练的 embedding
    • 在采样的随机游走序列中,仅保留旧节点为上下文节点,从而确保 skip-gram 目标函数中的每个点积操作都是一个旧节点和一个新节点。

    我们尝试了这两种方式,并始终选择效果最好的 DeepWalk 变体。

  5. 从经验来讲,DeepWalk 在引文网络上的效果要比 Reddit 网络更好。因为和引文网络相比,Reddit 的这种统计漂移更为严重:Reddit 数据集中,从测试集链接到训练集的边更少。在引文网络中,测试集有 96% 的新节点链接到训练集;在 Reddit 数据集中,测试集只有 73% 的新节点链接到训练集。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文