title: ARNN复现反思
date: 2022-04-26 22:37:56
tags: NLP的一些收获

因为找遍了一二三四作,都没有能得到An Attentional Recurrent Neural Networkfor Personalized Next Location Recommendation这篇论文的代码,一作没反应,二三四都让我找一作…麻了,所以硬下头皮准备复现。

其实任务量还好,最幸运的是这篇论文的模型架构与另外一篇DeepMove的模型十分相似,都是先embedding序列后,对序列元素进行attention的思路,不过也是有很多不同的。

思路

这篇论文的思路很清晰,先将check in序列处理好,得到用户的历史轨迹,这个历史轨迹包括loc、time和word的序列,分别把它们用相应的维度embedding后,在第三个维度拼接起来得到tensor:x(batch_size, num, dim),这样这个轨迹序列的元素就融合了地点、时间、语义上的含义。

之后对于历史轨迹的每一个loc,与该地点的所有neighbors求相似度,然后加权进行一次attention,得到targer,也就是与之最相似的loc向量,结果为ck。

然后将x和ck同样在第三个维度拼接起来得到一个新的tensor,让每一个位置的元素融合入与其它loc的转移关系,然后将它pack后输入LSTM,取出最后一个hidden state,融合入user的embedding,最后用softmax得到next poi的概率分布。

其中的loc的neighbors得到的方法是使用基于meta path的随机游走模型得到的,将历史轨迹序列构成图,我这里的操作其实和pageRank的处理方法类似,搞一个邻接表,然后严格按照原路径的类型进行游走,将访问到的loc纳入path,也就是起点loc的neighbor。具体做法详见上篇blog。

遇到的问题

(1)在随机游走时,太慢了,虽然现在也不算快,但是一开始参照一位githuber的pageRank代码改造,是用dict代替多维list,连每一步的带权乘法都要自己用for循环写,很慢。后来想着可以把它搞成矩阵,然后转化为tensor,既能调用torh里的乘法,还可以放到GPU上运算,所以就这么做了,真的有很大的改进,但确实也不算快。

(2)不会写attention,第一次复现嘛,一开始很傻,从tensor中一条一条数据遍历,然后找到对应的loc_id,再通过loc_id找到对应的neighbors的序列,然后再对neighbors embedding…太繁琐了,导致一个batch就要两三分钟。

所以,我思考了一下,可不可以把neighbors的embedding也做成一个tensor,然后让二者去运算,这样是可以调用torch.matmul,方便加速的。但是问题在于不同的loc邻居的数量也不同,所以我采用的办法是取游走获得的path中出现次数最多的前n个邻居作为loc的neighbors,这样维度就统一了。

假设batch_size为128,n是10,dim是100,那么一开始的loc_neighbor_emb就是(128, 464, 100, 10),原本的loc_emb是(128, 464, 100),为了方便相乘,unsqueeze一下为(128, 464, 1, 100),这样二者的batch就统一了,为128*464,因为torch.matmul规定四维tensor的运算前两维为batch,前者转置一下,相乘后softmax就是相似度的矩阵了,大小为(128, 464, 1, 10)。这个大小一看就很对,对于每个loc都有10个neighbors对应的weight。最后再将其和neighbors的embedding相乘,得到最终的结果。

然后用这样的方法再尝试,果然快了很多,一个batch就20s左右。时间估计都用在attention前loc_neighbor_emb的构建上了。

(3)每个epoch的最后一个batch不足batch_size,一开始我还想着continue过去,但是想想会影响到测试集到验证的,所以就查了查,发现有解决办法,在DataLoader中设置参数drop_last = True,其实也是drop掉了,不过在一开始就去掉了,不会产生影响。

总结

总结一下,最困难的部分也就是核心模块的编写和随机游走的编写了,对于框架代码的编写其实没有涉足过多,毕竟是拿别人的代码改动的,下次有机会还是尝试一下自己从0开始,体会应该更深刻一些吧。另外就是一些torch的函数,了解的还是太少,还是应该多分析分析别人的代码。

最后

代码:ARNN-master

我学识鄙陋,有问题一定要告诉我!

An Attentional Recurrent Neural Networkfor Personalized Next Location Recommendation【ARNN】代码复现反思相关推荐

  1. Nerf(Representing Scenes as Neural Radiance Fields for View Synthesis)代码复现笔记

    前言:本文旨在帮助小白快速了解or学习复现出Nerf的代码,整体结构保持不变,不过会针对部分细节为了更好理解进行了修改. 本文会相应更新讲解视频于B站,id 出门吃三碗饭,有问题到b站评论区留言 同步 ...

  2. Paper:《Generating Sequences With Recurrent Neural Networks》的翻译和解读

    Paper:<Generating Sequences With Recurrent Neural Networks>的翻译和解读 目录 Generating Sequences With ...

  3. Attention和增强RNN (Attention and Augmented Recurrent Neural Networks)

    原文: Attention and Augmented Recurrent Neural Networks 递归神经网络是一种主流的深度学习模型,它可以用神经网络模型来处理序列化的数据,比如文本.音频 ...

  4. Paper:RNN之《Generating Sequences With Recurrent Neural Networks用循环神经网络生成序列》的翻译和解读

    Paper:<Generating Sequences With Recurrent Neural Networks>的翻译和解读 目录 Generating Sequences With ...

  5. 【多标签文本分类】Ensemble Application of Convolutional and Recurrent Neural Networks for Multi-label Text

    ·阅读摘要:   本文提出基于Seq2Seq模型,提出CNN-RNN模型应用于多标签文本分类.论文表示CNN-RNN模型在大型数据集上表现的效果很好,在小数据集效果不好. ·参考文献:   [1] E ...

  6. (zhuan) Recurrent Neural Network

    Recurrent Neural Network 2016年07月01日 Deep learning Deep learning 字数:24235 this blog from: http://jxg ...

  7. RNN(Recurrent Neural Network)的几个难点

    1. vanish of gradient RNN的error相对于某个时间点t的梯度为: \(\frac{\partial E_t}{\partial W}=\sum_{k=1}^{t}\frac{ ...

  8. Recurrent Neural Network系列2--利用Python,Theano实现RNN

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORKS T ...

  9. 循环神经网络教程Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs

    Recurrent Neural Networks (RNNs) are popular models that have shown great promise in many NLP tasks. ...

最新文章

  1. mysql教程联合索引_MySQL中的联合索引学习教程
  2. java基础----IO序列化Serializable
  3. http post请求 参数放在路径后面 java_「思唯网络学院」网络基本概念之HTTP协议...
  4. 小程序WXML基本使用
  5. sql order by 结合case when then
  6. nginx配置文件祥解
  7. sql优化基数和耗费_基数估计在SQL Server优化过程中的位置
  8. 【语音分析】基于matlab GUI语音信号线性预测(LPC)分析【含Matlab源码 910期】
  9. NLPIR系统的中文语义分析模式介绍
  10. 统计学、统计学习和统计推断之间的关系
  11. inode客户端linux 怎样运行,H3C_iNode智能客户端安装指导(Linux)
  12. UBUNTU18.04系统安装打印机
  13. C语言经典题目50题
  14. lr mysql 增删改查_ssh增删改查流程
  15. 数独android程序,简单实现Android数独游戏
  16. 电商设计的文字的选择与排版
  17. 网易云音乐热评的规律,44万条数据告诉你
  18. LeGo-LOAM激光雷达定位算法源码阅读(二)
  19. 你总是喜欢抓不住的东西
  20. [Leetcode] 33. Search in Rotated Sorted Array 解题报告

热门文章

  1. 《Graph Neural Networks Foundations,Frontiers and Applications》第一部分第一章1.1节翻译和解读
  2. 世界杯php源码 haoid,DoYouHaoBaby(PHP开发框架)v2.5.2 Release20130727
  3. 第20章 一些随机波动率模型的近似解
  4. 福昕阅读器软件foxit设置快…
  5. Foxit Reader(福昕阅读器)(CVE-2020-14425)命令注入漏洞复现
  6. 计算机毕业设计之吊炸天Python+Spark电影推荐系统 电影采集大数据分析 电影购票系统 电影购票小程序app 电影院管理系统 电影数据分析大屏
  7. S3C2440上MMC/SD卡驱动分析(二)
  8. 云客Drupal源码分析之系统AJAX(一):概述与示例
  9. acwing 4378
  10. 用JSP完成图书信息查询功能