引言

Beam Search 是一种受限的宽度优先搜索方法,经常用在各种 NLP 生成类任务中,例如机器翻译、对话系统、文本摘要。本文首先介绍 Beam Search 的基本思想,然后再介绍一些beam search的优化方法,最后附上自己的代码实现。

1. Beam Search的基础版本

在生成文本的时候,通常需要进行解码操作,贪心搜索 (Greedy Search) 是比较简单的解码。Beam Search 对贪心搜索进行了改进,扩大了搜索空间,更容易得到全局最优解。Beam Search 包含一个参数 beam size k,表示每一时刻均保留得分最高的 k 个序列,然后下一时刻用这 k 个序列继续生成。示意图如下所示:

假设我们生成词表中有三个单词{我,爱,你}。我们设 K = 2 K=2 K=2。那么我们在第一时刻确定两个候选输出是{我,你}。紧接着我们要考虑第二个输出,具体步骤如下:

  • 确定单词“我”为第一时刻输出,并将其作为第二时刻输入,在已知 p ( x , 我 ) p(x,我) p(x,我)的情况下,各个单词的输出概率为3种情况,每个组合的概率为 P ( 我 ∣ x ) P ( y 2 ∣ x , 我 ) P(我|x)P(y_2|x,我) P(我∣x)P(y2​∣x,我)。
  • 同样我们把“你”也作为第二时刻输入,同样也有三种组合。
  • 最后我们在六种组合中选择概率最大的三个组合。

接下来要做的重复这个过程,逐步生成单词,直到遇到结束标识符停止。最后得到概率最大的那个生成序列。其概率为:

以上就是Beam search算法的思想,当beam size=1时,就变成了贪心算法。

2. Beam Search的优化

Beam search算法也有许多改进的地方。

2.1 Length normalization:惩罚短句

根据最后的概率公式可知,该算法倾向于选择最短的句子,因为在这个连乘操作中,每个因子都是小于1的数,因子越多,最后的概率就越小。解决这个问题的方式,最后的概率值除以这个生成序列的单词数,这样比较的就是每个单词的平均概率大小。此外,连乘因子较多时,可能会超过浮点数的最小值,可以考虑取对数来缓解这个问题。谷歌给的公式如下:

其中α∈[0,1],谷歌建议取值为[0.6,0.7]之间,α用于length normalization。

2.2 Coverage normalization:惩罚重复

另外我们在序列到序列任务中经常会发现一个问题,2016 年, 华为诺亚方舟实验室的论文提到,机器翻译的时候会存在over translation or undertranslation due to attention coverage。 作者提出coverage-based atttention机制来解决coverage 问题。 Google machine system 利用了如下的方式进行了length normalization 和 coverage penalty。

还是上述公式,β用于控制coverage penalty

coverage penalty 主要用于使用 Attention 的场合,通过 coverage penalty 可以让 Decoder 均匀地关注于输入序列 x x x 的每一个 token,防止一些 token 获得过多的 Attention

2.3 End of sentence normalization:抑制长句

有的时候我们发现生成的序列一直生成下去不会停止,有的时候我们可以显式的设置最大生成长度进行控制,这里我们可以采用下式来进行约束:

其中 ∣ X ∣ |X| ∣X∣是source的长度, ∣ Y ∣ |Y| ∣Y∣是当前target的长度,那么由上式可知,target长度越长的话,上述得分越低,这样就会防止出现生成一直不停止的情况。

3. Beam Search的代码实现

总的来说,beam search不保证全局最优,但是比greedy search搜索空间更大,一般结果比greedy search要好。下面附上一些代码实现:

首先,首先定义一个 Beam 类,作为一个存放候选序列的容器,属性需维护当前序列中的 token 以及对应的对数概率,同时还需维护跟当前 timestep 的 Decoder 相关的一些变量。此外,还需要给 Beam 类实现两个函数:一个 extend 函数用以扩展当前的序列(即添加新的 time step的 token 及相关变量);一个 score 函数用来计算当前序列的分数(在Beam类下的seq_score函数中有Length normalization以及Coverage normalization)。

class Beam(object):def __init__(self,tokens,log_probs,decoder_states,coverage_vector):self.tokens = tokensself.log_probs = log_probsself.decoder_states = decoder_statesself.coverage_vector = coverage_vectordef extend(self,token,log_prob,decoder_states,coverage_vector):return Beam(tokens=self.tokens + [token],log_probs=self.log_probs + [log_prob],decoder_states=decoder_states,coverage_vector=coverage_vector)def seq_score(self):"""This function calculate the score of the current sequence."""len_Y = len(self.tokens)# Lenth normalizationln = (5+len_Y)**config.alpha / (5+1)**config.alphacn = config.beta * torch.sum(  # Coverage normalizationtorch.log(config.eps +torch.where(self.coverage_vector < 1.0,self.coverage_vector,torch.ones((1, self.coverage_vector.shape[1])).to(torch.device(config.DEVICE)))))score = sum(self.log_probs) / ln + cnreturn scoredef __lt__(self, other):return self.seq_score() < other.seq_score()def __le__(self, other):return self.seq_score() <= other.seq_score()

接着我们需要实现一个 best_k 函数,作用是将一个 Beam 容器中当前 time step 的变量传入 Decoder 中,计算出新一轮的词表概率分布,并从中选出概率最大的 k 个 token 来扩展当前序列(其中加入了End of sentence normalization),得到 k 个新的候选序列。

    def best_k(self, beam, k, encoder_output, x_padding_masks, x, len_oovs):"""Get best k tokens to extend the current sequence at the current time step."""# use decoder to generate vocab distribution for the next tokenx_t = torch.tensor(beam.tokens[-1]).reshape(1, 1)x_t = x_t.to(self.DEVICE)# Get context vector from attention network.context_vector, attention_weights, coverage_vector = \self.model.attention(beam.decoder_states,encoder_output,x_padding_masks,beam.coverage_vector)# Replace the indexes of OOV words with the index of OOV token# to prevent index-out-of-bound error in the decoder.p_vocab, decoder_states, p_gen = \self.model.decoder(replace_oovs(x_t, self.vocab),beam.decoder_states,context_vector)final_dist = self.model.get_final_distribution(x,p_gen,p_vocab,attention_weights,torch.max(len_oovs))# Calculate log probabilities.log_probs = torch.log(final_dist.squeeze())# Filter forbidden tokens.# EOS token penalty. Follow the definition in# https://opennmt.net/OpenNMT/translation/beam_search/.log_probs[self.vocab.EOS] *= \config.gamma * x.size()[1] / len(beam.tokens)log_probs[self.vocab.UNK] = -float('inf')# Get top k tokens and the corresponding logprob.topk_probs, topk_idx = torch.topk(log_probs, k)# Extend the current hypo with top k tokens, resulting k new hypos.best_k = [beam.extend(x,log_probs[x],decoder_states,coverage_vector) for x in topk_idx.tolist()]return best_k

最后我们实现主函数 beam_search。初始化encoder、attention和decoder的输⼊,然后对于每⼀个decodestep,对于现有的k个beam,我们分别利⽤best_k函数来得到各⾃最佳的k个extended beam,也就是每个decode step我们会得到k*k个新的beam,然后只保留分数最⾼的k个,作为下⼀轮需要扩展的k个beam。为了只保留分数最⾼的k个beam,我们可以⽤⼀个堆(heap)来实现,堆的中只保存k个节点,根结点保存分数最低的beam。

    def beam_search(self,x,max_sum_len,beam_width,len_oovs,x_padding_masks):"""Using beam search to generate summary."""# run body_sequence input through encoderencoder_output, encoder_states = self.model.encoder(replace_oovs(x, self.vocab))coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)# initialize decoder states with encoder forward statesdecoder_states = self.model.reduce_state(encoder_states)# initialize the hypothesis with a class Beam instance.init_beam = Beam([self.vocab.SOS],[0],decoder_states,coverage_vector)# get the beam size and create a list for stroing current candidates# and a list for completed hypothesisk = beam_widthcurr, completed = [init_beam], []# use beam search for max_sum_len (maximum length) stepsfor _ in range(max_sum_len):# get k best hypothesis when adding a new tokentopk = []for beam in curr:# When an EOS token is generated, add the hypo to the completed# list and decrease beam size.if beam.tokens[-1] == self.vocab.EOS:completed.append(beam)k -= 1continuefor can in self.best_k(beam,k,encoder_output,x_padding_masks,x,torch.max(len_oovs)):# Using topk as a heap to keep track of top k candidates.# Using the sequence scores of the hypos to campare# and object ids to break ties.add2heap(topk, (can.seq_score(), id(can), can), k)curr = [items[2] for items in topk]# stop when there are enough completed hypothesisif len(completed) == beam_width:break# When there are not engouh completed hypotheses,# take whatever when have in current best k as the final candidates.completed += curr# sort the hypothesis by normalized probability and choose the best oneresult = sorted(completed,key=lambda x: x.seq_score(),reverse=True)[0].tokensreturn result

Beam Search的学习笔记(附代码实现)相关推荐

  1. 对联智能生成的原理(学习笔记附代码实现与详解)

    文章均从个人微信公众号" AI牛逼顿"转载,文末扫码,欢迎关注! 过年的脚步越来越近,是不是该给家里贴上一副对联呢?除了买买买,有没有想过自己动手写出一副对联?来吧,撸起袖子加油干 ...

  2. JUC.Condition学习笔记[附详细源码解析]

    JUC.Condition学习笔记[附详细源码解析] 目录 Condition的概念 大体实现流程 I.初始化状态 II.await()操作 III.signal()操作 3个主要方法 Conditi ...

  3. (实验38)单片机,STM32F4学习笔记,代码讲解【SD卡实验】【正点原子】【原创】

    文章目录 其它文章链接,独家吐血整理 实验现象 主程序 SD卡驱动程序 代码讲解 其它文章链接,独家吐血整理 (实验3)单片机,STM32F4学习笔记,代码讲解[按键输入实验][正点原子][原创] ( ...

  4. (实验39)单片机,STM32F4学习笔记,代码讲解【FATFS实验】【正点原子】【原创】

    文章目录 其它文章链接,独家吐血整理 实验现象 主程序 FATFS初始化程序 代码讲解 其它文章链接,独家吐血整理 (实验3)单片机,STM32F4学习笔记,代码讲解[按键输入实验][正点原子][原创 ...

  5. (实验55)单片机,STM32F4学习笔记,代码讲解【网络通信实验】【正点原子】【原创】

    文章目录 其它文章链接,独家吐血整理 实验现象 主程序 LWIP初始化程序 代码讲解 其它文章链接,独家吐血整理 (实验3)单片机,STM32F4学习笔记,代码讲解[按键输入实验][正点原子][原创] ...

  6. (实验37)单片机,STM32F4学习笔记,代码讲解【内存管理实验】【正点原子】【原创】

    文章目录 其它文章链接,独家吐血整理 实验现象 主程序 内存池初始化程序 代码讲解 其它文章链接,独家吐血整理 (实验3)单片机,STM32F4学习笔记,代码讲解[按键输入实验][正点原子][原创] ...

  7. [学习笔记]《代码整洁之道》(八)

    [学习笔记] <代码整洁之道>- 第9章 单元测试 TDD 三定律 谁都知道TDD要求我们在编写生产代码之前先编写单元测试. 定律一:在编写不能通过的测试单元前,不可以编写生产代码. 定律 ...

  8. (实验4)单片机,STM32F4学习笔记,代码讲解【串口实验】【正点原子】【原创】

    文章目录 其它文章链接,独家吐血整理 实验现象 主程序 串口中断程序 代码讲解 其它文章链接,独家吐血整理 (实验3)单片机,STM32F4学习笔记,代码讲解[按键输入实验][正点原子][原创] (实 ...

  9. openCV4.0 C++ 快速入门30讲学习笔记(自用 代码+注释)详细版

    课程来源:哔哩哔哩 环境:OpenCV4.5.1 + VS2019 目录 002.图像色彩空间转换 003.图像对象的创建与赋值 004.图像像素的读写操作 005.图像像素的算术操作(加减乘除4种不 ...

最新文章

  1. java基础学习(一)方法
  2. Python--format()学习记录
  3. azure机器学习_Microsoft Azure机器学习x Udacity —第4课笔记
  4. 关于爬虫中常见的两个网页解析工具的分析 —— lxml / xpath 与 bs4 / BeautifulSoup...
  5. druid dubbo 生产者_dubbo项目扩展druid sql监控
  6. mysql三高讲解(二):2.8 mysql视图相关概念
  7. IntelliJ IDEA的安装详解
  8. oracle typehandler,Mybatis实现自定义的类型转换器TypeHandler
  9. qq群t人php,QQ群机器人,自动加人、T人、与人聊天,你不在,有机器人呢
  10. vue json 编辑组件_内置为Vue组件的Visual JSON编辑器
  11. ××项目日常工作制度和流程(草案)
  12. HUSTOJ搭建后为了方便作为Judger调用进行的一些修改操作
  13. 152. 精读《recoil》
  14. 经济寒冬之后,是人工智能的春天
  15. 魔兽 服务器 角色 最多,魔兽科普:国服人最多的几个服务器都什么来头
  16. 教育与人生:教师节有感
  17. MQTT-新一代物联网协议
  18. 爱心的数学函数方程_笛卡尔心形线公式表白是什么?公式内容整理
  19. C# Winform软件多语言(汉语、英语。。。)界面的切换,低耦合 - 转
  20. 所谓“尽人事,听天命”

热门文章

  1. gps经纬度换算,WGS-84->GCJ02->BD-09
  2. Redis缓存穿透与解决方案
  3. Vue给组件加v-model,封装表单组件
  4. Cocos2d-3.x Action动作锦集
  5. linux 映射远程文件夹,Linux远程映射磁盘的方法
  6. 前端面试必会的设计模式以及在前端开发中应用
  7. 融资租赁业务系统LPR调息引擎
  8. Fleet的台子,我不能落后!
  9. Java死锁如何避免?
  10. pythonocc进阶学习:分割面/合并面(体)splitter/glue