文章目录

  • 一 RNN在训练过程中的问题
  • 二 RNN的两种训练模式
  • 三 什么是Teacher Forcing
  • 四 Free-Running vs Teacher Forcing 实例
    • 4.1 Free-running 训练过程
    • 4.2 Teacher-Forcing 训练过程
  • 五 Teacher Forcing的缺点及其解决办法
    • 5.1 Teacher Forcing的缺点
    • 5.2 集束搜索(Beam Search)
    • 5.3 有计划地学习(Curriculum Learning)

我看到有些seq2seq模型训练过程中使用了这个机制,一时搜不到适合我的中文教程资源,寥寥一两篇翻译国外大神的,不过那翻译质量个人觉得还是差,那就自己动手写下这个学习笔记吧。

一 RNN在训练过程中的问题

训练迭代过程早期的RNN预测能力非常弱,几乎不能给出好的生成结果。如果某一个unit产生了垃圾结果,必然会影响后面一片unit的学习。
teacher forcing最初的motivation就是解决这个问题的。

二 RNN的两种训练模式

其实RNN存在着两种训练模式(mode):

  1. free-running mode
  2. teacher-forcing mode

free-running mode就是大家常见的那种训练网络的方式: 上一个state的输出作为下一个state的输入。而Teacher Forcing是一种快速有效地训练循环神经网络模型的方法,该模型使用来自先验时间步长的输出作为输入。

三 什么是Teacher Forcing

所谓Teacher Forcing,就是在学习时跟着老师(ground truth)走!

它是一种网络训练方法,对于开发用于机器翻译,文本摘要,图像字幕的深度学习语言模型以及许多其他应用程序至关重要。它每次不使用上一个state的输出作为下一个state的输入,而是直接使用训练数据的标准答案(ground truth)的对应上一项作为下一个state的输入。
看一下大佬们对它的评价:

Models that have recurrent connections from their outputs leading back into the model may be trained with teacher forcing. — Page 372, Deep Learning, 2016.
译: 存在把输出返回到模型输入中的这种循环连接单元的模型可以通过Teacher Forcing机制进行训练。

这种技术最初被作为反向传播的替代技术进行宣传与开发

An interesting technique that is frequently used in dynamical supervised learning tasks is to replace the actual output y(t) of a unit by the teacher signal d(t) in subsequent computation of the behavior of the network, whenever such a value exists. We call this technique teacher forcing. — A Learning Algorithm for Continually Running Fully Recurrent Neural Networks, 1989.
译: 在动态监督学习任务中经常使用的一种有趣的技术是,在计算过程中用教师信号 d(t)d(t)d(t) 替换上一个单元的实际输出 y(t)y(t)y(t) 。我们称这种技术为Teacher Forcing。

Teacher Forcing工作原理: 在训练过程的 ttt 时刻,使用训练数据集的期望输出或实际输出: y(t)y(t)y(t), 作为下一时间步骤的输入: x(t+1)x(t+1)x(t+1),而不是使用模型生成的输出h(t)h(t)h(t)。

Teacher forcing is a procedure […] in which during training the model receives the ground truth output y(t) as input at time t + 1. — Page 372, Deep Learning, 2016.
译: teacher forcing 是这样的一个程序: 在训练过程中接收ground truth的输出 y(t)y(t)y(t) 作为t+1t+1t+1时刻的输入

四 Free-Running vs Teacher Forcing 实例

给定如下输入序列:

Mary had a little lamb whose fleece was white as snow

我们想要训练这样一个模型,在给定序列中前一个单词的情况下生成序列中的下一个单词。
那首先,我们得给这个序列的首尾加上起止token:

[START] Mary had a little lamb whose fleece was white as snow [END]

接下来,我们把 “[START]” 输入模型,让模型生成下一个单词。

4.1 Free-running 训练过程

想象下,现在模型生成了一个 “a”, 不过我们当然期望它先生成一个 “Mary”。

XXX y^\hat{y}y^​
“[START]” “a”

接下来,如果把"a"输入模型,来生成序列中的下一个单词,那现在的情况就是:

XXX y^\hat{y}y^​
“[START]” , “a” ?

可以看到,模型现在已经偏离正轨 ,因为生成的错误结果,会导致后续的学习都受到不好的影响,导致学习速度变慢,模型也变得不稳定。

4.2 Teacher-Forcing 训练过程

假如现在模型生成了一个“a”,我们可以在计算了error之后,丢弃这个输出,把"Marry"作为后续的输入。如果要继续预测下一个单词的话,那么现在的情形就变成了:

XXX y^\hat{y}y^​
“[START]” , “Marry” ?

以此类推,所有训练步骤情形为:

XXX y^\hat{y}y^​
“[START]” ?
“[START]” , “Marry” ?
“[START]”, “Marry”, “had” ?
“[START]”, “Marry”, “had”, “a” ?
?

该模型将更正模型训练过程中的统计属性,更快地学会生成正确的序列。

五 Teacher Forcing的缺点及其解决办法

5.1 Teacher Forcing的缺点

Teacher Forcing同样存在缺点: 一直靠老师带的孩子是走不远的。
因为依赖标签数据,在训练过程中,模型会有较好的效果,但是在测试的时候因为不能得到ground truth的支持,所以如果目前生成的序列在训练过程中有很大不同,模型就会变得脆弱。
也就是说,这种模型的cross-domain能力会更差,也就是如果测试数据集与训练数据集来自不同的领域,模型的performance就会变差。
那有没有解决这个限制的办法呢?

5.2 集束搜索(Beam Search)

在预测单词这种离散值的输出时,一种常用方法是对词表中每一个单词的预测概率执行搜索,生成多个候选的输出序列。
这个方法常用于机器翻译(MT)等问题,以优化翻译的输出序列。
beam search是完成此任务应用最广的方法,通过这种启发式搜索(heuristic search),可减小模型学习阶段performance与测试阶段performance的差异。

5.3 有计划地学习(Curriculum Learning)

注: 本来我想翻译为课程学习,后来感觉太不对原本的意思,所以改为"有计划地学习"
如果模型预测的是实值(real-valued)而不是离散值(discrete value),那么beam search就力不从心了。
因为beam search方法仅适用于具有离散输出值的预测问题,不能用于预测实值(real-valued)输出的问题。

Curriculum Learning是Teacher Forcing的一个变种:

We propose to change the training process in order to gradually force the model to deal with its own mistakes, as it would have to during inference. — Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks, 2015.
译: 我们建议改变训练过程,以便逐步迫使模型处理它自己的错误,就像它在推断过程中必须做的那样。

有计划地学习的意思就是: 使用一个概率ppp去选择使用ground truth的输出y(t)y(t)y(t)还是前一个时间步骤模型生成的输出h(t)h(t)h(t)作为当前时间步骤的输入x(+1)x(+1)x(+1)。
这个概率ppp会随着时间的推移而改变,这就是所谓的计划抽样(scheduled sampling)
训练过程会从force learning开始,慢慢地降低在训练阶段输入ground truth的频率。

本文一部分译自: https://machinelearningmastery.com/teacher-forcing-for-recurrent-neural-networks/

一文弄懂关于循环神经网络(RNN)的Teacher Forcing训练机制相关推荐

  1. 人人都能看懂的循环神经网络RNN

    循环神经网络 基础篇   我们假设您有一个管家,他很擅长做苹果派.汉堡以及炸鸡这三样食物.管家制作食物的种类取决于天气,若是晴天,他会做苹果派:若是雨天,他会做汉堡.这样制作食物的规则很容易用神经网络 ...

  2. 【Spring源码:循环依赖】一文弄懂Spring循环依赖

    1. 什么是循坏依赖 很简单,其实就是互相依赖对方,比如,有一个A对象依赖了B对象,B对象又依赖了A对象. // A依赖了B public class A{private B b; }// B依赖了A ...

  3. 一文弄懂神经网络中的反向传播法

    最近在看深度学习的东西,一开始看的吴恩达的UFLDL教程,有中文版就直接看了,后来发现有些地方总是不是很明确,又去看英文版,然后又找了些资料看,才发现,中文版的译者在翻译的时候会对省略的公式推导过程进 ...

  4. 一文弄懂神经网络中的反向传播法——BackPropagation【转】

    本文转载自:https://www.cnblogs.com/charlotte77/p/5629865.html 一文弄懂神经网络中的反向传播法--BackPropagation 最近在看深度学习的东 ...

  5. 一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

    <繁凡的深度学习笔记>第 15 章 元学习详解 (上)万字中文综述(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net ...

  6. 一文弄懂各种loss function

    有模型就要定义损失函数(又叫目标函数),没有损失函数,模型就失去了优化的方向.大家往往接触的损失函数比较少,比如回归就是MSE,MAE,分类就是log loss,交叉熵.在各个模型中,目标函数往往都是 ...

  7. Keras 中的循环神经网络 (RNN)

    简介 循环神经网络 (RNN) 是一类神经网络,它们在序列数据(如时间序列或自然语言)建模方面非常强大. 简单来说,RNN 层会使用 ​​for​​ 循环对序列的时间步骤进行迭代,同时维持一个内部状态 ...

  8. 循环神经网络_小孩都看得懂的循环神经网络

    点击上方"MLNLP",选择"星标"公众号 重磅干货,第一时间送达 全文共 2014 字,28 幅图,预计阅读时间 20 分钟. 本文是「小孩都看得懂」系列的第 ...

  9. Python手撸机器学习系列(十六):循环神经网络RNN的实现

    目录 循环神经网络RNN 1.公式推导 2.代码实现 循环神经网络RNN 1.公式推导 对于该循环神经网络,以中间的RNN单元为例,推导前向传播: 对于Layer-1: z h = w i x + w ...

最新文章

  1. 2020-10-09
  2. #中regex的命名空间_Python空间分析||geopandas安装与基本使用
  3. 轻量型「孟子」模型比肩千亿大模型!AI大牛周明率队刷新CLUE新纪录
  4. TensorFlow高层次机器学习API (tf.contrib.learn)
  5. python字符串常用方法_字符串常用方法
  6. 易语言https服务器,E2EE应用服务器套件 - 文档 - [基础教程] 使用HTTPS(SSL) - E2EE易语言网站敏捷开发框架...
  7. IBM-X3650 6核处理器安装sql server 2005报错解决方法
  8. 微信公众平台无法使用支付宝收付款的解决方案
  9. 面向对象(final/抽象类/接口/内部类)
  10. Lovesource博士:或者我是如何学会不再担心和热爱开放的
  11. android文本框自动补全,[Android]AutoCompleteTextView自动补全文本框
  12. jenkind + git + mave + shell + tomcat
  13. 变频器压频比的正确设置方法
  14. 矿井下无线基站和地面服务器,煤矿井下无线通信系统_矿井通信
  15. 惠普电脑安装Nvidia显卡驱动和cuda小记
  16. 固态硬盘和机械硬盘的区别(7大区别,简单易懂)
  17. 分布式数据库之TiDB
  18. 【技能教学】如何通过FFMPEG编码推RTSP视频直播流到EasyDarwin开源平台时叠加时间水印?
  19. Mysql出现问题:ERROR 1064 (42000): You have an error in your SQL syntax; check the manual that corres解决方案
  20. 如何帮女朋友快速抢到各种票!火车票,演唱会票等!

热门文章

  1. [实践篇]13.21 la qcom watchdog学习笔记
  2. 【转】数据库基本知识:(七)函数和表达式的使用
  3. (4/4) Biweekly Contest 42
  4. int const*与int * const
  5. 雪崩效应及其常见场景和解决方案
  6. 全球及中国分子束外延(MBE)系统行业研究及十四五规划分析报告
  7. flask flask_paginate简单分页案例;网页url随着分页动态变化
  8. 使用ps分隔图片,形成拼图的效果
  9. 数学建模Word排版——细节决定成败
  10. Google+无法取代个人博客