本文的beam search源码来自:CodeBERT/model.py at master · microsoft/CodeBERT (github.com)https://github.com/microsoft/CodeBERT/blob/master/CodeBERT/code2nl/model.py

理解过程中加入了注释:

class Beam(object):def __init__(self, size,sos,eos):self.size = sizeself.tt = torch.cudaself.scores = self.tt.FloatTensor(size).zero_()# 大小为[beam_size],记录当前每个beam的分数总和self.prevKs = []# 记录每一步选取的是第几个beam,便于最后回溯生成结果self.nextYs = [self.tt.LongTensor(size).fill_(0)]# nextYs: [seq_len=1, beam_size],随着预测过程seq_len逐渐增加,表示每一步的输出结果# seq_len即为time_stepself.nextYs[0][0] = sos# Has EOS topped the beam yet.self._eos = eosself.eosTop = False# Time and k pair for finished.self.finished = []def getCurrentState(self):batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)# batch: [beam_size, seq_len],用于加入到下一次模型的输入中。return batchdef getCurrentOrigin(self):"Get the backpointers for the current timestep."return self.prevKs[-1]def advance(self, wordLk):'''更新beam中的信息wordLk: [beam_size, vocab_size],上一个时间节点每个beam的模型预测结果,需要用LogSoftMax进行归一化'''numWords = wordLk.size(1)# numWords: vocab_sizeif len(self.prevKs) > 0:beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)# scores: [beam_size]# wordLk是当前的分数,scores是之前的分数,加起来得到beamLk: [beam_size, vocab_size]for i in range(self.nextYs[-1].size(0)):if self.nextYs[-1][i] == self._eos:beamLk[i] = -1e20# 把第i个beam的概率全部设置为负无穷else:beamLk = wordLk[0]# beamLk: [vocab_size] 刚开始只有第一个beamflatBeamLk = beamLk.view(-1) # beamlLk展开bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) # topk个最好分数self.scores = bestScores# scores: [beam_size]prevK = bestScoresId // numWords# prevK: [beam_size]self.prevKs.append(prevK)# prevKs: [time_step, beam_size] 记录了每个时间节点的结果来自于第几个beamself.nextYs.append((bestScoresId - prevK * numWords))# nextYs: [seq_len, beam_size] 记录了每个事件节点选取的id, seq_len即time_step# 对nextYs的最后一个时间节点进行遍历,检查是否出现了结束符for i in range(self.nextYs[-1].size(0)):if self.nextYs[-1][i] == self._eos:s = self.scores[i]self.finished.append((s, len(self.nextYs) - 1, i))# i 表示第几个beam# 若出现结束符,将(总分数,句子长度,beam的id)三元组加入到finished列表中# finished列表中存的是已经结束的beam的信息# End condition is when top-of-beam is EOS and no global score.if self.nextYs[-1][0] == self._eos:# 当nextYs中最后一个时间点的第一个id为结束符时,将eosTop设置为Trueself.eosTop = Truedef done(self):# 当eosTop为True且已经结束的beam数大于等于beam_size的时候就结束。return self.eosTop and len(self.finished) >=self.sizedef getFinal(self):if len(self.finished) == 0:# 这里的情况就是所有beam的句子长度都达到了max_length但没有任何一个产生了结束符self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))# 这种情况下就手动将第0个beam设置为已经结束self.finished.sort(key=lambda a: -a[0])# 将finished按beam的分数由大到小排序if len(self.finished) != self.size:# 将没有结束的句子也按(分数,长度,beam_id)三元组的形势加入到finished中unfinished=[]for i in range(self.nextYs[-1].size(0)):if self.nextYs[-1][i] != self._eos:s = self.scores[i]unfinished.append((s, len(self.nextYs) - 1, i)) unfinished.sort(key=lambda a: -a[0])self.finished+=unfinished[:self.size-len(self.finished)]# 已经结束的beam排在未结束的句子前面return self.finished[:self.size]def getHyp(self, beam_res):"""回溯,生成结果"""# beam_res 传入的就是finished列表,由get_final得到hyps=[]for _,timestep, k in beam_res:# k是指该结果来自于第几个beamhyp = []for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):# prevKs: [time_step, beam_size] 记录了每个时间节点的结果来自于第几个beamhyp.append(self.nextYs[j+1][k])# nextYs: [time_step, beam_size] 记录了每个beam的每一步选择,将该id加入到hyp中k = self.prevKs[j][k]# k为结果来自于第几个beamhyps.append(hyp[::-1])# hyp反过来加入到hyps中# 最后得到的hyps:[beam_size, ~]列表,~即长度不一,是每一个beam的预测结果,按分数大小排列return hypsdef buildTargetTokens(self, preds):# preds即为getHyp产生的hyps,记录了每个beam产生的结果,按分数大小排列# 这个函数的目的是截断eos之后的结果sentence=[]for pred in preds:tokens = []for tok in pred:if tok==self._eos:breaktokens.append(tok)sentence.append(tokens)return sentence

Beam Search源码理解相关推荐

  1. 压缩跟踪Compressive Tracking源码理解

    压缩跟踪Compressive Tracking源码理解 zouxy09@qq.com http://blog.csdn.net/zouxy09 在前面一个介绍<Real-Time Compre ...

  2. 从hotspot底层对象结构理解锁膨胀升级过程||深入jdk源码理解longadder的分段cas优化机制——分段CAS优化

    深入jdk源码理解longadder的分段cas优化机制 longadder

  3. faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data)

    转载自:faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data) - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.n ...

  4. faster rcnn的源码理解(一)SmoothL1LossLayer论文与代码的结合理解

    转载自:faster rcnn的源码理解(一)SmoothL1LossLayer论文与代码的结合理解 - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u ...

  5. TLD(Tracking-Learning-Detection)学习与源码理解之(六)

    TLD(Tracking-Learning-Detection)学习与源码理解之(六) zouxy09@qq.com http://blog.csdn.net/zouxy09 下面是自己在看论文和这些 ...

  6. TLD(Tracking-Learning-Detection)学习与源码理解之(五)

    TLD(Tracking-Learning-Detection)学习与源码理解之(五)   zouxy09@qq.com http://blog.csdn.net/zouxy09 下面是自己在看论文和 ...

  7. TLD(Tracking-Learning-Detection)学习与源码理解之(四)

    TLD(Tracking-Learning-Detection)学习与源码理解之(四) zouxy09@qq.com http://blog.csdn.net/zouxy09 下面是自己在看论文和这些 ...

  8. TLD(Tracking-Learning-Detection)学习与源码理解之(三)

    TLD(Tracking-Learning-Detection)学习与源码理解之(三) zouxy09@qq.com http://blog.csdn.net/zouxy09 下面是自己在看论文和这些 ...

  9. TLD(Tracking-Learning-Detection)学习与源码理解之(二)

    TLD(Tracking-Learning-Detection)学习与源码理解之(二) zouxy09@qq.com http://blog.csdn.net/zouxy09 OpenTLD下载与编译 ...

最新文章

  1. 【错误记录】Android 分区存储 错误 ( 文件格式不匹配 )
  2. tinyxml使用指导
  3. 怎么设置html页面背景图片大小怎么设置,HTML – 响应式网页设计:“如何根据浏览器窗口大小使用CSS调整背景图像的大小”?...
  4. OpenCV用方形棋盘进行相机校准
  5. 配置PIX515E DMZ的基本方法与故障排除
  6. Android 和 Chrome OS 融合的可能性
  7. python爬虫入门,10分钟就够了,这可能是我见过最简单的基础教学
  8. 我的YUV播放器MFC小笔记:解析文件名称
  9. mysql中的group_MySQL中使用group
  10. Linux安装Diamond软件,1.1 Linux下安装diamond
  11. Linux下使用FDDB 测试MTCNN人脸检测模型生成 ROC 曲线
  12. 什么是5G技术-认识5G
  13. 路由器自动重启指令_如何按计划自动重启路由器,简便方法
  14. 单像空间后方交会模型
  15. Flash按钮操作(画面暂停与播放)
  16. 5万款Lr顶级调色预设合集,精心整理,分类清晰,摄影师调色师必备素材,够用一辈子
  17. Android设置默认输入法
  18. 婴儿脸上起湿疹吃什么好
  19. 奔图Pantum M6608 一体机驱动
  20. 关于Visual Studio订阅(原MSDN订阅)中无法激活Office 365权益的解决方法(仅适用于MVP)

热门文章

  1. 程序员连续剧_每个程序员都应该看的5部最佳电视连续剧
  2. 微软服务器架云,微软云计算的“三驾马车”
  3. 记主板的南桥芯片和北桥芯片作用及区别(“干南桥”)
  4. 太阳能计算机作文500字,制作太阳能热水器
  5. cocos2dx 常用的基本功能
  6. oracle11.2g递归查询(树形结构查询)
  7. 支持直接外链的最佳网络硬盘OpenDrive
  8. 教师计算机学业水平测试,2020年全国教师资格考试信息技术学科知识与能力练习题...
  9. 布尔巴基学派的灵魂人物—安德烈` 韦伊
  10. 高性能低功耗4口高速USB2.0 HUB NS1.1S 兼容FE1.1