GNN教程:DGL框架实现GCN算法!

Datawhale

共 4385字,需浏览 9分钟

 · 2020-12-22

↑↑↑关注后"星标"Datawhale
每日干货 & 每月组队学习,不错过
 Datawhale干货 
作者:秦州,算法工程师,Datawhale成员

引言

本文为GNN教程的第七篇文章【使用DGL框架实现GCN算法】。图神经网络的计算模式大致相似,节点的Embedding需要汇聚其邻接节点Embedding以更新,从线性代数的角度来看,这就是邻接矩阵和特征矩阵相乘。然而邻接矩阵通常都会很大,因此另一种计算方法是将邻居的Embedding传递到当前节点上,再进行更新。很多图并行框架都采用详细传递的机制进行运算(比如Google的Pregel)。而图神经网络框架DGL也采用了这样的思路。

后台回复【GNN】进图神经网络交流群。

从本篇博文开始,我们使用DGL做一个系统的介绍,我们主要关注他的设计,尤其是应对大规模图计算的设计。这篇文章将会介绍DGL的核心概念 — 消息传递机制,并且使用DGL框架实现GCN算法。

DGL 核心 — 消息传递

DGL 的核心为消息传递机制(message passing),主要分为消息函数 (message function)和汇聚函数(reduce function)。如下图所示:

  • 消息函数(message function):传递消息的目的是将节点计算时需要的信息传递给它,因此对每条边来说,每个源节点将会将自身的Embedding(e.src.data)和边的Embedding(edge.data)传递到目的节点;对于每个目的节点来说,它可能会受到多个源节点传过来的消息,它会将这些消息存储在"邮箱"中。
  • 汇聚函数(reduce function):汇聚函数的目的是根据邻居传过来的消息更新跟新自身节点Embedding,对每个节点来说,它先从邮箱(v.mailbox['m'])中汇聚消息函数所传递过来的消息(message),并清空邮箱(v.mailbox['m'])内消息;然后该节点结合汇聚后的结果和该节点原Embedding,更新节点Embedding。

下面我们以GCN的算法为例,详细说明消息传递的机制是如何work的。

用消息传递的方式实现GCN

GCN 的线性代数表达

GCN 的逐层传播公式如下所示:

从线性代数的角度,节点Embedding的的更新方式为首先左乘邻接矩阵以汇聚邻居Embedding,再为新Embedding做一次线性变换(右乘)。

简而言之:每个节点拿到邻居节点信息汇聚到自身 embedding 上在进行一次变换。具体 GCN 内容介绍可参考之前的文章

从消息传递的角度分析

上面的数学描述可以利用消息传递的机制实现为:

  1. 在 GCN 中每个节点都有属于自己的表示 ;
  2. 根据消息传递(message passing)的范式,每个节点将会收到来自邻居节点发送的Embedding;
  3. 每个节点将会对来自邻居节点的 Embedding进行汇聚以得到中间表示 
  4. 对中间节点表示  进行线性变换,然后在利用非线性函数进行计算:;
  5. 利用新的节点表示  对该节点的表示 进行更新。

具体实现

step 1,引入相关包

import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

step 2,我们需要定义 GCN 的 message 函数和 reduce 函数, message 函数用于发送节点的Embedding,reduce 函数用来对收到的 Embedding 进行聚合。在这里,每个节点发送Embedding的时候不需要任何处理,所以可以通过内置的copy_scr实现,out='m'表示发送到目的节点后目的节点的mailbox用m来标识这个消息是源节点的Embedding。

目的节点的reduce函数很简单,因为按照GCN的数学定义,邻接矩阵和特征矩阵相乘,以为这更新后的特征矩阵的每一行是原特征矩阵某几行相加的形式,"某几行"是由邻接矩阵选定的,即对应节点的邻居所在的行。因此目的节点reduce只需要通过sum将接受到的信息相加就可以了。

gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')

step 3,我们定义一个应用于节点的 node UDF(user defined function),即定义一个全连接层(fully-connected layer)来对中间节点表示  进行线性变换,然后在利用非线性函数进行计算:

class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation):
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation

def forward(self, node):
h = self.linear(node.data['h'])
h = self.activation(h)
return {'h' : h}

step 4,我们定义 GCN 的Embedding更新层,以实现在所有节点上进行消息传递,并利用 NodeApplyModule 对节点信息进行计算更新。

class GCN(nn.Module):
def __init__(self, in_feats, out_feats, activation):
super(GCN, self).__init__()
self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

def forward(self, g, feature):
g.ndata['h'] = feature
g.update_all(gcn_msg, gcn_reduce)
g.apply_nodes(func=self.apply_mod)
return g.ndata.pop('h')

step 5,最后,我们定义了一个包含两个 GCN 层的图神经网络分类器。我们通过向该分类器输入特征大小为 1433 的训练样本,以获得该样本所属的类别编号,类别总共包含 7 类。

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.gcn1 = GCN(1433, 16, F.relu)
self.gcn2 = GCN(16, 7, F.relu)

def forward(self, g, features):
x = self.gcn1(g, features)
x = self.gcn2(g, x)
return x
net = Net()
print(net)

step 6,加载 cora 数据集,并进行数据预处理。

from dgl.data import citation_graph as citegrh
def load_cora_data():
data = citegrh.load_cora()
features = th.FloatTensor(data.features)
labels = th.LongTensor(data.labels)
mask = th.ByteTensor(data.train_mask)
g = data.graph
# add self loop
g.remove_edges_from(g.selfloop_edges())
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
return g, features, labels, mask

step 7,训练 GCN 神经网络。

import time
import numpy as np
g, features, labels, mask = load_cora_data()
optimizer = th.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(30):
if epoch >=3:
t0 = time.time()

logits = net(g, features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask])

optimizer.zero_grad()
loss.backward()
optimizer.step()

if epoch >=3:
dur.append(time.time() - t0)

print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
epoch, loss.item(), np.mean(dur)))

后话

本篇博文介绍了如何利用图神经网络框架DGL编写GCN模型,接下来我们会介绍如何利用DGL实现GraphSAGE中的采样机制,以减少运算规模。

Reference

  1. DGL Basics
  2. Graph Convolutional Network
  3. PageRank with DGL Message Passing
  4. DGL 作者答疑!关于 DGL 你想知道的都在这里
“干货学习,三连
浏览 43
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报