在pytorch 中实现真正的 pairwise distances
文章目录
- 问题
- 解决方法
问题
pairwise distances即输入两个张量,比如张量 AM×D,BN×DA^{M \times D} ,B^{N \times D}AM×D,BN×D,M,N分布代表数据数量,D为特征维数,输出张量A和B 两两之间的距离,即一个 M×NM \times NM×N 的张量.
这个在 sklearn 中有个很方便的函数 pairwise_distances
,其实这个功能在 pytorch 中也有实现.
但是很坑跌的是,torch 中居然要求张量 A,B 的形状一样= =||||.
因此这里记录一下,自己的处理方法:即借用广播机制重新处理一下输入张量
解决方法
import torch
from torch import nna=torch.tensor([[1,1,1],[2,2,2]])
b=torch.tensor([[2,2,2],[1,1,1],[2,2,2],[1,1,1],[2,2,2]
])
print(a.shape)
print(b.shape)
def pdist(a: torch.Tensor, b: torch.Tensor, p: int = 2) -> torch.Tensor:return (a-b).abs().pow(p).sum(-1).pow(1/p)a_=a.unsqueeze(1)
b_=b.unsqueeze(0)print(pdist(a_,b_))
输出:
>>> torch.Size([2, 3])
>>> torch.Size([5, 3])
>>> tensor([[1.7321, 0.0000, 1.7321, 0.0000, 1.7321],[0.0000, 1.7321, 0.0000, 1.7321, 0.0000]])
在pytorch 中实现真正的 pairwise distances相关推荐
- pytorch中调整学习率的lr_scheduler机制
pytorch中调整学习率的lr_scheduler机制 </h1><div class="clear"></div><div class ...
- pytorch中如何处理RNN输入变长序列padding
一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...
- PyTorch中的MIT ADE20K数据集的语义分割
PyTorch中的MIT ADE20K数据集的语义分割 代码地址:https://github.com/CSAILVision/semantic-segmentation-pytorch Semant ...
- PyTorch中nn.Module类中__call__方法介绍
在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...
- 利用 AssemblyAI 在 PyTorch 中建立端到端的语音识别模型
作者 | Comet 译者 | 天道酬勤,责编 | Carol 出品 | AI 科技大本营(ID:rgznai100) 这篇文章是由AssemblyAI的机器学习研究工程师Michael Nguyen ...
- 实践指南 | 用PyTea检测 PyTorch 中的张量形状错误
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨陈萍.泽南 来源丨机器之心 编辑丨极市平台 导读 韩国首尔大学 ...
- 实践教程 | 浅谈 PyTorch 中的 tensor 及使用
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | xiaopl@知乎(已授权) 来源 | https://z ...
- 详解PyTorch中的ModuleList和Sequential
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨小占同学@知乎(已授权) 来源丨https://zhuanla ...
- 在PyTorch中进行双线性采样:原理和代码详解
↑ 点击蓝字 关注视学算法 作者丨土豆@知乎 来源丨https://zhuanlan.zhihu.com/p/257958558 编辑丨极市平台 在pytorch中的双线性采样(Bilinear Sa ...
最新文章
- Veeam Backup Replication试用(四):配置同步(Replication Job)与恢复(Restore)
- 线性时间复杂度求数组中第K大数
- 傅德良:选择视频编码器的误区
- 《Linux From Scratch》第二部分:准备构建 第五章:构建临时文件系统- 5.2 工具链技术备注...
- mysql序列号生成_一文看懂mycat的6种全局序列号实现方式
- es6 箭头函数 rest参数 扩展运算符
- [转]游戏UI与flash 组件开发
- 机器视觉用c还是python_机器视觉_opencv-python环境搭建
- jinja Macros
- 我爬了价值1800亿的商品信息
- Exchange 2013反垃圾邮件功能
- Exchange 2010 使用http访问 OWA
- jsp九大内置对象的作用及用法
- Shell判断路径是否存在
- 项目管理中如何如何进行风险控制
- Batch Normalization详解(原理+实验分析)
- 游戏 - PS4 海绵宝宝: 争霸比基尼海滩重注版
- micro:bit 了解
- 付费专栏热销排行榜·0315更新
- 奥西300工程机服务器装系统,奥西pw300驱动
热门文章
- 使用geocoder_你在哪? 使用Geocoder PHP实现地理位置
- PLSQL-创建函数
- 手把手教你在ubuntu上安装搜狗输入法
- MySQL 约束(Constraint)
- ffmpeg转码视频真的好用!(ffmpeg的简单使用方法)
- 计算机应用基础进制转换说课稿,进制和进制转换说课稿.doc
- python爬虫大众点评_python爬虫——按城市及店铺面爬取大众点评分类
- mysql使用Navicat自动备份+javamail发送邮件
- 这些页面还有这一些的埋点知识|风控人应知系列
- Dgraph对GraphQL最新云端无服务器支持