之前在网上找到了一个文本匹配实现仓库,但是没有提供DSSM的代码,我就根据那个代码实现以下DSSM。数据集采用的是蚂蚁金服的数据集。也参考过别人的代码,但是总感觉怪怪的,DSSM原文中,一个query有对应的正样本和负样本,因此在实现的时候分别计算query与正负样本的余弦相似度,最后拼接再接softmax,但是蚂蚁金服数据集中每一个样本都已一个query和doc,对应一个label,并没有成对的正负样本,因此在实现中遇到了困难,因此最后我索性直接将余弦值作为网络输出,貌似还取得了不错的效果,那么代码会有些许不同。
第一,损失函数采用了二分类损失函数:

class torch.nn.BCELoss(weight=None, size_average=True)

第二,判断类别时:

def correct_predictions(output_probabilities, targets):"""Compute the number of predictions that match some target classes in theoutput of a model.Args:output_probabilities: A tensor of probabilities for different outputclasses.targets: The indices of the actual target classes.Returns:The number of correct predictions in 'output_probabilities'."""# _, out_classes = output_probabilities.max(dim=1)out_classes = output_probabilities.ge(0.5).byte().float()correct = (out_classes == targets).sum()return correct.item()

第三,网络结构设计如下:

class DSSM(nn.Module):def __init__(self, dropout=0.2,device="gpu"):super(DSSM, self).__init__()self.device = deviceself.embed = nn.Embedding(7901, 100)self.fc1 = nn.Linear(100, 256)self.fc2 = nn.Linear(256, 512)self.fc3 = nn.Linear(512,256)self.dropout = nn.Dropout(dropout)self.Sigmoid = nn.Sigmoid() #method1self.relu = nn.ReLU()def forward(self, a, b):a = self.embed(a).sum(1)b = self.embed(b).sum(1)a = self.relu(self.fc1(a)) #torch.tanh# a = self.dropout(a)a = self.relu(self.fc2(a))# a = self.dropout(a)a = self.relu(self.fc3(a))# a = self.dropout(a)b = self.relu(self.fc1(b))# b = self.dropout(b)b = self.relu(self.fc2(b))# b = self.dropout(b)b = self.relu(self.fc3(b))# b = self.dropout(b)cosine = torch.cosine_similarity(a, b, dim=1, eps=1e-8)  #计算两个句子的余弦相似度# cosine = self.Sigmoid(cosine-0.5)cosine = self.relu(cosine)cosine = torch.clamp(cosine,0,1)return cosine

这样在蚂蚁金服测试集的准确率可以达到77以上,如果cosine后面不接relu,我跑到了78以上,但是总感觉出现了过拟合现象。此外,加入dropout效果反而不好,可能这个网络本身就不复杂吧。
其他的训练代码我参考了:https://github.com/zhaogaofeng611/TextMatch

DSSM pytorch实现相关推荐

  1. DSSM双塔模型及pytorch实现

    本文介绍用于商业兴趣建模的 DSSM 双塔模型.作为推荐领域中大火的双塔模型,因为效果不错并且对工业界十分友好,所以被各大厂广泛应用于推荐系统中. 通过构建 user 和 item 两个独立的子网络, ...

  2. 【PyTorch基础教程30】DSSM双塔模型(线上召回 | 模型更新)

    内容总结 召回中,一般的训练方式分为三种:point-wise.pair-wise.list-wise.RecHub中用参数mode来指定训练方式,每一种不同的训练方式也对应不同的Loss.对应的三种 ...

  3. 深度学习模型训练的一般方法(以DSSM为例)

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 本文主要用于记录DSSM模型学习期间遇到的问题及分析.处理经验.先统领性地提出深度学习模型训练 ...

  4. 自然语言推断(NLI)、文本相似度相关开源项目推荐(Pytorch 实现)

    Awesome-Repositories-for-NLI-and-Semantic-Similarity mainly record pytorch implementations for NLI a ...

  5. pytorch从hdfs载入模型、从二进制字符串载入模型

    1,问题描述 离线用torch训练了一个简单的双塔模型,存到了hdfs上,希望可以在spark离线任务中使用.但spark离线任务要载入这个模型,就无法像pytorch官方的模型载入方式(这里:tor ...

  6. 通过anaconda2安装python2.7和安装pytorch

    ①由于官网下载anaconda2太慢,最好去byrbt下载,然后安装就行 ②安装完anaconda2会自动安装了python2.7(如终端输入python即进入python模式) 但是可能没有设置环境 ...

  7. 记录一次简单、高效、无错误的linux上安装pytorch的过程

    1 准备miniconda Miniconda Miniconda 可以理解成Anaconda的免费.浓缩版.它非常小,只包含了conda.python以及它们依赖的一些包.我们可以根据我们的需要再安 ...

  8. 各种注意力机制PyTorch实现

    给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...

  9. PyTorch代码调试利器_TorchSnooper

    GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper 大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch ...

最新文章

  1. C# 中奇妙的函数 -- 1. ToLookup
  2. Linux 配置yum本地安装源
  3. SAP Netweaver gateway framework序列化
  4. shp与json互转(转载)
  5. android打包工具多渠道批量打包,Android 快速渠道批量打包详解教程-美团多渠道打包方案...
  6. python numpy常用操作、Numpy 多维数组、矩阵相乘、矩阵乘以向量
  7. cobol_在尝试之前不要讨厌COBOL
  8. 苹果供应商:iPhone SE 3 5G和AirPods生产平稳
  9. 找不到该项目,请确认该项目的位置的解决办法
  10. html网页文档无法复制粘贴图片,教你处理不能复制粘贴在网页中的详细图文
  11. 计算机网络 —— 冲突域和广播域
  12. 触动精灵中return-break-exit的使用总结
  13. 谷歌网盘下载 根据文件ID miniimagenet
  14. document.referrer和history.go(-1)退回上一页区别
  15. WebStorm英文版汉化
  16. 关于GPS坐标转换(一)
  17. div失去焦点事件onblur()不触发解决方法
  18. Keycloak实现手机验证码登录
  19. C语言小项目之扫雷游戏(简易版)
  20. ZOJ ~ 3469 ~ Food Delivery (区间DP)

热门文章

  1. Mac labelme安装及运行时崩溃bug解决
  2. cesium1.102和以上的版本,自定义材质报‘texture2D‘ : no matching overloaded function found错误
  3. AutoCAD对象模型笔记(一)(vba)
  4. 推荐一部关于母亲的西语电影
  5. REXROTH力士乐溢流阀DBDS20P18/200
  6. Mysql80解压版安装与卸载
  7. sqlserver设置密码_sql server 用户#x27;sa#x27;登录失败(错误18456)
  8. 互联网时代,人的大脑用来思考,记录交给电子名片
  9. nacos的配置中心
  10. CAD二次开发之撤销上一步(Undo)