MLP、GCN、GAT在数据集citeseer等上的节点分类任务

算是GNN的helloworld,直接上代码,注释很详细

# -*- coding: utf-8 -*-
"""
Created on Fri Feb 18 19:10:05 2022@author: lz
"""from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeaturesdataset = Planetoid(root = 'dataset', name='CiteSeer', transform=NormalizeFeatures())print()
print(f'Dataset:{dataset}')
print(f'Number of Graph:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')data = dataset[0]print()
print(data)print()
print(f'Number of nodes:{data.num_nodes}')
print(f'Number of edges:{data.num_edges}')
print(f'Average node degree:{data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes:{data.train_mask.sum()}')
print(f'Training node label rate:{data.train_mask.sum() / data.num_nodes:.2f}')
print(f'Contains isolated nodes:{data.has_isolated_nodes()}')
print(f'Contains self-loops:{data.has_self_loops()}')
print(f'Is undirected:{data.is_undirected()}')'''
可视化节点表征分布的方法
'''
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
def visualize(h, color):z = TSNE(n_components=2).fit_transform(out.detach().cpu().numpy())plt.figure(figsize=(10,10))plt.xticks([])plt.yticks([])plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")plt.show()'''
MLP神经网络的构造
'''
import torch
from torch.nn import Module
from torch.nn import Linear
import torch.nn.functional as Fclass MLP(Module):def __init__(self, hidden_channels):super(MLP, self).__init__()torch.manual_seed(12345)self.lin1 = Linear(dataset.num_features, hidden_channels)#dataset.num?不应该是dataset[0].num?难道dataset也有Num属性?self.lin2 = Linear(hidden_channels, dataset.num_classes)def forward(self, x):x = self.lin1(x)#等价于 self.lin1.forward(x),还是module call的forwardrelu = torch.nn.ReLU(inplace = True)x = relu(x)x = F.dropout(x, p=0.5, training=self.training)#!x = self.lin2(x)return xmodel = MLP(hidden_channels=16)
print()
print('MLP神经网络的构造')
print(model)print()
print('利用交叉熵损失和Adam优化器来训练这个简单的MLP神经网络')
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay= 5e-4)def train():model.train()#!optimizer.zero_grad()out = model(data.x)loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()#!optimizer.step()return lossprint('开始训练')for epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')print('看看测试集上的表现')def test():model.eval()#!out = model(data.x)pred = out.argmax(dim = 1)#选择概率最大的类test_correct = pred[data.test_mask] == data.y[data.test_mask]#检查标签是否正确test_acc = int(test_correct.sum()) / int(data.test_mask.sum()) return test_acctest_acc = test()
print(f'Test Accuracy:{test_acc:.4f}')print('将MLP中的torch.nn.Linear 替换为torch_geometric.nn.GCNConv,我们就可以得到一个GCN网络')
from torch_geometric.nn import GCNConvclass GCN(Module):def __init__(self, hidden_channels):super(GCN, self).__init__()torch.manual_seed(12345)self.conv1 = GCNConv(dataset.num_features, hidden_channels)#dataset.num?不应该是dataset[0].num?难道dataset也有Num属性?self.conv2 = GCNConv(hidden_channels, dataset.num_classes)def forward(self, x, edge_index):x = self.conv1(x, edge_index)#等价于 self.lin1.forward(x),还是module call的forwardrelu = torch.nn.ReLU(inplace = True)x = relu(x)x = F.dropout(x, p=0.5, training=self.training)#!x = self.conv2(x, edge_index)return xmodel = GCN(hidden_channels=16)
print(model)        print()
print('可视化未经训练的GCN生成的节点表征')
model.eval()   out = model(data.x, data.edge_index)
visualize(out, color=data.y)print()
print('训练GCN图神经网络')
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)#进行一次正向计算loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return lossfor epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')print('测试集上的准确性')
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim = 1)#选择概率最大的类test_correct = pred[data.test_mask] == data.y[data.test_mask]test_acc = int(test_correct.sum()) / int(data.test_mask.sum())return test_acctest_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')print()
print('可视化训练后的GCN生成的节点表征')
model.eval()out = model(data.x, data.edge_index)
visualize(out, color=data.y)print()
print('将MLP中的torch.nn.Linear 替换为torch_geometric.nn.GCNConv,我们就可以得到一个GCN网络')
from torch_geometric.nn import GATConv
class GAT(Module):def __init__(self, hidden_channels):super(GAT, self).__init__()torch.manual_seed(12345)self.conv1 = GATConv(dataset.num_features, hidden_channels)#dataset.num?不应该是dataset[0].num?难道dataset也有Num属性?self.conv2 = GATConv(hidden_channels, dataset.num_classes)def forward(self, x, edge_index):x = self.conv1(x, edge_index)#等价于 self.lin1.forward(x),还是module call的forwardrelu = torch.nn.ReLU(inplace = True)x = relu(x)x = F.dropout(x, p=0.5, training=self.training)#!x = self.conv2(x, edge_index)return xmodel = GAT(hidden_channels=16)
print(model)print()
print('训练GAT图神经网络')
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)#进行一次正向计算loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return lossfor epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')print('测试集上的准确性')
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim = 1)#选择概率最大的类test_correct = pred[data.test_mask] == data.y[data.test_mask]test_acc = int(test_correct.sum()) / int(data.test_mask.sum())return test_acctest_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')print()
print('可视化训练后的GAT生成的节点表征')
model.eval()out = model(data.x, data.edge_index)
visualize(out, color=data.y)

datawhalechina-GNN组队学习 作业:PyG不同模块在PyG数据集上的应用相关推荐

  1. 迁移学习的使用技巧和在不同数据集上的选择

    迁移学习的使用技巧和在不同数据集上的选择 1.迁移学习是指调整预训练的神经网络并应用到新的不同数据集上. 根据以下两个方面:新数据集的大小,以及新数据集和原始数据集之间的相似性 使用迁移学习的方式将不 ...

  2. Datawhale组队学习周报(第035周)

    希望开设的开源内容 目前Datawhale的开源内容分为两种:第一种是已经囊括在我们的学习路线图内的Datawhale精品课,第二种是暂未囊括在我们的学习路线图内的Datawhale打磨课.我们根据您 ...

  3. Datawhale组队学习周报(第032周)

    希望开设的开源内容 目前Datawhale的开源内容分为两种:第一种是已经囊括在我们的学习路线图内的Datawhale精品课,第二种是暂未囊括在我们的学习路线图内的Datawhale打磨课.我们根据您 ...

  4. 跟优秀的人一起进步:四月组队学习

    Datawhale学习 主办:Datawhale,人民邮电出版社异步社区 寄语:本次组队学习涵盖了机器学习算法.计算机视觉.Pandas.爬虫编程实践四个模块的内容. 第十期:Datawhale联合伯 ...

  5. Datawhale组队学习周报(第028周)

    吼一嗓子: 如果您有开源的内容希望通过组队学习的方式与大家分享,那么请跟我联系,我们来排期. 如果您对Datawhale某一门开源内容感兴趣,希望跟我们一起为学习者答疑解惑,那么请跟我联系,我们来排期 ...

  6. Datawhale组队学习周报(第027周)

    吼一嗓子: 如果您有开源的内容希望通过组队学习的方式与大家分享,那么请跟我联系,我们来排期. 如果您对Datawhale某一门开源内容感兴趣,希望跟我们一起为学习者答疑解惑,那么请跟我联系,我们来排期 ...

  7. AI学习笔记(十一)CNN之图像识别(上)

    AI学习笔记之CNN之图像识别(上) 图像识别 图像识别简介 模式识别 图像识别的过程 图像识别的应用 分类与检测 VGG Resnet 迁移学习&inception 卷积神经网络迁移学习fi ...

  8. HNU工训中心STC-B学习板大作业-基于OLED模块的多功能MP4

    主要功能在下面这张流程图里(直接用报告的流程图了) 下面展示一下效果(数码管的"welcome"比较抽象) ps. 后面新加的功能(我觉得MP4应该还具有看小说的功能,但是小说字太 ...

  9. python编程语言的优缺点_组队学习优秀作业 | Python的发展历史及其前景

    ↑↑↑关注后"星标"BioPython每日干货 & 每月组队学习,不错过BioPython学习 开源贡献: BioPython团队 创始人 Guido van Rossum ...

最新文章

  1. 报名 | DeeCamp2019:实战AI 铸造定雨神针
  2. 瑞友虚拟服务器网页登录,瑞友云端虚拟专网系统
  3. 深入剖析 iOS 编译 Clang LLVM(编译流程)
  4. python语言句块的标记_Python语言语句块的标记是()
  5. tp5模板 使用php代码,thinkPHP的Html模板标签使用方法
  6. js Math用法jquery是否为空对象判断
  7. java并发中的延迟初始化
  8. Win8消费者预览版下载地址 包含中文下载地址及中文手册
  9. 将Javascript带到边缘设备
  10. python可迭代对象,迭代器,生成器
  11. SQL Server删除整个数据库中表数据的方法(不影响表结构)
  12. sqoop导入/导出
  13. android绘制过程3d图形,Android开发 OpenGL ES绘制3D 图形实例详解
  14. [转] C# mysql 事务回滚
  15. 格雷码转换成二进制c语言程序,各位老师格雷码和二进制有什么区别,怎么转换....
  16. 软件设计文档——概要设计书
  17. 华为路由器IPv6 over IPv4 GRE隧道配置详解
  18. iPhone iPad Cydia 软件源 大全
  19. 方面级情感分析论文泛读02:Syntax-Aware Aspect-Level Sentiment Classification with Proximity-Weighted Convolution
  20. 老A:什么是抖音弹幕互动游戏,玩法以及如何参与

热门文章

  1. 三,容器类型及其函数(列表,元组,字典,集合)
  2. frp工具实现内网穿透以及配置多个ssh和web服务
  3. python 实现多线程编程
  4. Qt开发 — QtQuick无法加载
  5. Web前端开发技术————期末编程例题
  6. Date ----数码时钟
  7. ZEMAX中的非球面参数解释
  8. 软件测试入门知识——学习笔记
  9. excel表格中打开可以显示整个表格但是打印却只能打印一个单元格
  10. vs2010使用过程中的问题和解决、vs密钥