图神经网络相似度计算

注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
今天和大家分享一篇关于图神经网络相似度计算的论文
SimGNN: A Neural Network Approach to Fast Graph Similarity Computation
#博学谷IT学习技术支持#

文章目录

  • 图神经网络相似度计算
  • 前言
  • 一、训练数据
  • 二、模型的输入
  • 三、图神经网络提取更新每个点的信息
  • 四、计算点和点之间的关系得到直方图特征
  • 四、Attention Layer 得到图的特征
  • 五、运用NTN网络计算图和图之间的关系得到特征
  • 六、预测得到模型的结果
  • 总结

前言

图神经网络是当下比较火的模型之一,使用神经网络来学习图结构数据,提取和发掘图结构数据中的特征和模式,满足聚类、分类、预测、分割、生成等图学习任务需求的算法。本文是主要通过图神经网络来对两个图的相似性进行快速打分的模型。


一、训练数据

本文采用torch内置数据集GEDDataset,直接调用就可以了,数据集一共有700个图,每个图最多有10个点组成,每个点由29种特征组成

代码如下(示例):

 def process_dataset(self):"""Downloading and processing dataset."""print("\nPreparing dataset.\n")self.training_graphs = GEDDataset("datasets/{}".format(self.args.dataset), self.args.dataset, train=True)self.testing_graphs = GEDDataset("datasets/{}".format(self.args.dataset), self.args.dataset, train=False)

二、模型的输入

每次输入两幅图,包含边的信息了,点的特征

代码如下(示例):

 def forward(self, data):edge_index_1 = data["g1"].edge_indexedge_index_2 = data["g2"].edge_indexfeatures_1 = data["g1"].xprint(features_1.shape)features_2 = data["g2"].xprint(features_2.shape)batch_1 = (data["g1"].batchif hasattr(data["g1"], "batch")else torch.tensor((), dtype=torch.long).new_zeros(data["g1"].num_nodes))batch_2 = (data["g2"].batchif hasattr(data["g2"], "batch")else torch.tensor((), dtype=torch.long).new_zeros(data["g2"].num_nodes))

三、图神经网络提取更新每个点的信息

这里运用直方图方式做特征比较新颖。

    def convolutional_pass(self, edge_index, features):"""Making convolutional pass.:param edge_index: Edge indices.:param features: Feature matrix.:return features: Abstract feature matrix."""features = self.convolution_1(features, edge_index)features = F.relu(features)features = F.dropout(features, p=self.args.dropout, training=self.training)features = self.convolution_2(features, edge_index)features = F.relu(features)features = F.dropout(features, p=self.args.dropout, training=self.training)features = self.convolution_3(features, edge_index)return features
#每个点都走三层gcn
abstract_features_1 = self.convolutional_pass(edge_index_1, features_1)
print(abstract_features_1.shape)
abstract_features_2 = self.convolutional_pass(edge_index_2, features_2)
print(abstract_features_2.shape)

四、计算点和点之间的关系得到直方图特征

    def calculate_histogram(self, abstract_features_1, abstract_features_2, batch_1, batch_2):abstract_features_1, mask_1 = to_dense_batch(abstract_features_1, batch_1)abstract_features_2, mask_2 = to_dense_batch(abstract_features_2, batch_2)B1, N1, _ = abstract_features_1.size()B2, N2, _ = abstract_features_2.size()mask_1 = mask_1.view(B1, N1)mask_2 = mask_2.view(B2, N2)num_nodes = torch.max(mask_1.sum(dim=1), mask_2.sum(dim=1))scores = torch.matmul(abstract_features_1, abstract_features_2.permute([0, 2, 1])).detach()hist_list = []for i, mat in enumerate(scores):mat = torch.sigmoid(mat[: num_nodes[i], : num_nodes[i]]).view(-1)hist = torch.histc(mat, bins=self.args.bins)hist = hist / torch.sum(hist)hist = hist.view(1, -1)hist_list.append(hist)print(torch.stack(hist_list).view(-1, self.args.bins).shape)return torch.stack(hist_list).view(-1, self.args.bins)
if self.args.histogram:hist = self.calculate_histogram(abstract_features_1, abstract_features_2, batch_1, batch_2)

四、Attention Layer 得到图的特征

    def forward(self, x, batch, size=None):size = batch[-1].item() + 1 if size is None else sizemean = scatter_mean(x, batch, dim=0, dim_size=size)transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix))coefs = torch.sigmoid((x * transformed_global[batch]).sum(dim=1))weighted = coefs.unsqueeze(-1) * xreturn scatter_add(weighted, batch, dim=0, dim_size=size)pooled_features_1 = self.attention(abstract_features_1, batch_1)
pooled_features_2 = self.attention(abstract_features_2, batch_2)

五、运用NTN网络计算图和图之间的关系得到特征

def forward(self, embedding_1, embedding_2):batch_size = len(embedding_1)scoring = torch.matmul(embedding_1, self.weight_matrix.view(self.args.filters_3, -1))scoring = scoring.view(batch_size, self.args.filters_3, -1).permute([0, 2, 1]) #filters_3可以理解成找多少种关系scoring = torch.matmul(scoring, embedding_2.view(batch_size, self.args.filters_3, 1)).view(batch_size, -1)combined_representation = torch.cat((embedding_1, embedding_2), 1)block_scoring = torch.t(torch.mm(self.weight_matrix_block, torch.t(combined_representation)))scores = F.relu(scoring + block_scoring + self.bias.view(-1))return scores

六、预测得到模型的结果

 def process_batch(self, data):self.optimizer.zero_grad()data = self.transform(data)target = data["target"]prediction = self.model(data)loss = F.mse_loss(prediction, target, reduction="sum")loss.backward()self.optimizer.step()return loss.item()

总结

本文通过点和点的比较,加上图和图的比较,结合在一起,最后计算出两幅图的相似度。其中运用到GCN ,NTN,ATTENTION,直方图等方法。较为有创意。

图神经网络相似度计算相关推荐

  1. 图神经网络如何对知识图谱建模? | 赠书

    几乎所有早期的知识图谱嵌入的经典方法都是在对每个三元组打分,在实体和关系的表示中并没有完全考虑到整幅图的结构. 早期,图神经网络的方法在知识图谱嵌入中并没有被重视,主要由于: 早期的图神经网络更多是具 ...

  2. 【读文献笔记】图神经网络加速结构综述

    [读文献笔记]图神经网络加速结构综述 前言 一.图神经网络来源 1.图神经网络用途 2.图神经网络特点 3.图神经网络主要阶段 4.图神经网络加速面临的挑战 5.本笔记内容包含内容 二.图与图神经网络 ...

  3. 基于多尺度图神经网络的流场预测,实现精度与速度的平衡

    项目简介 本项目来源于飞桨AI for Science共创计划的论文复现赛题,复现论文为<AMGNET: multi-scale graph neural networks for flow f ...

  4. 图神经网络代码_第一篇:图神经网络(GNN)计算框架绪论

    写在开头: 这个专栏是为了总结我本科毕业设计中所设计的题目<基于GPU的图神经网络算法库的设计钰实现>.这半年来一直在这个方向上啃代码,读论文,真的学到了很多东西.尤其是阅读了大佬团队写的 ...

  5. 计算未来轻沙龙 | 当深度学习遇上归纳推理,图神经网络有多强大?

    作为一名新世纪的深度学习炼丹师 是否整天面对各种结构的原(shu)料(ju)? 对于无规则的空间数据 传统炼丹大法好像并不能发挥奇效 图作为一种非常神奇的表示方式 可以表示生活中绝大多数现象或情境 那 ...

  6. 图神经网络应用——基于深度学习的图相似度计算(以SIMGNN为例的保姆级讲解)

    为啥想写这篇文章呢..因为之前提到的图神经网络应用篇鸽了一年多了,把自己的研究方向做一个总结,并向其他同样研究方向的朋友做一个报告,如有错误,敬请指出.而且,这个研究方向人太少了,万望能借此引起更多人 ...

  7. 实录分享 | 计算未来轻沙龙:图神经网络前沿研讨会

    2019年12月1日,PaperWeekly携手清华大学计算机系推出的计算未来轻沙龙之图神经网络前沿研讨会在清华大学FIT楼二层多功能报告厅成功举办. 本次论坛邀请了清华大学计算机系硕士生岑宇阔.博士 ...

  8. 神经网络特征图计算_GNNFiLM:基于线性特征调制的图神经网络

    GNN-FiLM:基于线性特征调制的图神经网络 论文链接:https://arxiv.org/abs/1906.12192v3 源代码:https://github.com/Microsoft/tf- ...

  9. 软考知识点——Gant图与Pert图、McCabe复杂度计算

    目录 一.Gant图与Pert图 1.Gant图与Pert图的概念 2.关键路径.总时差.松弛时间 3.真题 (1)2021下半年软考上午真题18~19 (2)2021上半年软考上午真题17~18 二 ...

最新文章

  1. jQuery中 :first 和 :last 选择器诡异问题
  2. tensorboard的初次使用
  3. 亚太地区数学建模优秀论文_数学建模美赛强势来袭!
  4. 彻底理解js中的和||
  5. 使用ubuntu18搭建nfs分布式文件系统
  6. matlab有限元分析程序,matlab有限元分析与应用(书及源程序)
  7. Hive on spark执行子查询报错code3
  8. Maven实战_许晓斌
  9. php如何实现快速压缩视频,如何把大视频压缩小 怎么将视频压缩到最小方便储存...
  10. 数据分析师职业规划——数据分析师的职业焦虑与未来发展
  11. 构建一个代号为1的聊天应用程序4
  12. 基于opencv和pillow实现人脸识别系统(附详细源代码)
  13. 喝酒聚会神器小程序部署
  14. Flutter 修改App Logo图标
  15. ATTCK靶场系列(二)
  16. 知乎周源微信_每周源代码7
  17. Android studio中的警告Hardcoded text
  18. 数据库与MPP数仓(十九):高效SQL
  19. OLAP介绍(zhuanzai)
  20. 【Unity3d】在Unity3d中使用百度AI人脸识别功能

热门文章

  1. 如何用FLASH做网页背景
  2. 安装Ubuntu+gpu+tensorflow+py2+py3
  3. ng-appdata-ng-app
  4. 王者荣耀s22服务器维护,王者荣耀安卓无法更新解决方法 S22更新问题汇总
  5. java面试题大全必备神器
  6. match_phrase 跨值查询中 position_increment_gap 参数用法
  7. IOS中Socket详解
  8. VUE插槽/插槽传参
  9. ML之RF:基于葡萄牙银行机构营销活动数据集(年龄/职业/婚姻/违约等)利用Pipeline框架(两种类型特征并行处理)+RF模型预测(调参+交叉验证评估+模型推理)客户是否购买该银行的产品二分类案例
  10. NDCG指标——qjzcy的博客