1. 背景

之前介绍了如何在 RNN-T 流式模型上应用时延正则,以及在 Conformer 和 LSTM 上的实验结果。

本期公众号重点带大家回顾下具体的思路,以及如何类似地在 CTC 流式模型上应用时延正则。

有些内容可能有所重复,读者可适当跳过。

2. Delay penalty for RNN-T

标准 RNN-T

如图1所示,RNN-T lattice 包含了特征序列标签序列之间所有可能的对齐路径,两个序列的长度通常不一致。在 lattice 中,从点 (t,u) 出发,向上走的边表示输出 yu+1,分数为 y(t,u);向右走的边表示输出 ∅,分数为 ∅(t,u)。

此处我们提及的 lattice 边上的分数,无特殊说明情况下,都是 log-probability。

图1

假设 lattice 中路径 i 的 分数为 si,RNN-T 的目标函数 L 为最大化 lattice 中所有路径的分数之和:

L=log⁡∑iexp⁡(si)

我们通常使用动态规划算法 forward-backward[1] 来高效地计算目标函数 L,不需要显式计算每条路径的分数 si。具体地,令 α(t,u) 表示在 lattice 中在看到了特征 x0…t 的条件下,输出标签 y0…u 的分数。我们可以得到状态转移方程:

α(t,u)=LogAdd(α(t,u−1)+y(t,u−1),α(t−1,u)+∅(t−1,u)),

lattice 中所有路径的总分数 L,即状态转移的终点,可以计算为:

L=α(T−1,U)+∅(T−1,U)

我们可以发现,RNN-T 的目标函数 L 并没有考虑不同的路径所对应的时延。如图1所示,红色的路径更早地输出 symbol,时延较低;而蓝色的路径更晚地输出 symbol,时延较高。

与非流式模型不同,流式模型无法看到句子中所有的 context。流式模型为了看到更多的上下文,以达到更好的识别性能,会倾向于增强时延较高的路径, 如图1中蓝色的路径。如图2蓝色线所示,随着训练进行,没有时延正则的 RNN-T 流式模型的时延逐渐上升。

图2

Delay-penalized RNN-T

为了惩罚 RNN-T 模型的时延,我们的想法是在目标函数 L 上增加一个时延正则项 Ldelay,得到一个新的目标函数 Laug:

Laug=L+Ldelay

Ldelay 表示 lattice 中所有路径的平均时延分数(值越大,代表时延越低),定义为:

Ldelay=λ∑idiwi

其中,di 为路径 i 的时延分数,λ 是一个超参数,wi 为路径 i 的分数在整个 lattice 中的比重:

wi=∂L∂si=exp⁡(si)∑iexp⁡(si)

此处,di 的值越大,表示路径 i 的时延越低。

下文会具体讲解时延分数 di 的定义。

因此,通过引入时延正则项 Ldelay,RNN-T 会被约束着去增强那些时延较低(di 较大)的路径 i,为他们赋予一个更高的分数 si。

上文提到,我们在优化 L 的过程中,并没有显式计算各个路径 i 的分数 si。那么问题来了,为了优化 Laug,难道我们还要去显示地求出各个路径 i 的分数 si,来计算 wi 吗?这无疑是一种极其低效且不优雅的做法。

此时,Daniel 抛出了一长串数学公式,证明了我们可以优雅地实现 Laug 的优化。

由于篇幅限制,我们不在此列出具体的证明过程。感兴趣的同学可以阅读论文 https://arxiv.org/pdf/2211.00490.pdf,保证学过高中数学的同学都能看懂。

简而言之,对于一个较小的超参数 λ,带时延正则的目标函数 Laug 对路径分数 si 的导数 ∂Laug∂si 可以近似为:

∂Laug∂si≈exp⁡(λdi+si)∑iexp⁡(λdi+si)

我们只需要在优化标准目标函数 L 的过程中,将 si 替换为 λdi+si,即可达到近似地优化 Laug 的效果:

si′=λdi+si

接下来我们来讲一下在 RNN-T lattice 中如何定义 di。令 π={πu}0U−1 为输出标签序列 y0...U−1 (即向上走的边)的帧索引。我们定义路径 i 的时延分数 di 为这些帧索引 πu 相对于句子中间帧的 offset:

di=∑u(T−12−πu)

此处,之所以要加上它们相对于中间帧的 offset,是为了使得引入时延正则后,loss 函数的数值不会和原来相差太大。

图3

如图3所示,为了实现 si′,我们只需要修改 lattice 中那些输出 symbol 的边(即向上走的边),加上与帧索引对应的 offset:

y′(t,u)=y(t,u)+λ×(T−12−t)

因此,在执行 forward-backward 算法之前,我们只需要将 y(t,u) 替换为 y′(t,u),即可以一种简单高效的方式,近似地优化带时延正则的目标函数 Laug。

如图2中红色的线所示,通过在 RNN-T 目标函数上添加时延正则项,随着训练的进行,我们可以逐步降低流式模型的时延。

代码可以参考 k2 的 PR https://github.com/k2-fsa/k2/pull/976 和 icefall 的 PR https://github.com/k2-fsa/icefall/pull/654。

3. Delay penalty for CTC

CTC 的目标函数[2]和 RNN-T 目标函数的公式一样,也是最大化 lattice 中所有可能的对齐路径分数之和 L:

L=log⁡∑iexp⁡(si)

我们希望可以像 RNN-T 一样,对于 lattice 中每条路径,根据时延对应地修改它的分数 si,即 si′=λdi+si,达到近似地优化带时延正则的目标函数 Laug 的效果。

下面将介绍如何使用 k2 fsa 巧妙地实现这个功能。

大家可以下载文件 https://github.com/k2-fsa/next-gen-kaldi-wechat/blob/master/pdf/LF-MMI-training-and-decoding-in-k2-Part-I.pdf,了解如何用 k2 fsa 实现计算 CTC 目标函数。

图4

假设特征序列的长度为5,标签序列为 Z,O,O。利用 k2 fsa 我们可以得到对应的 CTC lattice。在图4所示,在 CTC lattice 中,每条从起点到终点的路径为:特征序列和标签序列之间的合法对齐路径。每条边上有三个属性:(1)输入标签(label);(2)输出标签( aux_label);(3)分数,即 log_softmax(encoder_output)

例如,以下三条对齐路径对应着不同的输入标签序列,他们的输出标签序列经过去除 ϵ 后,都可以得到 Z,O,O:

Z,O,∅,O,∅→Z,O,ϵ,O,ϵ

Z,Z,O,∅,O→Z,ϵ,O,ϵ,O

Z,∅,O,∅,O→Z,ϵ,O,ϵ,O

每条对齐路径的时延,取决于那些首次输出 symbol 的边的帧索引 π={πu}0U−1 ,如下面加粗的 symbol:

Z,O,∅,O,∅→Z,O,ϵ,O,ϵ

Z,Z,O,∅,O→Z,ϵ,O,ϵ,O

Z,∅,O,∅,O→Z,ϵ,O,ϵ,O

每条路径中,那些首次输出 symbol 的边的数量是相同的,为标签序列的长度 U。我们可以像上文 RNN-T 一样,定义每个路径 i 的时延分数 di 为:这些帧索引 πu 相对于句子中间帧的 offset。

图5

如图5所示,为了在 CTC 中实现 si′,我们只需要修改 lattice 中首次输出 symbol 的边(标记为红色)上的分数 yt,加上与帧索引(相对于中间帧)的 offset:

yt′=yt+λ×(T−12−t)

因此,在执行动态规划算法求 CTC lattice 中所有路径总分数之前,我们只需要将 yt 替换为 yt′,即可以一种简单高效的方式,近似地优化带时延正则的目标函数 Laug。

在 k2-fsa CTC 实现过程中,利用 k2.Fsa.get_total_scores() 求得 lattice 所有路径总分数。

具体地,如何修改 lattice 上那些首次输出 symbol 的边的分数,可以参考 k2 的 PR https://github.com/k2-fsa/k2/pull/1086,和 icefall 的 PR https://github.com/k2-fsa/icefall/pull/669,里面有详细的注释。

4. 实验结果

RNN-T

如表1所示,在使用 RNN-T 训练的流式 Conformer(chunk=0.32s)和 LSTM 模型上,应用时延正则可以有效降低模型的时延。我们只需通过调节超参数 λ,即可控制 WER 和 symbol delay 之间的 trade-off。

关于 RNN-T 时延正则,大家可以阅读论文 https://arxiv.org/pdf/2211.00490.pdf 了解更详细的实验结果。

表1

CTC

表2展示了使用 CTC 训练的流式 Conformer 模型 (chunk=0.32s),应用了时延正则后,在 librispeech 数据集 test-clean 和 test-other 上的结果。可以看出,我们同样可以通过调节超参数 λ,即可控制 WER 和 symbol delay 之间的 trade-off。

由于模型只使用了 CTC 损失函数训练了 25 个 epoch,WER 较差,大家可忽略其绝对数值。

表2

5. 总结

最后,再附上论文地址 https://arxiv.org/pdf/2211.00490.pdf,感兴趣的同学可以阅读 Daniel 的详细证明过程。有疑问的同学欢迎通过 github issue 或者评论区和我们讨论。

参考资料

[1] forward-backward: https://arxiv.org/pdf/1211.3711.pdf

[2] CTC 的目标函数: https://www.cs.toronto.edu/~graves/

Delay Penalty for RNN-T and CTC相关推荐

  1. 语音识别2:CTC对齐的算法

    一.提要 如果现在有一个包含剪辑语音和对应的文本,我们不知道如何将语音片段与文本进行对应,这样对于训练一个语音识别器增加了难度. 如下图,存在图片与文本的对齐不易,语音声波对文本的对齐不易. 以上构成 ...

  2. 端到端流式语音识别研究综述——语音识别(论文研读)

    端到端流式语音识别研究综述(2022.09) 摘要: 引言: 1 端到端流式语音识别模型 1.1 可直接实现流式识别的端到端模型 1.2 改进后可实现流式识别的端到端模型 1.2.1 基于单调注意力机 ...

  3. 收藏 | Tensorflow实现的深度NLP模型集锦(附资源)

    来源:深度学习与NLP 本文约2000字,建议阅读5分钟. 本文收集整理了一批基于Tensorflow实现的深度学习/机器学习的深度NLP模型. 收集整理了一批基于Tensorflow实现的深度学习/ ...

  4. Tensorflow实现的深度NLP模型集锦(附资源)

    https://www.toutiao.com/a6685688607191073294/ 本文约2000字,建议阅读5分钟. 本文收集整理了一批基于Tensorflow实现的深度学习/机器学习的深度 ...

  5. 吴恩达 NIPS 2016:利用深度学习开发人工智能应用的基本要点(含唯一的中文版PPT)...

    雷锋网按:为了方便读者学习和收藏,雷锋网(公众号:雷锋网)特地把吴恩达教授在NIPS 2016大会中的PPT做为中文版,由三川和亚峰联合编译并制作. 今日,在第 30 届神经信息处理系统大会(NIPS ...

  6. LSTM 之父发文:2010-2020,我眼中的深度学习十年简史

    作者 | Jürgen Schmidhuber 译者 | 刘畅.若名 出品 | AI科技大本营(ID:rgznai100) 作为LSTM发明人.深度学习元老,Jürgen Schmidhuber于2月 ...

  7. 第1140期AI100_机器学习日报(2017-11-01)

    AI100_机器学习日报 2017-11-01 神经网络模型压缩和加速方法综述 @东北大学自然语言处理实验室 神经网络基础知识:激活函数以及损失函数 @wx: 百度发布 Deep Speech 3,不 ...

  8. LSTM之父发文:2010-2020,我眼中的深度学习十年简史

    2020-02-23 15:04:22 作者 | Jürgen Schmidhuber 编译 | 刘畅.若名 出品 | AI科技大本营(ID:rgznai100) 作为LSTM发明人.深度学习元老,J ...

  9. 论文翻译-Hamming OCR A Locality Sensitive Hashing Neural Network for Scene Text Recognition

    论文翻译-Hamming OCR A Locality Sensitive Hashing Neural Network for Scene Text Recognition 原文地址:https:/ ...

最新文章

  1. 发布一个持续集成的npm包并加上装逼小icon
  2. Windows Server 排错和发帖求助必读
  3. 电脑台式计算机描述不可用,win7系统计算机描述不可用的解决方法
  4. SEO配置信息操作文档
  5. 各色“独特的”数据中心安置法,藏太深了!
  6. ios查看帧率的软件_程序员必看!直播软件开发弱网下保障高清流畅推流的方法...
  7. 在Linux中某些程序无法运行,为何linux下的程序不能在windows下运行,不是“废话”那么简单...
  8. 大数据WEB阶段总结
  9. [导入]C#中TextBox只能输入数字的代码
  10. weakhashmap_Java WeakHashMap putAll()方法与示例
  11. 几何修复_*ST海润:实施终止退市 光伏产业修复成几何?
  12. 【Kettle】作业和转换中的内置变量
  13. Docker 外部访问容器Pp、数据管理volume、网络network 介绍
  14. matlab钢材切割,一种基于MATLAB的钢材裂纹扩展速率试验数据处理方法
  15. 【ant项目构建学习点滴】--(3)打包及运行jar文件
  16. MySQL 报错:Translating SQLException with SQL state '42000', error code '1064', message
  17. v6使用手册 天正电气t20_电气工程设计软件-T20天正电气软件下载 v6.0官方版--pc6下载站...
  18. Unity 自由视角的惯性旋转
  19. .mdf数据库恢复mysql_恢复mdf文件到数据库方法
  20. 高中计算机课听课记录表,初中信息技术课听课记录中学信息技术评课笔记

热门文章

  1. 攻略:邮件搬家同一个域名操作步骤,设置邮箱搬家功能的方法
  2. 滤波器的抽头系数、通带、阻带、过渡带
  3. 大神都在看的五金模具设计之端子模具设计要点
  4. 【OpenPose-Windows】OpenPose1.4.0+VS2017+CUDA9.2+cuDNN9.2+Windows配置教程
  5. 新网工李白——>李白你好(来抽大奖啦~)
  6. SLAM学习 | 小觅相机的图像与IMU时间戳对齐分析
  7. 重返月球,铺路火星:2024年首位女性登月,280亿美元开启太空探索新纪元-1
  8. SD卡分区教程 安卓手机SD卡分区
  9. Linux 2.6内核配置说明(7----Bus options (PCI, PCMCIA, EISA, MCA, ISA)总线选项)
  10. 计算机基础与程序设计(基于C语言)学习笔记