参考:

文章目录

  • 1.backbones.py分析
    • 1.def __init__(self, num_features, *args, **kwargs):中*args和**kwargs到底是什么
    • 2.BatchNorm2d函数分析:
    • 3.nn.ReLU(inplace=True)代码分析
    • 4.output stride = 4解析:
    • 5.__all__ = ['AlexNetV1', 'AlexNetV2', 'AlexNetV3']中__all__解析
    • 6.nn.Conv2d(384, 32, 3, 1, groups=2))这个groups=2
  • 2.heads.py分析
  • 3.train.py分析
  • 4.siamfc.py
    • SiamFCTransforms
    • train_step

1.backbones.py分析

from __future__ import absolute_importimport torch.nn as nn__all__ = ['AlexNetV1', 'AlexNetV2', 'AlexNetV3']class _BatchNorm2d(nn.BatchNorm2d):def __init__(self, num_features, *args, **kwargs):super(_BatchNorm2d, self).__init__(num_features, *args, eps=1e-6, momentum=0.05, **kwargs)class _AlexNet(nn.Module):def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = self.conv5(x)return xclass AlexNetV1(_AlexNet):output_stride = 8def __init__(self):super(AlexNetV1, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 96, 11, 2),_BatchNorm2d(96),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2))self.conv2 = nn.Sequential(nn.Conv2d(96, 256, 5, 1, groups=2),_BatchNorm2d(256),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2))self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1),_BatchNorm2d(384),nn.ReLU(inplace=True))self.conv4 = nn.Sequential(nn.Conv2d(384, 384, 3, 1, groups=2),_BatchNorm2d(384),nn.ReLU(inplace=True))self.conv5 = nn.Sequential(nn.Conv2d(384, 256, 3, 1, groups=2))class AlexNetV2(_AlexNet):output_stride = 4def __init__(self):super(AlexNetV2, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 96, 11, 2),_BatchNorm2d(96),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2))self.conv2 = nn.Sequential(nn.Conv2d(96, 256, 5, 1, groups=2),_BatchNorm2d(256),nn.ReLU(inplace=True),nn.MaxPool2d(3, 1))self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1),_BatchNorm2d(384),nn.ReLU(inplace=True))self.conv4 = nn.Sequential(nn.Conv2d(384, 384, 3, 1, groups=2),_BatchNorm2d(384),nn.ReLU(inplace=True))self.conv5 = nn.Sequential(nn.Conv2d(384, 32, 3, 1, groups=2))class AlexNetV3(_AlexNet):output_stride = 8def __init__(self):super(AlexNetV3, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 192, 11, 2),_BatchNorm2d(192),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2))self.conv2 = nn.Sequential(nn.Conv2d(192, 512, 5, 1),_BatchNorm2d(512),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2))self.conv3 = nn.Sequential(nn.Conv2d(512, 768, 3, 1),_BatchNorm2d(768),nn.ReLU(inplace=True))self.conv4 = nn.Sequential(nn.Conv2d(768, 768, 3, 1),_BatchNorm2d(768),nn.ReLU(inplace=True))self.conv5 = nn.Sequential(nn.Conv2d(768, 512, 3, 1),_BatchNorm2d(512))

这个module主要实现了3个AlexNet版本作为backbone,开头的__all__ = [‘AlexNetV1’, ‘AlexNetV2’, ‘AlexNetV3’]主要是为了让别的module导入这个backbones.py的东西时,只能导入__all__后面的部分。最后作者用的是AlexNetV1。


class _BatchNorm2d(nn.BatchNorm2d):def __init__(self, num_features, *args, **kwargs):super(_BatchNorm2d, self).__init__(num_features, *args, eps=1e-6, momentum=0.05, **kwargs)

1.def init(self, num_features, args, **kwargs):中args和**kwargs到底是什么

*args的用法:当传入的参数个数未知,且不需要知道参数名称时。
**args的用法:当传入的参数个数未知,但需要知道参数的名称时(字典,即键值对)
这里的args不是必须的,也就是说可以换成kwargs等等,但是星号是必须的

2.BatchNorm2d函数分析:

BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
num_features:一般输入参数为batch_sizenum_featuresheight*width,即为其中特征的数量,即为输入BN层的通道数;
eps:分母中添加的一个值,目的是为了计算的稳定性,默认为:1e-5,避免分母为0;
momentum???:一个用于运行过程中均值和方差的一个估计参数(我的理解是一个稳定系数,类似于SGD中的momentum的系数);(手册里说,一般不修改就对了)
affine:是否需要仿射。如果affine=False,γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。


class AlexNetV1(_AlexNet):output_stride = 8def __init__(self):super(AlexNetV1, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 96, 11, 2),_BatchNorm2d(96),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2))self.conv2 = nn.Sequential(nn.Conv2d(96, 256, 5, 1, groups=2),_BatchNorm2d(256),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2))self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1),_BatchNorm2d(384),nn.ReLU(inplace=True))self.conv4 = nn.Sequential(nn.Conv2d(384, 384, 3, 1, groups=2),_BatchNorm2d(384),nn.ReLU(inplace=True))self.conv5 = nn.Sequential(nn.Conv2d(384, 256, 3, 1, groups=2))

3.nn.ReLU(inplace=True)代码分析

inplace = False 时,不会修改输入对象的值,而是返回一个新创建的对象,所以打印出对象存储地址不同,类似于C语言的值传递
inplace = True 时,会修改输入对象的值,所以打印出对象存储地址相同,类似于C语言的址传递

4.output stride = 4解析:

output stride为该矩阵经过多次卷积pooling操作后,尺寸缩小的值(这个还保留疑问)

5.all = [‘AlexNetV1’, ‘AlexNetV2’, ‘AlexNetV3’]中__all__解析

__all__是一个字符串list,用来定义模块中对于from XXX import 时要对外导出的符号,即要暴露的接口,但它只对import *起作用,对from XXX import XXX不起作用。
控制 from xxx import * 的行为

6.nn.Conv2d(384, 32, 3, 1, groups=2))这个groups=2

这个groups=2,是将卷积分为两组:

注意:有些人会感觉这里输入输出通道对不上,这是因为像原本AlexNet分成了2个group,所以会有48->96, 192->384这样。


if __name__ == '__main__':alexnetv1 = AlexNetV1()import torchz = torch.randn(1, 3, 127, 127)output = alexnetv1(z)print(output.shape)  # torch.Size([1, 256, 6, 6])x = torch.randn(1, 3, 256, 256)output = alexnetv1(x)print(output.shape)  # torch.Size([1, 256, 22, 22])# 换成AlexNetV2依次是:# torch.Size([1, 32, 17, 17])、torch.Size([1, 32, 49, 49])# 换成AlexNetV3依次是:# torch.Size([1, 512, 6, 6])、torch.Size([1, 512, 22, 22])

2.heads.py分析


from __future__ import absolute_importimport torch.nn as nn
import torch.nn.functional as F__all__ = ['SiamFC']class SiamFC(nn.Module):def __init__(self, out_scale=0.001):super(SiamFC, self).__init__()self.out_scale = out_scaledef forward(self, z, x):return self._fast_xcorr(z, x) * self.out_scaledef _fast_xcorr(self, z, x):# fast cross correlation# x size 8,256,20,20# z size 8,256,6,6nz = z.size(0)  #size(0)即取第一个shape值#nz = 8nx, c, h, w = x.size()#nx = 8,c = 256,h = 20,w = 20x = x.view(-1, nz * c, h, w)#x.shape = [1,2048,20,20]out = F.conv2d(x, z, groups=nz)# out.shape = [1,8,15,15]#输入是4维,输出也是4维,**高层补1???**# print(out.size())out = out.view(nx, -1, out.size(-2), out.size(-1))# out.shape = [8,1,15,15]return out

为什么这里会有个out_scale,根据作者说是因为, z 和x 互相关之后的值太大,经过sigmoid函数之后会使值处于梯度饱和的那块,梯度太小,乘以out_scale就是为了避免这个。
这里nn.Con2d要与nn.function.conv2d区分开:(可参考手册)
1.nn.Conv2d
torch.nn.Conv2d(in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True,padding_mode=‘zeros’)
in_channels-----输入通道数
out_channels-------输出通道数
kernel_size--------卷积核大小
stride-------------步长
padding---------是否对输入数据填充0
2.nn.function.conv2d
torch.nn.functional.conv2d(input,weight,bias=None,stride=1,padding=0,dilation=1,groups=1)
input-------输入tensor大小(minibatch,in_channels,iH, iW)
weight------权重大小(out_channels, in_channels/groups, kH, kW)
注意:权重参数中,第一个卷积核的输出通道数,第二个是输入通道数
这里针对nn.function.conv2d讨论
例如:torch.nn.functional.conv2d(self, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=2)
input:
minibatch:batch中的样例个数
in_channels:每个样例数据的通道数
iH:每个样例的高(行数)
iW:每个样例的宽(列数)

weight(就是filter):
out_channels:卷积核的个数
in_channels/groups:每个卷积核的通道数
kH:每个卷积核的高(行数)
kW:每个卷积核的宽(列数)
groups作用:对input中的每个样例数据,将通道分为groups等份,即每个样例数据被分成了groups个大小为(in_channel/groups, iH, iW)的子数据。对于这每个子数据来说,卷积核的大小为(in_channel/groups, kH, kW)。这一整个样例数据的计算结果为各个子数据的卷积结果拼接所得

3.train.py分析

参考siamfc-pytorch代码讲解(二):train&siamfc

因为作者使用了GOT-10k这个工具箱,train.py代码非常少,就下面几行:


from __future__ import absolute_importimport os
from got10k.datasets import *from siamfc import TrackerSiamFCif __name__ == '__main__':root_dir = os.path.expanduser('~/data/GOT-10k')seqs = GOT10k(root_dir, subset='train', return_meta=True)tracker = TrackerSiamFC()tracker.train_over(seqs)

首先我们就需要按照GOT-10k download界面去下载好数据集,并且按照这样的文件结构放好(因为现在用不到验证集和测试集,可以先不用下,训练集也只要先下载1个split,所以就需要把list.txt中只保留前500项,因为GOT-10k_Train_000001里面有500个squences):???

|-- GOT-10k/|-- train/|  |-- GOT-10k_Train_000001/|  |   ......|  |-- GOT-10k_Train_000500/|  |-- list.txt

这里可以打印一下seps到底是什么,因为他是train_over的入参:


print(seqs)
# <got10k.datasets.got10k.GOT10k object at 0x000002366865CF28>
print(seqs[0])
# 这里比较多,截取一部分
# seqs[0]就是指第一个序列GOT-10k_Train_000001,返回三个元素的元组
# 第一个元素是一个路径列表,第二个是np.ndarray,第三个是字典,包含具体信息
# (['D:\\GOT-10k\\train\\GOT-10k_Train_000001\\00000001.jpg', ...],
# array([[347., 443., 429., 272.],...[551., 467., 513., 318.]]),
# {'url': 'https://youtu.be/b0ZnfLI8YPw',...})

4.siamfc.py

SiamFCTransforms

SiamFCTransforms是transforms.py里面的一个类,主要是对输入的groung truth的z, x, bbox_z, bbox_x进行一系列变换,构成孪生网络的输入,这其中就包括了

RandomStretch:主要是随机的resize图片的大小,其中要注意cv2.resize()的一点用法,可以参考我的这篇博客:cv2.resize()的一点小坑
CenterCrop:从img中间抠一块(size, size)大小的patch,如果不够大,以图片均值进行pad之后再crop
RandomCrop:用法类似CenterCrop,只不过从随机的位置抠,没有pad的考虑
Compose:就是把一系列的transforms串起来
ToTensor: 就是字面意思,把np.ndarray转化成torch tensor类型

类初始化里面针对self.transforms_z和self.transforms_x数据增强方法中具体参数的设置可以参考 issue#21,作者提到在train phase和test phase embedding size不一样没太大的影响,而且255-16可以模拟测试阶段目标的移动(个人感觉这里没有完全就按照论文上来,但也不用太在意,自己可以试着改回来看哪一个效果好)。#???

基于pytorch的代码,一般获取和处理数据,都定义在数据集定义中的__getitem__类方法中。这里我们到dataset.py中的Pair类的__getitem__中,找到如下:

 crop_z = self.crop(img_z, bndbox_z, self.exemplarSize)  # crop template patch from img_z, then resize [127, 127]crop_x = self.crop(img_x, bndbox_x, self.instanceSize)  # crop search patch from img_x, then resize [255, 255]

下面具体讲里面的_crop函数:

def _crop(self, img, box, out_size):# convert box to 0-indexed and center based [y, x, h, w]box = np.array([box[1] - 1 + (box[3] - 1) / 2,box[0] - 1 + (box[2] - 1) / 2,box[3], box[2]], dtype=np.float32)center, target_sz = box[:2], box[2:]context = self.context * np.sum(target_sz)#(w+h)/2=2p,context就是添加边界的宽度size = np.sqrt(np.prod(target_sz + context))#根号下(w+2p)*(h+2p)size *= out_size / self.exemplar_sz#255/127,s缩放系数avg_color = np.mean(img, axis=(0, 1), dtype=float)#填充色为像素的均值interp = np.random.choice([cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_AREA,cv2.INTER_NEAREST,cv2.INTER_LANCZOS4])# np.random.choice()函数:在给定的元素中随机选取patch = ops.crop_and_resize(img, center, size, out_size,border_value=avg_color, interp=interp)return patch

因为GOT-10k里面对于目标的bbox是以ltwh(即left, top, weight, height)形式给出的,上述代码一开始就先把输入的box变成center based,坐标形式变为[y, x, h, w],结合下面这幅图就非常好理解:

论文中,在bounding box加上了周围的一些边界信息,上下左右各加上p个像素信息。其中:
假设bounding box的大小为(w,h),加上边界后的大小为(w+2p,h+2p)。对于模板图像而言,还需要对其加上边界后的结果乘上一个缩放系数s,使得区域的面积为127x127,当然了,这里缩放后的长宽不一定非要等于127,那么
会发现虽然现在区域面积为127x127,但是他并不是127x127的方形,所以需要进行resize。

crop_and_resize:


def crop_and_resize(img, center, size, out_size,border_type=cv2.BORDER_CONSTANT,border_value=(0, 0, 0),interp=cv2.INTER_LINEAR):# convert box to corners (0-indexed)计算绿色框中的两个点size = round(size)  # the size of square cropcorners = np.concatenate((np.round(center - (size - 1) / 2),np.round(center - (size - 1) / 2) + size))corners = np.round(corners).astype(int)# pad image if necessarypads = np.concatenate((-corners[:2], corners[2:] - img.shape[:2]))npad = max(0, int(pads.max()))if npad > 0:img = cv2.copyMakeBorder(img, npad, npad, npad, npad,border_type, value=border_value)# crop image patchcorners = (corners + npad).astype(int)patch = img[corners[0]:corners[2], corners[1]:corners[3]]# resize to out_sizepatch = cv2.resize(patch, (out_size, out_size),interpolation=interp)return patch


train_step


for epoch in range(self.cfg.epoch_num):# update lr at each epochself.lr_scheduler.step(epoch=epoch)# loop over dataloaderfor it, batch in enumerate(dataloader):loss = self.train_step(batch, backward=True)print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(epoch + 1, it + 1, len(dataloader), loss))sys.stdout.flush()

而train_step里面难度又是在于理解_create_labels,具体的一些tensor的shape可以看我的注释,我好奇就把他打印出来了,看来本来__getitem__返回一对pair(z, x),经过dataloader的加载,还是z堆叠一起,x堆叠一起,并不是(z, x)绑定堆叠一起【主要自己对dataloader源码不是很熟,手动捂脸】
而且criterion使用的BalancedLoss,是调用F.binary_cross_entropy_with_logits,进行一个element-wise的交叉熵计算,所以创建出来的labels的shape其实就是和responses的shape是一样的:


def train_step(self, batch, backward=True):# set network modeself.net.train(backward)# parse batch dataz = batch[0].to(self.device, non_blocking=self.cuda)x = batch[1].to(self.device, non_blocking=self.cuda)# print("batch_z shape:", z.shape)  # torch.Size([8, 3, 127, 127])# print("batch_x shape:", x.shape)  # torch.Size([8, 3, 239, 239])with torch.set_grad_enabled(backward):# inferenceresponses = self.net(z, x)# print("responses shape:", responses.shape) # torch.Size([8, 1, 15, 15])# calculate losslabels = self._create_labels(responses.size())loss = self.criterion(responses, labels)if backward:# back propagationself.optimizer.zero_grad()loss.backward()self.optimizer.step()

创建标签,论文里是这么说的:为我们的exemplar image z zz 和search image x xx都是以目标为中心的,所以labels的中心为1,中心以外为0。


def _create_labels(self, size):# skip if same sized labels already createdif hasattr(self, 'labels') and self.labels.size() == size:return self.labelsdef logistic_labels(x, y, r_pos, r_neg):dist = np.abs(x) + np.abs(y)  # block distancelabels = np.where(dist <= r_pos,np.ones_like(x),np.where(dist < r_neg,np.ones_like(x) * 0.5,np.zeros_like(x)))return labels# distances along x- and y-axisn, c, h, w = sizex = np.arange(w) - (w - 1) / 2y = np.arange(h) - (h - 1) / 2x, y = np.meshgrid(x, y)# create logistic labels 这里除以stride,是相对score map上来说r_pos = self.cfg.r_pos / self.cfg.total_strider_neg = self.cfg.r_neg / self.cfg.total_stridelabels = logistic_labels(x, y, r_pos, r_neg)# repeat to sizelabels = labels.reshape((1, 1, h, w))labels = np.tile(labels, (n, c, 1, 1))# convert to tensorsself.labels = torch.from_numpy(labels).to(self.device).float()return self.labels

最后出来的一个batch下某一个通道下的label就是下面这样的,有没有一种扫雷的既视感,

siamfc-pytorch代码分析相关推荐

  1. GAT: 图注意力模型介绍及PyTorch代码分析

    文章目录 GAT: 图注意力模型介绍及代码分析 原理 图注意力层(Graph Attentional Layer) 情境一:节点和它的一个邻居 情境二:节点和它的多个邻节点 聚合(Aggregatio ...

  2. GAT:图注意力模型介绍及PyTorch代码分析

    文章目录 1.计算注意力系数 2.聚合 2.1 附录--GAT代码 2.2 附录--相关代码 3.完整实现 3.1 数据加载和预处理 3.2 模型训练 1.计算注意力系数 对于顶点 iii ,通过计算 ...

  3. 【BasicNet系列:六】MobileNet 论文 v1 v2 笔记解读 + pytorch代码分析

    1.MobileNet V1 MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications 参考 ...

  4. 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)

    首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple- ...

  5. BNN Pytorch代码阅读笔记

    BNN Pytorch代码阅读笔记 这篇博客来写一下我对BNN(二值化神经网络)pytorch代码的理解,我是第一次阅读项目代码,所以想仔细的自己写一遍,把细节理解透彻,希望也能帮到大家! 论文链接: ...

  6. CNN网络实现手写数字(MNIST)识别 代码分析

    CNN网络实现手写数字(MNIST)识别 代码分析(自学用) Github代码源文件 本文是学习了使用Pytorch框架的CNN网络实现手写数字(MNIST)识别 #导入需要的包 import num ...

  7. ResNet论文笔记及Pytorch代码解析

    注:个人学习记录 感谢B站up主"同济子豪兄"的精彩讲解,参考视频的记录 [精读AI论文]ResNet深度残差网络_哔哩哔哩_bilibili 算法的意义(大概介绍) CV史上的技 ...

  8. 门控图神经网络(GGNN)及代码分析

    门控图神经网络GGNN及代码分析 基本概念 GGNN是一种基于GRU的经典的空间域message passing的模型 问题描述 一个图 G = (V, E), 节点v ∈ V中存储D维向量,边e ∈ ...

  9. 语音识别(2): kws项目实现、数据集代码分析

    语音识别(2):KWS数据集代码分析 数据集分析 kws的语音数据为该数据集有 30 个短单词的 65000 个长度 1 秒钟的发音. 这是Google的一个语音数据集 下载地址:http://dow ...

最新文章

  1. 读取xml忽略dtd验证
  2. 滚动条禁止_Axure 教程:不可见滚动条的页面滚动效果
  3. 一款功能强大,可扩展端到端加密反向Shell的工具
  4. Angular 路由守卫
  5. 2020蓝桥杯省赛---java---B---9(子串分值和)
  6. mysql connector cpp_MySQL Connector/C++(一)
  7. 随想录(windows上cuda环境安装)
  8. 凸优化笔记(非常零碎)
  9. oracle客户端登录失败,Win7系统配置Oracle客户端连接失败的解决方法
  10. python调用matlab
  11. (2020)Java后端开发----(面试题和笔试题)
  12. [读书笔记] 代码整洁之道(二)
  13. 课程作业记录3:瑞利衰落信道下的BPSK/QPSK/16QAM的Matlab仿真
  14. vs2019,C#,MySQL创建图书管理系统7(用户借/还书)
  15. 从零到熟悉,带你掌握Python len() 函数的使用
  16. 连接本地数据库,mysql提示Can‘t connect to MySQL server on localhost (10061)解决办法
  17. Element 表单只能输入数字校验
  18. PLD PLA PAL GAL
  19. matlab绘制动态图,Matlab绘制动态图的两种方式(参考)
  20. SSH学习(个人笔记)

热门文章

  1. qss使用及优先级关系
  2. 口袋之旅html5超强账号,口袋之旅h5高级账号,h5裂空座多少高级狩猎卷
  3. 如何下载谷歌/百度/高德大字体地图用于打印
  4. java sca_SCA java编码入门
  5. kettle工具下载、安装、数据迁移、定时任务详解
  6. 论文解读:HINGRL:通过异构信息网络上的图表示学习预测药物-疾病关联
  7. 汽车品牌 API数据接口
  8. lucene实现分组统计的方法
  9. win98vmdk镜像_VMDK、VHD镜像互转工具
  10. java计算机毕业设计消防安全应急培训管理平台源码+系统+数据库+lw文档+mybatis+运行部署