返回介绍

数学基础

统计学习

深度学习

工具

Scala

五、GNN 评估陷阱 [2018]

发布于 2023-07-17 23:38:24 字数 11374 浏览 0 评论 0 收藏 0

  1. 图的半监督节点分类是图挖掘中的一个经典问题,最近提出的图神经网络graph neural network:GNN 在这个任务上取得了瞩目的成就。尽管取得了巨大的成功,但是由于实验评估程序的某些问题,我们无法准确判断模型所取得的进展:

    • 首先,很多提出的模型仅仅是在《Revisiting semi-supervised learning with graph embeddings》 给出的三个数据集(CORA,CiteSeer,PubMed) 上、且使用该论文相同的train/validation/test 数据集拆分上评估的。

      这种实验配置倾向于寻找最过拟合overfit the most 的模型,并且违反了train/validation/test 拆分的目的:寻找泛化能力最佳best generalization 的模型。

    • 其次,在评估新模型的性能时,新模型和 baseline 通常采用不同的训练程序。例如,有的采用了早停策略,有的没有采用早停策略。这使得很难确定新模型性能的提升是来自于新模型的优秀架构,还是来自于训练过程或者超参数配置。这使得对新模型产生不公平的好处unfairly benefit

    有鉴于此,论文《Pitfalls of Graph Neural Network Evaluation》 表明现有的 GNN 模型评估策略存在严重缺陷。

    为解决这些问题,论文在 transductive 半监督节点分类任务上对四种著名的 GNN 架构进行了全面的实验评估。论文在同一个框架内实现了四个模型:GCN, MoNet, GraphSage, GAT

    在评估过程中,论文专注于两个方面:

    • 对所有模型都使用标准化的训练和超参数选择过程。在这种情况下,性能的差异可以高度确定地 high certainty 归因于模型架构的差异,而不是其它因素。

    • 对四个著名的引文网络数据集进行实验,并为节点分类问题引入四个新的数据集。

      对于每个数据集使用 100 次随机 train/validation/test 拆分,对于每次拆分分别执行20 次模型的随机初始化。

      这种配置使得我们能够更准确地评估不同模型的泛化性能,而不仅仅是评估模型在一个固定测试集的上的性能。

    论文声明:作者不认为在 benchmark 数据集上的准确性是机器学习算法的唯一重要特性。发展和推广现有方法的理论,建立与其他领域的联系(和适应来自其他领域的思想)是推动该领域发展的重要研究方向。然而,彻底的实证评估对于理解不同模型的优势和局限性至关重要。

    实验结果表明:

    • 考虑数据集的不同拆分会导致模型排名的显著不同。
    • 如果对所有模型都合理地调整超参数和训练过程,那么简单的 GNN 架构就可以超越复杂的 GNN 架构。

5.1 模型和数据集

  1. 考虑图上的transductive 半监督节点分类问题。本文中我们比较了以下四种流行的图神经网络架构:

    • Graph Convolutional Network: GCN:是早期的模型,它对谱域卷积执行线性近似。
    • Mixture Model Network: MoNet:推广了 GCN 架构从而允许学习自适应的卷积滤波器。
    • Graph Attention Network: GAT :采用一种注意力机制从而允许在聚合步骤中以不同的权重加权邻域中的节点。
    • GraphSAGE :侧重于 inductive 节点分类,但是也可用于 transductive 配置中。我们考虑原始论文中的三个变体:GS-mean, GS-meanpool, GS-maxpool

    所有上述模型的原始论文和参考实现均使用不同的训练程序,包括:不同的早停策略、不同的学习率衰减decay、不同的 full-batch /mini-batch 训练。如下图所示:

    不同的实验配置使得很难凭实验确定模型性能提升的背后原因。因此在我们的实验中,我们对所有模型都使用标准化的训练和超参数调优程序,从而进行更公平的比较。

    此外,我们考虑了四个 baseline 模型,包括:逻辑回归Logistic Regression: LogReg、多层感知机 Multilayer Perceptron: MLP、标签传播Label Propagation: LabelProp 、归一化的拉普拉斯标签传播Normalized Laplacian Label Propagation: LabelProp NL 。其中: LogReg,MLP 是基于属性的模型,它们不考虑图结构;LabelProp, LabelProp NL 仅考虑图结构而忽略节点属性。

  2. 数据集:

    我们考虑四个著名的引文网络数据集:PubMed, CiteSeer, CORA, CORA-Full。其中 CORA-FullCORA 的扩展版本。

    我们还为节点分类任务引入了四个新的数据集:Coauthor CSCoauthor PhysicsAmazon ComputersAmazon Photo

    • Amazon ComputersAmazon PhotoAmazon co-purchase 图的一部分。其中:节点代表商品、边代表两个商品经常被一起购买,节点特征为商品评论的 bag-of-word,类别标签label 为产品类目category
    • Coauthor CSCoauthor Physics 是基于 Microsoft Academic Graphco-authorship 图。其中:节点代表作者、边代表两名作者共同撰写过论文,节点特征为每位作者论文的论文关键词,类别标签为作者最活跃的研究领域。

    我们对数据集进行了标准化,其中对 CORA_full 添加了 self-loop,并删除CORA_full 样本太少的类别。 我们删除了 CORA_full 中样本数量少于 50 个节点的 3 个类别,因为我们对这些类别无法执行很好的数据集拆分(在后续数据集拆分中,每个类别至少要有 20 个标记节点作为训练集、30 个标记节点作为验证集)。

    对于所有数据集,我们将图视为无向图,并且仅考虑最大连通分量。

    数据集的统计量如下表所示。其中:

    • Label rate 为数据集的标记率,它表示训练集的标记节点的占比。因为我们对每个类别选择 20 个标记节点作为训练集,因此:

      Label rate=×20
    • Edge density 为图的链接占所有可能链接的比例,它等于:

      Edge density=×(1)/2

5.2 实验配置

  1. 模型架构:我们保持原始论文/参考实现中相同的模型架构,其中包括:层layer 的类型和顺序、激活函数的选择、dropout 的位置、L2$ L_2 $ 正则化位置的选择。

    我们还将 GATattention head 数量固定为 8MoNet 高斯核的数量固定为 2

    所有模型都有 2 层:input features --> hidden layer --> output layer

  2. 训练过程:为了更公平的比较,我们对所有模型都使用相同的训练过程。对于所有的模型:

    • 相同的优化器,即:带默认参数的 Adam 优化器。
    • 相同的初始化,即:根据 Glorot 初始化权重,而 bias 初始化为零。
    • 都没有学习率衰减。
    • 相同的最大训练 epoch 数量。
    • 相同的早停准则、相同的 patience、相同的验证频率validation frequency
    • 都使用 full-batch 训练,即:每个 epoch 都使用训练集中的所有节点。
    • 同时优化所有的模型参数,包括:GATattention weightsMoNetkernel parameters、所有模型的权重矩阵。
    • 所有情况下,我们选择每个类别 20 个带标签的节点作为训练集、每个类别 30 个带标签的节点作为验证集、剩余节点作为测试集。

    我们最多训练 100kepoch,但是由于我们使用了严格的早停策略,因此实际训练时间大大缩短了。具体而言,如果总的验证损失(数据损失 +正则化损失)在 50epoch 都没有改善,则提前停止训练。一旦训练停止,则我们将权重的状态重置为验证损失最小的 step

  3. 超参数选择:我们对每个模型使用完全相同的策略进行超参数选择。具体而言,我们对学习率、hidden layer 维度、L2$ L_2 $ 正则化强度、dropout rate 执行网格搜索。搜索空间:

    • 隐层维度:[8, 16, 32, 64]
    • 学习率:[0.001, 0.003, 0.005, 0.008, 0.01]
    • dropout rate[0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    • attention 系数的 dropout rate(仅用于 GAT):[0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    • L2$ L_2 $ 正则化强度:[1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1]

    对于每个模型,我们选择使得 Cora 数据集和 CiteSeer 数据集上平均准确率最高的超参数配置。这是对每个数据集执行 100 次随机 train/validation/test 拆分、对于每次拆分分别执行20 次模型的随机初始化从而取得的。所选择的最佳超参数配置如下表所示,这些配置用于后续实验。

    这里应该针对不同数据集进行超参数调优,而不是所有数据集都采用相同的超参数(即,在 Cora 数据集和 CiteSeer 数据集上平均准确率最高的超参数配置)。

    注意:

    • GAT 有两个 dropout rate:节点特征上的 dropout、注意力系数上的 dropout
    • 所有 GraphSAGE 模型都有额外的权重用于 skip connection,这使得实际的 hidden size 翻倍。因此这里 GS-meanEffective hiden size = 32
    • GS-meanpool/GS-maxpool 具有两个 hidden size:隐层的 size、中间特征转换的 size
    • GAT 使用 8headmulti-head 架构,而 MoNet 使用 2head

5.3 实验结果

  1. 所有 8 个数据集的所有模型的平均准确率(以及标准差)如下表所示。结果是在100 次随机 train/validation/test 拆分、对于每次拆分我们分别执行20 次模型的随机初始化上取得的。

    对每个数据集,准确率最高的得分用粗体标记。N/A 表示由于 GPU RAM 的限制而无法由 full-batch 版本的 GS-maxpool 处理的数据集。

    结论:

    • 首先,在所有数据集中,基于 GNN 的方法(GCN, MoNet, GAT, GraphSAGE) 显著优于 baseline 方法(MLP, LogReg, LabelProp, LabelProp NL) 。

      这符合我们的直觉,并证明了与仅考虑属性或仅考虑结构的方法相比,同时考虑了结构和属性信息的、基于 GNN 的方法的优越性。

    • 其次,GNN 方法中没有明显的winner 能够在所有数据集中占据主导地位。

      实际上,在8 个数据集中的 5 个数据集,排名第二、第三的方法得分和排名最高的方法得分,平均相差不到 1%

      如果我们有兴趣对模型之间进行比较,则可以进行pairwise t-test 。这里我们考虑模型之间的相对准确率作为 pairwise t-test 的替代。具体而言:

      • 首先对每个数据集进行随机拆分。
      • 然后在这个拆分中,训练并得到每个模型的准确率(已经对 20 次随机初始化取平均)。
      • 对于这个拆分中,准确率最高模型的准确率为最优准确率。我们将每个模型的准确率除以最优准确率,则得到相对准确率。
      • 然后我们对模型根据相对准确率排名,1 表示最佳、10 表示最差。

      对于每个模型,考虑所有拆分的排名、以及平均相对准确率,如下表所示。我们观察到:GCN 在所有模型中实现最佳性能。

      尽管这个结论令人惊讶,但是在其它领域都有类似报道。如果对所有方法均谨慎地执行超参数调优,则简单的模型通常会超越复杂的模型。

    • 最后,令人惊讶的是 GAT 针对 Amazon ComputersAmazon Photo 数据集获得的结果得分相对较低,且方差很大。

      为研究这个现象,我们可视化了Amazon Photo 数据集上不同模型的准确率得分。尽管所有 GNN 模型的中位数median 得分都很接近,但是由于某些权重初始化,GAT 模型的得分非常低(低于 40%)。尽管这些异常值较少出现(2000 次结果中有 138 次发生),但是这显著降低了 GAT 的平均得分。

  2. 我们评估 train/validation/test 拆分的效果。为此,我们执行以下简单实验:将数据集按照 《Revisiting semi-supervised learning with graph embeddings》 中的随机拆分(仅拆分一次),然后运行四个模型并评估模型的相对准确率。

    可以看到:如果执行另外一次随机拆分(拆分比例都相同),则模型的相对准确率排名完全不同。

    这证明了单次拆分中的评估结果的脆弱和误导性。考虑到小扰动的情况下,GNN 的预测可能发生很大变化,这进一步明确了基于多次拆分的评估策略的必要性。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文