前言

关于PointNet模型的构成、原理、效果等等论文部分内容,我在之前一篇论文中写到过,可以参考这个链接:PointNet论文笔记    下边我就直接放一张网络组成图,并对代码进行解释,我以一种比较容易理解的顺序放置,希望耐心阅读。

网络结构图示

在分类网络中,输入n个点,对输入做特征变换,再进行最大池化输出k个种类;分割网络是分类网络的一个拓展,它考虑了全局和局部的特征以及每个点的输出分数。mlp代表多层感知机,括号中是感知机的层数,批标准化(Batchnorm)本用于所有带有ReLU函数的层,Dropout层被用于分类网络中最后一个多层感知机中。

代码详解

首先我先来讲解分类网络,图中深色部分,首先输入点经过一个transform,再经过多层感知机,再经过一个feature transform,再经过多层感知机和max pooling,最后经过多层感知机获得分类结果,网络结构是比较清晰的,下边一块一块看:

input transform

首先这一层的目的是对输入的每一个点云,在这里是2500个三坐标点,目的是要获得一个3×3的变换矩阵,获得这个矩阵的原因是:要对点云的姿态进行校正,而该变换矩阵就是根据点云特性,做出一个刚体变换,使点云处于一个比较容易检测的姿态。先对输入经过三级卷积核为1×1的卷积处理得到1024通道的数据,再经过全连接处映射到九个数据,最后调整为3×3

class STN3d(nn.Module):def __init__(self):super(STN3d, self).__init__()self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 9)self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)def forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)if x.is_cuda:iden = iden.cuda()x = x + idenx = x.view(-1, 3, 3)return x

feature transform

下边我们先考虑后边这个feature transform层,这个其实和上边那个是一样的,只是从电源数据中获取一个64×64的变换矩阵,这个也是对特征的一种校正,一种广义的位姿变换,代码几乎没有差别

class STNkd(nn.Module):def __init__(self, k=64):super(STNkd, self).__init__()self.conv1 = torch.nn.Conv1d(k, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k*k)self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)self.k = kdef forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)if x.is_cuda:iden = iden.cuda()x = x + idenx = x.view(-1, self.k, self.k)return x

主体部分

这部分讲max pooling之前的剩余部分,首先经过STN3d获得3×3矩阵,乘以点云做完位姿变换,再经过多层感知机(实际上多层感知机与卷积核边长为1的卷积操作本质是一样的),再乘以经过STNkd获得的64×64的矩阵,完成位姿变换,再经过多层感知机(这里同样用边长为1的卷积核的卷积操作),得到n×1024的矩阵,n为每批次读入的数据文件个数。下边这个类中调用了前边两个类。

class PointNetfeat(nn.Module):def __init__(self, global_feat = True, feature_transform = False):super(PointNetfeat, self).__init__()self.stn = STN3d()self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.global_feat = global_featself.feature_transform = feature_transformif self.feature_transform:self.fstn = STNkd(k=64)def forward(self, x):n_pts = x.size()[2]trans = self.stn(x)x = x.transpose(2, 1)x = torch.bmm(x, trans)x = x.transpose(2, 1)x = F.relu(self.bn1(self.conv1(x)))if self.feature_transform:trans_feat = self.fstn(x)x = x.transpose(2,1)x = torch.bmm(x, trans_feat)x = x.transpose(2,1)else:trans_feat = Nonepointfeat = xx = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)if self.global_feat:return x, trans, trans_featelse:x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)return torch.cat([x, pointfeat], 1), trans, trans_feat

后处理部分

下边就要进行最大池化和多层感知机进行分类了,经过全连接分成k类,根据概率来判别究竟属于哪一类

class PointNetCls(nn.Module):def __init__(self, k=2, feature_transform=False):super(PointNetCls, self).__init__()self.feature_transform = feature_transformself.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k)# 防止过拟合self.dropout = nn.Dropout(p=0.3)# 归一化防止梯度爆炸与梯度消失self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.relu = nn.ReLU()def forward(self, x):# 完成网络主体部分x, trans, trans_feat = self.feat(x)# 经过三个全连接层(多层感知机)映射成k类x = F.relu(self.bn1(self.fc1(x)))x = F.relu(self.bn2(self.dropout(self.fc2(x))))x = self.fc3(x)# 返回的是该点云是第ki类的概率return F.log_softmax(x, dim=1), trans, trans_feat

分割网络

分割网络是借用了分类网络的两部分,分别是64通道和1024通道,堆积在一起形成1088通道的输入,经过多层感知机输出了结果m通道的结果,m代表类的个数,也就是每个点属于哪一类,实际上分割是在像素级或者点级的分类,本质上是一样的

class PointNetDenseCls(nn.Module):def __init__(self, k = 2, feature_transform=False):super(PointNetDenseCls, self).__init__()self.k = kself.feature_transform=feature_transformself.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)self.conv1 = torch.nn.Conv1d(1088, 512, 1)self.conv2 = torch.nn.Conv1d(512, 256, 1)self.conv3 = torch.nn.Conv1d(256, 128, 1)self.conv4 = torch.nn.Conv1d(128, self.k, 1)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.bn3 = nn.BatchNorm1d(128)def forward(self, x):batchsize = x.size()[0]n_pts = x.size()[2]x, trans, trans_feat = self.feat(x)x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = self.conv4(x)x = x.transpose(2,1).contiguous()x = F.log_softmax(x.view(-1,self.k), dim=-1)x = x.view(batchsize, n_pts, self.k)return x, trans, trans_feat

训练结果

训练过程可以参考源码

网络分类性能还是很强的,只是迭代了一次,精度就达到了91%以上

在点较少的情况下,分割效果也还是可以的,5次迭代可以达到 80.0mIoU

PointNet模型的Pytorch代码详解相关推荐

  1. BilSTM 实体识别_NLP-入门实体命名识别(NER)+Bilstm-CRF模型原理Pytorch代码详解——最全攻略

    最近在系统地接触学习NER,但是发现这方面的小帖子还比较零散.所以我把学习的记录放出来给大家作参考,其中汇聚了很多其他博主的知识,在本文中也放出了他们的原链.希望能够以这篇文章为载体,帮助其他跟我一样 ...

  2. BilSTM 实体识别_NLP入门实体命名识别(NER)+BilstmCRF模型原理Pytorch代码详解——最全攻略...

    来自 | 知乎   作者 | seven链接 | https://zhuanlan.zhihu.com/p/79552594编辑 | 机器学习算法与自然语言处理公众号本文仅作学术分享,如有侵权,请联系 ...

  3. 将卷积引入transformer中VcT(Introducing Convolutions to Vision Transformers)的pytorch代码详解

    文章目录 1. Motivation: 2. Method 2.1 Convolutional Token Embedding 模块 2.2 Convolutional Projection For ...

  4. 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)

    首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple- ...

  5. 目标检测之Faster-RCNN的pytorch代码详解(模型准备篇)

    十月一的假期转眼就结束了,这个假期带女朋友到处玩了玩,虽然经济仿佛要陷入危机,不过没关系,要是吃不上饭就看书,吃精神粮食也不错,哈哈!开个玩笑,是要收收心好好干活了,继续写Faster-RCNN的代码 ...

  6. 一文读懂经典卷积网络模型——LeNet-5模型(附代码详解、MNIST数据集)

    欢迎关注微信公众号[计算机视觉联盟] 获取更多前沿AI.CV资讯 LeNet-5模型是Yann LeCun教授与1998年在论文Gradient-based learning applied to d ...

  7. 生产者消费者模型——C语言代码详解

    概念 生产者消费者模式就是通过一个容器来解决生产者和消费者的强耦合问题.生产者和消费者彼此之间不直接通讯,而通过阻塞队列来进行通讯,所以生产者生产完数据之后不用等待消费者处理,直接扔给阻塞队列,消费者 ...

  8. PointNet代码详解

    PointNet代码详解 最近在做点云深度学习的机器人抓取,这篇博客主要是把近期学习PointNet的一些总结的知识点汇总一下. PointNet概述详见以下网址和博客,这里也就不再赘述了. 三维深度 ...

  9. 从PointNet到PointNet++理论及代码详解

    从PointNet到PointNet++理论及代码详解 1. 点云是什么 1.1 三维数据的表现形式 1.2 为什么使用点云 1.3 点云上以往的相关工作 2. PointNet 2.1 基于点云的置 ...

最新文章

  1. attention retain_Attention-Aware Compositional Network
  2. 干货分享丨轻松玩转 Huawei LiteOS 传感框架
  3. windows下编译c语言文件路径,解决JNI在Windows环境下因长路径导致编译失败问题
  4. 广西科技大学计算机考研,广西科技大学研究生院
  5. android(八)、触摸事件分发
  6. 翻译连载 | 附录 C:函数式编程函数库-《JavaScript轻量级函数式编程》 |《你不知道的JS》姊妹篇...
  7. BigDecimal.divide方法
  8. STM32Cubemx出现工程突然自动退出的问题
  9. 【渝粤教育】广东开放大学 学前儿童保育学 形成性考核 (40)
  10. 用TMG搭建×××服务器(二)---L2TP/IPsec ×××
  11. SPI 机制-插件化扩展功能
  12. 【算法笔记HDU4825】Xor Sum(01字典树模版)
  13. 哨兵卫星及数据下载平台介绍
  14. 算法导论第三版 16.1-5习题答案
  15. 毕业设计 python opencv 机器视觉图像拼接算法
  16. 谷歌地球尝试验证时检测到错误_深思考丨验证码为何越来越难了?
  17. 数字签名技术原理介绍
  18. 服务器托管双线技术方案
  19. 常用Intent合集 Android
  20. 如何用Java设计一个简单的窗口界面(初级二)

热门文章

  1. C# Flurl 高性能的访问http
  2. Android开发要达到阿里P7水平,很难吗,Android高级工程师必备知识
  3. BRISK特征点描述算法详解
  4. 【OpenCV 例程 300篇】245. 特征检测之 BRISK 算子
  5. Java 开发者得力助手,深入实践 Spring Boot
  6. Koa入门(一)—— Koa项目基础框架搭建
  7. 超强1000个jquery极品插件!(转)
  8. Ubuntu 16.04/18.04 安装和使用QQ和微信最简洁的方式(2019.10.28更新)
  9. SSM框架整合(企业权限管理系统)
  10. Vue中,一个组件调用其他组件的方法(非父子组件)