ST-GCN源码分析
在上篇的blog中,写了一下对于ST-GCN论文的分析ST-GCN论文分析_Eric加油学!的博客-CSDN博客,这篇blog写一下对于ST-GCN源码的理解和整理,参考了一些写的比较好的文章,在文末附上链接。
文末有更新的ST-GCN复现全过程(详细)
目录
整体的运行逻辑
静态代码流
动态代码流
核心代码分析
graph.py
def get_edge(self,layout):
def get_hop_distance(num_node,edge,max_hop=1):
def get_adjacency(self,strategy):
st-gcn.py
GCN模块
TCN模块
整体的运行逻辑
静态代码流
主程序 st-gcn-master/main.py
#主程序 st-gcn-master/main.py
#调用不同类文件 (recognition类、demo类)#a、导入recognition类
processors['recognition'] = import_class('processor.recognition.REC_Processor')#b、调取recognition中默认参数
subparsers.add_parser(k, parents=[p.get_parser()]) #c、接受命令行中参数
arg = parser.parse_args()#d、实例化recognition类并传入命令行中参数(同时完成类初始化)
Processor = processors[arg.processor]
p = Processor(sys.argv[2:]) #e、调用recognition类中开始函数
p.start()
类程序1:recognition类 st-gcn-master/processor/recognition.py
def weights_init(m): #权重初始化class REC_Processor(Processor):def load_model(self): #加载模型def load_optimizer(self): #加载优化器def adjust_lr(self): #调整学习率def show_topk(self,k): #显示精度def train(self):def test(self,evaluation = True):def get_parser(add_help = False):
类程序2:processor类 st-gcn-master/processor/processor.py
class Processor(IO):def __init__(self,argv=None):def init_environment(self):def load_optimizer(self):def load_data(self):def show_epoch_info(self):def show_iter_info(self):def train(self):def test(self):def start(self):def get_parser(add_help = False):
类程序3:IO类 st-gcn-master/processor/io.py
class IO():def __init__(self,argv=None):def load_arg(self,argv=None):def init_environment(self):def load_model(self):def load_weights(self):def gpu(self):def start(self):def get_parser(add_help=False):
动态代码流
以NTU交叉主题模型训练为例:
当在终端输入命令: python main.py recognition -c config/st_gcn/ntu-xsub/train.yaml后,执行主程序
#a、导入recognition类
processors['recognition'] = import_class('processor.recognition.REC_Processor')
#b、调取recognition中默认参数
subparsers.add_parser(k, parents=[p.get_parser()])
# ---> def get_parser(add_help=False):
#c、接受命令行中参数
arg = parser.parse_args()#d、实例化recognition类并传入命令行中参数(同时完成类初始化)
Processor = processors[arg.processor] #arg:processor:recognition
p = Processor(sys.argv[2:]) #sys.argv[2:]:-c config/st_gcn/ntu-xsub/train.yaml
其中实例化和初始化过程如下:
# processor/processor.py
class Processor(IO):def __init__(self,argv=None):self.load_argv(argv) #参数加载,得到self.arg# --> def load_arg(self,argv=None):'''1、读取默认参数到参数表2、使用输入参数更新参数表3、读取参数配置文件更新参数表 配置文件:ntu-xsub/train.yaml4、使用输入参数更新参数表'''self.init_environment():super().init_environment() #继承调用 processor/io.py'''self.io= 获取自定义包中self.io类self.io.save_arg(self.arg) 将参数表保存到工作区配置文件如果使用GPU:获取GPU号和设备号'''#添加定义类参数self.result = dict()self.iter_info = dict()self.epoch_info = dict()self.meta_info = dict(epoch=0, iter=0)self.load_model() # --> recognition.pydef load_model(self):self.model = #下载模型,获得模型self.model#模型文件: /net/st_gcn.pyself.model.apply(weights_init) #权重初始化,见def weights_init(m):self.loss = nn.CrossEntropyLoss() #定义交叉商为损失函数self.load_weights():self.gpu():'''def gpu(self): 将self.arg、self.io等放到gpus上如果使用gpu且数量大于1,使模型并行'''self.load_data():#--> def load_data(self):'''Feeder = import_class(self.arg.feeder) 导入Feeder类self.arg.train_feeder_args['debug'] = self.arg.debugself.data_loader = dict() #建立数据字典self.data_loader['train'] = #训练数据装入self.data_loader['test'] = #测试数据装入'''self.load_optimizer():#-->def load_optimizer(self):#self.optimizer= 定义模型优化参数
开始函数
#e、调用recognition类中开始函数
p.start()
def start(self):if self.arg.phase == 'train':for epoch in range(): #迭代epochself.train()'''def train(self):self.adjust_lr()loader = self.data_loader['train']for data,label in loader:#数据放到GPU上#forward#backward'''self.io.save_model(self.model, filename) #保存模型参数self.test() #测试评估模型'''下载测试数据迭代获取结果,显示精度'''
核心代码分析
核心代码共分3个文件,在net文件夹下,分别为graph.py, tgcn.py, st-gcn.py。
graph.py 中包含邻接矩阵的建立和节点分组策略
st-gcn.py 包含整个网络部分的结构和前向传播方法
tgcn.py 主要是空间域卷积的结构和前向传播方法
graph.py
class Graph的构造函数使用了self.get_edge、self.hop_dis 、self.get_adjacency
def __init__(self,layout='openpose',strategy='uniform',max_hop=1,dilation=1):self.max_hop = max_hopself.dilation = dilationself.get_edge(layout) #确定图中节点间边的关系self.hop_dis = get_hop_distance( #获得邻接矩阵self.num_node, self.edge, max_hop=max_hop)self.get_adjacency(strategy) #根据分区策略获得邻域
def get_edge(self,layout):
根据layout是openpose还是ntu-rgb+d,我们可以决定是18个关键点还是25个关键点的骨架图
neighbor_link中描述了关键点之间的连接情况 。根据代码可以看出,中心点是 1号点,也就是脖子
def get_edge(self, layout):if layout == 'openpose':self.num_node = 18self_link = [(i, i) for i in range(self.num_node)]neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12,11),(10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),(0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]self.edge = self_link + neighbor_linkself.center = 1elif layout == 'ntu-rgb+d':self.num_node = 25self_link = [(i, i) for i in range(self.num_node)]neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21),(6, 5), (7, 6), (8, 7), (9, 21), (10, 9),(11, 10), (12, 11), (13, 1), (14, 13), (15, 14),(16, 15), (17, 1), (18, 17), (19, 18), (20, 19),(22, 23), (23, 8), (24, 25), (25, 12)]neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]self.edge = self_link + neighbor_linkself.center = 21 - 1
OpenPose的18关键点顺序如下图。
注意,这里采用的是OpenPose的节点举例的。但是可以发现作者的节点连接顺序与OpenPose提供的输出格式连接顺序是不同的。这样的连接对结果没有影响,但也不能认为将OpenPose中的节点pair改为st-gcn的顺序就可以了,因为OpenPose中的PAF的训练是按照原顺序进行的。
分别对应的位置:
0-鼻子; 1-脖子 ; 2-右肩; 3-右肘; 4-右手腕; 5-左肩
6-左肘 ; 7-左手腕; 8-右臀; 9- 右膝盖; 10-右脚踝; 11-左臀
12-左膝盖; 13-左脚踝; 14-右眼; 15-左眼; 16-右耳; 17-左耳
def get_hop_distance(num_node,edge,max_hop=1):
def get_hop_distance(num_node, edge, max_hop=1):A = np.zeros((num_node, num_node))for i, j in edge: #构建邻接矩阵(无向图为对称矩阵)A[j, i] = 1A[i, j] = 1# compute hop stepshop_dis = np.zeros((num_node, num_node)) + np.inftransfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]#transfer_mat是list类型,需要将list堆叠成一个数组才能进行>操作arrive_mat = (np.stack(transfer_mat) > 0)for d in range(max_hop, -1, -1):hop_dis[arrive_mat[d]] = dreturn hop_dis
这段代码中获得了带自环的邻接矩阵 (18 * 18),其中非连接处为inf (即infinity,无穷大)
稍微复杂一点的是,hop_dis[arrive_mat[d]]=d的计算,第一次接触,参考了python中np.array 矩阵的高级操作_Booker Ye的博客-CSDN博客花了点时间才看懂。
大致就是如下的:
#max_hop=1
#假设 arrive_mat=array([[[True,True],
# [True,False]],
# [[True,False],
# [True,False]]])
# hop_dis=array([[inf,inf],[inf,inf]])
# inf是infinity的缩写,表示无穷大for d in range(max_hop, -1, -1):hop_dis[arrive_mat[d]] = d#结果为:
# hop_dis=array([0,0],
# [0,inf])
矩阵hop_dis中的元素是根据arrive_mat中对应位置的bool值进行变化的,若对应位置为True,则执行赋值操作
def get_adjacency(self,strategy):
def get_adjacency(self, strategy):#合法的距离值:0或1 ,抛弃了infvalid_hop = range(0, self.max_hop + 1, self.dilation)adjacency = np.zeros((self.num_node, self.num_node))for hop in valid_hop:adjacency[self.hop_dis == hop] = 1# 图卷积的预处理normalize_adjacency = normalize_digraph(adjacency)...elif strategy == 'spatial': #按照paper中的第三种分区策略:空间配置划分A = []for hop in valid_hop:a_root = np.zeros((self.num_node, self.num_node))a_close = np.zeros((self.num_node, self.num_node))a_further = np.zeros((self.num_node, self.num_node))for i in range(self.num_node):for j in range(self.num_node):if self.hop_dis[j, i] == hop: #如果i和j是邻接节点#比较节点i和j分别到中心点的距离,center点是1号点(脖子)if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]:a_root[j, i] = normalize_adjacency[j, i]elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]:a_close[j, i] = normalize_adjacency[j, i]else:a_further[j, i] = normalize_adjacency[j, i]if hop == 0:A.append(a_root) #A的第一维第一个矩阵:自身节点组else:A.append(a_root + a_close) #第一维第二个矩阵:向心组矩阵 (列对应节点到中心点的距离比行对应节点到中心点距离近或者相等)A.append(a_further) #第一维第三个矩阵:离心组矩阵A = np.stack(A)self.A = A #A的shape (3,18,18)...
def normalize_digraph(A): #归一化Dl = np.sum(A, 0) #计算邻接矩阵的度,将每一列的元素求和,Dl的shape为(18,1)num_node = A.shape[0] # 18Dn = np.zeros((num_node, num_node))for i in range(num_node):if Dl[i] > 0:Dn[i, i] = Dl[i]**(-1) # Dn是一个对角矩阵,只有主对角元素,为度的倒数AD = np.dot(A, Dn) # A * Dnreturn AD
get_adjacency输出的是一个 (3,18,18)的权值分组A矩阵。
st-gcn.py
网络的输入
整个网络的输入是一个(N,C,T,V,M)的tensor
N : batch size 视频个数
C : 3 输入数据的通道数量 (X,Y,S)代表一个点的信息 (位置x,y + 置信度)
T : 300 一个视频的帧数 paper规定为300
V : 18 根据不同的骨骼获取的节点数而定,coco为18个节点
M : 2 paper中将人数限定在最大2个人
def forward(self, x):# data normalizationN, C, T, V, M = x.size()x = x.permute(0, 4, 3, 1, 2).contiguous()x = x.view(N * M, V * C, T)x = self.data_bn(x)x = x.view(N, M, V, C, T)x = x.permute(0, 1, 3, 4, 2).contiguous()x = x.view(N * M, C, T, V)# forwadfor gcn, importance in zip(self.st_gcn_networks, self.edge_importance):x, _ = gcn(x, self.A * importance)# global poolingx = F.avg_pool2d(x, x.size()[2:])x = x.view(N, M, -1, 1, 1).mean(dim=1)# predictionx = self.fcn(x)x = x.view(x.size(0), -1)return x
从这里可以看出在forward前就已经改变为(N*M,C,T,V),第一层输入就是(512,3,150,18)
512是N=256和M=2的乘积。
不考虑残差结构的话,每一个st-gcn块相当于是input -> gcn ->tcn -> output
网络的结构
# build networksspatial_kernel_size = A.size(0) #空间核大小 就是A 0维(N)的值,也就是batch sizetemporal_kernel_size = 9 #时间核大小为9kernel_size = (temporal_kernel_size, spatial_kernel_size) #核大小 (9,batch size)self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}self.st_gcn_networks = nn.ModuleList((st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),st_gcn(64, 64, kernel_size, 1, **kwargs),st_gcn(64, 64, kernel_size, 1, **kwargs),st_gcn(64, 64, kernel_size, 1, **kwargs),st_gcn(64, 128, kernel_size, 2, **kwargs), #步长为2,作为池化层st_gcn(128, 128, kernel_size, 1, **kwargs),st_gcn(128, 128, kernel_size, 1, **kwargs),st_gcn(128, 256, kernel_size, 2, **kwargs), #步长为2,作为池化层st_gcn(256, 256, kernel_size, 1, **kwargs),st_gcn(256, 256, kernel_size, 1, **kwargs),))# initialize parameters for edge importance weightingif edge_importance_weighting:self.edge_importance = nn.ParameterList([nn.Parameter(torch.ones(self.A.size()))for i in self.st_gcn_networks])else:self.edge_importance = [1] * len(self.st_gcn_networks)# fcn for predictionself.fcn = nn.Conv2d(256, num_class, kernel_size=1)
一共10层st_gcn层,但作者没有将第一层算在stgcn模块中,所以共9层。
每一个stgcn层都用residual模块改进。在源码中可以看出当通道数要增加时,使用1x1卷积来进行通道的翻倍,另外 步长为2来完成池化。
根据st-gcn的具体结构图
可以看出一个ST-GCN层包含了一个GCN模块和一个TCN模块,另外还有邻接矩阵和边权重矩阵(edge_importance)的内积,所以更新的模型也分为了两个方面,一是gcn和tcn内 卷积核参数,二是edge_importance内的参数。
GCN模块
gcn模块位于 tgcn.py . 主要过程是一个Conv2d和一个einsum,可以看forward函数
class ConvTemporalGraphical(nn.Module):def __init__(self,in_channels,out_channels,kernel_size,t_kernel_size=1,t_stride=1,t_padding=0,t_dilation=1,bias=True):super().__init__()#这个kernel_size指的是空间上的kernal size,为3,也等于分区策略划分的子集数Kself.kernel_size = kernel_sizeself.conv = nn.Conv2d(in_channels,out_channels * kernel_size,kernel_size=(t_kernel_size, 1), #Conv(1,1)padding=(t_padding, 0),stride=(t_stride, 1),dilation=(t_dilation, 1),bias=bias)def forward(self, x, A):assert A.size(0) == self.kernel_sizex = self.conv(x) #这里输入x是(N,C,T,V),经过conv(x)之后变为(N,C*kernel_size,T,V)n, kc, t, v = x.size()# 这里把kernel_size的维度拿出来,变成(N,K,C,T,V)x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v) x = torch.einsum('nkctv,kvw->nctw', (x, A)) #爱因斯坦约定求和法return x.contiguous(), A
有两个值得注意的操作:
x = self.conv(x)
x = torch.einsum('nkctv,kvw->nctw',(x,A))
这里的self.conv(x)的卷积核真正大小是(t_kernel_size,1),t_kernel_size预设值是1,那么就是一个1x1的卷积层,在第一层时就相当于把输入的(512,3,150,18)转变为(512,output_channels * kernel_size,150,18),第一层中output_channels为64,kernel_size是[1,2,3]中的一个,视分区策略而定。这一层1x1卷积只是把特征升维,且按自己分了多少组加倍。
对于卷积(1,1)的过程,可以参考一下这个链接的可视化过程ST-GCN中,空域图卷积的可视化过程_Lauris_P的博客-CSDN博客_gcn空域
n, kc, t, v = x.size()x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v) x = torch.einsum('nkctv,kvw->nctw', (x, A))
前两句是把(512,output_channels * kernel_size,150,18)转化为了(512,kernel_size,64,150,18),也就是(n,k,c,t,v)的形状,而A是(K,V,V)的邻接矩阵,所以einsum()对A和x进行维度融合,相当于是
根据邻接矩阵中的邻接关系做了一次邻接节点间的特征融合,输出就变回了(N*M,C,T,V)的格式进入tcn。
TCN模块
该模块是让网络在时域上进行特征的提取,类似于LSTM。GCN的输出是一个(N,C,T,V),在TCN中可以理解为和CNN的输入格式(B,C,H,W)一样。要整合不同时间上的节点特征,就是对其进行卷积。
self.tcn = nn.Sequential(nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels,out_channels,(kernel_size[0], 1),(stride, 1),padding,),nn.BatchNorm2d(out_channels),nn.Dropout(dropout, inplace=True),)
tcn是用(temporal_kernel_size, 1)的卷积核对t维度进行卷积运算。这部分相对于gcn就很好理解了,就是正常的卷积操作,对于同一个节点在不同t下的特征的卷积。
gcn中是在单个时间t的图上生成新的特征和特征交流,tcn是在时间维度上特征交流。
其余参数解析,数据集加载,等等的代码就不写了。
更新:ST-GCN复现的全过程(详细)_Eric加油学!的博客-CSDN博客
ST-GCN复现的全过程
参考文章:ST-GCN的学习之路(二)源码解读 (Pytorch版)_LgrandStar的博客-CSDN博客_stgcn代码详解
ST-GCN中,空域图卷积的可视化过程_Lauris_P的博客-CSDN博客_gcn空域
ST-GCN源码分析相关推荐
- supervisor源码分析
Supervisor分析 1.运行原理概述: Supervisor生成主进程并将主进程变成守护进程,supervisor依次生成配置文件中的工作进程,然后依次监控工作进程的工作状态,并且主进程负责与s ...
- NanoHttpd源码分析
最近在GitHub上发现一个有趣的项目--NanoHttpd. 说它有趣,是因为他是一个只有一个Java文件构建而成,实现了部分http协议的http server. GitHub地址:https:/ ...
- Android源码分析--MediaServer源码分析(二)
在上一篇博客中Android源码分析–MediaServer源码分析(一),我们知道了ProcessState和defaultServiceManager,在分析源码的过程中,我们被Android的B ...
- SRS(simple-rtmp-server)流媒体服务器源码分析--启动
SRS(simple-rtmp-server)流媒体服务器源码分析--系统启动 一.前言 小卒最近看SRS源码,随手写下博客,其一为了整理思路,其二也是为日后翻看方便.如果不足之处,请指教! 首先总结 ...
- HTTP服务器的本质:tinyhttpd源码分析及拓展
已经有一个月没有更新博客了,一方面是因为平时太忙了,另一方面是想积攒一些干货进行分享.最近主要是做了一些开源项目的源码分析工作,有c项目也有python项目,想提升一下内功,今天分享一下tinyhtt ...
- Android ADB 源码分析(三)
前言 之前分析的两篇文章 Android Adb 源码分析(一) 嵌入式Linux:Android root破解原理(二) 写完之后,都没有写到相关的实现代码,这篇文章写下ADB的通信流程的一些细节 ...
- 这篇文章绝对让你深刻理解java类的加载以及ClassLoader源码分析
前言 package com.jvm.classloader;class Father2{public static String strFather="HelloJVM_Father&qu ...
- 解密android日志xlog,XLog 详解及源码分析
一.前言 这里的 XLog 不是微信 Mars 里面的 xLog,而是elvishew的xLog.感兴趣的同学可以看看作者 elvishwe 的官文史上最强的 Android 日志库 XLog.这里先 ...
- 三星uboot1.1.6源码分析——start.s(4)——从NAND复制源码到RAM(3)
通过上两篇博客终于把从NAND复制源码到RAM的c语言写的部分说完了,现在回到start.s中,接着分析余下的代码. ----------------------------------------- ...
- 并发-阻塞队列源码分析
阻塞队列 参考: http://www.cnblogs.com/dolphin0520/p/3932906.html http://endual.iteye.com/blog/1412212 http ...
最新文章
- 类 或 对象 的一些小点 【仅记录,方便以后查阅】
- mysql 1005 错误
- CSS学习17之动画
- 装饰器,闭包,高阶函数,嵌套函数
- hcna(华为)_Telnet篇
- 大数据的发展体现在哪些方面
- mysql 历史记录查询
- (二)面向对象设计原则
- window10c语言软件下载,win10中文语言包下载
- 银行即将关闭直接代扣通道,第三方支付有麻烦了
- r610服务器维修,戴尔服务器R610
- python程序填空快乐的数字_Python习题之快乐的数字
- 服务器采集数据源码,Skywalking数据采集与收集源码分析
- 应该将composer.lock致力于版本控制吗?
- 计算机网络技术之局域网
- CSS基础(P45-P65)
- latex中文简易模板,课程论文使用
- 用容斥原理计算具有有限重数的多重集合的 r-组合(附代码)
- Docker教程(1)Docker 入门
- CISCO banner MOTD, Login的区别