mpi4py的wrapper

给mpi4py写了个wrapper。包括并行写入和读出,以及对于numpy array的split、scatter、bcast、gather,基本完成。如果有新想法应该会持续更新,加入新功能。

# Standard library
import time

# Third party: h5py for HDF5 I/O, mpi4py for MPI communication.
import h5py as h5
import numpy as np
from mpi4py import MPI

# Module-level MPI state shared by every helper below.
mpi_comm = MPI.COMM_WORLD
mpi_size = mpi_comm.Get_size()
mpi_rank = mpi_comm.Get_rank()


def process_size(total_size, rank=mpi_rank, size=mpi_size):
    '''Return the chunk length owned by process `rank` when an axis of
    length `total_size` is split as evenly as possible over `size`
    processes.  The first `total_size % size` ranks get one extra element.
    '''
    if rank < int(total_size % size):
        return int(total_size // size + 1)
    else:
        return int(total_size // size)


def ind_end(total_size, rank=mpi_rank, size=mpi_size):
    '''Return the (exclusive) end index of the chunk owned by process
    `rank` when an axis of length `total_size` is split over `size`
    processes.'''
    remainder = int(total_size % size)
    all_size = [int(total_size // size + 1)] * remainder
    # The remaining (size - remainder) ranks hold the smaller chunk; this
    # keeps the list length equal to `size`, so indexing by `rank` also
    # works when total_size < size.
    all_size += [int(total_size // size)] * (size - remainder)
    return np.cumsum(all_size)[rank]


def ind_start(total_size, rank=mpi_rank, size=mpi_size):
    '''Return the start index of the chunk owned by process `rank` when an
    axis of length `total_size` is split over `size` processes.'''
    return ind_end(total_size, rank=rank, size=size) \
        - process_size(total_size, rank=rank, size=size)


def parallel_save_dataset(filename, key, data, group_name=None, axis=0):
    '''Save data from every process into one HDF5 dataset; the per-process
    arrays are concatenated along `axis`.

    filename: file to save the data into.
    key: name of the dataset.
    data: ndarray to save; all axes except `axis` must have the same
        length on every process.
    group_name: if None, save to file[key], else to file[group_name][key].
    axis: axis along which the per-process arrays are concatenated.

    Raises IOError if a rank still cannot write the file after 10 retries.
    '''
    data = np.asarray(data)
    shp = list(data.shape)
    num = shp[axis]
    # Root collects every rank's length along `axis` and scatters back the
    # cumulative sums, i.e. each rank's exclusive end index in the file.
    len_axes = mpi_comm.gather(num, root=0)
    if mpi_rank == 0:
        ied = np.cumsum(len_axes)
    else:
        ied = None
    ied = mpi_comm.scatter(ied, root=0)
    ist = ied - num
    save_slice = [slice(None, None, None)] * len(shp)
    save_slice[axis] = slice(ist, ied, None)
    save_slice = tuple(save_slice)
    # Rank 0 creates the (empty) dataset with the full concatenated shape.
    if mpi_rank == 0:
        shp[axis] = sum(len_axes)
        with h5.File(filename, 'a') as filein:
            if group_name is None:
                filein.create_dataset(key, shape=shp, dtype=data.dtype)
            else:
                # Only create the group when it does not exist yet,
                # consistent with parallel_save_multi_dataset.
                if group_name not in filein.keys():
                    filein.create_group(group_name)
                filein[group_name].create_dataset(key, shape=shp,
                                                  dtype=data.dtype)
    mpi_comm.barrier()
    if group_name is None:
        print_key = key
    else:
        print_key = '%s\' in group \'%s'%(key, group_name)
    # Ranks take turns writing their own slice; the barrier at the end of
    # each iteration serializes the file access.
    for ii in range(mpi_size):
        if ii == mpi_rank:
            for _ in range(10):
                try:
                    with h5.File(filename, 'a') as filein:
                        if group_name is None:
                            filein[key][save_slice] = data
                        else:
                            filein[group_name][key][save_slice] = data
                    print('Rank %d save dataset \'%s\' %d to %d into %s!'%(mpi_rank, print_key, ist, ied, filename))
                    # Presumably lets the filesystem release the handle
                    # before the next rank opens the file — TODO confirm.
                    time.sleep(0.5)
                    break
                except IOError as e:
                    print('%s for rank %d, sleep 0.5 second!'%(e, mpi_rank))
                    time.sleep(0.5)
            else:
                # All 10 retries failed.
                raise IOError('Rank %d cannot save dataset \'%s\' %d to %d into %s!'%(mpi_rank, print_key, ist, ied, filename))
        mpi_comm.barrier()


def parallel_load_dataset(filename, key, group_name=None, axis=0):
    '''Load one HDF5 dataset and spread it over the processes along `axis`.

    filename: file to read the data from.
    key: name of the dataset.
    group_name: if None, read file[key], else file[group_name][key].
    axis: axis along which the dataset is distributed.

    Returns this rank's chunk as an ndarray.
    '''
    with h5.File(filename, 'r') as filein:
        if group_name is None:
            dataset = filein[key]
        else:
            dataset = filein[group_name][key]
        shp = dataset.shape
        num = shp[axis]
        ist = ind_start(num)
        ied = ind_end(num)
        load_slice = [slice(None, None, None)] * len(shp)
        load_slice[axis] = slice(ist, ied, None)
        load_slice = tuple(load_slice)
        # Slicing the open dataset reads only this rank's part from disk.
        return dataset[load_slice]


def parallel_save_multi_dataset(filename, key, data, group_name=None):
    '''Save data from every process into one file, each process under its
    own key.

    filename: file to save the data into.
    key: name of the dataset; should differ between processes.
    data: ndarray to save.
    group_name: if None, save to file[key], else to file[group_name][key].

    Raises IOError if a rank still cannot write the file after 10 retries.
    '''
    if group_name is None:
        print_key = key
    else:
        print_key = '%s\' in group \'%s'%(key, group_name)
    # Ranks take turns writing; the barrier per iteration serializes access.
    for ii in range(mpi_size):
        if ii == mpi_rank:
            for _ in range(10):
                try:
                    with h5.File(filename, 'a') as filein:
                        if group_name is None:
                            filein[key] = data
                        elif group_name in filein.keys():
                            filein[group_name][key] = data
                        else:
                            filein.create_group(group_name)
                            filein[group_name][key] = data
                    print('Rank %d save dataset \'%s\' into %s!'%(mpi_rank, print_key, filename))
                    # Presumably lets the filesystem release the handle
                    # before the next rank opens the file — TODO confirm.
                    time.sleep(0.5)
                    break
                except IOError as e:
                    print('%s for rank %d, sleep 0.5 second!'%(e, mpi_rank))
                    time.sleep(0.5)
            else:
                raise IOError('Rank %d cannot save \'%s\' into %s!'%(mpi_rank, print_key, filename))
        mpi_comm.barrier()


def split_uneven_array(data, root=0, axis=0):
    '''np.array_split `data` at `root` along `axis` and scatter the chunks
    as python objects (works for chunks of unequal length).'''
    if mpi_rank == root:
        data = np.asarray(data)
        data = np.array_split(data, mpi_size, axis=axis)
    new_data = mpi_comm.scatter(data, root=root)
    return new_data


def split_even_array(data, root=0, axis=0):
    '''Split `data` at `root` along `axis` into equal chunks and scatter
    them as numpy arrays (fast buffer path; the axis length must divide
    evenly by the number of processes).'''
    if mpi_rank == root:
        data = np.asarray(data)
        shp = list(data.shape)
        assert shp[axis]%mpi_size==0, 'Axis %d with length %d cannot exactly divided by mpi size %d!'%(axis, shp[axis], mpi_size)
        dtype = data.dtype
        # Stack the chunks on a new leading axis so Scatter can hand each
        # rank one contiguous block along axis 0.
        data = np.array_split(data, mpi_size, axis=axis)
        data = np.asarray(data)
    else:
        dtype = None
        shp = None
    # Non-root ranks learn the dtype/shape so they can allocate the buffer.
    dtype = mpi_comm.bcast(dtype, root=root)
    shp = mpi_comm.bcast(shp, root=root)
    shp[axis] = process_size(shp[axis])
    new_data = np.empty(shp, dtype=dtype)
    mpi_comm.Scatter(data, new_data, root=root)
    return new_data


def split_array(data, root=0, axis=0):
    '''Split `data` at `root` along `axis` and scatter it to the other
    processes.  If the axis length divides evenly over the processes the
    chunks are scattered as numpy arrays, otherwise as python objects.'''
    if mpi_size == 1:
        return np.asarray(data)
    if mpi_rank == root:
        data = np.asarray(data)
        shp = list(data.shape)
        even = shp[axis] % mpi_size == 0
    else:
        even = None
    even = mpi_comm.bcast(even, root=root)
    if even:
        print('Split and scatter as numpy array!')
        return split_even_array(data, root=root, axis=axis)
    else:
        print('Split and scatter as python object!')
        return split_uneven_array(data, root=root, axis=axis)


def bcast_array(data, root=0):
    '''Broadcast an array from `root` as a numpy array; receiving ranks do
    not need to allocate the buffer manually.'''
    if mpi_size == 1:
        return np.asarray(data)
    if mpi_rank == root:
        # Bcast needs a contiguous buffer; this also covers list input.
        data = np.ascontiguousarray(data)
        dtype = data.dtype
        shp = data.shape
    else:
        dtype = None
        shp = None
    dtype = mpi_comm.bcast(dtype, root=root)
    shp = mpi_comm.bcast(shp, root=root)
    if mpi_rank != root:
        data = np.empty(shp, dtype=dtype)
    mpi_comm.Bcast(data, root=root)
    return data


def gather_array(data, root=0, axis=0, expand_dim=False, ascontiguous=True):
    '''Gather arrays from all processes at `root` and concatenate them
    along `axis`.

    If expand_dim, each array is expanded with np.expand_dims first and
    stacked along the new axis (all shapes must then be equal).  If
    ascontiguous, the result is made contiguous via np.ascontiguousarray.
    The fast numpy-buffer Gather is used when all per-rank lengths along
    `axis` agree; otherwise the arrays travel as python objects.

    Returns the combined array on `root`; other ranks return None (except
    when mpi_size == 1, where the local array is returned directly).
    '''
    if mpi_size == 1:
        if expand_dim:
            return np.expand_dims(data, axis=axis)
        else:
            return np.asarray(data)
    data = np.asarray(data)
    shp = list(data.shape)
    if expand_dim:
        print('Gather as numpy array and expand axis=%d!'%axis)
        even = True
        new_shp = [mpi_size] + shp
        all_shp = mpi_comm.gather(shp, root=root)
        all_shp = mpi_comm.bcast(all_shp, root=root)
        for ii in all_shp[1:]:
            assert np.array_equal(ii, all_shp[0]), 'Shape of data must be the same if expand_dim!'
    else:
        all_shp = mpi_comm.gather(shp, root=root)
        all_shp = mpi_comm.bcast(all_shp, root=root)
        shp0 = all_shp[0]
        even = True
        total_len = shp0[axis]
        for ii in all_shp[1:]:
            assert len(shp0) == len(ii), 'Data from different mpi process should have the same number of dimensions! Shapes are: %s'%all_shp
            shp1 = shp0.copy()
            shp2 = ii.copy()
            del shp1[axis]
            del shp2[axis]
            assert np.array_equal(shp1, shp2), 'Data from different mpi process should have the same shape except for the merge axis! Shapes are: %s'%all_shp
            if ii[axis] != shp0[axis]:
                even = False
            total_len += ii[axis]
        if even:
            print('Gather as numpy array!')
            # Gather fills axis 0, so build the receive shape with the
            # merge axis moved to the front.
            new_shp = shp0.copy()
            del new_shp[axis]
            new_shp = [total_len] + new_shp
        else:
            print('Gather as python object!')
    if even:
        if mpi_rank == root:
            new_data = np.empty(new_shp, dtype=data.dtype)
        else:
            new_data = None
        if expand_dim:
            data = np.expand_dims(data, 0)
        else:
            # Move the merge axis to the front so the buffer Gather
            # concatenates along it.
            data = np.ascontiguousarray(np.moveaxis(data, axis, 0))
        mpi_comm.Gather(data, new_data, root=root)
        if mpi_rank == root:
            new_data = np.moveaxis(new_data, 0, axis)
            if ascontiguous:
                new_data = np.ascontiguousarray(new_data)
        return new_data
    else:
        new_data = mpi_comm.gather(data, root=root)
        if mpi_rank == root:
            new_data = np.concatenate(new_data, axis=axis)
        return new_data


if __name__ == '__main__':
    # Self test: every rank writes a differently sized slab into one
    # dataset, rank 0 reads the files back and prints the max error.
    if mpi_rank == 1:
        with h5.File('test.hdf5', 'w') as filein:
            pass
        with h5.File('test1.hdf5', 'w') as filein:
            pass
    a = np.random.rand(10, 2, mpi_rank + 3, 6, 5)
    axis = 2
    parallel_save_dataset('test.hdf5', 'a', a, group_name='b', axis=axis)
    parallel_save_dataset('test1.hdf5', 'a', a, axis=-3)
    a1 = mpi_comm.gather(a, root=0)
    if mpi_rank == 0:
        a1 = np.concatenate(a1, axis=axis)
        with h5.File('test.hdf5', 'r') as filein:
            print(np.abs(filein['b']['a'][:] - a1).max())
        with h5.File('test1.hdf5', 'r') as filein:
            print(np.abs(filein['a'][:] - a1).max())
    exit()

    # Second demo exercising gather_array; unreachable while the exit()
    # above is active.
    axis = 1
    expand_dim = True
    root = 1
    np.random.seed(mpi_rank + 1)
    a = np.random.rand(10, 3, 20)
    print(np.shape(a), mpi_rank)
    a = gather_array(a, root=root, axis=axis, expand_dim=expand_dim)
    print(np.shape(a), mpi_rank)
    if mpi_rank == root:
        # Rebuild the expected result from every rank's seed and compare.
        b = []
        for ii in range(mpi_size):
            np.random.seed(ii + 1)
            b.append(np.random.rand(10, 3, 20))
            if expand_dim:
                b[-1] = np.expand_dims(b[-1], axis=axis)
        b = np.concatenate(b, axis=axis)
        print(np.abs(a - b).max())

mpi4py的wrapper相关推荐

  1. java windows wrapper_Java Service Wrapper 使用(windows)

    1       简介 最近项目中需要做一个Windows系统服务,记录一下使用过程. Java Service Wrapper 可以将Java程序包装成系统服务,这样就可以随着系统的运行而自动运行.J ...

  2. Apache Unable to find the wrapper https - did you forget to enable it when you configured PHP?

    微信小程序开发交流qq群   173683895    承接微信小程序开发.扫码加微信. Apache Unable to find the wrapper "https" - d ...

  3. tcp wrapper

    tcp wrapper概述: tcp wrapper同iptables一样都是网络资源访问器,工作在传输层只对工作在TCP协议的部分服务做访问控制:tcp wrapper是一个库文件即libwrap. ...

  4. mpi4py多进程实例/举例

    前言: 看了那么多关于mpi4py使用的,却没见到一个能够举例在实际情况中的使用,笔者也是初学者,于是花了一整个下午来找例子并详细解答,希望能帮助想用mpi4py的后来者 提醒:这里不讨论如何使用mp ...

  5. Mybatis-Plus实战中的几个条件构造器Wrapper用法

    Mybatis-Plus实战中的几个条件构造器Wrapper用法 其实Wrapper有很多其他的方法,组合起来也是殊途同归,大家可以自己点开源码去查看一些方法的使用说明 @Testvoid conte ...

  6. Swift Property Wrapper 属性包装器

    @propertyWrapper属性包装器:在定义存储属性时添加一个分离层,代表该属性被包装起来,且在包装器内部可以做一些事情.把一些通用复用的代码放在了包装器中,比如线程安全检查或者数据存储到数据库 ...

  7. weblogic.jdbc.wrapper.Blob_oracle_sql_BLOB cannot be cast to oracle.sql.BLOB 解决方法

    weblogic.jdbc.wrapper.Blob_oracle_sql_BLOB cannot be cast to oracle.sql.BLOB 解决方法 参考文章: (1)weblogic. ...

  8. string转date类型_10:Wrapper;String;Date;Math;File;Enumeration;Syst

    1. 包装类的基本用法 2.自动装箱和拆箱 1. 包装类的基本用法 1.1 为什么需要包装类(Wrapper Class)
 Java 并不是纯面向对象的语言. Java 语言是一个面向对象的语言,但 ...

  9. android log4,GitHub - oronno/log4android: Log4Android - Simple Logging Wrapper Library for Android

    log4android Log4Android - Simple Logging Wrapper Library for Android Tired writing TAG each time wri ...

最新文章

  1. svn更新maven项目报错_使用svn管理Maven项目的方法步骤
  2. 我们离爱因斯坦想了解的“上帝的思想”,还有多远?
  3. usb-key登录windows+远程桌面
  4. SwiftUI3.0封装Lottie动画库
  5. Thinkphp 配置不用输入index.php
  6. java禁止修改map_Java中实现不可变Map
  7. Swift UI开发初探
  8. php7.0 freetype_php7.0.5安装教程
  9. 计算机科学工程哲学学位,2020年剑桥大学硕士读多久
  10. 迪士尼收购福克斯,传媒巨头江山瓦解?
  11. eclipse android环境搭建,Eclipse Android开发环境搭建教程
  12. 泰文Unicode编码表及排版规则
  13. a标签的href属性 download属性
  14. 计算机科学对自然观的影响,浅谈自然辩证法对计算机科学研究的意义
  15. PS技巧三------五彩斑斓的黑色(滤镜---镜头光晕和波浪|||||混合选项---柔光)
  16. 更新驱动后重启黑屏且进不了bios时的一个解决办法
  17. [转] ThreeJS中,那些会让阴影失效的操作
  18. photoshop给照片去斑的一些办法
  19. css和js3d粒子,使用EaselJS实现的3D球形粒子运动
  20. 小米机器人履带双轮平衡_小米米兔机器人评测:一个站在平衡车上的机器人

热门文章

  1. js字符串拼接效率问题
  2. 跨平台SSH软件-Termius
  3. HbuilderX实现ios真机运行uniapp教程
  4. 华为鸿蒙2.0系统是安卓吗,华为鸿蒙2.0可以替代安卓吗,华为鸿蒙2.0优势在哪
  5. CustomerResourceGrid
  6. MACD改良抓牛神器 通达言指标公式 副图 源码 无加密 无未来
  7. 建造者2全部岛屿_勇者斗恶龙:建造者2空荡岛相关任务及剧情攻略分享
  8. Excel学习日记:L16-vlookup函数绝对参照设定
  9. 【中国海洋大学】考研初试复试资料分享
  10. C语言经典编程题100例(21-40)