import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraphgcn_msg=fn.copy_src(src="h",out="m")
gcn_reduce=fn.sum(msg="m",out="h")#聚合邻居节点的特征#定义节点的UDF apply_nodes  他是一个完全连接层
class NodeApplyModule(nn.Module):#初始化def __init__(self,in_feats,out_feats,activation):super(NodeApplyModule,self).__init__()self.linear=nn.Linear(in_feats,out_feats)self.activation=activation#前向传播def forward(self,node):h=self.linear(node.data["h"])if self.activation is not None:h=self.activation(h)return {"h":h}#定义GCN模块  GCN模块的本质是在所有节点上执行消息传递  然后再调用NOdeApplyModule全连接层
class GCN(nn.Module):#初始化def __init__(self,in_feats,out_feats,activation):super(GCN,self).__init__()#调用全连接层模块self.apply_mod=NodeApplyModule(in_feats,out_feats,activation)#前向传播def forward(self,g,feature):g.ndata["h"]=feature#feature应该对应的整个图的特征矩阵g.update_all(gcn_msg,gcn_reduce)g.apply_nodes(func=self.apply_mod)#将更新操作应用到节点上return g.ndata.pop("h")#利用cora数据集搭建网络然后训练
class Net(nn.Module):#初始化网络参数def __init__(self):super(Net,self).__init__()self.gcn1=GCN(1433,16,F.relu)#第一层GCNself.gcn2=GCN(16,7,None)#前向传播def forward(self,g,features):x=self.gcn1(g,features)x=self.gcn2(g,x)return x
net=Net()
net#使用DGL内置模块加载cora数据集
from dgl.data import citation_graph as citegrh
import networkx as nx
def load_cora_data():data = citegrh.load_cora()#加载数据集features=th. FloatTensor(data.features)#特征向量  张量的形式labels=th.LongTensor(data.labels)#所属类别train_mask=th.BoolTensor(data.train_mask)#那些参与训练test_mask=th.BoolTensor(data.test_mask)#哪些是测试集g=data.graphg.remove_edges_from(nx.selfloop_edges(g))#删除自循环的边g = DGLGraph(g)g.add_edges(g.nodes(), g.nodes())return g, features, labels, train_mask, test_maskg, features, labels, train_mask, test_mask=load_cora_data()import matplotlib.pyplot as plt
nx.draw(g.to_networkx(),node_size=50,with_labels=True)
plt.show()#测试模型
def evaluate(model, g, features, labels, mask):model.eval()#会通知所有图层您处于评估模式with th.no_grad():logits = model(g, features)logits = logits[mask]labels = labels[mask]_, indices = th.max(logits, dim=1)correct = th.sum(indices == labels)return correct.item() * 1.0 / len(labels)#训练网络
import time
import numpy as np
g, features, labels, train_mask, test_mask = load_cora_data()#定义优化器
optimizer=th.optim.Adam(net.parameters(),lr=1e-3)
dur=[]#时间
for epoch in range(100):print(epoch)if epoch>=3:t0=time.time()net.train()logits = net(g, features)logp = F.log_softmax(logits, 1)loss = F.nll_loss(logp[train_mask], labels[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()if epoch >=3:dur.append(time.time() - t0)acc = evaluate(net, g, features, labels, test_mask)print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), acc, np.mean(dur)))

DGL系列之(二):使用DGL实现GCN相关推荐

  1. DGL教程【二】如何通过DGL表示一个Graph

    通过本节,将学到: 从头开始用DGL构建一个Graph 给Graph添加节点和边的特征 获取一些图的信息,如节点的度或节点的其他属性 将DGL graph 转换到另一个graph 加载.保存DGL g ...

  2. 经典算法研究系列:二、Dijkstra 算法初探

    经典算法研究系列:二.Dijkstra 算法初探  July   二零一一年一月 ====================== 本文主要参考:算法导论 第二版.维基百科. 写的不好之处,还望见谅. 本 ...

  3. 容器开启数据服务之旅系列(二):Kubernetes如何助力Spark大数据分析

    摘要: 容器开启数据服务之旅系列(二):Kubernetes如何助力Spark大数据分析 (二):Kubernetes如何助力Spark大数据分析 概述 本文为大家介绍一种容器化的数据服务Spark ...

  4. Skype for business混合部署系列之二自定义拓扑信息

    Skype for business混合部署系列之二自定义拓扑信息 此次部署前端服务器共3台,后端数据库2台采用always on方式,2台SQL Server服务器已经安装完成,在这里不做文档,本章 ...

  5. 黄聪:Microsoft Enterprise Library 5.0 系列教程(二) Cryptography Application Block (高级)

    原文:黄聪:Microsoft Enterprise Library 5.0 系列教程(二) Cryptography Application Block (高级) 本章介绍的是企业库加密应用程序模块 ...

  6. 深入理解 Linux Cgroup 系列(二):玩转 CPU

    原文链接:深入理解 Linux Cgroup 系列(二):玩转 CPU 上篇文章主要介绍了 cgroup 的一些基本概念,包括其在 CentOS 系统中的默认设置和控制工具,并以 CPU 为例阐述 c ...

  7. 【冰极峰教程系列之二】:牢不可破的九宫格布局

    原创:冰极峰 转载请注明出处 时间:2009年6月22日 8:40:16 冰极峰教程系列之一:九宫格基本布局 冰极峰教程系列之二:牢不可破的九宫格布局 冰极峰教程系列之三:三层分离的完美九宫格 冰极峰 ...

  8. WPF技术触屏上的应用系列(二): 嵌入百度地图、API调用及结合本地数据库在地图上进行自定义标点的实现...

    原文:WPF技术触屏上的应用系列(二): 嵌入百度地图.API调用及结合本地数据库在地图上进行自定义标点的实现 去年某客户单位要做个大屏触屏应用,要对档案资源进行展示之用.客户端是Window7操作系 ...

  9. MySQL优化系列(二)--查找优化(1)(非索引设计)

    MySQL优化系列(二)--查找优化(1)(非索引设计) 接下来这篇是查询优化,用户80%的操作基本都在查询,我们有什么理由不去优化他呢??所以这篇博客将会讲解大量的查询优化(索引以及库表结构优化等高 ...

  10. 精通八大排序算法系列:二、堆排序算法

    精通八大排序算法系列:二.堆排序算法 作者:July .二零一一年二月二十日 本文参考:Introduction To Algorithms,second edition. ------------- ...

最新文章

  1. Nginx 独立图片服务器的搭建
  2. SQL Date Utility
  3. 第三十四期:花了一个星期,我终于把RPC框架整明白了!
  4. 简单的优化mysql,提高查询性能
  5. oracle apache服务占用80端口
  6. chrome developer tool 调试技巧2
  7. Sobel 边缘检测 matlab代码实现
  8. C#批量发送短信操作
  9. 用C语言计算矩阵求和
  10. 《燃点》-- 星星之火可以燎原
  11. 787. K 站中转内最便宜的航班
  12. 知识点滴 - 什么是YAML文件
  13. 如何使用文件保险箱加密 Mac 数据?
  14. win11无法连接wifi怎么办?
  15. 简单演示Exploit SEH原理(未开启SafeSEH模块)
  16. 镜头与相机法兰距匹配
  17. 英语----我们快乐生活的一部分
  18. 机器人行业发展方向预测报告
  19. 新股民零基础入市必读
  20. 一位金融工程小硕的华丽逆袭人生!超真实Quant菜鸟的修行路

热门文章

  1. macOS 10.10 u盘安装win7
  2. C++实现走迷宫算法(1)
  3. java打印取消页眉页脚_Javascript页面打印的页眉页脚的清除与设置
  4. Spark列级血缘(字段级别血缘)开发与实现
  5. 如何给女朋友解释什么是面向对象编程?
  6. 4245. 【五校联考6day2】er
  7. JZOJ 4252. 【五校联考7day2】QYQ的图
  8. dhcp服务器怎么设置虚拟网段,配置DHCP服务器不同网段分配ip
  9. 如何在 Excel 中筛选数据透视表中的数据?
  10. 深入理解DNS(域名系统)