Siamese+LSTM

网络结构模块

  • 孪生网络

    首先要理解什么是孪生网络模块,我们在词嵌入和编码(LSTM)过程中使用的是相同的参数,比如说我下面的代码中,在编码阶段,对于两个句子的输入,我都使用了相同的LSTM,这也就是Siamese+LSTM。

  • 疑问
    不过我有个疑问,我在某些文章中看到用LSTM来判断句子相似度,它不是Siamese+LSTM,在文章中对比的两种方法,一种是Siamese+LSTM,还有一种是LSTM。我也不清楚这个单独的LSTM是什么意思,后来我思考了一下,可能是在对两个句子进行编码的时候,使用了两个不同的LSTM结构(我猜是这样)。

  • 编码
    编码完之后,也就是LSTM的输出阶段。LSTM的输出尺寸是[len,batch,hidden],因为在LSTM结构中我没有去声明batch_first这个属性,对LSTM的输出,我选取了最后一个输出,相当于final_ht。

  • 距离
    两个尺寸都是[batch,hidden],然后对两个句子取绝对值的差作为全连接神经网络的输入,torch.abs(),这里有个dim参数有时候有用,我在这里没写,正常情况下应该是dim=-1。

  • sigmoid()函数
    关于最后为什么使用sigmoid函数,我最终是把输出的值通过sigmoid函数归到[0,1],最终loss函数采用的nn.BCEloss()。当然这里你也可以进行二分类,我试了一下,感觉没有输出[0,1]之间的效果好,虽然效果都不是很好。。。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import picklefile1 = r"./data/vocab.pkl"
file2 = r"./data/vocabs_matrix.pkl"vocabs = pickle.load(open(file1,'rb'))
Embedding_matrix = pickle.load(open(file2,'rb'))
Vocab_size = len(vocabs)class LSTM1(nn.Module):def __init__(self):super(LSTM1, self).__init__()self.Vocab_size = Vocab_sizeself.batch_size = 500self.input_size = 300self.n_hidden1 = 128self.Embedding_dim = 300self.n_class=2self.seq_len = 20self.dropout = nn.Dropout(0.2)self.Embedding_matrix = Embedding_matrixself.word_embeds = nn.Embedding(self.Vocab_size+1, self.Embedding_dim)pretrained_weight = np.array(self.Embedding_matrix)self.word_embeds.weight.data.copy_(torch.from_numpy(pretrained_weight))self.Lstm1 = nn.LSTM(self.Embedding_dim, hidden_size=self.n_hidden1, bidirectional=False)#self.fc = nn.Linear(self.n_hidden1*2,self.n_class,bias=False)self.fc1 = nn.Linear(self.n_hidden1,32,bias=False)self.b1 = nn.Parameter(torch.rand([32]))self.fc2 = nn.Linear(32,1,bias=False)self.b2 = nn.Parameter(torch.rand([1]))passdef forward(self,train_left,train_right):device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')train_left = self.word_embeds(train_left).to(device)train_right = self.word_embeds(train_right).to(device)train_left = train_left.transpose(0,1)train_right = train_right.transpose(0,1)hidden_state1 = torch.rand(1,self.batch_size,self.n_hidden1).to(device)cell_state1 = torch.rand(1,self.batch_size,self.n_hidden1).to(device)outputs1_L,(final_state1_L,_) =self.Lstm1(train_left,(hidden_state1,cell_state1))outputs1_L = self.dropout(outputs1_L)outputs1_R,(final_state1_R,_) =self.Lstm1(train_right,(hidden_state1,cell_state1))outputs1_R = self.dropout(outputs1_R)outputs1 = outputs1_L[-1]outputs2 = outputs1_R[-1]output = torch.abs(outputs1-outputs2)output = self.fc1(output)+self.b1output = self.dropout(output)output = self.fc2(output)+self.b2output = torch.sigmoid(output)return outputpass

train模块

def train(model, device, train_dataloader, optimizer, epoch):model.train()train_loss = 0num_correct = 0for batch_idx,(train_left,train_right,lables) in enumerate(train_dataloader):train_left = train_left.to(device)train_right = train_right.to(device)lables = lables.to(device)optimizer.zero_grad()output = model(train_left,train_right)output = output.view_as(lables)loss = loss_fn(output,lables)loss.backward()optimizer.step()train_loss += float(loss.item())true = lables.data.cpu()predict = torch.round(output).cpu()num_correct += torch.eq(predict, true).sum().float().item()total_len = len(train_dataloader.dataset)train_acc = num_correct / len(train_dataloader.dataset)train_loss = train_loss/ len(train_dataloader)print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \t Acc: {:.6f}'.format(epoch,batch_idx * len(train_left),len(train_dataloader.dataset),100. * batch_idx / len(train_dataloader),train_loss,train_acc))

结果(数据集采用的是Quora Question Pairs)

D:\Anaconda3\envs\python36\python.exe D:/TextMatching/NewProject2/train.py
Learning rate:  0.01
Epochs:  10
Training on 363861 samples...
Train epoch: 1 [363000/363861 (100%)]    Loss: 0.450285      Acc: 0.787457
Make prediction for 40429 samples...
F1 improved at epoch: 1 ; best_F1:0.75543 ; best_Accuracy:0.79368 ; best_Precision:0.67190 ; best_Recall:0.86267Training on 363861 samples...
Train epoch: 2 [363000/363861 (100%)]    Loss: 0.368117      Acc: 0.833332
Make prediction for 40429 samples...
F1:0.75543 Accuracy:0.79368 Precision:0.67190 Recall:0.86267 No improvement since epoch: 1Training on 363861 samples...
Train epoch: 3 [363000/363861 (100%)]    Loss: 0.336102      Acc: 0.848426
Make prediction for 40429 samples...
F1 improved at epoch: 3 ; best_F1:0.76229 ; best_Accuracy:0.79602 ; best_Precision:0.66912 ; best_Recall:0.88559Training on 363861 samples...
Train epoch: 4 [363000/363861 (100%)]    Loss: 0.320470      Acc: 0.855816
Make prediction for 40429 samples...
F1 improved at epoch: 4 ; best_F1:0.77047 ; best_Accuracy:0.81200 ; best_Precision:0.70121 ; best_Recall:0.85491Training on 363861 samples...
Train epoch: 5 [363000/363861 (100%)]    Loss: 0.308956      Acc: 0.862206
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4Training on 363861 samples...
Train epoch: 6 [363000/363861 (100%)]    Loss: 0.302345      Acc: 0.864965
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4Training on 363861 samples...
Train epoch: 7 [363000/363861 (100%)]    Loss: 0.296583      Acc: 0.867658
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4Training on 363861 samples...
Train epoch: 8 [363000/363861 (100%)]    Loss: 0.293175      Acc: 0.868862
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4Training on 363861 samples...
Train epoch: 9 [363000/363861 (100%)]    Loss: 0.288697      Acc: 0.871360
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4Training on 363861 samples...
Train epoch: 10 [363000/363861 (100%)]   Loss: 0.287799      Acc: 0.871404
Make prediction for 40429 samples...
F1:0.77047 Accuracy:0.81200 Precision:0.70121 Recall:0.85491 No improvement since epoch: 4Process finished with exit code 0

最后,如果是做文本相似方向的可以和我一起交流一下,我也是个新手。。

Siamese+LSTM+Distance(abs)相关推荐

  1. 面向中文自然语言处理的60余类系统开源实践项目与工业探索索引

    项目介绍 面向中文自然语言处理的六十余类实践项目及学习索引,涵盖语言资源构建.社会计算.自然语言处理组件.知识图谱.事理图谱.知识抽取.情感分析.深度学习等几个学习主题.包括作者个人简介.学习心得.语 ...

  2. 系统学习NLP(十七)--文本相似度

    转自:https://blog.csdn.net/qq_28031525/article/details/79596376 看原文吧,这里公式改不过来,烂 在自然语言处理(Natural Langua ...

  3. 一文详解文本语义相似度的研究脉络和最新进展

    每天给你送来NLP技术干货! ©作者 | 崔文谦 单位 | 北京邮电大学 研究方向 | 医学自然语言处理 编辑 | PaperWeekly 本文旨在帮大家快速了解文本语义相似度领域的研究脉络和进展,其 ...

  4. 标题相似度算法_智能客服问题相似度算法设计——第三届魔镜杯大赛第12名解决方案...

    目录: 一.比赛介绍 二.数据介绍 三.解决方案 (一)问题分析 (二)数据探索 (三)模型 (四)调参 (五)特征工程 (六)模型集成 (七)后处理 四.比赛总结 (一)比赛成绩 (二)代码分享 ( ...

  5. 2018 ATEC NLP比赛 15th 总结

    这次比赛跟以往的比赛似乎很不一样(虽然这个是我第一次参加),以往比赛的特征技巧,融合技巧,以及一些典型的模型都在这次比赛都失效.我一度怀疑蚂蚁金服是故意设计了数据.... 赛题介绍 问题相似度计算,即 ...

  6. LeetCode简单题之找到最近的有相同 X 或 Y 坐标的点

    题目 给你两个整数 x 和 y ,表示你在一个笛卡尔坐标系下的 (x, y) 处.同时,在同一个坐标系下给你一个数组 points ,其中 points[i] = [ai, bi] 表示在 (ai, ...

  7. 数据挖掘技术在出行体验上的应用!

    桔妹导读:每天滴滴都会为上千万人提供出行服务,在这一过程中积累了海量轨迹数据.这些轨迹数据来自于公共服务,本文介绍如何利用这些数据回馈大众,改善出行体验. 1.  背景 首先简要介绍一下什么是数据挖掘 ...

  8. android绘制高亮区域,实现高亮某行的RecyclerView效果

    最终效果 全部代码:github 方式有二 组合控件,RecyclerView + View 自定义RecyclerView 1中只需要控制View,但是不好封装. 2中需要重写RecyclerVie ...

  9. (数据挖掘-入门-1)基于用户的协同过滤之最近邻

    主要内容: 1.什么是基于用户的协同过滤 2.python实现 1.什么是基于用户协同过滤: 协同过滤:Collaborative Filtering,一般用于推荐系统,如京东,亚马逊等电商网站上的& ...

最新文章

  1. Win7 64位的SSDTHOOK(1)---SSDT表的寻找
  2. Python之range和xrange的区别
  3. Autodesk Maya 2019中文版
  4. java学绘图吗_Java绘图
  5. 硬盘的原理以及SQL Server如何利用硬盘原理减少IO
  6. 寻找唯一特等奖java,大工斩获唯一特等奖!这次,请为我工老师疯狂打call!
  7. 大学python怎么过_大学生该不该学Python?太纠结了?
  8. 您收到一封 2019 阿里云峰会 (北京) 邀请函
  9. 强调团体与配合的jinbiguandan
  10. 水滴石穿C语言之声明的语法
  11. Joe博客模板Typecho主题
  12. vs studio2015导入本地项目_Visual Studio2019自定义项目模板
  13. pbs分解_产品分解结构
  14. 提高网站访问速度的方法汇总
  15. winrar解压时出现诊断信息怎么办?
  16. 水晶报表的宽度调整方法(设计器、代码调整、rpt文件属性)
  17. 大数据统计分析毕业设计_数据分析毕业设计 大数据可视化毕业设计
  18. 【项目管理】干系人管理
  19. 必应(bing)广告的费用是多少?bing搜索广告推广简介
  20. 使用keybase给你的Github commit加上GPG Verified签名认证(keybase教程)

热门文章

  1. 【问题总结(12)】Cascader 省市区联动 obj[] Object.key() some.() String() forEach() 数组筛选过滤filter
  2. 【图像识别】基于卷积神经网络cnn实现银行卡数字识别matlab源码
  3. 通过例子学TLA+(十三)--多进程与await
  4. JS 分享(微博,微信,QQ ,QQ 空间)
  5. Java项目:基于jsp+mysql+Spring+SpringMVC+mybatis的办公用品领用管理系统
  6. ATE测试程序:ATE测试程序中的public、protected、private类权限
  7. 【CV】第 6 章:使用迁移学习的视觉搜索
  8. 程序员的注意事项(网上拷贝)
  9. 2021年真正强大、最值得推荐的的视频播放器(全平台)
  10. python学习手册beaut_平平无奇的python学习手册