torch中提供了topk方法用来返回矩阵中对应维度中最大的K个元素以及在对应维度中的index,但是numpy并没有提供和torch一样的topk方法,所以在这里通过numpy的argpartition实现torch中的topk方法。

直接给出代码:

def topk_(matrix, K, axis=1):if axis == 0:row_index = np.arange(matrix.shape[1 - axis])topk_index = np.argpartition(-matrix, K, axis=axis)[0:K, :]topk_data = matrix[topk_index, row_index]topk_index_sort = np.argsort(-topk_data,axis=axis)topk_data_sort = topk_data[topk_index_sort,row_index]topk_index_sort = topk_index[0:K,:][topk_index_sort,row_index]else:column_index = np.arange(matrix.shape[1 - axis])[:, None]topk_index = np.argpartition(-matrix, K, axis=axis)[:, 0:K]topk_data = matrix[column_index, topk_index]topk_index_sort = np.argsort(-topk_data, axis=axis)topk_data_sort = topk_data[column_index, topk_index_sort]topk_index_sort = topk_index[:,0:K][column_index,topk_index_sort]return topk_data_sort, topk_index_sort

测试功能:

# torch.topk方法
>>> a=torch.rand(5,4)
>>> a
tensor([[0.0154, 0.5266, 0.6294, 0.6897],[0.2201, 0.7039, 0.2639, 0.0681],[0.1006, 0.0464, 0.3314, 0.2052],[0.3954, 0.4373, 0.2147, 0.3532],[0.6560, 0.0549, 0.8040, 0.3528]])
>>> a.topk(2,dim=1,largest=True,sorted=True)
values=tensor([[0.6897, 0.6294],[0.7039, 0.2639],[0.3314, 0.2052],[0.4373, 0.3954],[0.8040, 0.6560]]),
indices=tensor([[3, 2],[1, 2],[2, 3],[1, 0],[2, 0]]))
>>> a.topk(2,dim=0,largest=True,sorted=True)
torch.return_types.topk(
values=tensor([[0.6560, 0.7039, 0.8040, 0.6897],[0.3954, 0.5266, 0.6294, 0.3532]]),
indices=tensor([[4, 1, 4, 0],[3, 0, 0, 3]]))# 自定义numpy的topk方法
>>> a_np=np.array(a)
>>> a_np
array([[0.01537341, 0.5266498 , 0.6293524 , 0.689658  ],[0.2201249 , 0.70394784, 0.26386315, 0.06814277],[0.10058308, 0.04639381, 0.3313678 , 0.20519769],[0.395352  , 0.43731135, 0.21468669, 0.35324287],[0.655955  , 0.05492574, 0.80404747, 0.35280174]], dtype=float32)
>>> topk(a_np,2,axis=1)
(array([[0.689658  , 0.6293524 ],[0.70394784, 0.26386315],[0.3313678 , 0.20519769],[0.43731135, 0.395352  ],[0.80404747, 0.655955  ]], dtype=float32), array([[3, 2],[1, 2],[2, 3],[1, 0],[2, 0]]))
>>> topk(a_np,2,axis=0)
(array([[0.655955  , 0.70394784, 0.80404747, 0.689658  ],[0.395352  , 0.5266498 , 0.6293524 , 0.35324287]], dtype=float32), array([[4, 1, 4, 0],[3, 0, 0, 3]]))

可以发现已经完全实现了和torch.topk相同的功能~

numpy实现torch的topk方法相关推荐

  1. numpy和torch数据操作对比

    对numpy和torch数据操作进行对比,避免遗忘. ndarray和tensor import torch import numpy as npnp_data = np.arange(6).resh ...

  2. list,numpy,tensor之间相互转换的方法

    list,numpy,tensor之间相互转换的方法: a=[[1,2],[3,4]]#list print(a) b=np.array(a)#list->numpy print(b) c=to ...

  3. Pytorch函数之topk()方法

    根据Pytorch中的手册可以看到,topk()方法用于返回输入数据中特定维度上的前k个最大的元素. torch.topk(input, k, dim=None, largest=True, sort ...

  4. pytorch下Numpy,Torch,Spicy,NetworkX及其他基本数据类型相关操作(持续更新)

    Tricks 1. torch.sparse.FloatTensor(position, value) 稀疏张量表示为一对稠密张量:一个值张量和一个二维指标张量(每一维中存储多个值).一个稀疏张量可以 ...

  5. PyTorch中的topk方法以及分类Top-K准确率的实现

    PyTorch中的topk方法以及分类Top-K准确率的实现 Top-K 准确率 在分类任务中的类别数很多时(如ImageNet中1000类),通常任务是比较困难的,有时模型虽然不能准确地将groun ...

  6. python增加一列数据_Python编程给numpy矩阵添加一列方法示例

    首先我们有一个数据是一个mn的numpy矩阵现在我们希望能够进行给他加上一列变成一个m(n+1)的矩阵 import numpy as np a = np.array([[1,2,3],[4,5,6] ...

  7. PyTorch | 通过torch.normal()创建概率分布的张量 | torch.normal()如何使用?torch.normal()使用方法 | torch.normal()例子

    在公众号[计算机视觉联盟]后台回复[9076]获取独家200页手推AI笔记:我的微信:PursueWin:    --by Sophia 中科院学霸 | 上市AI算法工程师 | CSDN博客专家 通过 ...

  8. PyTorch | 通过torch.eye创建单位对角矩阵 | torch.eye()如何使用?torch.eye()例子 | torch.eye()使用方法

    在公众号[计算机视觉联盟]后台回复[9076]获取独家200页手推AI笔记:我的微信:PursueWin:    --by Sophia 中科院学霸 | 上市AI算法工程师 | CSDN博客专家 t ...

  9. PyTorch | torch.linspace()创建均分数列张量 | torch.linspace()如何使用?| torch.linspace()使用方法 | torch.linspace例子

    公众号[计算机视觉联盟]后台回复[PyTorch]可以获得独家PyTorch学习教程pdf版 通过torch.linspace创建均分数列 张量 步长=(Start - end)/(Steps - 1 ...

最新文章

  1. Http协议中的各种长度限制总结
  2. Dos中查找文件命令的使用find
  3. 发布ccnet的步骤
  4. Nginx正向代理实现
  5. Quo Vadis JUnit
  6. mongodb atlas_如何使用MongoDB Atlas将MERN应用程序部署到Heroku
  7. mysql创建的数据库都在哪里看_mysql 怎么查看创建的数据库和表
  8. 【Git、GitHub、GitLab】七 git中分支的删除以及出现分离头指针的情况
  9. JEECG第二期深入使用培训(报名截止2014-06-21)
  10. 博达路由器常见功能教学0
  11. nginx搭建视频服务器
  12. 《Microsoft SQL Server入门教程》第01篇 SQL Server 简介
  13. 解决添加打印机print spooler打印服务自动关闭故障
  14. Android网络编程(一次网络请求)
  15. python爬虫框架论文开题报告范文_研究思路及框架--开题报告
  16. Python Spider入门
  17. CentOS 7 minimal安装完成之后安装图形界面
  18. 网络直播电视之寻找直播地址(下)
  19. RulersGuides.js – 网站中实现 Photoshop 标尺效果
  20. 最全的Java版本历史

热门文章

  1. revit运行dll文件弹出:未能加载文件或程序集“presentationframework, Version=5.0.0.0, Culture=neutral, PublicKeyToken
  2. 易语言 python库_精易Python支持库 (1.1#1205版)发布啦!
  3. 老毛桃win10pe 启动显示B1InitializeLibrary failed 0xc000009a解决方法
  4. root = Tk() 和 root = Tkinter.Tk() 区别
  5. lms算法的matlab实现,LMS算法的MATLAB实现
  6. 【摸鱼系列】如何用Python做一个有趣的Loading彩蛋游戏~
  7. 第十一次 作业 视图的应用
  8. MTK Android software Tools工具的说明
  9. R语言升级版本和迁移老版本中的包到新版本上的一些问题
  10. XSS Challenges/刷题/Stage #4