返回介绍

数学基础

统计学习

深度学习

工具

Scala

二十八、MPNN [2017]

发布于 2023-07-17 23:38:24 字数 103216 浏览 0 评论 0 收藏 0

  1. 机器学习预测分子和材料的性质仍处于起步阶段。迄今为止,将机器学习应用于化学任务的大多数研究都围绕着特征工程展开,神经网络在化学领域并未广泛采用。这使人联想到卷积神经网络被广泛采用之前的图像模型image model 的状态,部分原因是缺乏经验证据表明:具有适当归纳偏置inductive bias 的神经网络体系结构可以在该领域获得成功。

    最近,大规模的量子化学计算 quantum chemistry calculation 和分子动力学模拟molecular dynamics simulation,加上高通量high throughput 实验的进展,开始以前所未有的速度产生数据。大多数经典的技术不能有效地利用现在的大量数据。假设我们能找到具有适当归纳偏置的模型,将更强大和更灵活的机器学习方法应用于这些问题的时机已经成熟。原子系统的对称性表明,在图结构数据上操作并对图同构graph isomorphism 不变的神经网络可能也适合于分子。足够成功的模型有朝一日可以帮助实现药物发现或材料科学中具有挑战性的化学搜索问题的自动化。

    在论文 《Neural Message Passing for Quantum Chemistry》 中,作者的目标是为化学预测问题展示有效的机器学习模型,这些模型能够直接从分子图 molecular graph 中学习特征,并且对图同构不变 invariant 。为此,论文描述了一个在图上进行监督学习的一般框架,称为信息传递神经网络(Message Passing Neural Network: MPNN)。MPNN 简单地抽象了现有的几个最有前景的图神经模型之间的共性,以便更容易理解它们之间的关系,并提出新的变体。鉴于许多研究人员已经发表了适合 MPNN 框架的模型,作者认为社区应该在重要的图问题上尽可能地推动这种通用方法,并且只提出由application 所启发的新变体,例如论文中考虑的应用:预测小有机分子的量子力学特性(如下图所示)。

    最后,MPNN 在分子属性预测benchmark 上取得了 state-of-the-art 的结果。

    论文贡献:

    • 论文开发了一个 MPNN 框架 ,它在所有13个目标target 上都取得了 SOTA 的结果,并在 13 个目标中的 11 个目标上预测到 DFT 的化学准确性。
    • 论文开发了几种不同的 MPNN ,在 13 个目标中的5个目标上预测到 DFT 的化学准确性,同时仅对分子的拓扑结构进行操作(没有空间信息作为输入)。
    • 论文开发了一种通用的方法来训练具有更大 node representationMPNN,而不需要相应地增加计算时间或内存,与以前的MPNN相比,在高维node representation 方面产生了巨大的节省。

    作者相信论文的工作是朝着使设计良好的 MPNN成为中等大小分子上的监督学习的默认方法迈出的重要一步。为了实现这一点,研究人员需要进行仔细的实证研究,以找到使用这些类型的模型的正确方法,并对其进行必要的改进。

  2. 相关工作:尽管原则上量子力学可以让我们计算分子的特性,但物理定律导致的方程太难精确解决。因此,科学家们开发了一系列的量子力学近似方法,对速度和准确率进行了不同的权衡,如带有各种函数的密度功能理论(Density Functional Theory: DFT)以及量子蒙特卡洛 Quantum Monte-Carlo 。尽管被广泛使用,DFT 仍然太慢,无法应用于大型系统(时间复杂度为O(Ne3)$ O(N^3_e ) $ ,其中Ne$ N_e $ 为电子数),并且相对于薛定谔方程的精确解,DFT 表现出系统误差和随机误差。

    《Combined first-principles calculation and neural-network correction approach for heat of formation 》 使用神经网络来近似 DFT 中一个特别麻烦的项,即交换相关势能 exchange correlation potential ,以提高DFT的准确性。然而,他们的方法未能提高DFT的效率,而是依赖于一大套临时的原子描述符 atomic descriptor。另一个方向试图直接对量子力学的解进行近似,而不求助于 DFT 。这两个方向都使用了有固有局限性的手工设计的特征。

28.1 MPNN

  1. 为简单起见我们考虑无向图。给定无向图G=(V,E)$ \mathcal G=(\mathcal V, \mathcal E) $ ,其中V={v1,,vn}$ \mathcal V=\{v_1,\cdots,v_n\} $ 为节点集合,E={ei,j}$ \mathcal E=\{e_{i,j}\} $ 为边集合。

    • 每个节点vi$ v_i $ 关联一个节点特征xiRdf$ \mathbf{\vec x}_i\in \mathbb R^{d_f} $ 。
    • 每条边ei,j$ e_{i,j} $ 关联一个边特征ei,jRde$ \mathbf{\vec e}_{i,j}\in \mathbb R^{d_e} $ 。

    将无向图推广到有向的多图multigraph(即多条边)也很容易。

  2. GNN 的前向传播具有两个阶段:消息传递阶段、readout 阶段:

    • 消息传递阶段执行T$ T $ 个时间 step,它通过消息函数message functionMt()$ M_t(\cdot) $ 和节点更新函数update functionUt()$ U_t(\cdot) $ 来定义。

      在消息传递阶段,节点v$ v $ 的隐状态hv(t+1)$ \mathbf{\vec h}_v^{(t+1)} $ 是基于消息mv(t+1)$ \mathbf{\vec m}_v^{(t+1)} $ 来更新的:

      mv(t+1)=uNvMt(hv(t),hu(t),ev,u)hv(t+1)=Ut(hv(t),mv(t+1)),hv(0)=xv

      其中Nv$ \mathcal N_v $ 为节点v$ v $ 的邻域。

    • readout 阶段根据所有节点在T$ T $ 时刻的状态来计算整个图的embedding 向量y^$ \hat{\mathbf{\vec y}} $ :

      y^=R({hv(T)vG})

      其中R()$ R(\cdot) $ 为readout 函数readout function

      R()$ R(\cdot) $ 函数对节点的状态集合进行操作,并且必须满足对节点集合的排列不变性 permutation invariant 从而使得 MPNN 对图的同构不变性graph isomorphism invariant

  3. 注意:你也可以在 MPNN 中通过引入边的状态向量hev,u(t)$ \mathbf{\vec h}_{e_{v,u}}^{(t)} $ 来学习边特征,并采取类似的更新方式:

    mev,u(t+1)=sNvMte(hv(t),hs(t),hev,s(t))+sNuMte(hu(t),hs(t),heu,s(t))hev,u(t+1)=Ute(hev,u(t),mev,u(t+1)),hev,u(0)=ev,u
  4. 消息函数Mt()$ M_t(\cdot) $ 、节点更新函数Ut()$ U_t(\cdot) $ 、readout 函数R()$ R (\cdot) $ 都是待学习的可微函数。接下来我们通过给定不同的Mt(),Ut(),R()$ M_t(\cdot),U_t(\cdot),R(\cdot) $ 来定义已有的一些模型。

    • 《Convolutional Networks for Learning Molecular Fingerprints》

      • 消息函数Mt(hv(t),hu(t),ev,u)=[hu(t)||ev,u]$ M_t\left(\mathbf{\vec h}_v^{(t)} , \mathbf{\vec h}_u^{(t)},\mathbf{\vec e}_{v,u}\right) =\left[\mathbf{\vec h}_u^{(t)}||\mathbf{\vec e}_{v,u}\right] $ ,其中[||]$ [\cdot||\cdot] $ 为向量拼接。

      • 节点更新函数Ut(hv(t),mv(t+1))=σ(Mdeg(v)(t)mv(t+1))$ U_t\left(\mathbf{\vec h}_v^{(t)},\mathbf{\vec m}_v^{(t+1)}\right)=\sigma\left(\mathbf M^{(t)}_{\text{deg(v)}}\mathbf{\vec m}_v^{(t+1)}\right) $ ,其中:

        • M(t)$ \mathbf M^{(t)} $ 为t$ t $ 时刻待学习的映射矩阵,下标deg(v)$ \text{deg}(v) $ 表示节点v$ v $ 的 degree,并且不同的 degree 使用不同的映射矩阵。
        • σ()$ \sigma(\cdot) $ 为 sigmoid 函数。
      • Readout 函数R()$ R(\cdot) $ 通过 skip connection 连接所有节点的所有历史状态hv(t)$ \mathbf{\vec h}_v^{(t)} $ ,并且等价于:

        f(vVt=1Tsoftmax(W(t)hv(t)))

        其中f()$ f(\cdot) $ 为一个神经网络,W(t)$ \mathbf W^{(t)} $ 为待学习的参数。

      这种消息传递方案可能是有问题的,因为得到的消息向量mv(t+1)=[uNvhu(t)||uNvev,u]$ \mathbf{\vec m}_v^{(t+1)}= \left[\sum_{u\in \mathcal N_v}\mathbf{\vec h}_u^{(t)}||\sum_{u\in \mathcal N_v}\mathbf{\vec e}_{v,u}\right] $ 分别在节点和边上进行求和。因此,这种消息传递方案无法识别节点状态和边状态之间的相关性。

    • Gated Graph Neural Networks:GG-NN

      • 消息函数Mt(hv(t),hu(t),ev,u)=Aev,uhu(t)$ M_t\left(\mathbf{\vec h}_v^{(t)} , \mathbf{\vec h}_u^{(t)},\mathbf{\vec e}_{v,u}\right) = \mathbf A_{e_{v,u}}\mathbf{\vec h}_u^{(t)} $ , 其中Aev,u$ \mathbf A_{e_{v,u}} $ 为待学习的矩阵,对每个 edge labele$ e $ (即边的类型)学习一个矩阵。注意:模型假设边的label是离散的。

      • 节点更新函数Ut(hv(t),mv(t+1))=GRU(hv(t),mv(t+1))$ U_t\left(\mathbf{\vec h}_v^{(t)},\mathbf{\vec m}_v^{(t+1)}\right)=\text{GRU}\left(\mathbf{\vec h}_v^{(t)},\mathbf{\vec m}_v^{(t+1)}\right) $ ,其中GRUGated Recurrent Unit

        该工作使用了权重绑定 weight tying,因此在每个时间步都使用相同的更新函数。

        即,它将每个节点的t$ t $ 个时间步视为一个序列。

      • Readout 函数R()=vVσ(fi(hv(T),hv(0)))(fj(hv(T)))$ R(\cdot)=\sum_{v\in \mathcal V} \sigma\left(f_i\left(\mathbf{\vec h}_v^{(T)},\mathbf{\vec h}_v^{(0)}\right)\right)\odot \left(f_j\left(\mathbf{\vec h}_v^{(T)}\right)\right) $ ,其中fi(),fj()$ f_i(\cdot),f_j(\cdot) $ 都是神经网络,$ \odot $ 为逐元素的乘积,σ()$ \sigma(\cdot) $ 为 sigmoid 函数。

    • Interaction Networks:该工作既考虑了 node-level 目标,也考虑了 graph-level 目标。也考虑了在节点上施加的外部效应。

      • 消息函数Mt(hv(t),hu(t),ev,u)=[hv(t)||hu(t)||ev,u]$ M_t\left(\mathbf{\vec h}_v^{(t)} , \mathbf{\vec h}_u^{(t)},\mathbf{\vec e}_{v,u}\right) =\left[\mathbf{\vec h}_v^{(t)}||\mathbf{\vec h}_u^{(t)}||\mathbf{\vec e}_{v,u}\right] $ ,其中[||]$ [\cdot||\cdot] $ 为向量拼接。
      • 节点更新函数Ut(hv(t),mv(t+1))=g([hv(t)||fv(t)||mv(t+1)])$ U_t\left(\mathbf{\vec h}_v^{(t)},\mathbf{\vec m}_v^{(t+1)}\right)=g\left(\left[\mathbf{\vec h}_v^{(t)}||\mathbf{\vec f}_v^{(t)}||\mathbf{\vec m}_v^{(t+1)}\right]\right) $ ,其中fv(t)$ \mathbf{\vec f}_v^{(t)} $ 为t$ t $ 时刻对节点v$ v $ 施加某些外部效应的外部向量,g()$ g(\cdot) $ 为神经网络函数。
      • 当进行 graph-level 输出时,Readout 函数R()=f(vGhv(T))$ R(\cdot) = f\left(\sum_{v\in \mathcal G}\mathbf{\vec h}_v^{(T)}\right) $ ,其中f()$ f(\cdot) $ 为神经网络函数。在原始论文中,T$ T $ 仅仅为 1
    • Molecular Graph Convolutions:该工作和 MPNN 稍有不同,因为它在消息传递阶段更新了边的表示ev,u(t)$ \mathbf{\vec e}_{v,u}^{(t)} $ 。

      • 消息函数Mt(hv(t),hu(t),ev,u(t))=ev,u(t)$ M_t\left(\mathbf{\vec h}_v^{(t)} , \mathbf{\vec h}_u^{(t)},\mathbf{\vec e}_{v,u}^{(t)}\right) =\mathbf{\vec e}_{v,u}^{(t)} $ 。

      • 节点更新函数Ut(hv(t),mv(t+1))=relu[W1(relu(W0hv(t)))||mv(t+1)]$ U_t\left(\mathbf{\vec h}_v^{(t)},\mathbf{\vec m}_v^{(t+1)}\right)=\text{relu}\left[\mathbf W_1\left(\text{relu}\left(\mathbf W_0\mathbf{\vec h}_v^{(t)}\right)\right)||\mathbf{\vec m}_v^{(t+1)}\right] $ ,reluReLU 非线性激活函数,W0,W1$ \mathbf W_0,\mathbf W_1 $ 为待学习的权重矩阵。

      • 边更新函数:

        ev,u(t+1)=Ute(ev,u(t),hv(t),hu(t))=relu[W4(relu(W2ev,u(t)))||relu(W3[hv(t)||hu(t)])]

        其中W2,W3,W4$ \mathbf W_2,\mathbf W_3,\mathbf W_4 $ 为待学习的权重矩阵。

    • Deep Tensor Neural Networks

      • 消息函数Mt(hv(t),hu(t),ev,u)=tanh(W1((W2hu(t)+b1)(W3ev,u+b2)))$ M_t\left(\mathbf{\vec h}_v^{(t)} , \mathbf{\vec h}_u^{(t)},\mathbf{\vec e}_{v,u} \right) =\tanh\left(\mathbf W_1\left(\left(\mathbf W_2\mathbf{\vec h}_u^{(t)}+\mathbf{\vec b}_1\right)\odot\left(\mathbf W_3\mathbf{\vec e}_{v,u}+\mathbf{\vec b}_2\right)\right)\right) $ 。其中W1,W2,W3$ \mathbf W_1,\mathbf W_2,\mathbf W_3 $ 为待学习的权重矩阵,b1,b2$ \mathbf{\vec b}_1,\mathbf{\vec b}_2 $ 为待学习的 bias 向量。
      • 节点更新函数Ut(hv(t),mv(t+1))=hv(t)+mv(t+1)$ U_t\left(\mathbf{\vec h}_v^{(t)},\mathbf{\vec m}_v^{(t+1)}\right)=\mathbf{\vec h}_v^{(t)}+\mathbf{\vec m}_v^{(t+1)} $ 。
      • Readout 函数R()=vNN(hv(T))$ R(\cdot) = \sum_v \text{NN}\left(\mathbf{\vec h}_v^{(T)}\right) $ ,其中NN()$ \text{NN}(\cdot) $ 为单层神经网络。
    • Laplacian Based Methods,例如GCN

      • 消息函数Mt(hv(t),hu(t),ev,u)=cv,uhu(t)$ M_t\left(\mathbf{\vec h}_v^{(t)} , \mathbf{\vec h}_u^{(t)},\mathbf{\vec e}_{v,u} \right) = c_{v,u}\mathbf{\vec h}_u^{(t)} $ ,其中:

        cv,u=Av,udeg(v)×deg(u)

        其中Av,u$ A_{v,u} $ 为邻接矩阵A$ \mathbf A $ 的项,deg(v) 为节点v$ v $ 的 degree

      • 节点更新函数Ut(hv(t),mv(t+1))=relu(W(t)mv(t+1))$ U_t\left(\mathbf{\vec h}_v^{(t)},\mathbf{\vec m}_v^{(t+1)}\right)=\text{relu}\left(\mathbf W^{(t)}\mathbf{\vec m}_v^{(t+1)}\right) $ 。

  5. 将这些方法抽象为通用的 MPNN 的好处是:我们可以确定关键的实现细节,并可能达到这些模型的极限,从而指导我们进行未来的模型改进。

    所有这些方法的缺点之一是计算时间。最近的工作通过在每个time step 仅在图的子集上传递消息,已经将 GG-NN 架构应用到更大的图。这里我们也提出了一种可以改善计算成本的 MPNN 修改。

28.2 MPNN 变体

  1. 我们基于 GG-NN 模型探索 MPNN,我们认为 GG-NN 是一个很强的 baseline 。我们聚焦于探索不同的消息函数、输出函数,从而找到适当的输入 representation 以及正确调优的超参数。

  2. 消息函数探索:

    • 矩阵乘法作为消息函数:首先考察 GG-NN 中使用的消息函数,它定义为

      Mt(hv(t),hu(t),ev,u)=Aev,uhu(t)

      其中Aev,u$ \mathbf A_{e_{v,u}} $ 为待学习的矩阵,对每个 edge labele$ e $ (即边的类型)学习一个矩阵。注意:模型假设边的label是离散的。

    • Edge Network:为了支持向量值的 edge 特征,我们使用以下消息函数:

      Mt(hv(t),hu(t),ev,u)=A(ev,u)hu(t)

      其中A(ev,u)$ \mathcal A(\mathbf{\vec e}_{v,u}) $ 是一个神经网络,它将每个edge 特征ev,u$ \mathbf{\vec e}_{v,u} $ 映射到一个Rd×d$ \mathbb R^{d\times d} $ 矩阵。其中d$ d $ 表示内部隐状态的维度。

    • Pair Message:前面两种消息函数仅依赖于隐状态hu$ \mathbf{\vec h}_u $ 和边特征ev,u$ \mathbf{\vec e}_{v,u} $ ,而不依赖于隐状态hv$ \mathbf{\vec h}_v $ 。理论上如果消息同时包含源节点和目标节点的信息,则网络可能更有效地传递消息。因此消息函数定义为:

      Mt(hv(t),hu(t),ev,u)=f(hv(t),hu(t),ev,u)

      其中f()$ f(\cdot) $ 为一个神经网络。

    当我们将上述消息函数应用于有向图时,将使用两个独立的函数Mt(in)$ M_t^{(\text{in})} $ 和Mt(out)$ M_t^{(\text{out})} $ 。至于在特定的边ev,u$ e_{v,u} $ 上应用哪一个,则取决于边的方向。

  3. 虚拟节点 & 虚拟边:我们探索了两种方式来在图中添加虚拟元素,从而修改了消息传递的方式(使得消息传播得更广):

    • 虚拟边:在未连接节点pair 对之间添加虚拟边,这个边的类型是特殊类型。这可以实现为数据预处理步骤,并允许消息在传播阶段传播很长一段距离。

    • 虚拟节点:虚拟一个 master 节点,该节点以特殊的边类型来连接到图中的每个输入节点。

      此时master 节点充当全局暂存空间,每个节点都在消息传递的每个step 中从master 读取信息、向 master 写入信息。这允许信息在传播阶段传播很长的距离。

      我们允许 master 节点具有单独的节点维度dmaster$ d_{\text{master}} $ ,也允许 master 节点在内部状态更新函数中使用单独的权重矩阵。

      由于加入了 master 节点,理论上模型复杂度有所增加,并提升了模型型容量。

  4. Readout 函数:我们尝试了两种 Readout 函数。

    • 一种是在 GG-NN 中使用的 Readout 函数:

      R()=vVσ(fi(hv(T),hv(0)))(fj(hv(T)))
    • 另一种是 Set2Set 模型,该模型专门为Set 输入而设计的,并且比简单地累加final node state具有更强的表达能力。

      该模型首先将线性投影应用于每个元组(hv(T),xv)$ \left(\mathbf{\vec h}_v^{(T)},\mathbf{\vec x}_v\right) $ ,然后将一个 set 的元组投影作为输入。然后,在经过N$ N $ 个计算 step 之后,Set2Set 模型将产生 graph-level embeddingqt$ \mathbf{\vec q}_t^* $ ,该 embedding 对于set 的顺序具有不变性。我们将这个 embeddingqt$ \mathbf{\vec q}_t^* $ 馈入一个神经网络从而产生输出。

  5. Multiple TowersMPNN 的一个问题是可扩展性,特别是对于稠密图。消息传递阶段的每个 step 需要O(n2d2)$ O(n^2d^2) $ 次浮点乘法。当n$ n $ 或者d$ d $ 较大时,其计算代价太大。为解决这个问题:

    • 我们将d$ d $ 维的节点 embeddinghv(t)$ \mathbf{\vec h}_v^{(t)} $ 拆分为K$ K $ 个维度为d/k$ d/k $ 的 embedding{hv(t,k)}k=1,,K$ \left\{\mathbf{\vec h}_v^{(t,k)}\right\}_{k=1,\cdots,K} $ ,每个拆分代表图在某个隐空间下的一种 embedding

    • 然后我们在每个隐空间k$ k $ 上独立地执行消息传递和节点更新,从而得到临时的 embedding{h~v(t,k)}$ \left\{\tilde{\mathbf{\vec h}}_v^{(t,k)}\right\} $ 。

    • 最后这K$ K $ 种 embedding 结果通过以下方式混合:

      (hv(t,1)||||hv(t,K))=g(h~v(t,1)||||h~v(t,K))

      其中:g()$ g(\cdot) $ 表示一个神经网络,并且g()$ g(\cdot) $ 在所有节点上共享;||$ || $ 表示向量拼接。

    这种混合方式保留了节点的排列不变性permutation invariant ,同时允许图的不同embedding 在传播阶段相互交流。

    这种方法是有利的,因为对于相同数量的参数数量,它能产生更大的假设空间,表达能力更强。并且时间复杂度更低。当消息函数是矩阵乘法时,某种 embedding 的传播step 花费O(n2(d/K)2)$ O(n^2(d/K)^2) $ 的时间,一共有K$ K $ 种embedding ,因此总的时间复杂度为O(n2d2/K)$ O(n^2d^2/K) $ 。另外还有一些额外的开销,因为有用于混合的神经网络。

    Multiple Towers 就是 multi-head 的思想。

28.3 实验

  1. 数据集:QM-9 分子数据集,包含 130462 个分子。我们随机选择 10000 个样本作为验证集、10000 个样本用于测试集、其它作为训练集。特征(如下表所示)和 label 的含义参考原始论文。

    我们使用验证集进行早停和模型选择,并在测试集上报告mean absolute error:MAE

    结论:

    • 针对每个目标训练一个模型始终优于对所有13 个目标进行联合训练。
    • 最优的 MPNN 变体使用edge network 消息函数。
    • 添加虚拟边、添加master 节点、将 graph-level 输出修改为 Set2Set 输出对于 13 个目标都有帮助。
    • Multiple Towers 不仅可以缩短训练时间,还可以提高泛化性能。

    具体实验细节参考原始论文。

    下图中,enn-s2s 表示最好的 MPNN 变体(使用 edge network 消息函数、set2set 输出、以及在具有显式氢原子的图上操作),enn-s2s-ens5 表示对应的 ensemble

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文