使用的Cora数据集,该数据集由2708篇论文,以及它们之间引用关系构成5429条边组成。根据主题这些论文被分为七类:神经网络、强化学习、规则学习、概率方法、遗传算法、理论研究、案例相关。废话不多说、直接上完整实验代码:(训练测试一键运行、最后给出两张实验结果可视化的图)

七个分类

#-*- coding: utf-8 -*-
import itertools
import os
import os.path as osp   #os.path 模块主要用于获取文件的属性
import pickle   #该pickle模块实现了用于序列化和反序列化Python对象结构的二进制协议。 “Pickling”是将Python对象层次结构转换为字节流的过程, “unpickling”是反向操作,从而将字节流(来自二进制文件或类似字节的对象)转换回对象层次结构。pickle模块对于错误或恶意构造的数据是不安全的。
import urllib   #urllib.request 模块提供了最基本的构造 HTTP 请求的方法,利用它可以模拟浏览器的一个请求发起过程,
from collections import namedtuple
# 全局取消证书验证
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optimfrom sklearn.manifold import TSNE
from matplotlib import cm
from matplotlib import pyplot as plt #用于保存处理好的数据
Data = namedtuple('Data', ['x', 'y', 'adjacency', 'train_mask', 'val_mask', 'test_mask'])
num_nodes = 0
class CoreData(object):download_url = "https://github.com/kimiyoung/planetoid/raw/master/data"filenames = ["ind.cora.{}".format(name) for name in['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]def __init__(self, data_root="cora", rebuild=False):""" 包括数据下载/处理/加载等功能当数据集的缓存文件存在时,将使用缓存文件,否则将下载、处理、并缓存到磁盘Args:-----------data_root: string, optional存放数据的目录,原始数据路径:{data_root}/raw缓存数据路径:{data_root}/processed_cora.pklrebuild: boolean, optional是否需要重新构建数据集,当设置为true时,如果缓存数据存在也会重建数据"""self.data_root = data_rootsave_file = osp.join(self.data_root, "processed_cora.pkl")if osp.exists(save_file) and not rebuild:print("Using Cached file:{}".format(save_file))self._data = pickle.load(open(save_file, "rb"))else:self.maybe_download()self._data = self.process_data()with open(save_file, "wb") as f:pickle.dump(self.data, f)print("Catched file:{}".format(save_file))@propertydef data(self):return self._datadef maybe_download(self):save_path = osp.join(self.data_root,"raw")for name in self.filenames:if not osp.exists(osp.join(save_path, name)):self.download_data("{}/{}".format(self.download_url, name), save_path)@staticmethoddef download_data(url, save_path):"""数据下载工具,当原始数据不存在时将会进行下载"""if not osp.exists(save_path):os.makedirs(save_path)data = urllib.request.urlopen(url)filename = osp.basename(url)with open(osp.join(save_path,filename),'wb') as f:f.write(data.read())return Truedef process_data(self):"""处理数据,得到节点特征和标签,邻接矩阵,训练集、验证集及测试集"""print("Process data...")_,tx,allx,y,ty,ally,graph,test_index = [self.read_data(osp.join(self.data_root, "raw", name)) for name in self.filenames]train_index = np.arange(y.shape[0])val_index = np.arange(y.shape[0], y.shape[0] + 500)sorted_test_index = sorted(test_index)x = np.concatenate((allx,tx), axis=0)y = np.concatenate((ally,ty), axis=0).argmax(axis=1)x[test_index] = x[sorted_test_index]y[test_index] = y[sorted_test_index]num_nodes = x.shape[0]print("num_nodes",num_nodes)train_mask = np.zeros(num_nodes, dtype=np.bool)val_mask = np.zeros(num_nodes, dtype=np.bool)test_mask = np.zeros(num_nodes, dtype=np.bool)train_mask[train_index] = Trueval_mask[val_index] = Truetest_mask[test_index] = Trueprint("Node's feature shape: ", x.shape)print("Node's label shape: ", y.shape)print("Number of training nodes: ", train_mask.sum())print("Number of validation nodes: ", val_mask.sum())print("Number of test nodes: ", test_mask.sum())adjacency = self.build_adjacency(graph)print("Adjacency's shape: ", adjacency.shape)return Data(x=x, y=y, adjacency=adjacency, train_mask=train_mask, val_mask=val_mask,test_mask=test_mask)@staticmethoddef build_adjacency(adj_dict):"""根据邻接表创建邻接矩阵"""edge_index = []num_modes = len(adj_dict)print('num_modes',num_modes)for src, dst in adj_dict.items():edge_index.extend([src, v] for v in dst)edge_index.extend([v, src] for v in dst)print('edge_index',len(edge_index)) #21716#由于上述结果中存在重复的边,删掉重复的边edge_index = list(k for k, _ in itertools.groupby(sorted(edge_index)))#print('edge_index111',len(edge_index)) #10556edge_index = np.asarray(edge_index)adjacency = sp.coo_matrix((np.ones(len(edge_index)),(edge_index[:, 0], edge_index[:, 1])),shape = (num_modes, num_modes), dtype="float32")print(adjacency)return adjacency@staticmethoddef read_data(path):"""使用不同的方式读取原始数据以进一步处理"""name = osp.basename(path)if name == "ind.cora.test.index":out = np.genfromtxt(path, dtype="int64")return outelse:out = pickle.load(open(path, "rb"), encoding="latin1")out = out.toarray() if hasattr(out, "toarray") else outreturn out############   根据GCN的定义来定义GCN层:
class GraphConvolution(nn.Module):def __init__(self, input_dim, output_dim, use_bias=True):"""图卷积:L*X*\thetaArgs:----------input_dim: int节点输入特征的维度output_dim: int输出特征维度use-bias: bool, optional是否使用偏置"""super(GraphConvolution, self).__init__()   #super() 函数是用于调用父类(超类)的一个方法。self.input_dim = input_dimself.output_dim = output_dimself.use_bias = use_biasself.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))if self.use_bias:self.bias = nn.Parameter(torch.Tensor(output_dim))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):init.kaiming_uniform_(self.weight)if self.use_bias:init.zeros_(self.bias)def forward(self, adjacency, input_feature):"""邻接矩阵是稀疏矩阵,因此在计算时使用稀疏矩阵乘法Args:--------------adjacency: torch.sparse.FloatTensor邻接矩阵input_feature: torch.Tensor输入特征"""support = torch.mm(input_feature, self.weight)output = torch.sparse.mm(adjacency, support)if self.use_bias:output += self.biasreturn output# 设计两层 GCN 的模型
class GcnNet(nn.Module):def __init__(self, input_dim=1433):super(GcnNet,self).__init__()self.gcn1 = GraphConvolution(input_dim, 16)self.gcn2 = GraphConvolution(16,7)def forward(self, adjacency, feature):h = F.relu(self.gcn1(adjacency, feature))logits = self.gcn2(adjacency, h)return logits# 模型构建与数据准备
def normalization(adjacency):"""计算 L = D^-0.5 * (A+I) * D^-0.5 """adjacency += sp.eye(adjacency.shape[0]) #增加自连接degree = np.array(adjacency.sum(1))d_hat = sp.diags(np.power(degree, -0.5).flatten())return d_hat.dot(adjacency).dot(d_hat).tocoo()#超参数定义
learning_rate = 0.1
weight_decay = 5e-4
epochs = 300
#模型定义,包括模型实例化、损失函数与优化器定义
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GcnNet().to(device)
#损失函数使用交叉熵
criterion = nn.CrossEntropyLoss().to(device)
#优化器使用Adam
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
#加载数据,并转化为torch.Tensor
dataset = CoreData().data
x = dataset.x / dataset.x.sum(1, keepdims=True) #归一化数据,使得每一行和为1
tensor_x = torch.from_numpy(x).to(device)
tensor_y = torch.from_numpy(dataset.y).to(device)
tensor_train_mask = torch.from_numpy(dataset.train_mask).to(device)
tensor_val_mask = torch.from_numpy(dataset.val_mask).to(device)
tensor_test_mask = torch.from_numpy(dataset.test_mask).to(device)
normalize_adjacency = normalization(dataset.adjacency) #规范化邻接矩阵
indices = torch.from_numpy(np.asarray([normalize_adjacency.row, normalize_adjacency.col]).astype('int64')).long()
values = torch.from_numpy(normalize_adjacency.data.astype(np.float32))
tensor_adjacency = torch.sparse.FloatTensor(indices, values, (2708, 2708)).to(device)#模型训练与测试def train():loss_history = []val_acc_history = []model.train()train_y = tensor_y[tensor_train_mask]for epoch in range(epochs):logits = model(tensor_adjacency, tensor_x) #前向传播train_mask_logits = logits[tensor_train_mask] #只选择训练节点进行监督loss = criterion(train_mask_logits, train_y) #计算损失值optimizer.zero_grad()loss.backward() #反向传播计算参数的梯度optimizer.step() #使用优化方法进行梯度更新train_acc = test(tensor_train_mask) #计算当前模型在训练集上的准确率val_acc = test(tensor_val_mask) #计算当前模型在验证集上的准确度#计算训练过程中损失值和准确率的变化,用于画图loss_history.append(loss.item())val_acc_history.append(val_acc.item())print("Epoch {:03d}:Loss {:.4f}, TrainAcc {:.4}, ValAcc {:.4f}".format(epoch, loss.item(), train_acc.item(), val_acc.item()))#TSNE降维可视化tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000) # TSNE降维,降到2# 只需要显示前500个plot_only = 300# 降维后的数据low_dim_embs = tsne.fit_transform(train_mask_logits.data.numpy()[:plot_only, :])# 标签labels = train_y.numpy()[:plot_only]plot_with_labels(low_dim_embs, labels,"epoch{}_step{}".format(epochs,epochs))plt.show()plt.clf() #清除画布#训练Log可视化font1 = {'family' : 'Times New Roman', 'weight' : 'normal', 'size'  : 9, } font2 = {'family' : 'Times New Roman', 'weight' : 'normal', 'size'  : 14, }#第一条线:lossepochi = range(0, len(loss_history))plt.plot(epochi, loss_history, color='blue', label='$Loss$', linewidth=1.2) # 绘制,指定颜色、标签、线宽,标签采用latex格式plt.ylim(0, 1)                      # 设定y轴范围hl=plt.legend(loc='upper right', prop=font1, frameon=False)                # 绘制图例,指定图例位置#set(hl,'Box','off');#第二条曲线:val_accepochi = range(0, len(val_acc_history))plt.plot(epochi, val_acc_history, color='red', label='$ValAcc$', linewidth=1.2)plt.legend(loc='upper right', prop=font1, frameon=False)                # 绘制图例,指定图例位置#plt.xticks([])                        # 去掉x坐标轴刻度plt.xlim(0, epochs)plt.savefig("train_log.png", dpi=600)plt.show()return loss_history, val_acc_historydef plot_with_labels(lowDWeights, labels,i):plt.cla()# 降到二维了,分别给x和yX, Y = lowDWeights[:, 0], lowDWeights[:, 1]# 遍历每个点以及对应标签for x, y, s in zip(X, Y, labels):c = cm.rainbow(int(255/9 * s)) # 为了使得颜色有区分度,把0-255颜色区间分为9分,然后把标签映射到一个区间plt.text(x, y, s, backgroundcolor=c, fontsize=9)plt.xlim(X.min(), X.max())plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer')plt.savefig("{}.jpg".format(i))def test(mask):model.eval()with torch.no_grad():logits = model(tensor_adjacency, tensor_x)test_mask_logits = logits[mask]predict_y = test_mask_logits.max(1)[1]accuracy = torch.eq(predict_y, tensor_y[mask]).float().mean()return accuracytrain()test_acc = test(tensor_test_mask)
print('test_acc',test_acc)

深入浅出图神经网络书本 GCN源码实战相关推荐

  1. GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析

    GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析 文章目录 GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析 SetUp,库声明 数据准 ...

  2. Java 并发编程 -- 线程池源码实战

    一.概述 小编在网上看了好多的关于线程池原理.源码分析相关的文章,但是说实话,没有一篇让我觉得读完之后豁然开朗,完完全全的明白线程池,要么写的太简单,只写了一点皮毛,要么就是是晦涩难懂,看完之后几乎都 ...

  3. matlab实现cnn代码,CNN 经典的卷积神经网络MATLAB实现源码,可直接运行。 276万源代码下载- www.pudn.com...

    文件名称: CNN下载  收藏√  [ 5  4  3  2  1 ] 开发工具: matlab 文件大小: 47017 KB 上传时间: 2016-11-03 下载次数: 93 提 供 者: 郝永达 ...

  4. 一步一步手绘Spring AOP运行时序图(Spring AOP 源码分析)

    相关内容: 架构师系列内容:架构师学习笔记(持续更新) 一步一步手绘Spring IOC运行时序图一(Spring 核心容器 IOC初始化过程) 一步一步手绘Spring IOC运行时序图二(基于XM ...

  5. 近期爬虫学习体会以及爬豆瓣Top250源码实战

    近期爬虫学习体会以及爬豆瓣Top250源码实战 我是在B站https://www.bilibili.com/video/BV12E411A7ZQ?p=25里学习的,至今已经可以手写爬豆瓣Top250代 ...

  6. 深入浅出图神经网络|GNN原理解析☄学习笔记(四)表示学习

    深入浅出图神经网络|GNN原理解析☄学习笔记(四)表示学习 文章目录 深入浅出图神经网络|GNN原理解析☄学习笔记(四)表示学习 表示学习 表示学习的意义 离散表示与分布式表示 端到端学习 基于重构损 ...

  7. 《深入浅出图神经网络》读书笔记 1-2

    <深入浅出图神经网络>读书笔记 1-2 第1章 图的概述 第2章 神经网络基础 2.1 机器学习基本概念 2.2 神经网络 2.4 训练神经网络 第1章 图的概述 图神经网络(Graph ...

  8. 静态网页-改图宝(附源码)

    静态网页------改图宝(赋源码) 这是人家官网的: 我做的是这样的: 总结 二级菜单有些问题,使用百分比后缩小网页时和一级菜单的一个边框对不齐,使用精准的像素后放大网页时对不齐 页脚用行块盒真的省 ...

  9. 最新COS美图在线写真站源码+去授权版

    正文: 价值200_元的美图在线写真站源码去授权版,无需授权,源码上传服务器即可使用,网站内容自行在config.php配置. 程序: wwefss.lanzoul.com/iOhZC0cok65i ...

最新文章

  1. 5G时代下,边缘计算产品的未来展望
  2. Linux修改主机名的方法
  3. 滚蛋吧,流量!红利已见顶,是时候步入「留量时代」了
  4. Exchange Server 2003群集系统方案
  5. svm学习之线性部分总结
  6. oracle增量备份0级,oracle_linux自动运行rman增量备份脚本,一、增量备份脚本0级备份脚本...
  7. 拳王虚拟项目公社:闲鱼操作虚拟资源的案例拆解,教你玩转闲鱼虚拟资源,货源+方法
  8. java核心技术 pdf下载_JAVA程序员面试秘笈 PDF 下载_Java知识分享网
  9. AI队列长度检测:R-CNN用于使用Keras进行自定义对象检测
  10. 为什么要使用 React-Redux?
  11. 程序员社区骂战:不满政治正确,LLVM元老宣布退出
  12. Android开发技巧——PagerAdapter再简单的包
  13. java 苹果支付(内购)
  14. JAVA疫情数据可视化系统毕业设计 开题报告
  15. Android系统的系统运行库层,Android系统框架
  16. Unix编程常见问题解答
  17. 生态建设发展势头迅猛,OKB未来价值空间广阔
  18. [转]关于计算机研究生报考方向的简要介绍
  19. ALL in —— 雷军的极致
  20. 关于CTF竞赛的了解

热门文章

  1. 河南是中国的一个缩影.
  2. 耳鸣是什么原因造成的?
  3. 现代控制理论习题解答与Matlab程序示例
  4. 批量友情链接监控检测查询工具
  5. (JavaScript学习记录):jQuery 样式操作
  6. c语言编程判断输入的一个字符串是否是“回文”。所谓“回文”字符串就是左读和右读都一样的字符串。例如: “abcba“就是一个回文字符串。
  7. 高斯消元法python编程_高斯消元法c语言实现
  8. commons email 简介
  9. pytorch之NIN
  10. 微信最新8.0.8版:专属铃声、折叠聊天、特别关注!