import torch
import torch.nn as nn
import torch.nn.functional as F# 先验网络
class Prior(nn.Module):def __init__(self,input_size=256,output_size=64):super(Prior,self).__init__()self.input_size = input_size # 输入大小self.output_size = output_size # 输出大小self.fc1 = nn.Linear(input_size,128)self.fc2 = nn.Linear(128,64)self.fc3 = nn.Linear(64,output_size)def forward(self,x):# x:[bs,256]x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 后验网络
class Recognition(nn.Module):def __init__(self,input_size=512,output_size=64):super(Recognition,self).__init__()self.input_size = input_size # 输入大小self.output_size = output_size # 输出大小self.fc1 = nn.Linear(input_size,256)self.fc2 = nn.Linear(256,128)self.fc3 = nn.Linear(128,output_size)def forward(self,x):# x:[bs,512]x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 目标网络
class Goal(nn.Module):def __init__(self,input_size=256+32,output_size=2):super(Goal,self).__init__()self.input_size = input_size # 输入大小self.output_size = output_size # 输出大小self.fc1 = nn.Linear(input_size,128)self.fc2 = nn.Linear(128,64)self.fc3 = nn.Linear(64,output_size)def forward(self,x):# x:[bs,256+32]x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xclass BiTrap(nn.Module):def __init__(self,embedding_size=64,input_size=2,output_size=2,gru_size=256,latent_size=32,gru_input_size=64,num_layer=1,pre_len=12):super(BiTrap,self).__init__()self.input_size = input_size # 输入大小self.output_size = output_size # 输出大小self.embedding_size = embedding_size # 空间编码self.gru_size = gru_size # GRU隐藏层大小self.latent_size = latent_size # z的大小self.gru_input_size = gru_input_size # GRU输入大小self.num_layer = num_layer # GRU层数self.pre_len = pre_len # 预测长度# 观测轨迹升维self.fcx = nn.Linear(input_size,embedding_size)# 预测轨迹升维self.fcy = nn.Linear(input_size,embedding_size)# 观测轨迹GRUself.grux = nn.GRU(embedding_size,gru_size)# 预测轨迹GRUself.gruy = nn.GRU(embedding_size,gru_size)# 先验网络 后验网络 目标网络self.prior = Prior(gru_size,2*latent_size)self.recognition = Recognition(gru_size*2,2*latent_size)self.goal = Goal(gru_size+latent_size,output_size)# 状态转换网络self.fcf = nn.Linear(gru_size+latent_size,gru_input_size)self.fc2 = nn.Linear(gru_size+latent_size,gru_size)self.fc3 = nn.Linear(gru_size,gru_input_size)self.fc4 = nn.Linear(gru_size+latent_size,gru_size)self.fc5 = nn.Linear(output_size,gru_input_size)# GRUcellself.forward_gru = nn.GRUCell(gru_input_size,gru_size)self.backward_gru = nn.GRUCell(gru_input_size,gru_size)# 最后的输出self.fc6 = nn.Linear(2*gru_size,output_size)def forward(self,x,mode='train',y=None):n,s = x.shape[-1],x.shape[2]# [bs,c,seq_len,n]->[seq_len,n,c]c=2x = x.squeeze(0).permute(1,2,0)# [seq_len*n,c]x = x.reshape(-1,x.shape[-1])# [seq_len*n,embedding]x = self.fcx(x)# [seq_len,n,embedding]x = x.reshape(s,n,-1)x_gru_h = torch.randn(self.num_layer,n,self.gru_size).cuda()_,h = self.grux(x,x_gru_h)# [n,embedding]x = h.squeeze(0)# 重参数化e = torch.randn(n,self.latent_size).cuda()# 求得先验分布p = self.prior(x)if(mode=='train'):n,s = y.shape[-1],y.shape[2]# [bs,c,seq_len,n]->[seq_len,n,c]c=2y = y.squeeze(0).permute(1,2,0)# [seq_len*n,c]y = y.reshape(-1,y.shape[-1])# [seq_len*n,embedding]y = self.fcy(y)# [seq_len,n,embedding]y = y.reshape(s,n,-1)y_gru_h = torch.randn(self.num_layer,n,self.gru_size).cuda()_,y = self.gruy(y,y_gru_h)# [n,embedding]y = y.squeeze(0)# [n,2*embeddin]y = torch.cat((x,y),1)# 求得后验分布q = self.recognition(y)z = q[:,0:self.latent_size]+q[:,self.latent_size:]*eelse:z = p[:,0:self.latent_size]+p[:,self.latent_size:]*e# [n,gru_size+latent_size]x = torch.cat((h.squeeze(0),z),1)# 求Goal [n,2]g = self.goal(x)forward_gru_h = self.fc2(x)f = self.fcf(x) # 前项输入backward_gru_h = self.fc4(x)b = self.fc5(g) # 后项输入# 计算前向forward_output = []for i in range(self.pre_len):forward_gru_h = self.forward_gru(f,forward_gru_h)forward_output.append(forward_gru_h)f = self.fc3(forward_gru_h)# 计算后向backward_output = []for i in range(self.pre_len-1,-1,-1):backward_gru_h = self.backward_gru(b,forward_gru_h)temp = torch.cat((forward_output[i],backward_gru_h),1)output = self.fc6(temp)backward_output.append(output)b = self.fc5(output)if(mode=='train'):return p,q,g,torch.stack(backward_output,0).unsqueeze(0)else:return torch.stack(backward_output,0).unsqueeze(0)"""
x = torch.randn(1,2,8,24)
y = torch.randn(1,2,12,24)
prior = BiTrap(embedding_size=64,input_size=2,output_size=2,gru_size=256,latent_size=32,gru_input_size=64)
p,q,g,b = prior(x,'train',y)
print(p.shape)
print(q.shape)
print(g.shape)
print(b.shape)b = prior(x,'test')
print(b.shape)
"""

BiTraP:Bi-directional Pedestrian Trajectory Prediction with Multi-modal Goal Estimation相关推荐

  1. 行人轨迹论文阅读SSAGCN: Social Soft Attention Graph Convolution Network for Pedestrian Trajectory Prediction

    SSAGCN: Social Soft Attention Graph Convolution Network for Pedestrian Trajectory Prediction SSAGCN: ...

  2. 论文翻译 SGCN:Sparse Graph Convolution Network for Pedestrian Trajectory Prediction 用于行人轨迹预测的稀疏图卷积网络

    SGCN:Sparse Graph Convolution Network for Pedestrian Trajectory Prediction 用于行人轨迹预测的稀疏图卷积网络 行人轨迹预测是自 ...

  3. 文献翻译:Social LSTM: Human Trajectory Prediction in Crowded Spaces

      这是我阅读的有关轨迹预测的第一篇文献,其内容和使用的模型相对简单,是比较适合的入门篇,我在此把原文翻译分享出来,便于大家交流学习. 这里写目录标题                 Abstract ...

  4. 【ECCV2020】Spatio-Temporal Graph Transformer Networks for Pedestrian Trajectory Prediction

    [ECCV2020]用于行人轨迹预测的时空图 Transformer 网络 摘要 了解人群运动动力学对于现实世界的应用至关重要,例如监控系统和自动驾驶.这是具有挑战性的,因为它需要对具有社会意识的人群 ...

  5. 评估行人行动预测的基准——Benchmark for Evaluating Pedestrian Action Prediction

    评估行人行动预测的基准--Benchmark for Evaluating Pedestrian Action Prediction Date of Conference: 3-8 Jan. 2021 ...

  6. [论文阅读]用于车辆轨迹预测的卷积社交池Convolutional Social Pooling for Vehicle Trajectory Prediction

    文章目录 一.摘要 二.介绍 三.相关研究 3.1 基于机动的模型 3.2 交互感知模型 3.3 运动预测的递归网络 四.问题制定 4.1 参照系 4.2 输入输出 4.3 概率运动预测 4.4 操作 ...

  7. 顶会论文笔记:联邦学习——ATPFL: Automatic Trajectory Prediction Model Design under Federated Learning Framework

    ATPFL: Automatic Trajectory Prediction Model Design under Federated Learning Framework 文章目录 ATPFL: A ...

  8. 文献阅读笔记:EvolveGraph: Multi-Agent Trajectory Prediction with Dynamic Relational Reasoning

    文献阅读笔记 摘要 1 引言 2 相关工作 3 Problem formulation 4 EvolveGraph 5 Experiments 6 结论 EvolveGraph: Multi-Agen ...

  9. Social LSTM: Human Trajectory Prediction in Crowded Spaces 论文翻译

    摘要 行人可沿不同的轨道行走,以避开障碍物及方便其他行人.在这样的场景中行驶的任何自动驾驶车辆都应该能够预见行人未来的位置,并相应地调整其路径以避免碰撞.轨迹预测问题可以看作是一个序列生成任务,我们感 ...

最新文章

  1. TensorRT 加速性能分析
  2. [WWDC] What's New in Swift 4 ?
  3. ftp服务用户访问权限设置
  4. svn教程----svn简介
  5. ELK 企业级日志分析系统
  6. ArcGIS Python
  7. 是网关吗_智能家居网关功能这么多,你都知道吗?
  8. SpringBoot与Mybatis的集成
  9. Java基础学习总结(173)——Java 8到Java 15新功能总结
  10. Javascript是最好的编程语言吗?
  11. OpenJ_Bailian 2814 拨钟问题
  12. 循序渐进!java开发手册阿里巴巴泰山版
  13. linux每日命令(11):cat命令
  14. [技术分享] 融云开发案例核心代码分享
  15. Python学习笔记--6.2 文件读写
  16. Vivado2018的使用
  17. css实现背景图片透明
  18. 【VBS脚本】VBS复制Excel工作簿
  19. 【直播升级——AWS 云之旅】
  20. 玩转Oracle服务器连接

热门文章

  1. Beautiful Soup爬虫
  2. 本地文章上传阿里云文件上传
  3. Qt Phonon介绍及安装
  4. 服务器设置邮箱屏蔽,邮箱服务器IP被屏蔽的问题
  5. 【C++】LeetCode 题库 1834. 单线程 CPU
  6. mysql主从模拟主服务器坏掉,主从切换,主服务器修复,主从恢复
  7. 基于golang的Json选择器
  8. Vue学习笔记 —— 路径引入
  9. 计算机基础教法改革,高职计算机应用基础课程教学改革探索
  10. 中国电信设的“互联星空”陷井