在上篇的blog中,写了一下对于ST-GCN论文的分析ST-GCN论文分析_Eric加油学!的博客-CSDN博客,这篇blog写一下对于ST-GCN源码的理解和整理,参考了一些写的比较好的文章,在文末附上链接。

文末有更新的ST-GCN复现全过程(详细)

目录

整体的运行逻辑

静态代码流

动态代码流

核心代码分析

graph.py

def get_edge(self,layout):

def get_hop_distance(num_node,edge,max_hop=1):

def get_adjacency(self,strategy):

st-gcn.py

GCN模块

TCN模块


整体的运行逻辑

静态代码流

主程序  st-gcn-master/main.py

#主程序  st-gcn-master/main.py
#调用不同类文件 (recognition类、demo类)#a、导入recognition类
processors['recognition'] = import_class('processor.recognition.REC_Processor')#b、调取recognition中默认参数
subparsers.add_parser(k, parents=[p.get_parser()]) #c、接受命令行中参数
arg = parser.parse_args()#d、实例化recognition类并传入命令行中参数(同时完成类初始化)
Processor = processors[arg.processor]
p = Processor(sys.argv[2:]) #e、调用recognition类中开始函数
p.start()

类程序1:recognition类        st-gcn-master/processor/recognition.py

def weights_init(m):    #权重初始化class REC_Processor(Processor):def load_model(self):        #加载模型def load_optimizer(self):    #加载优化器def adjust_lr(self):         #调整学习率def show_topk(self,k):        #显示精度def train(self):def test(self,evaluation = True):def get_parser(add_help = False):

类程序2:processor类        st-gcn-master/processor/processor.py

class Processor(IO):def __init__(self,argv=None):def init_environment(self):def load_optimizer(self):def load_data(self):def show_epoch_info(self):def show_iter_info(self):def train(self):def test(self):def start(self):def get_parser(add_help = False):

类程序3:IO类         st-gcn-master/processor/io.py

class IO():def __init__(self,argv=None):def load_arg(self,argv=None):def init_environment(self):def load_model(self):def load_weights(self):def gpu(self):def start(self):def get_parser(add_help=False):

动态代码流

以NTU交叉主题模型训练为例:

当在终端输入命令: python main.py recognition -c config/st_gcn/ntu-xsub/train.yaml后,执行主程序

#a、导入recognition类
processors['recognition'] = import_class('processor.recognition.REC_Processor')
#b、调取recognition中默认参数
subparsers.add_parser(k, parents=[p.get_parser()])
#    --->    def get_parser(add_help=False):
#c、接受命令行中参数
arg = parser.parse_args()#d、实例化recognition类并传入命令行中参数(同时完成类初始化)
Processor = processors[arg.processor]    #arg:processor:recognition
p = Processor(sys.argv[2:])   #sys.argv[2:]:-c config/st_gcn/ntu-xsub/train.yaml

其中实例化和初始化过程如下:

# processor/processor.py
class Processor(IO):def __init__(self,argv=None):self.load_argv(argv)    #参数加载,得到self.arg# --> def load_arg(self,argv=None):'''1、读取默认参数到参数表2、使用输入参数更新参数表3、读取参数配置文件更新参数表 配置文件:ntu-xsub/train.yaml4、使用输入参数更新参数表'''self.init_environment():super().init_environment() #继承调用 processor/io.py'''self.io=        获取自定义包中self.io类self.io.save_arg(self.arg)    将参数表保存到工作区配置文件如果使用GPU:获取GPU号和设备号'''#添加定义类参数self.result = dict()self.iter_info = dict()self.epoch_info = dict()self.meta_info = dict(epoch=0, iter=0)self.load_model()    # --> recognition.pydef load_model(self):self.model =     #下载模型,获得模型self.model#模型文件: /net/st_gcn.pyself.model.apply(weights_init)  #权重初始化,见def weights_init(m):self.loss = nn.CrossEntropyLoss()   #定义交叉商为损失函数self.load_weights():self.gpu():'''def gpu(self):  将self.arg、self.io等放到gpus上如果使用gpu且数量大于1,使模型并行'''self.load_data():#--> def load_data(self):'''Feeder = import_class(self.arg.feeder)    导入Feeder类self.arg.train_feeder_args['debug'] = self.arg.debugself.data_loader = dict()       #建立数据字典self.data_loader['train'] =      #训练数据装入self.data_loader['test'] =        #测试数据装入'''self.load_optimizer():#-->def load_optimizer(self):#self.optimizer= 定义模型优化参数

开始函数

#e、调用recognition类中开始函数
p.start()
def start(self):if self.arg.phase == 'train':for epoch in range():  #迭代epochself.train()'''def train(self):self.adjust_lr()loader = self.data_loader['train']for data,label in loader:#数据放到GPU上#forward#backward'''self.io.save_model(self.model, filename)    #保存模型参数self.test() #测试评估模型'''下载测试数据迭代获取结果,显示精度'''

核心代码分析

核心代码共分3个文件,在net文件夹下,分别为graph.py, tgcn.py, st-gcn.py。

graph.py 中包含邻接矩阵的建立和节点分组策略

st-gcn.py 包含整个网络部分的结构和前向传播方法

tgcn.py    主要是空间域卷积的结构和前向传播方法

graph.py

class Graph的构造函数使用了self.get_edge、self.hop_dis 、self.get_adjacency

    def __init__(self,layout='openpose',strategy='uniform',max_hop=1,dilation=1):self.max_hop = max_hopself.dilation = dilationself.get_edge(layout)   #确定图中节点间边的关系self.hop_dis = get_hop_distance(    #获得邻接矩阵self.num_node, self.edge, max_hop=max_hop)self.get_adjacency(strategy)    #根据分区策略获得邻域

def get_edge(self,layout):

根据layout是openpose还是ntu-rgb+d,我们可以决定是18个关键点还是25个关键点的骨架图

neighbor_link中描述了关键点之间的连接情况 。根据代码可以看出,中心点是 1号点,也就是脖子

    def get_edge(self, layout):if layout == 'openpose':self.num_node = 18self_link = [(i, i) for i in range(self.num_node)]neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12,11),(10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),(0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]self.edge = self_link + neighbor_linkself.center = 1elif layout == 'ntu-rgb+d':self.num_node = 25self_link = [(i, i) for i in range(self.num_node)]neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21),(6, 5), (7, 6), (8, 7), (9, 21), (10, 9),(11, 10), (12, 11), (13, 1), (14, 13), (15, 14),(16, 15), (17, 1), (18, 17), (19, 18), (20, 19),(22, 23), (23, 8), (24, 25), (25, 12)]neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]self.edge = self_link + neighbor_linkself.center = 21 - 1

OpenPose的18关键点顺序如下图。

注意,这里采用的是OpenPose的节点举例的。但是可以发现作者的节点连接顺序与OpenPose提供的输出格式连接顺序是不同的。这样的连接对结果没有影响,但也不能认为将OpenPose中的节点pair改为st-gcn的顺序就可以了,因为OpenPose中的PAF的训练是按照原顺序进行的。

分别对应的位置:

0-鼻子;        1-脖子 ;        2-右肩;        3-右肘;        4-右手腕;        5-左肩

6-左肘 ;        7-左手腕;     8-右臀;        9- 右膝盖;   10-右脚踝;      11-左臀

12-左膝盖;    13-左脚踝;  14-右眼;       15-左眼;      16-右耳;        17-左耳

def get_hop_distance(num_node,edge,max_hop=1):

def get_hop_distance(num_node, edge, max_hop=1):A = np.zeros((num_node, num_node))for i, j in edge:   #构建邻接矩阵(无向图为对称矩阵)A[j, i] = 1A[i, j] = 1# compute hop stepshop_dis = np.zeros((num_node, num_node)) + np.inftransfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]#transfer_mat是list类型,需要将list堆叠成一个数组才能进行>操作arrive_mat = (np.stack(transfer_mat) > 0)for d in range(max_hop, -1, -1):hop_dis[arrive_mat[d]] = dreturn hop_dis

这段代码中获得了带自环的邻接矩阵 (18 * 18),其中非连接处为inf (即infinity,无穷大)

稍微复杂一点的是,hop_dis[arrive_mat[d]]=d的计算,第一次接触,参考了python中np.array 矩阵的高级操作_Booker Ye的博客-CSDN博客花了点时间才看懂。

大致就是如下的:

#max_hop=1
#假设 arrive_mat=array([[[True,True],
#                   [True,False]],
#                   [[True,False],
#                   [True,False]]])
# hop_dis=array([[inf,inf],[inf,inf]])
# inf是infinity的缩写,表示无穷大for d in range(max_hop, -1, -1):hop_dis[arrive_mat[d]] = d#结果为:
# hop_dis=array([0,0],
#               [0,inf])

矩阵hop_dis中的元素是根据arrive_mat中对应位置的bool值进行变化的,若对应位置为True,则执行赋值操作

def get_adjacency(self,strategy):

    def get_adjacency(self, strategy):#合法的距离值:0或1 ,抛弃了infvalid_hop = range(0, self.max_hop + 1, self.dilation)adjacency = np.zeros((self.num_node, self.num_node))for hop in valid_hop:adjacency[self.hop_dis == hop] = 1# 图卷积的预处理normalize_adjacency = normalize_digraph(adjacency)...elif strategy == 'spatial':  #按照paper中的第三种分区策略:空间配置划分A = []for hop in valid_hop:a_root = np.zeros((self.num_node, self.num_node))a_close = np.zeros((self.num_node, self.num_node))a_further = np.zeros((self.num_node, self.num_node))for i in range(self.num_node):for j in range(self.num_node):if self.hop_dis[j, i] == hop:   #如果i和j是邻接节点#比较节点i和j分别到中心点的距离,center点是1号点(脖子)if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]:a_root[j, i] = normalize_adjacency[j, i]elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]:a_close[j, i] = normalize_adjacency[j, i]else:a_further[j, i] = normalize_adjacency[j, i]if hop == 0:A.append(a_root)    #A的第一维第一个矩阵:自身节点组else:A.append(a_root + a_close)  #第一维第二个矩阵:向心组矩阵 (列对应节点到中心点的距离比行对应节点到中心点距离近或者相等)A.append(a_further) #第一维第三个矩阵:离心组矩阵A = np.stack(A)self.A = A      #A的shape (3,18,18)...
def normalize_digraph(A):    #归一化Dl = np.sum(A, 0)   #计算邻接矩阵的度,将每一列的元素求和,Dl的shape为(18,1)num_node = A.shape[0]   # 18Dn = np.zeros((num_node, num_node))for i in range(num_node):if Dl[i] > 0:Dn[i, i] = Dl[i]**(-1)  # Dn是一个对角矩阵,只有主对角元素,为度的倒数AD = np.dot(A, Dn)      # A * Dnreturn AD

get_adjacency输出的是一个 (3,18,18)的权值分组A矩阵。

st-gcn.py

网络的输入

整个网络的输入是一个(N,C,T,V,M)的tensor

N : batch size  视频个数
C : 3           输入数据的通道数量 (X,Y,S)代表一个点的信息 (位置x,y + 置信度)
T : 300         一个视频的帧数 paper规定为300
V : 18          根据不同的骨骼获取的节点数而定,coco为18个节点
M : 2           paper中将人数限定在最大2个人
    def forward(self, x):# data normalizationN, C, T, V, M = x.size()x = x.permute(0, 4, 3, 1, 2).contiguous()x = x.view(N * M, V * C, T)x = self.data_bn(x)x = x.view(N, M, V, C, T)x = x.permute(0, 1, 3, 4, 2).contiguous()x = x.view(N * M, C, T, V)# forwadfor gcn, importance in zip(self.st_gcn_networks, self.edge_importance):x, _ = gcn(x, self.A * importance)# global poolingx = F.avg_pool2d(x, x.size()[2:])x = x.view(N, M, -1, 1, 1).mean(dim=1)# predictionx = self.fcn(x)x = x.view(x.size(0), -1)return x

从这里可以看出在forward前就已经改变为(N*M,C,T,V),第一层输入就是(512,3,150,18)

512是N=256和M=2的乘积。

不考虑残差结构的话,每一个st-gcn块相当于是input -> gcn ->tcn -> output

网络的结构

# build networksspatial_kernel_size = A.size(0)     #空间核大小 就是A 0维(N)的值,也就是batch sizetemporal_kernel_size = 9            #时间核大小为9kernel_size = (temporal_kernel_size, spatial_kernel_size)   #核大小 (9,batch size)self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}self.st_gcn_networks = nn.ModuleList((st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),st_gcn(64, 64, kernel_size, 1, **kwargs),st_gcn(64, 64, kernel_size, 1, **kwargs),st_gcn(64, 64, kernel_size, 1, **kwargs),st_gcn(64, 128, kernel_size, 2, **kwargs),      #步长为2,作为池化层st_gcn(128, 128, kernel_size, 1, **kwargs),st_gcn(128, 128, kernel_size, 1, **kwargs),st_gcn(128, 256, kernel_size, 2, **kwargs),     #步长为2,作为池化层st_gcn(256, 256, kernel_size, 1, **kwargs),st_gcn(256, 256, kernel_size, 1, **kwargs),))# initialize parameters for edge importance weightingif edge_importance_weighting:self.edge_importance = nn.ParameterList([nn.Parameter(torch.ones(self.A.size()))for i in self.st_gcn_networks])else:self.edge_importance = [1] * len(self.st_gcn_networks)# fcn for predictionself.fcn = nn.Conv2d(256, num_class, kernel_size=1)

一共10层st_gcn层,但作者没有将第一层算在stgcn模块中,所以共9层。

每一个stgcn层都用residual模块改进。在源码中可以看出当通道数要增加时,使用1x1卷积来进行通道的翻倍,另外 步长为2来完成池化。

根据st-gcn的具体结构图

可以看出一个ST-GCN层包含了一个GCN模块和一个TCN模块,另外还有邻接矩阵和边权重矩阵(edge_importance)的内积,所以更新的模型也分为了两个方面,一是gcn和tcn内 卷积核参数,二是edge_importance内的参数。

GCN模块

gcn模块位于 tgcn.py . 主要过程是一个Conv2d和一个einsum,可以看forward函数

class ConvTemporalGraphical(nn.Module):def __init__(self,in_channels,out_channels,kernel_size,t_kernel_size=1,t_stride=1,t_padding=0,t_dilation=1,bias=True):super().__init__()#这个kernel_size指的是空间上的kernal size,为3,也等于分区策略划分的子集数Kself.kernel_size = kernel_sizeself.conv = nn.Conv2d(in_channels,out_channels * kernel_size,kernel_size=(t_kernel_size, 1), #Conv(1,1)padding=(t_padding, 0),stride=(t_stride, 1),dilation=(t_dilation, 1),bias=bias)def forward(self, x, A):assert A.size(0) == self.kernel_sizex = self.conv(x)    #这里输入x是(N,C,T,V),经过conv(x)之后变为(N,C*kernel_size,T,V)n, kc, t, v = x.size()# 这里把kernel_size的维度拿出来,变成(N,K,C,T,V)x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v) x = torch.einsum('nkctv,kvw->nctw', (x, A)) #爱因斯坦约定求和法return x.contiguous(), A

有两个值得注意的操作:

x = self.conv(x)

x = torch.einsum('nkctv,kvw->nctw',(x,A))

这里的self.conv(x)的卷积核真正大小是(t_kernel_size,1),t_kernel_size预设值是1,那么就是一个1x1的卷积层,在第一层时就相当于把输入的(512,3,150,18)转变为(512,output_channels * kernel_size,150,18),第一层中output_channels为64,kernel_size是[1,2,3]中的一个,视分区策略而定。这一层1x1卷积只是把特征升维,且按自己分了多少组加倍。

对于卷积(1,1)的过程,可以参考一下这个链接的可视化过程ST-GCN中,空域图卷积的可视化过程_Lauris_P的博客-CSDN博客_gcn空域

        n, kc, t, v = x.size()x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v) x = torch.einsum('nkctv,kvw->nctw', (x, A))

前两句是把(512,output_channels * kernel_size,150,18)转化为了(512,kernel_size,64,150,18),也就是(n,k,c,t,v)的形状,而A是(K,V,V)的邻接矩阵,所以einsum()对A和x进行维度融合,相当于是

根据邻接矩阵中的邻接关系做了一次邻接节点间的特征融合,输出就变回了(N*M,C,T,V)的格式进入tcn。

TCN模块

该模块是让网络在时域上进行特征的提取,类似于LSTM。GCN的输出是一个(N,C,T,V),在TCN中可以理解为和CNN的输入格式(B,C,H,W)一样。要整合不同时间上的节点特征,就是对其进行卷积。

        self.tcn = nn.Sequential(nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels,out_channels,(kernel_size[0], 1),(stride, 1),padding,),nn.BatchNorm2d(out_channels),nn.Dropout(dropout, inplace=True),)

tcn是用(temporal_kernel_size, 1)的卷积核对t维度进行卷积运算。这部分相对于gcn就很好理解了,就是正常的卷积操作,对于同一个节点在不同t下的特征的卷积
gcn中是在单个时间t的图上生成新的特征和特征交流,tcn是在时间维度上特征交流。


其余参数解析,数据集加载,等等的代码就不写了。


更新:ST-GCN复现的全过程(详细)_Eric加油学!的博客-CSDN博客

ST-GCN复现的全过程

参考文章:ST-GCN的学习之路(二)源码解读 (Pytorch版)_LgrandStar的博客-CSDN博客_stgcn代码详解

ST-GCN中,空域图卷积的可视化过程_Lauris_P的博客-CSDN博客_gcn空域

ST-GCN源码分析相关推荐

  1. supervisor源码分析

    Supervisor分析 1.运行原理概述: Supervisor生成主进程并将主进程变成守护进程,supervisor依次生成配置文件中的工作进程,然后依次监控工作进程的工作状态,并且主进程负责与s ...

  2. NanoHttpd源码分析

    最近在GitHub上发现一个有趣的项目--NanoHttpd. 说它有趣,是因为他是一个只有一个Java文件构建而成,实现了部分http协议的http server. GitHub地址:https:/ ...

  3. Android源码分析--MediaServer源码分析(二)

    在上一篇博客中Android源码分析–MediaServer源码分析(一),我们知道了ProcessState和defaultServiceManager,在分析源码的过程中,我们被Android的B ...

  4. SRS(simple-rtmp-server)流媒体服务器源码分析--启动

    SRS(simple-rtmp-server)流媒体服务器源码分析--系统启动 一.前言 小卒最近看SRS源码,随手写下博客,其一为了整理思路,其二也是为日后翻看方便.如果不足之处,请指教! 首先总结 ...

  5. HTTP服务器的本质:tinyhttpd源码分析及拓展

    已经有一个月没有更新博客了,一方面是因为平时太忙了,另一方面是想积攒一些干货进行分享.最近主要是做了一些开源项目的源码分析工作,有c项目也有python项目,想提升一下内功,今天分享一下tinyhtt ...

  6. Android ADB 源码分析(三)

    前言 之前分析的两篇文章 Android Adb 源码分析(一) 嵌入式Linux:Android root破解原理(二) 写完之后,都没有写到相关的实现代码,这篇文章写下ADB的通信流程的一些细节 ...

  7. 这篇文章绝对让你深刻理解java类的加载以及ClassLoader源码分析

    前言 package com.jvm.classloader;class Father2{public static String strFather="HelloJVM_Father&qu ...

  8. 解密android日志xlog,XLog 详解及源码分析

    一.前言 这里的 XLog 不是微信 Mars 里面的 xLog,而是elvishew的xLog.感兴趣的同学可以看看作者 elvishwe 的官文史上最强的 Android 日志库 XLog.这里先 ...

  9. 三星uboot1.1.6源码分析——start.s(4)——从NAND复制源码到RAM(3)

    通过上两篇博客终于把从NAND复制源码到RAM的c语言写的部分说完了,现在回到start.s中,接着分析余下的代码. ----------------------------------------- ...

  10. 并发-阻塞队列源码分析

    阻塞队列 参考: http://www.cnblogs.com/dolphin0520/p/3932906.html http://endual.iteye.com/blog/1412212 http ...

最新文章

  1. 类 或 对象 的一些小点 【仅记录,方便以后查阅】
  2. mysql 1005 错误
  3. CSS学习17之动画
  4. 装饰器,闭包,高阶函数,嵌套函数
  5. hcna(华为)_Telnet篇
  6. 大数据的发展体现在哪些方面
  7. mysql 历史记录查询
  8. (二)面向对象设计原则
  9. window10c语言软件下载,win10中文语言包下载
  10. 银行即将关闭直接代扣通道,第三方支付有麻烦了
  11. r610服务器维修,戴尔服务器R610
  12. python程序填空快乐的数字_Python习题之快乐的数字
  13. 服务器采集数据源码,Skywalking数据采集与收集源码分析
  14. 应该将composer.lock致力于版本控制吗?
  15. 计算机网络技术之局域网
  16. CSS基础(P45-P65)
  17. latex中文简易模板,课程论文使用
  18. 用容斥原理计算具有有限重数的多重集合的 r-组合(附代码)
  19. Docker教程(1)Docker 入门
  20. CISCO banner MOTD, Login的区别

热门文章

  1. 【计算机网络】计算机网络的性能指标
  2. 常见业务指标(用户、行为、业务)
  3. 机器人创业大赛 | RoboStartup开营初体验
  4. Python中的6种标准数据类型
  5. 国产手机厂商扎堆海外,vivo成国际化标杆
  6. 2022熔化焊接与热切割试题及答案
  7. ACCESS数据库开发-DLookup and DCount
  8. 00003 不思议迷宫.0009.7:一键采矿(钻石、金蛋等)
  9. 阿里的职级划分 ---- P8是什么概念
  10. 7-1 时间换算(15分)