最近才真正开始研究目标跟踪领域(好吧,是真的慢)。就先看了一篇论文:

Fully-Convolutional Siamese Networks for Object Tracking【ECCV2016 workshop】

又因为学的是PyTorch框架,所以找了一份比较clean的代码,还是pytorch1.0的:

https://github.com/huanglianghua/siamfc-pytorch

因为这个作者也是GOT-10k toolkit的主要贡献者,所以用上这个工具箱之后显得training和test会clean一些,要能跑训练和测试代码,还得去下载GOT-10k数据集,训练数据分成了19份,如果只是为了跑一下下一份就行。

论文概述

SiamFC这篇论文算是将深度神经网络较早运用于tracking的,比它还早一点的就是SINT了,主要是运用了相似度学习的思想,采用孪生网络,把127×127的exemplar image

和255×255的search image

输入同一个backbone(论文中就是AlexNet)也叫Embedding Network,生成各自的Embedding,然后这两个Embedding经过互相关计算的得到score map,其上大的位置就代表对应位置上的Embedding相似度大,反之亦然。整个训练流程可以用下图表示:

SiamFC训练流程

个人感觉,训练就是为了优化Embedding Network,在见到的序列中生成一个更好embedding,从而使生成的score map和生成的ground truth有更小的logistic loss。更多细节在之后的几篇会和代码一起分析。

backbones.py分析

from __future__ import absolute_import

import torch.nn as nn

__all__ = ['AlexNetV1', 'AlexNetV2', 'AlexNetV3']

class _BatchNorm2d(nn.BatchNorm2d):

def __init__(self, num_features, *args, **kwargs):

super(_BatchNorm2d, self).__init__(

num_features, *args, eps=1e-6, momentum=0.05, **kwargs)

class _AlexNet(nn.Module):

def forward(self, x):

x = self.conv1(x)

x = self.conv2(x)

x = self.conv3(x)

x = self.conv4(x)

x = self.conv5(x)

return x

class AlexNetV1(_AlexNet):

output_stride = 8

def __init__(self):

super(AlexNetV1, self).__init__()

self.conv1 = nn.Sequential(

nn.Conv2d(3, 96, 11, 2),

_BatchNorm2d(96),

nn.ReLU(inplace=True),

nn.MaxPool2d(3, 2))

self.conv2 = nn.Sequential(

nn.Conv2d(96, 256, 5, 1, groups=2),

_BatchNorm2d(256),

nn.ReLU(inplace=True),

nn.MaxPool2d(3, 2))

self.conv3 = nn.Sequential(

nn.Conv2d(256, 384, 3, 1),

_BatchNorm2d(384),

nn.ReLU(inplace=True))

self.conv4 = nn.Sequential(

nn.Conv2d(384, 384, 3, 1, groups=2),

_BatchNorm2d(384),

nn.ReLU(inplace=True))

self.conv5 = nn.Sequential(

nn.Conv2d(384, 256, 3, 1, groups=2))

class AlexNetV2(_AlexNet):

output_stride = 4

def __init__(self):

super(AlexNetV2, self).__init__()

self.conv1 = nn.Sequential(

nn.Conv2d(3, 96, 11, 2),

_BatchNorm2d(96),

nn.ReLU(inplace=True),

nn.MaxPool2d(3, 2))

self.conv2 = nn.Sequential(

nn.Conv2d(96, 256, 5, 1, groups=2),

_BatchNorm2d(256),

nn.ReLU(inplace=True),

nn.MaxPool2d(3, 1))

self.conv3 = nn.Sequential(

nn.Conv2d(256, 384, 3, 1),

_BatchNorm2d(384),

nn.ReLU(inplace=True))

self.conv4 = nn.Sequential(

nn.Conv2d(384, 384, 3, 1, groups=2),

_BatchNorm2d(384),

nn.ReLU(inplace=True))

self.conv5 = nn.Sequential(

nn.Conv2d(384, 32, 3, 1, groups=2))

class AlexNetV3(_AlexNet):

output_stride = 8

def __init__(self):

super(AlexNetV3, self).__init__()

self.conv1 = nn.Sequential(

nn.Conv2d(3, 192, 11, 2),

_BatchNorm2d(192),

nn.ReLU(inplace=True),

nn.MaxPool2d(3, 2))

self.conv2 = nn.Sequential(

nn.Conv2d(192, 512, 5, 1),

_BatchNorm2d(512),

nn.ReLU(inplace=True),

nn.MaxPool2d(3, 2))

self.conv3 = nn.Sequential(

nn.Conv2d(512, 768, 3, 1),

_BatchNorm2d(768),

nn.ReLU(inplace=True))

self.conv4 = nn.Sequential(

nn.Conv2d(768, 768, 3, 1),

_BatchNorm2d(768),

nn.ReLU(inplace=True))

self.conv5 = nn.Sequential(

nn.Conv2d(768, 512, 3, 1),

_BatchNorm2d(512))

这个module主要实现了3个AlexNet版本作为backbone,开头的__all__ = ['AlexNetV1', 'AlexNetV2', 'AlexNetV3']主要是为了让别的module导入这个backbones.py的东西时,只能导入__all__后面的部分。

后面就是三个类AlexNetV1、AlexNetV2、AlexNetV3,他们都集成了类_AlexNet,所以他们都是使用同样的forward函数,依次通过五个卷积层,每个卷积层使用nn.Sequential()堆叠,只是他们各自的total_stride和具体每层卷积层实现稍有不同(当然跟原本的AlexNet还是有些差别的,比如通道数上):

AlexNetV1和AlexNetV2:

共同点:conv2、conv4、conv5这几层都用了groups=2的分组卷积,这跟原来的AlexNet会更接近一点

不同点:conv2中的MaxPool2d的stride不一样大,conv5层的输出通道数不一样

AlexNetV1和AlexNetV3:前两层的MaxPool2d是一样的,但是中间层的卷积层输入输出通道都不一样,最后的输出通道也不一样,AlexNetV3最后输出经过了BN

AlexNetV2和AlexNetV3:conv2中的MaxPool2d的stride不一样,AlexNetV2最后输出通道数小很多

其实感觉即使有这些区别,但是这并不是很重要,这一部分也是整体当中容易理解的,所以不必太去纠结为什么不一样,最后作者用的是AlexNetV1,论文中是这样的结构,其实也就是AlexNetV1:

论文中backbone结构

注意:有些人会感觉这里输入输出通道对不上,这是因为像原本AlexNet分成了2个group,所以会有48->96, 192->384这样。

也可以在此py文件下面再加一段代码,测试一下打印出的tensor的shape:

if __name__ == '__main__':

alexnetv1 = AlexNetV1()

import torch

z = torch.randn(1, 3, 127, 127)

output = alexnetv1(z)

print(output.shape) # torch.Size([1, 256, 6, 6])

x = torch.randn(1, 3, 256, 256)

output = alexnetv1(x)

print(output.shape) # torch.Size([1, 256, 22, 22])

# 换成AlexNetV2依次是:

# torch.Size([1, 32, 17, 17])、torch.Size([1, 32, 49, 49])

# 换成AlexNetV3依次是:

# torch.Size([1, 512, 6, 6])、torch.Size([1, 512, 22, 22])

heads.py

先放代码为敬:

class SiamFC(nn.Module):

def __init__(self, out_scale=0.001):

super(SiamFC, self).__init__()

self.out_scale = out_scale

def forward(self, z, x):

return self._fast_xcorr(z, x) * self.out_scale

def _fast_xcorr(self, z, x):

# fast cross correlation

nz = z.size(0)

nx, c, h, w = x.size()

x = x.view(-1, nz * c, h, w)

out = F.conv2d(x, z, groups=nz) # shape:[nx/nz, nz, H, W]

out = out.view(nx, -1, out.size(-2), out.size(-1)) #[nx, 1, H, W]

return out

为什么这里会有个out_scale,根据作者说是因为,

互相关之后的值太大,经过sigmoid函数之后会使值处于梯度饱和的那块,梯度太小,乘以out_scale就是为了避免这个。

_fast_xcorr函数中最关键的部分就是F.conv2d函数了,可以通过官网查询到用法

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor

input – input tensor of shape (

,

,

,

)

weight – filters of shape (

,

,

,

)

所以根据上面条件,可以得到:x shape:[nx/nz, nz*c, h, w] 和 z shape:[nz, c, hz, wz],最后out shape:[nx, 1, H, W]

其实最后真实喂入此函数的z embedding shape:[8, 256, 6, 6], x embedding shape:[8, 256, 20, 20], output shape:[8, 1, 15, 15]【这个之后再回过来看也行】

同样的,也可以用下面一段代码测试一下:

if __name__ == '__main__':

import torch

z = torch.randn(8, 256, 6, 6)

x = torch.randn(8, 256, 20, 20)

siamfc = SiamFC()

output = siamfc(z, x)

print(output.shape) # torch.Size([8, 1, 15, 15])

好了,这部分先讲到这里,这一块还是算简单的,一般看一下应该就能理解,之后的代码会更具挑战性,嘻嘻,放一个辅助链接,下面这个版本中有一些动图,还是会帮助理解的:

还有下面是GOT-10k的toolkit,可以先看一下,但是训练部分代码还不是涉及很多:

下一篇

siamfc代码解读_siamfc-pytorch代码讲解(一):backbonehead相关推荐

  1. python实现胶囊网络_Capsule Network胶囊网络解读与pytorch代码实现

    本文是论文<Dynamic Routing between Capsules>的论文解读与pytorch代码实现. 如需转载本文或代码请联系作者 @Riroaki 并声明. 众所周知,卷积 ...

  2. ResNet及其变种的结构梳理、有效性分析与代码解读(PyTorch)

    点击我爱计算机视觉标星,更快获取CVML新技术 本文来自知乎,作者费敬敬,现为同济大学计算机科学与技术硕士. https://zhuanlan.zhihu.com/p/54289848 温故而知新,理 ...

  3. 如何将tensorflow1.x代码改写为pytorch代码(以图注意力网络(GAT)为例)

    之前讲解了图注意力网络的官方tensorflow版的实现,由于自己更了解pytorch,所以打算将其改写为pytorch版本的. 对于图注意力网络还不了解的可以先去看看tensorflow版本的代码, ...

  4. java车间调度算法_混合算法(GA+TS)求解作业车间调度问题代码解读+完整JAVA代码...

    程序猿声 代码黑科技的分享区 前两篇文章中,我们介绍了FJSP问题,并梳理了一遍HA算法.这一篇文章对小编实现的(很乱很烂的)代码进行简单解读. 往期回顾: 代码下载请关注公众号,后台回复[FJSPH ...

  5. matlab sift代码解读,MATLAB SIFT 代码

    [实例简介] matlab 实现的 sift 变换 的代码,包含整个过程的详细步骤. [实例截图] [核心代码] sift-0.9.0 ├── data │   ├── img3.jpg │   ├─ ...

  6. tensorflow代码翻译成pytorch代码 -详细教程+案例

  7. Transformer Pytorch代码实现以及理解

    Transformer结构​​​​​​​ 论文:Attention is all you need Transformer模型是2017年Google公司在论文<Attention is All ...

  8. shfflenetv2代码解读

    shufflenetv2代码解读 目录 shufflenetv2代码解读 概述 shufflenetv2网络结构图 shufflenetv2架构参数 shufflenetv2代码细节分析 概述 shu ...

  9. GoogLeNet代码解读

    GoogLeNet代码解读 目录 GoogLeNet代码解读 概述 GooLeNet网络结构图 1)从输入到第一层inception 2)从第2层inception到第4层inception 3)从第 ...

  10. Inception代码解读

    Inception代码解读 目录 Inception代码解读 概述 Inception网络结构图 inception网络结构框架 inception代码细节分析 概述 inception相比起最开始兴 ...

最新文章

  1. Retrofit 找不到 GsonConverterFactory
  2. zabbix4.0添加mysql报警_部署监控三剑客 Zabbix4.0 监控以及告警机制
  3. phpwind自定义推送模块
  4. [c#]Dll自定义目录
  5. c语言运行VBA,C语言选择题部分模块和VBA.ppt
  6. forEach遍历对象数组案例
  7. Html垂直居中不起作用,html – 垂直居中不起作用,因为行不会达到100%的高度
  8. 蛋糕是叫胚子还是坯子_这个生日蛋糕太适合手残党了,不会裱花也能做,学会再不买着吃了...
  9. Centos7安装WPS
  10. SpringBoot RestTemplate 发送请求 忽略证书不安全
  11. 2022“杭电杯”中国大学生算法设计超级联赛(7) 2022杭电多校第七场
  12. 计算机用户名怎么改好听,电脑版本优酷视频如何设置呢称_昵称起名
  13. python拨打网络电话_0成本搭建IP电话系统,统一通信系统,呼叫中心系统-3CX快速安装手册...
  14. vux组件库更换主题颜色的方法
  15. 实现在 .net 中使用 HttpClient 下载文件时显示进度
  16. 损失函数-MSE-CEE
  17. pancakeswap 前端源码编译及部署-linux
  18. python 数据挖掘 之 对数据进行简单预处理(1)
  19. php培训视频全套,43G 干货分享 2017年泰牛PHP全套视频+培训全套完整版课件
  20. springboot2 配置404、403、500等错误页面自动跳转

热门文章

  1. python导入鸢尾花数据集_2020-11-01 鸢尾花数据集Python处理
  2. 华为主题包hwt下载_emui主题打包下载-emui主题打包 v1.0_手机乐园
  3. python调用按键精灵插件_按键精灵 插件命令 重中之重务必要记住怎么操作
  4. Excel表格如何根据身份证号计算年龄
  5. C在mac上用不了malloc.h头文件的解决方法
  6. 最流行的三大数据建模工具
  7. 机器学习-数据归一化方法(Normalization Method)
  8. 小米笔记本、小米游戏本重装原装出厂镜像教程-有百度盘的提取码
  9. 让电脑「读懂」你的思想——java工程师的职业规划
  10. Oracle 密码过期