文章目录

  • 1 概述
  • 2 Tensor的基本操作
    • 2.1 Tensor的初始化
      • (1)通过数组创建
      • (2)通过默认方法创建
      • (3)通过其他的`tensor`创建
      • (4)通过`opencv::core::Mat`创建
    • 2.2 Tensor的属性
    • 2.3 Tensor的运算
      • (1)改变device
      • (2)获取值(indexing and slicing)
      • (3)合并tensors
      • (4)四则运算
  • 参考资料

1 概述

在使用rust进行torch模型部署时,不可避免地会用到tch-rs。但是tch-rs的文档太过简洁,和没有一样,网上的资料也少得可怜,很多操作需要我们自己去试。这些内容虽然简单,但是自己找起来很费时间。

这篇文章总结了如何使用tch-rs进行tensor的基本操作。讲述的内容参考了pytorch的tensor教程。

运行环境:

[dependencies]
tch = "0.7.0"
opencv = "0.63"

2 Tensor的基本操作

用到的库

use std::iter;use opencv::prelude::*;
use opencv::core::{Mat, Scalar};
use opencv::core::{CV_8UC3};
use tch::IndexOp;
use tch::{Device, Tensor};

2.1 Tensor的初始化

(1)通过数组创建

let t = Tensor::of_slice::<i32>(&[1, 2, 3, 4, 5]);
t.print();
// vector也是一样的
let v = vec![1,2,3];
let t = Tensor::of_slice::<i32>(&v);
t.print();
// 2d vector
let v = vec![[1.5,2.0,3.9,4.4], [3.1,4.3,5.1,6.9]];
let v:Vec<f32> = v.iter().flat_map(|array| array.iter()).cloned().collect();
let data = unsafe{std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * std::mem::size_of::<f32>())
};
let t = Tensor::of_data_size(data, &[2,4], tch::Kind::Float);
t.print();

print的结果是

 12345
[ CPUIntType{5} ]123
[ CPUIntType{3} ]1.5000  2.0000  3.9000  4.40003.1000  4.3000  5.1000  6.9000
[ CPUFloatType{2,4} ]

(2)通过默认方法创建

let t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t = Tensor::ones(&[2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t = Tensor::zeros(&[2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t = Tensor::arange_start(0, 2 * 3, (tch::Kind::Float, Device::Cpu)).view([2, 3]);
t.print();

print的结果是

 1.0522  0.6981  0.92360.2324 -1.1048 -2.5820
[ CPUFloatType{2,3} ]1  1  11  1  1
[ CPUFloatType{2,3} ]0  0  00  0  0
[ CPUFloatType{2,3} ]0  1  23  4  5
[ CPUFloatType{2,3} ]

(3)通过其他的tensor创建

let t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
let t = t.rand_like();
t.print();

print的结果是

 0.3376  0.1885  0.34150.5135  0.8321  0.4140
[ CPUFloatType{2,3} ]

(4)通过opencv::core::Mat创建

这可以用在opencv读取图像后,转为torch tensor。当然tch-rs本身也有各种读取图片的方式,可见tch::vision::image。这里介绍两种方法,一种通过tch::Tensor::f_of_blob,一种通过tch::Tensor::of_data_size

// 创建一个(row, col, channel)=(2, 3, 3)=(height, width, channel)的Mat
let mat = Mat::new_rows_cols_with_default(2, 3, CV_8UC3, Scalar::from((3.0, 2.0, 1.0))
).unwrap();
// 获取mat的size,这里的结果是[2, 3, 3]
let size: Vec<_> = mat.mat_size().iter().cloned().map(|dim| dim as i64).chain(iter::once(mat.channels() as i64)).collect();
// 获取每个dimension的stride,这里的结果是[9, 3, 1]
let strides = {let mut strides: Vec<_> = size.iter().rev().cloned().scan(1, |prev, dim| {let stride = *prev;*prev *= dim;Some(stride)}).collect();strides.reverse();strides
};
// 构建tensor
let t = unsafe {let ptr = mat.ptr(0).unwrap() as *const u8;tch::Tensor::f_of_blob(ptr, &size, &strides, tch::Kind::Uint8, tch::Device::Cpu).unwrap()
};
t.print();

print的结果是

(1,.,.) = 3  2  13  2  13  2  1(2,.,.) = 3  2  13  2  13  2  1
[ CPUByteType{2,3,3} ]

还有一种比较简洁的转换方法

let mut mat = Mat::new_rows_cols_with_default(2, 3, CV_8UC3, Scalar::from((3.0, 2.0, 1.0))
).unwrap();
let h = mat.size().unwrap().height;
let w = mat.size().unwrap().width;
let data = mat.data_bytes_mut().unwrap();
let t = tch::Tensor::of_data_size(data, &[h as i64, w as i64, 3], tch::Kind::Uint8);
t.print();

print的结果也是

(1,.,.) = 3  2  13  2  13  2  1(2,.,.) = 3  2  13  2  13  2  1
[ CPUByteType{2,3,3} ]
test tensor_ops::init_ops ... ok

2.2 Tensor的属性

用tch::Tensor的print()方法可打印出数据的所有属性,但是想要获取到这些属性,需要用其他的方法。

let t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
println!("size of the tensor: {:?}", t.size());
println!("kind of the tensor: {:?}", t.kind());
println!("device on which the tensor is located: {:?}", t.device());

打印的结果是

size of the tensor: [2, 3]
kind of the tensor: Float
device on which the tensor is located: Cpu

2.3 Tensor的运算

(1)改变device

.to().to_device()这两个方法都可以。

let mut t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
if tch::Cuda::is_available(){t = t.to(Device::Cuda(0));println!("change device to {:?}", t.device());
}
t = t.to_device(Device::Cpu);
println!("change device to {:?}", t.device());

如果是有cuda,且安装了cuda版本的tch-rs的话,就会打印出

change device to Cuda(0)
change device to Cpu

(2)获取值(indexing and slicing)

这个在tch-rs的例子中有很多,详见tests/tensor_indexing.rs。这里列几种常用的。

通过.i()进行索引

let tensor = Tensor::arange_start(0, 2 * 3, (tch::Kind::Float, Device::Cpu)).view([2, 3]);
println!("original tensor:");
tensor.print();
println!("tensor.i(0):");
tensor.i(0).print();
println!("tensor.i((1, 1)):");
tensor.i((1, 1)).print();
println!("tensor.i((.., 2)):");
tensor.i((.., 2)).print();
println!("tensor.i((.., -1)):");
tensor.i((.., -1)).print();
println!("tensor.i((.., [2, 0])):");
let index: &[_] = &[2, 0];
tensor.i((.., index)).print();

打印的结果是

original tensor:0  1  23  4  5
[ CPUFloatType{2,3} ]
tensor.i(0):012
[ CPUFloatType{3} ]
tensor.i((1, 1)):
4
[ CPUFloatType{} ]
tensor.i((.., 2)):25
[ CPUFloatType{2} ]
tensor.i((.., -1)):25
[ CPUFloatType{2} ]
tensor.i((.., [2, 0])):2  05  3
[ CPUFloatType{2,2} ]

通过.index()进行索引

let tensor = Tensor::arange(6, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
println!("original tensor:");
tensor.print();
let rows_select = Tensor::of_slice(&[0i64, 1, 0]);
let column_select = Tensor::of_slice(&[1i64, 2, 2]);
let selected = tensor.index(&[Some(rows_select), Some(column_select)]);
println!("selecte by row and column:");
selected.print();

打印的结果是

original tensor:0  1  23  4  5
[ CPULongType{2,3} ]
selecte by row and column:152
[ CPULongType{3} ]

(3)合并tensors

Tensor::f_cat不会生成新的axis,而Tensor::stack会生成新的axis。

let t1 = Tensor::arange(6, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
let t2 = Tensor::arange_start(6, 12, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
let tensor = Tensor::f_cat(&[t1.copy(), t2.copy()], 1).unwrap();
println!("using Tensor::f_cat");
tensor.print();
let tensor = Tensor::stack(&[t1.copy(), t2.copy()], 1);
println!("using Tensor::stack");
tensor.print();

打印的结果是

using Tensor::f_cat0   1   2   6   7   83   4   5   9  10  11
[ CPULongType{2,6} ]
using Tensor::stack
(1,.,.) = 0  1  26  7  8(2,.,.) = 3   4   59  10  11
[ CPULongType{2,2,3} ]

(4)四则运算

tch-rs对[+, -, *, /]都进行了重载,可以实现和标量的直接运算。涉及到dim的复杂运算可以用tensor来处理。下面以加法为例,其他与f_add对应的分别是f_subf_mulf_div

let tensor = Tensor::ones(&[2, 4, 3], (tch::Kind::Float, Device::Cpu));
tensor.print();
// add with scalar
let add_tensor = &tensor + 0.5;
add_tensor.print();
// add with tensor
let add_tensor = Tensor::of_slice::<f32>(&[1.0,2.0,3.0]).view((1,1,3));
let add_tensor = &tensor.f_add(&add_tensor).unwrap();
add_tensor.print();

打印的结果为

original tensor:
(1,.,.) = 1  1  11  1  11  1  11  1  1(2,.,.) = 1  1  11  1  11  1  11  1  1
[ CPUFloatType{2,4,3} ]
add with scalar:
(1,.,.) = 1.5000  1.5000  1.50001.5000  1.5000  1.50001.5000  1.5000  1.50001.5000  1.5000  1.5000(2,.,.) = 1.5000  1.5000  1.50001.5000  1.5000  1.50001.5000  1.5000  1.50001.5000  1.5000  1.5000
[ CPUFloatType{2,4,3} ]
add with tensor:
(1,.,.) = 2  3  42  3  42  3  42  3  4(2,.,.) = 2  3  42  3  42  3  42  3  4
[ CPUFloatType{2,4,3} ]

参考资料

[1] https://github.com/LaurentMazare/tch-rs
[2] https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html#

tch-rs指南 - Tensor的基本操作相关推荐

  1. PyTorch | (3)Tensor及其基本操作

    PyTorch | (1)初识PyTorch PyTorch | (2)PyTorch 入门-张量 PyTorch | (3)Tensor及其基本操作 Tensor attributes: 在tens ...

  2. pytorch方法,Tensor及其基本操作_重点

    由于之前的草稿都没了,现在只有重写-. 我好痛苦 本章只是对pytorch的常规操作进行一个总结,大家看过有脑子里有印象就好,知道有这么个东西,需要的时候可以再去详细的看,另外也还是需要在实战中多运用 ...

  3. pytorch Tensor及其基本操作

    转自: https://zhuanlan.zhihu.com/p/36233589 由于之前的草稿都没了,现在只有重写-. 我好痛苦 本章只是对pytorch的常规操作进行一个总结,大家看过有脑子里有 ...

  4. 从 X 入门Pytorch——环境安装建议,Tensor多种构造方式,Tensor的基本操作

    本文参加新星计划人工智能(Pytorch)赛道: https://bbs.csdn.net/topics/613989052 满打满算,入门CV的坑已经快一年了,现在忙着换模型,加模块,看效果. 但是 ...

  5. Tensor基础操作总结

    文章目录 前言 一.tensor是什么? 二.tensor的基本操作 1.tensor的创建 2.基础操作 3.索引与切片 4.广播机制 5.tensor和numpy的互相转换 6.在GPU上使用te ...

  6. PyTorch | (4)神经网络模型搭建和参数优化

    PyTorch | (1)初识PyTorch PyTorch | (2)PyTorch 入门-张量 PyTorch | (3)Tensor及其基本操作 PyTorch | (4)神经网络模型搭建和参数 ...

  7. 60分钟入门PyTorch,官方教程手把手教你训练第一个深度学习模型(附链接)

    来源:机器之心 本文约800字,建议阅读5分钟. 本文介绍了官方教程入门PyTorch的技巧训练. 近期的一份调查报告显示:PyTorch 已经力压 TensorFlow 成为各大顶会的主流深度学习框 ...

  8. 60分钟入门PyTorch,官方教程手把手教你训练第一个深度学习模型

    点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自机器之心. 近期的一份调查报告显示:PyTorch 已经力压 TensorFlow 成为各大顶会的主流深度学习框架.想发论文,不学 PyTor ...

  9. Pytorch学习笔记总结

    往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...

最新文章

  1. Eclipse如何更改包名后,批量修改文件的包名
  2. VC界面库BCGControlBar和Xtreme Toolkit详细对比评测
  3. Web开发(一)·期末不挂之第四章·CSS语法基础(CSS选择器选择器优先级各类样式表的使用方法)
  4. 控件 qml_Flat风格的Qml进度条
  5. JEECG 3.7.8 新版表单校验提示风格使用升级方法(validform 新风格漂亮,布局简单)
  6. Android开发--apk的生成
  7. 计算机专业论文选题网站方面,5大网站汇总,搞定新颖的计算机专业毕业设计网站汇总...
  8. SQLite Tutorial 4 : How to export SQLite file into CSV or Excel file
  9. 自动驾驶感知-车道线系列(二)——Canny边缘检测
  10. 5G面临的挑战和应用场景
  11. 2019最烂密码榜单出炉,教你设置神级密码!
  12. Numpy中的Boardcast机制
  13. Android FFmpeg视频转码并保存到本地
  14. 阿里云2022年双十一活动各云产品新购和续费优惠政策汇总
  15. 现代计算机领域出现了,时空道路网最近邻查询技术
  16. 华为老总身份彻底曝光,全世界感到害怕!
  17. 网约车2.0时代,首汽约车让AI实时“听懂”打车服务
  18. AutoJS4.1.0实战教程---京东领京豆
  19. 用shell手撸容器实现批量用openssl签证书
  20. 用JS实现的完美无限级联下拉菜单

热门文章

  1. HR:您好,您应聘的软件测试工程师岗位录取Offer会在三天之内发到您的邮箱。
  2. 软件测试面试题:什么是数据的对立性,有几个层次?
  3. 前端图片渲染性能优化与实践 — 图片懒加载
  4. C++天气预报小软件
  5. React 10分钟快速入门
  6. 层次化的设计(hierarchy design):概论
  7. P1500 丘比特的烦恼(KMMCMF)
  8. 如何高效的使用便利贴 win10便签贴工具居然可以这么好用
  9. 人民币贬值速度计算公式及应对措施
  10. 彻底搞懂Python切片操作_xing2516_新浪博客