1、先在pytroch中把torch.save的神经网络参数文件转化为git.trace格式

def convertTracedScriptModule(opt):# load pretraind netif opt.input_size >=256:G = network.Generator(opt.in_ch, opt.out_ch, opt.ngf)# D = network.Discriminator(opt.in_ch, opt.out_ch, opt.ndf)else:G = network.Generator_Small(opt.in_ch, opt.out_ch, opt.ngf)# D = network.Discriminator_Small(opt.in_ch , opt.out_ch, opt.ndf)if opt.weights == "not_use":print("please set the path of pretrained net")returnckpt = torch.load(opt.weights)G.load_state_dict(ckpt['G_model'], strict=False)epoch = ckpt['epoch']G.eval()print(G)batch_size = 1example = torch.rand(batch_size,opt.in_ch,opt.input_size, opt.input_size)traced_script_module = torch.jit.trace(G, example)output = traced_script_module(torch.ones(batch_size,opt.in_ch,opt.input_size, opt.input_size))print(output)# save loaded netmodule_name = './traced_script_module/%s_epoch_%d_in_ch_%d_out_ch%d_input_size_%d.pt'%(opt.dataset, epoch, opt.in_ch, opt.out_ch, opt.input_size)traced_script_module.save(module_name)print(module_name + " is saved")

2、在c++中加载traced_script_module

定义一个简单的类,TraceScriptModel,用来加载预训练的网络,并执行预测,头文件如下

其中load用来加载预训练的模型,runForward方法执行预测,batch_size为 1是,返回结果是多通道的数据,即三维连续数据,长度是out_channel*out_rows*out*cols, runForwardFacies也执行预测,返回结果是多通道合并后的二维离散数据,离散值是满足门槛值的通道号。

#pragma once#include <string>
#include <vector>
#include <memory>
#include <torch/script.h>class TracedScriptModel
{
public:TracedScriptModel();TracedScriptModel(int inChannel=1, int inputSize=128);~TracedScriptModel();int load(std::string& modelPath);std::vector<float> runForward(std::vector<float>&inputSeis);std::vector<int> runForwardFacies(std::vector<float>& inputSeis);private:std::shared_ptr<torch::jit::Module> _module = nullptr;int _inCh=1, _inputSize=128;float _cutoff = 0.3;
};
#include "TracedScriptModel.h"TracedScriptModel::TracedScriptModel()
{
}TracedScriptModel::TracedScriptModel(int inChannel, int inputSize):_inCh(inChannel), _inputSize(inputSize)
{}TracedScriptModel::~TracedScriptModel()
{
}int TracedScriptModel::load(std::string& modelPath)
{// Deserialize the ScriptModule from a file using torch::jit::load()._module = std::make_shared<torch::jit::Module>(torch::jit::load(modelPath));assert(_module != nullptr);std::cout << "moudle is loaded ok\n";for (const auto& subModule : _module->modules()) {for (const auto& parms: subModule.named_parameters()) {std::cout << parms.name << std::endl;//std::cout << parms.value << std::endl;}}return 0;
}std::vector<float> TracedScriptModel::runForward( std::vector<float>& inputSeis)
{// transfer data to tensortorch::Tensor tensor_input = torch::from_blob(inputSeis.data(), { _inputSize, _inputSize, _inCh }, torch::kFloat);tensor_input = tensor_input.permute({ 2, 0, 1 });//tensor_input = tensor_input.toType(torch::kFloat);//tensor_input = tensor_input.div(255);tensor_input = tensor_input.unsqueeze(0);//std::cout << tensor_input << std::endl;// 网络前向计算// Execute the model and turn its output into a tensor.at::Tensor output = _module->forward({ tensor_input }).toTensor();//std::cout << output << std::endl;std::vector<float> v(output.data_ptr<float>(), output.data_ptr<float>() + output.numel());return std::move(v);}std::vector<int> TracedScriptModel::runForwardFacies(std::vector<float>& inputSeis)
{// transfer data to tensortorch::Tensor tensor_input = torch::from_blob(inputSeis.data(), { _inputSize, _inputSize, _inCh }, torch::kFloat);tensor_input = tensor_input.permute({ 2, 0, 1 });//tensor_input = tensor_input.toType(torch::kFloat);//tensor_input = tensor_input.div(255);tensor_input = tensor_input.unsqueeze(0);//std::cout << tensor_input << std::endl;// 网络前向计算// Execute the model and turn its output into a tensor.at::Tensor output = _module->forward({ tensor_input }).toTensor();//std::cout << output << std::endl;auto tsize = output.sizes();int out_ch = tsize[1];int out_rows = tsize[2];int out_cols = tsize[3];at::Tensor facies = at::zeros({ out_rows, out_cols }, torch::kInt32);at::Tensor ones = at::ones({ out_rows, out_cols }, torch::kInt32);for (int ch = 0; ch < out_ch; ch++) {facies = at::where(output[0][ch] > _cutoff, ones * ch, facies);}std::vector<int> v(facies.data_ptr<int>(), facies.data_ptr<int>() + facies.numel());return v;
}

3、模块调用和结果显示

设置预训练模型的路径,加载预训练模型,输入所需数据,得到预测结果。通过matplotlibcpp进行显示

void testPreNet(int argc, char* argv[])
{int inCh = 1;int inputSize = 128;int outCh = 3;TracedScriptModel tsm{ inCh, inputSize };std::string modelPath = R"(H:\deeplearning\pix2pix_geomodel\pix2pix_lyf\traced_script_module\yuejin_epoch_442_in_ch_1_out_ch3_input_size_128.pt)";tsm.load(modelPath);//std::vector<float> inputDat(128 * 128, 0.5);//tsm.runForward(inputDat);std::string seisNpyFile = R"(H:\yuejin\yuejin_sample\yuejin_facies_freq6_segI\seismic_848.npy)";auto inputSeis = EclipseModel::IModel2D<float>::loadNpy(seisNpyFile);inputSeis /= 800;/*std::vector<float> output = tsm.runForward(inputDat.grid());cout << output.size() << endl;plt::imshow(output.data(), inputSize, inputSize, 1);plt::show();plt::imshow(output.data() + inputSize * inputSize, inputSize, inputSize, 1);plt::show();plt::imshow(output.data() + inputSize * inputSize * 2, inputSize, inputSize, 1);plt::show();*//*plt::imshow(output.data(), inputSize, inputSize, 3);plt::show();*/std::vector<int>output = tsm.runForwardFacies(inputSeis.grid());std::vector<unsigned char> facies(output.size());for (int i = 0; i < output.size(); i++) {facies[i] = output[i];}plt::imshow(facies.data(), inputSize, inputSize, 1);plt::show();
}

效果如下

在c++中利用libtorch部署python中训练的pytorch网络相关推荐

  1. java 中利用subString 截取字符串中第三个/后面的内容,并将/用代替

    原文地址为: java 中利用subString 截取字符串中第三个"/"后面的内容,并将/用>代替 private String extractString(String ...

  2. python中readlines函数用法,python中read() readline()以及readlines()用法

    我们谈到"文本处理"时,我们通常是指处理的内容.Python 将文本文件的内容读入可以操作的字符串变量非常容易.文件对象提供了三个"读"方法: .read(). ...

  3. 【科学文献计量】将Endnote中的文献读入python中进行数据分析,并顺便将结果保存为Excel文件,并封装函数直接调用

    将Endnote中的文献读入python中进行数据分析,并顺便将结果保存为Excel文件 1 需求 2 功能完成 2.1 文献下载 2.2 文献导入到Endnote 2.3 文献导出 2.4 文件加载 ...

  4. python中webdriver_Linux上部署python+selenium+webdriver常见问题解决方案

    折腾了几天的WechatScraper终于部署到Linux服务器上能生产使用了 用篇文章来记录下部署过程中遇到的各种bug和坑. 1. 运行问题 webdriver在有GUI界面的系统上运行是只需要下 ...

  5. python中的range函数|python中的range函数|range()函数详解|Python中range(len())的用法

    本期目录 一.range()传递不同的参数 1.传递一个参数时 2.传递两个参数时 3.传递三个参数时 二.使用 range() 构建 for 循环 三.遍历列表时使用 range(len()) 的用 ...

  6. python中元组_理解python中的元组

    理解 python 中的元组 引言 在 Python 中元组是这样的: 元组是是这样一种数据结构:不变的或者不可改变的(简单来说不能重新赋值) .元素的有序序列.因为元组是 不变的,所以他的数值是不能 ...

  7. python中怎么输出中文-python中使用print输出中文的方法

    看Python简明教程,学习使用print打印字符串,试了下打印中文,不行. 编辑环境:IDLE 上网搜了下解决办法,各种说法,试了两种: print u"学习" print (u ...

  8. python应用中调用spark_在python中使用pyspark读写Hive数据操作

    1.读Hive表数据 pyspark读取hive数据非常简单,因为它有专门的接口来读取,完全不需要像hbase那样,需要做很多配置,pyspark提供的操作hive的接口,使得程序可以直接使用SQL语 ...

  9. linux中popen汉字乱码,Python中使用subprocess.Popen返回值乱码解决方案

    Python中使用subprocess.Popen返回值乱码解决方案 问题描述 在python 2.7中,使用subprocess.Popen()调用*nix命令,并通过管道,获取其输出,并将其返回值 ...

最新文章

  1. 高德地图关键字搜索oc版
  2. c语言中 d的用法,C语言中的#define用法总结
  3. Effective Java 之个人总结
  4. Intellij IDEA 提交代码到远程GitHub仓库
  5. 大咖茶话会 | 与原新浪微博副总裁零距离沟通
  6. python保存数据到本地_Python爬虫入门 | 6 将爬回来的数据存到本地
  7. python声明编码为gbk_Python字符串编码坑彻底详细解决
  8. POJ - 3842 An Industrial Spy dfs(水)
  9. “约见”面试官系列之常见面试题第二十五篇之对vue-router的理解(建议收藏)
  10. 谷粒商城基础篇爬坑笔记--项目导入intellij IDEA后pom.xml无法识别为maven文件和程序包import com.atguigu.common.XXX不存在两个问题解决方法
  11. opencv 显示程序运行时间
  12. Python基础笔记(一)数据类型、变量、字符串
  13. C++11中内联函数(inline)
  14. python辅助 sublime_Sublime+python设置
  15. html图片轮播_前端轮播图怎么做?JavaScript来帮你轻松搞定
  16. 如何优化深度学习模型
  17. MODFLOW Flex、GMS、FEFLOW、HYDRUS实践应用
  18. 免费版的 IDEA 如何使用 Tomcat
  19. 【方法篇】S-棕榈酰化蛋白修饰质谱鉴定方法
  20. 浅谈中国现货市场环境对期货市场发展的影响

热门文章

  1. 一个外行人学习一下TBOX 测试
  2. SWIFT PLM 功能介绍-项目管理的特色
  3. 协议 框架 解决方案
  4. 爬虫日记(40):Flask的模板介绍
  5. Flask中Jinja2模板|如何在Jinja2中格式化一个日期
  6. IMU的FSYNC脚的使用说明
  7. Encoder-Decoder -编码器解码器架构(RNN循环神经网络)
  8. 贝塞尔公式推导与物体跟随复杂曲线的轨迹运动
  9. 【store商城项目09】商品热销排行
  10. 深度学习框架量化感知训练的思考及OneFlow的一种解决方案