cpp的例子

device_malloc

  • cpp没有用具体数值初始化 float *d_from_tensor = NULL;device_malloc(&d_from_tensor, batch_size * seq_len * hidden_dim);
  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/sample/cpp/transformer_fp32.cc#L35-L38 直接用的cudaMalloc
void device_malloc(float** ptr, int size) // cudaMalloc函数为什么是二级指针的解释https://blog.csdn.net/CaiYuxingzzz/article/details/121112273
{cudaMalloc((void**)ptr, sizeof(float) * size);
}

allocator

  • allocator用于分配attr_out_buf_
https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/bert_encoder_transformer.h#L131-L135
buf_ = reinterpret_cast<DataType_*>(allocator_.malloc(sizeof(DataType_) * buf_size * 6));
  • 然后将这些参数和encoder_param打包成multi_head_init_param
    在初始化(encoder_transformer_->initialize)时传给attention_->initialize(multi_head_init_param);
    attention_->initialize则只需将传入的参数初始化给attention对象的参数,等forward时调用自己的参数
接口包含两个方法malloc,free
class IAllocator{public:virtual void* malloc(size_t size) const = 0;virtual void free(void* ptr) const = 0;
};
//AllocatorTypeyouenum class AllocatorType{CUDA, TF}; 用的应该是CUDA的
template<>
class Allocator<AllocatorType::CUDA> : public IAllocator{const int device_id_;public:Allocator(int device_id): device_id_(device_id){}void* malloc(size_t size) const {void* ptr = nullptr;int o_device = 0;check_cuda_error(get_set_device(device_id_, &o_device));check_cuda_error(cudaMalloc(&ptr, size));check_cuda_error(get_set_device(o_device));return ptr;}void free(void* ptr) const {int o_device = 0;check_cuda_error(get_set_device(device_id_, &o_device));check_cuda_error(cudaFree(ptr));check_cuda_error(get_set_device(o_device));return;}
};
fastertransformer::Allocator<AllocatorType::CUDA> allocator(0); // 0是device_id_

encoder_param

  • EncoderInitParam encoder_param; //init param here 包含参数的结构体,成员记录了GPU数据的地址

initialize

  BertEncoderTransformer<EncoderTraits_> *encoder_transformer_ = new BertEncoderTransformer<EncoderTraits_>(allocator, batch_size, from_seq_len, to_seq_len, head_num, size_per_head);encoder_transformer_->initialize(encoder_param);

trt_plugin的例子

将数值放入vector

  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/sample/tensorRT/transformer_trt.cc#L108-L136
  • 先分配地址
    host_malloc(&h_attr_kernel_Q, hidden_dim * hidden_dim);
  • 然后进行赋值
    h_attr_kernel_Q[i] = 0.001f;
   std::vector<T* > layer_param;layer_param.push_back(h_attr_kernel_Q);将值打包params.push_back(layer_param);}cudaStream_t stream;cudaStreamCreate(&stream);TRT_Transformer<T>* trt_transformer = new TRT_Transformer<T>(batch_size, seq_len, head_num, hidden_dim, layers);trt_transformer->build_engine(params);trt_transformer->do_inference(batch_size, h_from_tensor, h_attr_mask, h_transformer_out, stream);delete trt_transformer;
  • 构建TRT_Transformer时会调用算子插件,权重在void build_engine(std::vector<std::vector<T* > > &weights)时传入
    https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/trt_plugin/trt_model.h#L75-L77
auto plugin = new TransformerPlugin<T>(hidden_dim_, head_num_, seq_len_, batch_size_, point2weight(weights[i][0], hidden_dim_ * hidden_dim_),
  • 创建TransformerPlugin实例时会传入权重
TransformerPlugin(int hidden_dim, int head_num, int seq_len, int max_batch_size,const nvinfer1::Weights &w_attr_kernel_Q,...
  • 这里就是和cpp例子的不同了,其使用权重w_attr_kernel_Q
  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/trt_plugin/bert_transformer_plugin.h#L103
cudaMallocAndCopy(d_attr_kernel_Q_, w_attr_kernel_Q, hidden_dim * hidden_dim);
  • cudaMallocAndCopy定义在https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/trt_plugin/bert_transformer_plugin.h#L338-L352
    static void cudaMallocAndCopy(T *&dpWeight, const nvinfer1::Weights &w, int nValue) {assert(w.count == nValue);check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T)));check_cuda_error(cudaMemcpy(dpWeight, w.values, nValue * sizeof(T), cudaMemcpyHostToDevice));T* data = (T*)malloc(sizeof(T) * nValue);cudaMemcpy(data, dpWeight, sizeof(T) * nValue, cudaMemcpyDeviceToHost);}static void cudaMallocAndCopy(T*&dpWeight, const T *&dpWeightOld, int nValue) {check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T)));check_cuda_error(cudaMemcpy(dpWeight, dpWeightOld, nValue * sizeof(T), cudaMemcpyDeviceToDevice));}

cg

  • https://github.com/NVIDIA/TensorRT/blob/release/8.5/demo/Diffusion/models.py

FasterTransformer 005 初始化:如何将参数传给模型?相关推荐

  1. 通过BeanShell获取UUID并将参数传递给Jmeter

    有些HTTPS请求报文的报文体中包含由客户端生成的UUID,在用Jmeter做接口自动化测试的时候,因为越过了客户端,直接向服务器端发送报文,所以,需要在Jmeter中通过beanshell获取UUI ...

  2. 数组作为函数的参数传参时,数组名会退化为指针

    1.数组作为函数的参数传参时,数组名会退化为指针 数组作为函数的参数传参时,数组名会退化为指针,数值传参时,需要把数值的长度一起传过去,另外,sizeof()运算符包含字符串的哨兵'/0',而strl ...

  3. java+hadoop配置参数_将Hadoop参数传递给Java代码

    我有一个Uber jar执行一些级联ETL任务. jar的执行方式如下: hadoop jar munge-data.jar 我希望在作业启动时将参数传递给jar,例如 hadoop jar mung ...

  4. 如何将命令行参数传递给Node.js程序?

    我有一个用Node.js编写的Web服务器,我想使用一个特定的文件夹启动. 我不确定如何在JavaScript中访问参数. 我正在像这样运行节点: $ node server.js folder 这是 ...

  5. GoJS超详细入门(插件使用无非:引包、初始化、配参数(json)、引数据(json)四步)...

    GoJS超详细入门(插件使用无非:引包.初始化.配参数(json).引数据(json)四步) 一.总结 一句话总结:插件使用无非:引包.初始化.配参数(json).引数据(json)四步. 1.goj ...

  6. DL之DNN优化技术:采用三种激活函数(sigmoid、relu、tanh)构建5层神经网络,权重初始值(He参数初始化和Xavier参数初始化)影响隐藏层的激活值分布的直方图可视化

    DL之DNN优化技术:采用三种激活函数(sigmoid.relu.tanh)构建5层神经网络,权重初始值(He参数初始化和Xavier参数初始化)影响隐藏层的激活值分布的直方图可视化 目录

  7. DL之DNN优化技术:自定义MultiLayerNet【5*100+ReLU】对MNIST数据集训练进而比较三种权重初始值(Xavier参数初始化、He参数初始化)性能差异

    DL之DNN优化技术:自定义MultiLayerNet[5*100+ReLU]对MNIST数据集训练进而比较三种权重初始值(Xavier参数初始化.He参数初始化)性能差异 导读 #思路:观察不同的权 ...

  8. DL之DNN优化技术:DNN中参数初始化【Lecun参数初始化、He参数初始化和Xavier参数初始化】的简介、使用方法详细攻略

    DL之DNN优化技术:DNN中参数初始化[Lecun参数初始化.He参数初始化和Xavier参数初始化]的简介.使用方法详细攻略 导读:现在有很多学者认为,随着BN层的提出,权重初始化可能已不再那么紧 ...

  9. python get请求 url传参_用Python-get方法向页面发起请求,参数传不进去是怎么回事...

    源自:4-1 接口测试工具-python-get接口实战 用Python-get方法向页面发起请求,参数传不进去是怎么回事 #-*-coding:utf-8-*- import urllib impo ...

最新文章

  1. 题目1189:还是约瑟夫环
  2. UIButton拖动响应事件,距离问题
  3. JAVA——TCP连接中Socket的正确关闭方式
  4. vc6.0垃圾文件清理工具_MacClean360 for Mac(mac系统清理软件)
  5. VTK:PolyData之Casting
  6. 程序员十大心愿,程序员:你这么了解我的心声的嘛!
  7. python 解析xml 文件: SAX方式
  8. Silverlight+WCF 新手实例 象棋 游戏房间列表(十三)
  9. nginx搭建tomcat集群
  10. 【Python】利用MD5文件去重
  11. Codeforces - Robot Rapping Results Report
  12. 微信突破版本限制永久设置透明/半透明头像
  13. ffmpeg学习 函数分析swr_convert
  14. C#如何立即回收内存
  15. unity打包安卓(anroid)APK及安卓环境设置
  16. 知衣科技:夫妻搭档,创业是件骨子里的事
  17. pcf8563 C语言编程
  18. nmt模型源文本词项序列_TensorFlow NMT的数据处理过程
  19. 河工大大一c语言题库,河工大二级C语言题库.doc
  20. 数据结构-带头双向循环链表(增删查改详解)

热门文章

  1. 推荐算法-皮尔逊相关系数的相似度
  2. OSChina 周五乱弹 ——爱酱,我们还是在普通人类中夺冠吧!
  3. php yii调试,yii框架中debug怎么用
  4. C 编译错误 及解决方法总结
  5. java jsch jar_使用JSch从Java在远程计算机上执行命令
  6. oracle tabe unlock_Oracle数据库之统计信息锁住导致收集统计信息失败引起sql执行异常...
  7. office 删除密钥的方式
  8. Linux账号权限分离的必要性
  9. 七牛云 转码_七牛云的音频转码,微信的speex音频转码为mp3格式
  10. 如何用Pact进行微服务集成测试(二)