Safety-compliant Generative Adversarial Networks for Human Trajectory Forecasting

  • Abstract
  • I. INTRODUCTION
  • II. RELATED WORK
  • III. METHOD
    • A.问题定义
    • B.生成对抗网络
    • C.交互建模
    • D.SGANv2
    • E.训练
    • F.联合采样
  • IV. EXPERIMENTS

arXiv 论文翻译
利用符合安全性要求的生成对抗网络进行人类轨迹预测

Abstract

->人群中人类轨迹预测的主要难点:建模社会交互(social interactions) 和 生成无碰撞的多模态轨迹。
-> 随着SGAN(Social Generative Adversarial Networks)的成功,其后有许多‘人群中人类运动的建模’工作是基于GAN展开的。虽然这些工作在减少距离度量方面表现良好,但是都很难产生社会能够接受的轨迹;因为这些轨迹一般都会造成很高的碰撞风险。
->为了解决人类轨迹预测无法实际使用的问题(碰撞风险太高了),本文提出了SGANv2模型。这个模型同时具备‘时空交互建模’(spatio-temporal interaction modelling) 和 ‘基于注意力机制的判别器’。‘时空交互’能够更好建模人类的社会交互,‘基于注意力机制的判别器’能够更好的对时间序列进行建模。
->SGANv2 模型在测试阶段使用判别器进行联合采样(collaborative sampling)。联合采样不仅能够优化碰撞轨迹,而且能够缓解模式坍塌问题。
->在多个真实/合成数据集上,我们验证了SGANv2方法能够有效提供符合社会要求的多模态轨迹。

I. INTRODUCTION

人群中行人运动预测对于诸如自动驾驶、社会机器人这类的自动系统至关重要。为了准确预测行人轨迹,预测模型需要解决以下三个关键性挑战:

  1. 建模社会交互(Modelling social interactions):预测模型需要学会一个人的轨迹如何影响另一个人的轨迹
  2. 符合物理规律轨迹(Physically acceptable outputs):生成的轨迹不能相互碰撞。
  3. 多模态(Multimodality):给定历史信息,模型要能够输出未来的所有模式(mode)不是很懂所有的mode讲的是啥

多模态轨迹预测的目标就是学习一个能够预测未来轨迹的生成式子模型。GAN是轨迹生成式模型中较受欢迎的一个选择,它能够将一个噪声分布 映射为 目标真实分布。 Gupta 等人提出了SGAN模型用于学习人类交互(human interaction)和生成多模态轨迹,这个模型带有社会学机制(social mechanisms)。后来的许多改进方法,大都是在距离指标上表现良好,但是在‘社会交互’和‘碰撞轨迹’两个方面的表现实际上并不是很好。

之前的方法都没法生成低碰撞的轨迹 的原因可能是 这些模型中的判别器都没有建模人与人之间的交互,导致这些判别器没法区分真实轨迹和生成轨迹。只有当判别器能够区分出真实轨迹和生成轨迹的时候,从判别器到生成器的监督信息才是有指导意义的。基于这个判断,本文基于SGAN提出两点改进:

  1. 时空交互建模(spatio-temporal interaction modelling)- 能够使D更好的区分真假数据
  2. 基于注意力机制的判别器-加强时间序列的建模能力。

基于以上两点改进,本文提出的SGANv2 能够更好建模人类基本礼节,通过生成的轨迹碰撞较少能够得以体现。

为了进一步降低碰撞,本文提出的方法SGANv2在测试阶段也使用判别器。在测试的时候,我们在生成器和判别器之间使用 联合采样 来改进生成的不安全轨迹。从经验上表明,联合采样能够缓解GAN模型模式坍塌问题。

在合成和真实轨迹数据集上做对比试验验证SGANv2算法的效果。目前常用的Top-20 ADE/FDE 等指标其实不足以衡量生成轨迹的多样性。做了一个实验来验证以上说法:一个产生均匀分布的预测模型 与 当前最先进的预测模型 在Top-20 ADE/FDE两指标上表现一致。所以本文提出了一个新的评估方案,用来评估符合社会学的多样性性能指标。在新的评价指标里,本文的方法在多个数据集上表现良好。最后,文章在最近发布的Forking Paths 数据集上,验证了联合采样能够缓解模式坍塌问题。总结本文的2个主要贡献:

  1. 基于SGAN 提出 SGANv2, SGANv2模型中生成器和判别器同时采用时空交互建模,此外判别器采用基于transfermer的建模方式。
  2. 在测试的时候使用 联合采样 能够减少预测碰撞 和 缓解模式坍塌问题。

II. RELATED WORK

III. METHOD

SGANv2-主要有3个结构上的改动

  1. 生成器和判别器同时采用时空交互建模;
  2. 判别器采用基于transfermer的建模方式;
  3. 在测试的时候使用 生成器和判别器之间使用 联合采样 来改进生成的不安全轨迹;

A.问题定义

给定一个场景,模型输入:场景内所有行人的轨迹 X = { X 1 , X 2 , . . . , X n } X=\{X_1, X_2, ..., X_n\} X={X1,X2,...,Xn},其中n为行人总数。行人 i i i 的历史运动轨迹: X i = ( x i t , y i t ) X_i=(x_i^t, y_i^t) Xi=(xit,yit), 其中时间步长 t = 1 , 2 , . . . , T o b s t=1, 2,...,T_{obs} t=1,2,...,Tobs;未来gt轨迹表示为 Y i = ( x i t , y i t ) Y_i=(x_i^t, y_i^t) Yi=(xit,yit),其中时间步长 t = T o b s + 1 , . . . , T p r e d t=T_{obs}+1,...,T_{pred} t=Tobs+1,...,Tpred;模型对行人i 的预测输出表示为 Y ^ i \hat{Y}_i Y^i。轨迹预测模型需要同时预测所有人的未来轨迹 Y ^ = Y ^ 1 , Y ^ 2 . . . Y ^ n \hat{Y}=\hat{Y}_1,\hat{Y}_2...\hat{Y}_n Y^=Y^1,Y^2...Y^n。行人 i i i在时间t的速度表示为 v i t v_i^t vit

B.生成对抗网络

生成对抗网络的训练最小最大化博弈问题。
min ⁡ G max ⁡ D E x ∼ p r [ log ⁡ ( D ( x ) ) ] + E z ∼ p z [ 1 − l o g ( D ( G ( z ) ) ) ] \min_{G}\max_{D}\mathbb{E}_{x\sim p_r}[\log(D(x))]+\mathbb{E}_{z\sim p_z}[1-log(D(G(z)))] GminDmaxExpr[log(D(x))]+Ezpz[1log(D(G(z)))]

C.交互建模

社会交互建模对于生成安全准确轨迹至关重要。本文区分两个重要的概念:空间交互建模(spatial interaction modelling)-仅在一个时间步长内建立各个行人之间的交互,SGAN仅在生成器建模的过程中使用邻居行人的信息一次。时空交互建模 在每个时间步长(从 t = 1 t=1 t=1t = T p r e d t=T_{pred} t=Tpred)内使用空间交互建模。时间维度的建模通过序列建模机制,例如LSTM、Transformer。

D.SGANv2

SGANv2有三个关键组件:空间交互模块(Spatial Interaction embedding Module:SIM)、G、 D。 SIM 主要负责空间交互,G,D 负责时序交互,三者联合起来就能完成时空交互建模(Spatio-Temporal Interaction Modelling:STIM)。SIM负责建模每个行人的motion embedding + 各个行人在时间步长t处的spatial interaction embedding;G负责时间序列上的建模,利用LSTM encoder-decoder框架输出多模态预测轨迹; D使用Transfer, 输入观测轨迹X+预测轨迹 Y ^ \hat{Y} Y^/真实轨迹 Y Y Y,输出真假判断。

Spatial Interaction embedding Module:SIM–行人运动预测与一般的序列预测任务的主要区别是交互性:一个人的轨迹受到周围行人轨迹的影响。本文用一个单层的MLP 获得行人 i i i在时间 t t t时的运动embedding向量 e i t e_i^t eit,:
e i t = ϕ ( v i t ; W e m b ) e_i^t=\phi(v_i^t;W_{emb}) eit=ϕ(vit;Wemb)

其中: v i t v_i^t vit为输入速度, ϕ \phi ϕ 为embedding方程, W e m d W_{emd} Wemd为权重向量。可以使用附文[3] [43]所提出的方法获取行人 i i it t t时刻和场景交互的表征 p i t p_i^t pit, 然后将motion embedding e i t e_i^t eit和 spatial interaction p i t p_i^t pit拼接在一起 s i t = [ e i t , p i t ] s_i^t=[e_i^t, p_i^t] sit=[eit,pit],送给G/D。

Generator– LSTM encoder-decoder。encoder LSTM 建模每个时刻的空间embeding在时间维度上的交互关系:
h i t = L S T M e n c ( h i t − 1 , s i t ; W e n c o d e r ) h_i^t=LSTM_{enc}(h_i^{t-1},s_i^t;W_{encoder}) hit=LSTMenc(hit1,sit;Wencoder)

其中 h i t h_i^t hit为行人i在 t t t时刻的隐状态, W e n c o d e r W_{encoder} Wencoder为LSTM可学习的权重参数。encoder LSTM 最后一个输出表示为行人i的观察表征。本文与SGAN一致,使用这个观察表征作为GAN的条件输出。也就是说GAN 的生成器G的输入是行人 i i i的观察表征和随机噪声 z z z。decoder LSTM的隐状态由encoder LSTM最后一个隐状态初始化, decoder LSTM的的递归表达式为:
$ h i t = L S T M d e c ( h i t − 1 , [ s i t , z i ] ; W e n c o d e r ) h_i^t=LSTM_{dec}(h_i^{t-1},[s_i^t,z_i];W_{encoder}) hit=LSTMdec(hit1,[sit,zi];Wencoder)

没搞懂LSTM的输出和隐状态有啥区别?, 其中 W d e c o d e r W_{decoder} Wdecoder为LSTM可学习的权重参数。decoder LSTM 在t时刻的隐状态 h i t h_i^t hit被用于预测下一个时刻的速度 v i t + 1 v_i^{t+1} vit+1。下一个时刻的速度被建模为二元高斯分布,均值为 μ t + 1 = ( μ x , μ y ) t + 1 \mu^{t+1}=(\mu_x,\mu_y)^{t+1} μt+1=(μx,μy)t+1,方差为 σ t + 1 = ( σ x , σ y ) t + 1 \sigma^{t+1}=(\sigma_x,\sigma_y)^{t+1} σt+1=(σx,σy)t+1, 相关系数 ρ t + 1 \rho^{t+1} ρt+1
[ μ t , σ t , ρ t ] = ϕ ( h i t − 1 , W n o r m ) [\mu^t,\sigma^t,\rho^t]=\phi(h_i^{t-1}, W_{norm}) [μt,σt,ρt]=ϕ(hit1,Wnorm)

其中 ϕ d e c \phi_{dec} ϕdec为MLP, W n o r m W_{norm} Wnorm可学习参数。

Discriminator 人和人之间的交互是随着时间演进的,本文的D能够建模时空交互,主要是采用transfer 来进行时间序列的建模。判别器D的输入为 T r a j r e a l = [ X , Y ] Traj_{real}=[X,Y] Trajreal=[X,Y] 或者 T r a j f a k e = [ X , Y ^ ] Traj_{fake}=[X,\hat{Y}] Trajfake=[X,Y^], 输出相应的真假判断。判别器D有自己的空间交互建模SIM 来获取行人 i i i在时间步长 t t t里的空间交互embedding s i t s_i^t sit,将所有的时刻的embedding堆叠在一起:
S i = [ s i 1 ; s i 2 ; . . . , s i T p r e d ] S_i=[s_i^1;s_i^2;...,s_i^{Tpred}] Si=[si1;si2;...,siTpred]

这个堆叠向量将作为transformer网络的输入。transformer能够建模时间维度的交互性主要通过self-attention模块,在attention模块内部, S i S_i Si的每一个元素被分解为Q(query), K(key), V(value),输出矩阵由如下方程计算:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk

QKT)V

其中 d k d_k dks i t s_i^t sit的维度。attention模块的输出被normalized之后传给一个前馈网络以得到序列 S i S_i Si最终的表征R_i(representation):
R i = m a x ( 0 , A i ∗ W 1 + b 1 ) ∗ W 2 + b 2 R_i=max(0, A_i*W_1+b_1)*W2+b2 Ri=max(0,AiW1+b1)W2+b2

其中 W 1 , W 2 , b 1 , b 2 W_1,W_2,b_1,b_2 W1,W2,b1,b2都是可学习的; ∗ * 表示矩阵乘法; A i A_i Ai为attention模块正则化输出。最后将 R i R_i Ri输入到MLP ϕ d \phi_d ϕd中得到 S i S_i Si为真假的得分。

E.训练

SGANv2实际是一个conditional GAN模型,它的一个输入是从 N ( 0 , 1 ) N(0,1) N(0,1)中采样的随机噪声向量 z z z,另一个输入是历史观测轨迹 X X X, 输出未来的预测轨迹 X X X。研究人员发现最小二乘[60]目标函数有利于SGANv2的训练。最小二乘GAN不是从距离度量的角度来改进GAN模型,而是从防止D训练的过好而导致的梯度消失入手。:
min ⁡ G L ( G ) = 1 2 E z ∼ p z [ ( D ( X , G ( X , z ) ) − 1 ) 2 ] \min_G\mathcal{L}(G)=\frac{1}{2}\mathbb{E}_{z\sim p_z}[(D(X,G(X,z))-1)^2] GminL(G)=21Ezpz[(D(X,G(X,z))1)2]

min ⁡ D L ( D ) = 1 2 E x ∼ p r [ ( D ( X , Y ) − 1 ) 2 ] + 1 2 E z ∼ p z [ ( D ( X , G ( X , z ) ) ) 2 ] \min_D\mathcal{L}(D)=\frac{1}{2}\mathbb{E}_{x\sim p_r}[(D(X,Y)-1)^2]+\frac{1}{2}\mathbb{E}_{z\sim p_z}[(D(X,G(X,z)))^2] DminL(D)=21Expr[(D(X,Y)1)2]+21Ezpz[(D(X,G(X,z)))2]

为了增加生成样本的多样性,本文还使用了多样行loss[2]。对于每个场景模型会生成k条轨迹 ,利用最近的生成轨迹进行L2距离惩罚:
L v a r i e t y = min ⁡ k ∣ ∣ Y − G ( X , z ) ( k ) ∣ ∣ 2 2 \mathcal{L}_{variety}=\min_{k}||Y-G(X,z)(k)||_2^2 Lvariety=kmin∣∣YG(X,z)(k)22
在参考论文[43]策略的指导下,G只预测场景中感兴趣的行人。在测试的时候,同时预测一个场景中的所有行人,所有行人的预测任务共享相同的模型参数。

F.联合采样

联合采样是个什么玩意儿?,再看看吧

IV. EXPERIMENTS

PaperNotes(21)-Safety-compliant Generative Adversarial Networks for Human Trajectory Forecasting相关推荐

  1. Generative Adversarial Networks in Computer Vision: A Survey and Taxonomy(计算机视觉中的GANs:综述与分类)

    Abstract: 生成对抗网络(GANs)在过去几年得到了广泛的研究.可以说,他们最重要的影响是在计算机视觉领域,在挑战方面取得了巨大的进步,如可信的图像生成,图像之间的翻译,面部属性操纵和类似领域 ...

  2. 论文翻译:2019_Bandwidth Extension On Raw Audio Via Generative Adversarial Networks

    论文地址:原始音频的带宽扩展通过生成对抗网络 博客作者:凌逆战 博客地址:https://www.cnblogs.com/LXP-Never/p/10661950.html 摘要 基于神经网络的方法最 ...

  3. ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks论文翻译——中英文对照

    文章作者:Tyan 博客:noahsnail.com  |  CSDN  |  简书 声明:作者翻译论文仅为学习,如有侵权请联系作者删除博文,谢谢! 翻译论文汇总:https://github.com ...

  4. ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks论文翻译——中文版

    文章作者:Tyan 博客:noahsnail.com  |  CSDN  |  简书 声明:作者翻译论文仅为学习,如有侵权请联系作者删除博文,谢谢! 翻译论文汇总:https://github.com ...

  5. 史上最全GAN综述2020版:算法、理论及应用(A Review on Generative Adversarial Networks: Algorithms, Theory, and Applic)

    ** ** 史上最全GAN综述2020版:算法.理论及应用** 论文地址:https://arxiv.org/pdf/2001.06937.pdf ** 摘要:生成对抗网络(GANs)是近年来的一个研 ...

  6. Text to image论文精读 DM-GAN: Dynamic Memory Generative Adversarial Networks for t2i 用于文本图像合成的动态记忆生成对抗网络

    Text to image论文精读 DM-GAN: Dynamic Memory Generative Adversarial Networks for Text-to-Image Synthesis ...

  7. MM2018/风格迁移-Style Separation and Synthesis via Generative Adversarial Networks通过生成性对抗网络进行风格分离和合成

    Style Separation and Synthesis via Generative Adversarial Networks通过生成性对抗网络进行风格分离和合成 0.摘要 1.概述 2.相关工 ...

  8. 论文翻译:2018_Speech Bandwidth Extension Using Generative Adversarial Networks

    论文地址:基于生成对抗网络的语音频带扩展 博客作者(引用请指明出处):https://www.cnblogs.com/LXP-Never/p/10121897.html 摘要 语音盲带宽扩展技术已经出 ...

  9. 自动驾驶轨迹预测论文阅读(三)Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks

    [略读]GUPTA A, JOHNSON J, FEI-FEI L, et al., 2018. Social GAN: Socially Acceptable Trajectories with G ...

  10. Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks 中文翻译

    Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks 中文翻译 如有异议,请多指教,非专业 ...

最新文章

  1. Go 1.12发布:改进了运行时性能以及模块支持
  2. Docker创建虚机和swarm
  3. 最近和朋友微信卖螃蟹有点偏离重心了
  4. javascript 中的暗物质 - 闭包
  5. c语言e怎么表示_来测测!这11个C语言入门基础知识你都掌握了吗?
  6. 文本框为空按钮不可点击
  7. SAP License:自动创建带内部订单预算管控的在建工程
  8. 几个支持SCORM的免费平台
  9. 地理信息系统概论_南京大学815地理信息系统概论考研初试历年真题参考书目重难点笔记...
  10. 正则表达式判断手机号码
  11. Apache-Tomcat-Ajp文件读取漏洞(CVE-2020-1938、CNVD-2020-10487)
  12. 解决go合约fabric shim peer依赖问题
  13. java微信登录_java微信授权登陆
  14. ad7606与stm32连接电路介绍
  15. 电脑用户没有admin权限,如何配置node开发环境
  16. 3D电视与3D眼镜的工作原理
  17. 积分价值调整的两个原因
  18. 计算机网络断网吗,教您解决电脑网络常常断网掉线的方法?
  19. linux蓝牙日志,linux蓝牙
  20. 换行和回车(/n /r)

热门文章

  1. 戴飞创业笔记:做一个清醒的人!
  2. Android CameraX实现摄像头预览、拍照、录制视频
  3. LinuxCNC学习-Machinekit手册介绍
  4. Windows终端安全处置建议
  5. IDEA开发 springboot 日志彩色消失
  6. java解析fml32_学习了FML后,自己练手写了一个tuxedo FML32服务器端和客户端程序,供大家参考...
  7. bjca数字认证产品垃圾,服务垃圾
  8. mysql5.7 ERROR:Access denied for user ‘root‘@‘%完美解决
  9. Qt配置OpenCV教程,亲测已试过
  10. Cai_Sublime