CDIMC-Net[1] 中有个对整个数据集求 kNN 图的函数 get_kNNgraph2[2],是用 dense 的 numpy.ndarray 存的,空间复杂度 O ( n 2 ) O(n^2) O(n2),大数据集很吃内存,但其实 kNN 图很稀疏。这里用 scipy.sparse 的 API 改写。

Code

  • csr_matrix:row slicing 高效,因为一行对应一个 datum 的邻接链表,取 batch 是对行取,所以用它。
  • lil_matrix:说是「改变稀疏结构很高效」,用在图的构造时,构造完再转 csr_matrix(本来直接用 csr_matrix 构造,然后它建议用 lil_matrix)。
import numpy as np
from scipy.sparse import csr_matrix, lil_matrix
# import torchdef get_kNNgraph2(data,K_num):"""原来的构图函数https://github.com/DarrenZZhang/CDIMC-Net/blob/main/CDIMC-net-handwritten_final.py#L46"""# each row of data is a samplex_norm = np.reshape(np.sum(np.square(data), 1), [-1, 1])  # column vectorx_norm2 = np.reshape(np.sum(np.square(data), 1), [1, -1])  # column vectordists = x_norm - 2 * np.matmul(data, np.transpose(data))+x_norm2num_sample = data.shape[0]graph = np.zeros((num_sample,num_sample),dtype = np.int)for i in range(num_sample):distance = dists[i,:]small_index = np.argsort(distance)graph[i,small_index[0:K_num]] = 1graph = graph-np.diag(np.diag(graph))resultgraph = np.maximum(graph,np.transpose(graph))return resultgraphdef get_kNNgraph2_sparse(X, K_num, batch_size=256):"""sparse version of kNN graph calculation"""n = X.shape[0]  # full size# `(n, n)`  NOT `[n, n]`G = lil_matrix((n, n), dtype=np.int8)x_norm_all = np.sum(np.square(X), axis=1, keepdims=True).T  # [1, n]for _begin in range(0, n, batch_size):_end = min(_begin + batch_size, n)X_batch = X[_begin: _end]# euclidean distancex_norm = np.sum(np.square(X_batch), axis=1, keepdims=True)  # [batch_size, 1]D = x_norm - 2 * np.matmul(X_batch, np.transpose(X)) + x_norm_all  # [batch_size, n]small_index = np.argsort(D, axis=1)[:, :K_num]  # [batch_size, K_num]# mask the kNNfor i in range(small_index.shape[0]):_row_id = _begin + i_small_idx = small_index[i]G[_row_id, _small_idx] = 1# no self-loopG.setdiag(0)# symmetrizeG = G.maximum(G.transpose())# convert to `csr_matrix` for fast row slicingG = G.tocsr()return G"""验证一致性"""
N = 6  # num of data
D = 3  # data dim
K = N // 2
for i in range(150):# print(i)X = np.random.permutation(N * D).reshape(N, D)G1 = get_kNNgraph2(X, K)G2 = get_kNNgraph2_sparse(X, K).todense()diff = (G1 != G2).sum()if diff != 0:print("diff:", i, diff)  # 无输出# print("PyTorch sparse matrix")# x_nz, y_nz = G2.nonzero()# I = torch.cat([# torch.from_numpy(x_nz),# torch.from_numpy(y_nz),# ], 0).long()# V = torch.ones(x_nz.shape[0]).float()# breakprint("DONE")

References

  1. DarrenZZhang/CDIMC-Net
  2. get_kNNgraph2
  3. Sparse matrices (scipy.sparse)
  4. scipy.sparse.csr_matrix
  5. scipy.sparse.lil_matrix
  6. torch.sparse

scipy.sparse使用简例相关推荐

  1. scipy笔记:scipy.sparse

    1 稀疏矩阵介绍 在networkx包中,很多运算返回的是sparse matrix(如nx.laplacian_matrix),这是稀疏矩阵格式.隶属于scipy.sparse import net ...

  2. scipy.sparse.csr_matrix函数和coo_matrix函数

    Scipy高级科学计算库:和Numpy联系很密切,Scipy一般都是操控Numpy数组来进行科学计算.统计分析,所以可以说是基于Numpy之上了.Scipy有很多子模块可以应对不同的应用,例如插值运算 ...

  3. 看的懂的scipy.sparse.csr_matrix和scipy.sparse.csc_matrix

    一.导入   在用python进行科学运算时,常常需要把一个稀疏的np.array压缩,这时候就用到scipy库中的sparse.csr_matrix函数和sparse.csc_matric函数.   ...

  4. scipy.sparse.coo_matrix、csr_matrix、lil_matrix、dia_matrix

    文章目录 coo_matrix csr_matrix lil_matrix dia_matrix coo_matrix 1.coo啥意思?COOrdinate(坐标) 2.那么coo_matrix又是 ...

  5. oracle供应商导入,AP供应商导入简例.pdf

    AP供应商导入简例 Oracle 完全测试记录 供应商导入 吴若童 总述总述 总述总述 供应商供应商接口接口的原理的原理?? 供应商供应商接口接口的原理的原理?? 系统从三个表分别导入供应商.供应商地 ...

  6. Android RuntimePermissions运行时权限:单个运行时权限申请简例

    Android RuntimePermissions运行时权限:单个运行时权限申请简例 Android运行时权限申请的框架结构和步骤比较简单和固定,一般现状代码启动后检查当前的Android SDK版 ...

  7. 5.3linux下C语言socket网络编程简例

    原创文章,转载请注明转载字样和出处,谢谢! 这里给出在Linux下的简单socket网络编程的实例,使用tcp协议进行通信,服务端进行监听,在收到客户端的连接后,发送数据给客户端:客户端在接受到数据后 ...

  8. Ansible 入门:安装 简例 playbook应用

    Mysql 内:select unix_timestamp('2016-10-20')  <---> select from_unixtime(147662104) 转时间戳:date + ...

  9. 关于稀疏矩阵转化为稠密矩阵问题 (scipy.sparse格式和tensor稀疏张量格式)

    scipy.sparse: todense() pytorch中的稀疏张量tensor: to_dense()

最新文章

  1. python操作mongodb数据库
  2. Miniconda3+Tensorflow2.3(GPU版)+Win10_x64+GTX1060深度学习环境搭建
  3. ASP.NET Calendar 控件
  4. 模组使用之天线阻抗匹配、匹配过程、天线选型注意、RF走线Layout建议
  5. sdn和nfv的区别—Vecloud微云
  6. wxpython实现鼠标拖动事件
  7. [网络]------TCP UDP HTTP Socket 区别
  8. JS动态赋值同时触发onchange方法
  9. 中山大学提出新型行人重识别方法和史上最大评测基准
  10. axios-引入-常用语法-源码
  11. 韩顺平php视频笔记68 析构函数 php垃圾回收机制
  12. 年终盘点:2010年半导体产业的十大进展-转自老杳
  13. ubuntu 设置tab键自动补全
  14. python日期时间模块_Python模块|时间处理模块-日期时间模块,python,datetime
  15. ubuntu16.04安装搜狗拼音输入法
  16. 【重温经典】《谁谋杀了我们的游戏?》出自《黑神·话悟空》制作人Yocar
  17. 基于嵌入式端的人脸识别算法
  18. 学生表mysql查询语句
  19. Python基础入门实验3附加题
  20. Linux 学习资料

热门文章

  1. 自然语言处理之蒙古文词网生成系统
  2. 六级考研单词之路-二十八
  3. 怎样选择适合自己的发色?ps教你快速染发
  4. 中国各省区块链政策竞争力指数TOP10(2020年10月)|链塔月榜
  5. uni.showModal
  6. 3、【Xilinx下载器】【ILA】使用ILA调试时出错的解决方案
  7. plsql安装与配置
  8. Thymeleaf---基础知识
  9. 大数据学习之经典数据分析算法详解
  10. 干货解读 |大数据,数据挖掘,机器学习的区别和联系