numpy实现torch的topk方法
torch中提供了topk方法用来返回矩阵中对应维度中最大的K个元素以及在对应维度中的index,但是numpy并没有提供和torch一样的topk方法,所以在这里通过numpy的argpartition实现torch中的topk方法。
直接给出代码:
def topk_(matrix, K, axis=1):if axis == 0:row_index = np.arange(matrix.shape[1 - axis])topk_index = np.argpartition(-matrix, K, axis=axis)[0:K, :]topk_data = matrix[topk_index, row_index]topk_index_sort = np.argsort(-topk_data,axis=axis)topk_data_sort = topk_data[topk_index_sort,row_index]topk_index_sort = topk_index[0:K,:][topk_index_sort,row_index]else:column_index = np.arange(matrix.shape[1 - axis])[:, None]topk_index = np.argpartition(-matrix, K, axis=axis)[:, 0:K]topk_data = matrix[column_index, topk_index]topk_index_sort = np.argsort(-topk_data, axis=axis)topk_data_sort = topk_data[column_index, topk_index_sort]topk_index_sort = topk_index[:,0:K][column_index,topk_index_sort]return topk_data_sort, topk_index_sort
测试功能:
# torch.topk方法
>>> a=torch.rand(5,4)
>>> a
tensor([[0.0154, 0.5266, 0.6294, 0.6897],[0.2201, 0.7039, 0.2639, 0.0681],[0.1006, 0.0464, 0.3314, 0.2052],[0.3954, 0.4373, 0.2147, 0.3532],[0.6560, 0.0549, 0.8040, 0.3528]])
>>> a.topk(2,dim=1,largest=True,sorted=True)
values=tensor([[0.6897, 0.6294],[0.7039, 0.2639],[0.3314, 0.2052],[0.4373, 0.3954],[0.8040, 0.6560]]),
indices=tensor([[3, 2],[1, 2],[2, 3],[1, 0],[2, 0]]))
>>> a.topk(2,dim=0,largest=True,sorted=True)
torch.return_types.topk(
values=tensor([[0.6560, 0.7039, 0.8040, 0.6897],[0.3954, 0.5266, 0.6294, 0.3532]]),
indices=tensor([[4, 1, 4, 0],[3, 0, 0, 3]]))# 自定义numpy的topk方法
>>> a_np=np.array(a)
>>> a_np
array([[0.01537341, 0.5266498 , 0.6293524 , 0.689658 ],[0.2201249 , 0.70394784, 0.26386315, 0.06814277],[0.10058308, 0.04639381, 0.3313678 , 0.20519769],[0.395352 , 0.43731135, 0.21468669, 0.35324287],[0.655955 , 0.05492574, 0.80404747, 0.35280174]], dtype=float32)
>>> topk(a_np,2,axis=1)
(array([[0.689658 , 0.6293524 ],[0.70394784, 0.26386315],[0.3313678 , 0.20519769],[0.43731135, 0.395352 ],[0.80404747, 0.655955 ]], dtype=float32), array([[3, 2],[1, 2],[2, 3],[1, 0],[2, 0]]))
>>> topk(a_np,2,axis=0)
(array([[0.655955 , 0.70394784, 0.80404747, 0.689658 ],[0.395352 , 0.5266498 , 0.6293524 , 0.35324287]], dtype=float32), array([[4, 1, 4, 0],[3, 0, 0, 3]]))
可以发现已经完全实现了和torch.topk相同的功能~
numpy实现torch的topk方法相关推荐
- numpy和torch数据操作对比
对numpy和torch数据操作进行对比,避免遗忘. ndarray和tensor import torch import numpy as npnp_data = np.arange(6).resh ...
- list,numpy,tensor之间相互转换的方法
list,numpy,tensor之间相互转换的方法: a=[[1,2],[3,4]]#list print(a) b=np.array(a)#list->numpy print(b) c=to ...
- Pytorch函数之topk()方法
根据Pytorch中的手册可以看到,topk()方法用于返回输入数据中特定维度上的前k个最大的元素. torch.topk(input, k, dim=None, largest=True, sort ...
- pytorch下Numpy,Torch,Spicy,NetworkX及其他基本数据类型相关操作(持续更新)
Tricks 1. torch.sparse.FloatTensor(position, value) 稀疏张量表示为一对稠密张量:一个值张量和一个二维指标张量(每一维中存储多个值).一个稀疏张量可以 ...
- PyTorch中的topk方法以及分类Top-K准确率的实现
PyTorch中的topk方法以及分类Top-K准确率的实现 Top-K 准确率 在分类任务中的类别数很多时(如ImageNet中1000类),通常任务是比较困难的,有时模型虽然不能准确地将groun ...
- python增加一列数据_Python编程给numpy矩阵添加一列方法示例
首先我们有一个数据是一个mn的numpy矩阵现在我们希望能够进行给他加上一列变成一个m(n+1)的矩阵 import numpy as np a = np.array([[1,2,3],[4,5,6] ...
- PyTorch | 通过torch.normal()创建概率分布的张量 | torch.normal()如何使用?torch.normal()使用方法 | torch.normal()例子
在公众号[计算机视觉联盟]后台回复[9076]获取独家200页手推AI笔记:我的微信:PursueWin: --by Sophia 中科院学霸 | 上市AI算法工程师 | CSDN博客专家 通过 ...
- PyTorch | 通过torch.eye创建单位对角矩阵 | torch.eye()如何使用?torch.eye()例子 | torch.eye()使用方法
在公众号[计算机视觉联盟]后台回复[9076]获取独家200页手推AI笔记:我的微信:PursueWin: --by Sophia 中科院学霸 | 上市AI算法工程师 | CSDN博客专家 t ...
- PyTorch | torch.linspace()创建均分数列张量 | torch.linspace()如何使用?| torch.linspace()使用方法 | torch.linspace例子
公众号[计算机视觉联盟]后台回复[PyTorch]可以获得独家PyTorch学习教程pdf版 通过torch.linspace创建均分数列 张量 步长=(Start - end)/(Steps - 1 ...
最新文章
- Http协议中的各种长度限制总结
- Dos中查找文件命令的使用find
- 发布ccnet的步骤
- Nginx正向代理实现
- Quo Vadis JUnit
- mongodb atlas_如何使用MongoDB Atlas将MERN应用程序部署到Heroku
- mysql创建的数据库都在哪里看_mysql 怎么查看创建的数据库和表
- 【Git、GitHub、GitLab】七 git中分支的删除以及出现分离头指针的情况
- JEECG第二期深入使用培训(报名截止2014-06-21)
- 博达路由器常见功能教学0
- nginx搭建视频服务器
- 《Microsoft SQL Server入门教程》第01篇 SQL Server 简介
- 解决添加打印机print spooler打印服务自动关闭故障
- Android网络编程(一次网络请求)
- python爬虫框架论文开题报告范文_研究思路及框架--开题报告
- Python Spider入门
- CentOS 7 minimal安装完成之后安装图形界面
- 网络直播电视之寻找直播地址(下)
- RulersGuides.js – 网站中实现 Photoshop 标尺效果
- 最全的Java版本历史
热门文章
- revit运行dll文件弹出:未能加载文件或程序集“presentationframework, Version=5.0.0.0, Culture=neutral, PublicKeyToken
- 易语言 python库_精易Python支持库 (1.1#1205版)发布啦!
- 老毛桃win10pe 启动显示B1InitializeLibrary failed 0xc000009a解决方法
- root = Tk() 和 root = Tkinter.Tk() 区别
- lms算法的matlab实现,LMS算法的MATLAB实现
- 【摸鱼系列】如何用Python做一个有趣的Loading彩蛋游戏~
- 第十一次 作业 视图的应用
- MTK Android software Tools工具的说明
- R语言升级版本和迁移老版本中的包到新版本上的一些问题
- XSS Challenges/刷题/Stage #4