文章目录

  • 0 输入数据
  • 1 余弦相似度(Cosine Similarity)
  • 2 torch.cosine_similarity
  • 3 问题
  • 4 分析与解决
    • 4.1 答案
  • 5 另外的实现方法

0 输入数据

import torch
# 设置随机数种子,以保证结果可重现
torch.manual_seed(0)
a = torch.randn(4, 3)
tensor([[ 1.5410, -0.2934, -2.1788],[ 0.5684, -1.0845, -1.3986],[ 0.4033,  0.8380, -0.7193],[-0.4033, -0.5966,  0.1820]])

1 余弦相似度(Cosine Similarity)

  余弦相似度的公式如下所示:

2 torch.cosine_similarity

  可以使用torch自带的余弦相似度计算函数(下面三种用哪一个都可以,效果是一样的):

torch.cosine_similarity(x1, x2, dim=1, eps=1e-08)
torch.nn.CosineSimilarity(x1, x2, dim=1, eps=1e-08)
torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-08) → Tensor

  该函数原文档在:torch官方文档

3 问题

  cosine_similarity中的参数要两个tensor数据,而我们的需求是求一个tensor内的行与行之间的余弦相似度。很显然不能直接使用该函数。

4 分析与解决

  变量a的shape是(4, 3), 求a中的行与行之间的相似性,每行与每行有一个相似度,那么最终得到的结果应该是shape为(4, 4)。这个结果意味着a中的每一行要与包括本行在内的4行都有一个相似度数值。
  如下图所示,以第一行与所有行的相似度计算为例:

  分别拿着第一行的数据再与右侧所有行的数据按照余弦相似度计算,然后再换第二行的数据再与右侧所有行数据计算。这是正常人类的计算过程。
  固然上图中的处理torch 配合for循环依然能够实现。但是torch的一种快捷做法如下图所示,直接将第一行数据复制多份,每一份分别与右侧的每一行计算

  按照这个逻辑,下图则是计算的全过程:

  如此就满足torch.cosine_similarity需要输入两个tensor的要求了。左侧的数据是每一行都分别复制了4次,右侧的是整体所有行被复制了4次。

  a.unsqueeze(1)得到一个shape为(4,1,3)的tensor,a.unsqueeze(0)得到一个shape为(1,4,3)的tensor。(4,1,3)的tensor要和(1,4,3)的tensor进行运算,必须扩充到同一个shape,这样就都变成了(4,4,3),这就是torch的广播机制(broadcast)
  而在广播过程中,shape为(4,1,3)的tensor要变成(4,4,3)的tensor就只能把后边的3多复制几次。也就实现了上图中的左侧tensor的效果;同理shape为(1,4,3)的tensor要变成(4,4,3)的tensor就只能把后边的(4,3)多复制几次。

4.1 答案

  因此对于一个shape为(4,3)的tensor,使用下行公式可以得到其行与行的余弦相似度。

similarity = torch.cosine_similarity(a.unsqueeze(1), a.unsqueeze(0), dim=-1)
tensor([[ 1.0000,  0.8499,  0.6155, -0.4228],[ 0.8499,  1.0000,  0.1493,  0.1182],[ 0.6155,  0.1493,  1.0000, -0.9087],[-0.4228,  0.1182, -0.9087,  1.0000]])

5 另外的实现方法

  上述是一行实现余弦相似度的代码,在代码角度上非常简洁,但是会消耗较多的时间(因为参与计算的维度增大的原因)。下面可以根据公式两行代码实现,由公式可以看出,两行之间的余弦相似度是通过方差归一化后的两行数值内积得到的。
所以方法如下:

a = a / torch.norm(a, dim=-1, keepdim=True) # 方差归一化,即除以各自的模
similarity = torch.mm(a, a.T) # 矩阵乘法

  该方法可以不通过扩充快速实现同一tensor矩阵内每行之间的余弦相似度。该方法在CPU上的求解速度几乎是torch.cosine_similarity的10X倍。

如果该内容对您有用,请点击 收藏+点赞

pytorch一行实现:计算同一tensor矩阵内每行之间的余弦相似度相关推荐

  1. 文本相似度计算(切词、生成词向量,使用余弦相似度计算)

    项目需求 有多个文本,分别是正负样本,使用余弦相似度计算负样本与正样本的样本相似度,若准确率高,后期可判断新加样本与正样本的相似度. 输入如下所示: content label 今天下午,在龙口市诸由 ...

  2. 计算两个矩阵的行向量之间的欧式距离

    1 问题描述 矩阵P的大小为[m, d]   用行向量表示为P1, P2,...,Pm 矩阵C的大小为[n, d]    用行向量表示为C1, C2,...,Cn 求矩阵P的每个行向量与矩阵C的每个行 ...

  3. 余弦相似度计算的实现方式

    目录 一.余弦相似度计算方式 1.python 2.sklearn 3.scipy 4.numpy 5.pytorch 6.faiss 二.规模暴增计算加速 1.numpy矩阵计算GPU加速--cup ...

  4. 计算向量相似度 ---余弦相似度

    1.余弦相似度可用来计算两个向量的相似程度 对于如何计算两个向量的相似程度问题,可以把这它们想象成空间中的两条线段,都是从原点([0, 0, -])出发,指向不同的方向.两条线段之间形成一个夹角,如果 ...

  5. python用角度计算余弦值_Python 使用sklearn计算余弦相似度

    背景 在计算相似度时,常常用到余弦夹角来判断相似度,Cosine(余弦相似度)取值范围[-1,1],当两个向量的方向重合时夹角余弦取最大值1,当两个向量的方向完全相反夹角余弦取最小值-1,两个方向正交 ...

  6. 相似度计算(1)——余弦相似度

    余弦相似度   余弦相似度:用向量空间中两向量夹角的余弦值作为衡量两个个体之间差异的大小.余弦值越接近1,表明两个向量的夹角越接近0度,则两个向量越相似.余弦值越接近0,表明两个向量的夹角越接近180 ...

  7. 相似度计算——欧氏距离、汉明距离、余弦相似度

    计算图像间的相似性可以使用欧氏距离.余弦相似度/作为度量,前者强调点的思想,后者注重线的思想. 欧氏距离 欧式距离/Euclidean Distance即n维空间中两个点之间的实际距离.已知两个点A= ...

  8. 基于 TF-IDF 计算古诗之间的文本相似度

    步骤 对每一首古诗进行分词 计算每一个词的 tfidf 值 利用每首古诗的词向量计算两首古诗之间的余弦相似度 import pandas as pd import numpy as np import ...

  9. java计算余弦相似度

    您可以使用以下java代码来计算余弦相似度 import java.util.List;public class CosineSimilarity {public static double cosi ...

最新文章

  1. 写得蛮好的linux学习笔记(二)
  2. centos 6.5 安装 lamp 后mysql不能启动_CentOS 6.5系统安装配置LAMP(Apache+PHP5+MySQL)服务器环境...
  3. arm中断保护和恢复_ARM中断返回的详细分析
  4. nginx log_format 中的变量
  5. 经典文章解释apache与tomcat!看完秒懂
  6. Windows 下用 SecureCRT 连接 VirtualBox 中的 Ubuntu
  7. java match parent_Maven的聚合(多模块)和Parent继承
  8. 草稿 listview动态绑定数据
  9. infer的用法_imply和infer的用法区别:阅读理解题里频繁看到的词汇
  10. dmp导入数据 oracle_一文看懂oracle12c数据库跨小版本迁移
  11. 如何过滤掉xml中的转义字符_水肥一体化应用中如何选择过滤器?
  12. Django restful Framework 之序列化与反序列化
  13. 浪曦_Struts2应用开发系列_第2讲.Struts2的类型转换-学习笔记
  14. 智慧路灯解决方案-最新全套文件
  15. sql 语句,主键外键详解
  16. 白嫖!白嫖!【尚学堂】高淇Java300集全套学习资料!
  17. 企业网站建设流程是什么?三个流程要知道
  18. B/S架构与C/S架构
  19. 模拟器提示关闭 hyper-V,但 hyper-V实际上并没有开启
  20. 实现一个函数输入123456789,输出123,456,789”

热门文章

  1. 计算机会计系统风险与防范论文,【会计电算化论文】会计电算化的风险与防范措施(共3301字)...
  2. Java学习笔记3--类与对象
  3. 第8章 虚拟现实技术的相关软件
  4. 最详细的Excel模块Openpyxl教程(三)-使用公式
  5. 2017年最新15个漂亮的 HTML 摄影网站模板
  6. C++程序设计之STL学习笔记思维导图(拙作)
  7. 打印机驱动特殊安装步骤
  8. 2006JAVA类图书读者投票排行榜
  9. Mysql 查询表的记录
  10. 【安全知识分享】5S管理EHS和质量管理体系阐述(附下载)