今天试了下注册GPU支持的OP。
GPU内核
GPU内核分两部分实现:OpKernel和CUDA内核及其启动代码。

有时OpKernel的实现在CPU和GPU内核之间很常见,比如检查输入和分配输出。在这种情况下,建议的实施是:

  1. 定义在Device上模板化的OpKernel和张量的基本类型。

  2. 为了完成输出的实际计算,Compute函数调用模板函子结构。

  3. 该函数对CPUDevice的专门化定义在同一个文件中,但GPUDevice的专门化定义在.cu.cc文件中,因为它将与CUDA编译器一起编译。

在tensorflow/user_ops/下添加cuda_op_kernel.cu.cc,cuda_op_kernel.cc两个文件。

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "eigen3/unsupported/Eigen/CXX11/Tensor"__global__ void AddOneKernel(const int* in, const int N, int* out) {for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;i += blockDim.x * gridDim.x) {out[i] = in[i] + 1;}
}void AddOneKernelLauncher(const int* in, const int N, int* out) {AddOneKernel<<<32, 256>>>(in, N, out);
}#endif

这里定义的Eigen是一个高层次的C ++库,有效支持线性代数,矩阵和矢量运算,数值分析及其相关的算法。如果编译过程中说未定义eigen,可以自己在网上找个教程下载。然后将安装后产生的Eigen文件夹拷贝到/usr/local/include/下。

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"using namespace tensorflow;REGISTER_OP("AddOne").Input("input: int32").Output("output: int32").Doc(R"doc(
Adds 1 to all elements of the tensor.
output: A Tensor.output = input + 1
)doc");void AddOneKernelLauncher(const int* in, const int N, int* out);class AddOneOp : public OpKernel {public:explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}void Compute(OpKernelContext* context) override {// Grab the input tensorconst Tensor& input_tensor = context->input(0);auto input = input_tensor.flat<int32>();// Create an output tensorTensor* output_tensor = NULL;OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),&output_tensor));auto output = output_tensor->template flat<int32>();// Set all but the first element of the output tensor to 0.const int N = input.size();// Call the cuda kernel launcherAddOneKernelLauncher(input.data(), N, output.data());}
};REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_GPU), AddOneOp);

编译GPU设备的内核:
输入以下命令

1. TF_INC=$(python3.5 -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
2. TF_LIB=$(python3.5 -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
3. nvcc -std=c++11 -c -o cuda_op_kernel.cu.o cuda_op_kernel.cu.cc \
4. -I $TF_INC -I$TF_INC/external/nsync/public -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC5. g++ -std=c++11 -shared -o cuda_op_kernel.so cuda_op_kernel.cc \
6. cuda_op_kernel.cu.o -I $TF_INC -fPIC -lcudart

输入第6步后可能会报: mutex.h:25:22: fatal error: nsync_cv.h: 没有那个文件或目录
在代码中添加nsync_cv.h的路径,修改后输入:

cuda_op_kernel.cu.o -I $TF_INC -I /usr/local/lib/python3.5/dist-packages/tensorflow/include/external/nsync/public/ -fPIC -lcudart

若报/usr/bin/ld: 找不到 -lcudart
那么你需要导入导入这个库文件的路径,如果未安装CUDA库,则/usr/local/lib64需要在上面的第6个命令中明确指定路径。例如,-L /usr/local/cuda-8.0/lib64/如果您的CUDA已安装,请添加/usr/local/cuda-8.0。

cuda_op_kernel.cu.o -I $TF_INC -I$TF_INC/external/nsync/public -fPIC -lcudart -L$TF_LIB -ltensorflow_framework -L/usr/local/cuda-8.0/lib64/cuda_op_kernel.cu.o -I $TF_INC -I /usr/local/lib/python3.5/dist-packages/tensorflow/include/external/nsync/public/ -fPIC -lcudart

生成cuda_op_kernel.so文件后你就可以使用tf.load_op_library(‘你的.so文件路径/cuda_op_kernel.so’)来使用你新添加的OP啦。

import tensorflow as tffile = '/home/siat/Work/tensorflow-r1.4/tensorflow/user_ops/cuda_op_kernel.so'
cuda_op_module = tf.load_op_library(file)
with tf.Session(''):x = cuda_op_module.add_one([[6, 4], [2, 4]]).eval()
print(x)
#[[7 5][3 5]]

注意:我用的Python版本是3.5,若你是使用ubuntu自带2.7版本的话请将命令行中python3.5改为python。

参考:
https://cloud.tencent.com/developer/section/1475696
https://blog.csdn.net/andylei777/article/details/78542624?locationNum=4&fps=1
https://blog.csdn.net/qq_27637315/article/details/79114633

tensorflow添加自定义OP(GPU版本)相关推荐

  1. win7下安装TensorFlow框架的gpu版本

    win7下安装TensorFlow框架的gpu版本 首先附上成功截图 一.系统情况 二.安装工具准备 三.TensorFlow-GPU安装 四.Keras安装 首先附上成功截图 欢迎大家评论,若碰到了 ...

  2. TensorFlow GPU 版本安装个人总结:Win10 + Python3.5 + CUDA 9.0.176 + cudnn v7.5.0.56 + TensorFlow 1.12.0

    TensorFlow GPU 版本安装个人总结:Win10 + Python3.5 + CUDA 9.0.176 + cudnn v7.5.0.56 + TensorFlow 1.12.0 接触机器学 ...

  3. Window10 Tensorflow 2.1 GPU 安装和测试

    Tensorflow 2.1 GPU 安装和测试 1. 硬件要求 2. 软件要求 简单的描述一下它们的功能 3. 安装步骤 3.1. nvidia 驱动可以到这个地址下载, 我的显卡是RTX 2070 ...

  4. 安装指南:Win10系统+ tensorflow 1.7 GPU+Cuda v9.0+cudnnv7.1 +Python3.6

    一开始安装在WIN10安装tensorflow GPU版本可以说是费尽周折,花了一周多的时间,发现版本有一点不对应,就没办法成功,所以每一步安装都要非常"精准",为了给新手扫盲,特 ...

  5. Ubuntu tensorflow自定义GPU版本op节点

    参考:https://blog.csdn.net/qq_27637315/article/details/79114633 windows增加op节点: https://github.com/tens ...

  6. tensorflow自定义GPU版本op节点

    由于前段时间导师布置了一个任务,要修改损失函数,但是这个损失函数在tensorflow自带的库中又没有,想了很多办法,试来试去找不到一个解决方案,因为tensorflow是把框架和数据分开的,所以直接 ...

  7. tensorflow GPU版本配置加速环境

    import tensorflow as tf tf.test.is_gpu_available() 背景 环境:Anaconda .tensorflow_gpu==1.4.0 (这里就用1.4.0版 ...

  8. 禁用GPU版本TensorFlow,切换到CPU版本TensorFlow。

    #禁用gpu版本TensorFlow,因为CUDA号码从0开始,这里直接让CUDA使用-1的GPU,自然就无法使用gpu了. 代码前面加入: import os os.environ["CU ...

  9. 通过Anaconda在Ubuntu16.04上安装 TensorFlow(GPU版本)

    一. 安装环境 Ubuntu16.04.3 LST GPU: GeForce GTX1070 Python: 3.5 CUDA Toolkit 8.0 GA1 (Sept 2016) cuDNN v6 ...

最新文章

  1. codeforces Gargari and Permutations(DAG+BFS)
  2. Math.Round()——面试题小结
  3. git 远程仓库版本的回退以及git reset 几种常用方式记录
  4. port常用和不常用端口一览表
  5. python文件处理系列(一):配置文件处理
  6. http get extension information - another way to get host url and port number of current application
  7. MySQL在哪里看secret_key_K8S 创建和查看secret(九)
  8. 字符串太长 pep8_Python f字符串– PEP 498 –文字字符串插值
  9. Atitit..状态机与词法分析  通用分词器 分词引擎的设计与实现 attilax总结
  10. Linux 配置虚拟IP
  11. Flutter中的JSON解析
  12. 树莓派python 简介_树莓派与python语言概述
  13. LiDAR-based Panoptic Segmentation via Dynamic Shifting Network(论文阅读笔记)
  14. matlab 二次函数图像
  15. Xrm.Page.data.entity Properties and Methods
  16. 新零售mysql设计 订单表 订单详情表
  17. CentOS7 安装 chrome
  18. CSAPP第五章家庭作业参考答案
  19. 《语雀 IT 百科》发布了!
  20. 咖啡烘焙饕餮盛宴——洛阳新都汇有你想要的感觉

热门文章

  1. jstl中fn表达式
  2. matlab 非线性状态方程,非线性方程组求解及matlab实现.ppt
  3. 二本出身、逆袭网易、一路孤独、一路狂欢!
  4. Java反射机制的深入应用
  5. 会计学python就业_会计学是夕阳行业吗?
  6. RHCSA认证考试---2.给系统配置默认的存储库
  7. LISP多边形形心计算公式_计算几何-多边形重心公式
  8. htmltestrunner 中的字段含义verbosity
  9. 计算机二级为什么靠Java的少,计算机二级考试:Java语言学习六大要点
  10. 执行下列python程序输出结果是什么_下列Python程序的运行结果是 x=0 y=True print(xy and 'A''B')_学小易找答案...