事先声明:我只是个真·零基础小白,所以部分理解不到位或者有错误,望各位大佬不吝赐教!

第一部分 引入库部分

本代码采用的是苏老师写的bert4keras,即使用keras实现bert,包含层、模型、优化器、分词器等

bert4keras最好在tensorflow<=2.2以及keras<=2.3.1的条件下,即搭配python3.6食用

第二部分 加载BERT模型

BERT的结构是这样的(可以参考这篇文章:原来你是这样的BERT,i了i了!)

simBERTv2也是基于BERT的一个模型,所以需要先“搭建”一个BERT模型【此时我们只有最基本的bert配置】

我们先不看def/class的内容,直接快进到建立加载模型

然后我们去寻找build_transformer_model的源码(位于bert4keras\models.py),这个函数是负责建立一个模型(如果检查点存在,就会用load_weights_from_checkpoint加载检查点中存储的内容)

下面,我们来看看这些参数都是做什么的:

  1. config_path,checkpoint_path:变量如其名,即加载配置和检查点的路径
  2. model:即模型类型,bert4keras支持导入很多种模型,如果感兴趣的话可以前往bert4keras的models.py文件中的build_transformer_model()函数下查找
  3. application:从字面上看不出来是干什么的,我们返回*build_transformer_model()*中查看
application = application.lower()
if application in ['lm', 'unilm'] and model in ['electra', 't5']:raise ValueError('"%s" model can not be used as "%s" application.\n' %(model, application))if application == 'lm':MODEL = extend_with_language_model(MODEL)
elif application == 'unilm':MODEL = extend_with_unified_language_model(MODEL)

还不懂,找到extend_with_language_model()函数查看,苏老师在这里写了备注:

lm给其他语言模型使用,ulm(unified language model)给seq2seq模型用

  1. 其他参数:实际上我们并没有在build_transformer_model()中找到对应的参数,但是kwargs这个参数就很耐人寻味了,我们在models.py中发现,这个参数在*apply()*中层的建立,encoder/decoder的构建中使用过

    with_pool,with_mlm:可以完全参照字面意思,具体见这篇文章【NLP】bert4keras源码及矩阵计算解析

模型的剩余部分由这里建立

encoder = keras.models.Model(roformer.inputs, roformer.outputs[0])#bert中的encoder
seq2seq = keras.models.Model(roformer.inputs, roformer.outputs[1])#seq2seq用来计算损失outputs = TotalLoss([2, 3])(roformer.inputs + roformer.outputs)#损失
model = keras.models.Model(roformer.inputs, outputs)#模型建立

第三部分 模型训练的准备

AdamW = extend_with_weight_decay(Adam, 'AdamW')
optimizer = AdamW(learning_rate=1e-5, weight_decay_rate=0.01)
model.compile(optimizer=optimizer)
model.summary()

model.complie用于配置训练的优化器、损失函数和准确率评测标准,详见tensorflow中model.compile()用法

第四部分 模型训练

train_generator = data_generator(corpus(), batch_size)
evaluator = Evaluate()model.fit_generator(train_generator.forfit(),steps_per_epoch=steps_per_epoch,epochs=epochs,callbacks=[evaluator]
)

其中,data_generator用来生成数据,而我们虽然使用的语料是汉语文字,喂给真正BERT的还需要进行一些编码,即需要分词(Tokenizer);训练模型,喂给模型的不全是完整的句子,有时候需要遮住一些词喂给模型,而函数mask_encode起到的就是这个作用,苏老师在[博客](SimBERTv2来了!融合检索和生成的RoFormer-Sim模型 - 科学空间|Scientific Spaces)中的【生成】部分阐述了这个思想——即BART;

data_generator还涉及了数据蒸馏的部分,但是暂时我还学到这里,所以暂且搁置了

剩下的fit中则是一些比较常规的内容

训练语料corpus部分还是很容易懂的,有时间再写

第五部分 模型保存

训练好的模型会在每个epoch的最后保存下来,使用的方法为save_weights

def on_epoch_end(self, epoch, logs=None):model.save_weights('./latest_model.weights')# 保存最优if logs['loss'] <= self.lowest:self.lowest = logs['loss']model.save_weights('./best_model.weights')# 演示效果just_show()

加载模型参数则采用load_weights

-----以上,stage1使用语料进行训练,并进行模型的保存,就结束了-----

然而,我们的models.py中,如果要加载检查点的权重,则要使用:

transformer.load_weights_from_checkpoint(checkpoint_path)

这样则会造成一些冲突,即你通过stage1保存的模型,是无法直接使用bert_transformer_model()的

读源码之SimBertv2-stage1相关推荐

  1. 我是怎么读源码的,授之以渔

    点击上方"视学算法",选择"设为星标" 做积极的人,而不是积极废人 作者 :youzhibing 链接 :https://www.cnblogs.com/you ...

  2. 这样读源码,不牛X也难

    程序员在工作过程中,会遇到很多需要阅读源码的场景,比如技术预研.选择技术框架.接手以前的项目.review他人的代码.维护老产品等等.可以说,阅读源代码是程序员的基本功,这项基本功是否扎实,会在很大程 ...

  3. myisam怎么读_耗时半年,我成功“逆袭”,拿下美团offer(刷面试题+读源码+项目准备)...

    欢迎关注专栏[以架构赢天下]--每天持续分享Java相关知识点 以架构赢天下​zhuanlan.zhihu.com 以架构赢天下--持续分享Java相关知识点 每篇文章首发此专栏 欢迎各路Java程序 ...

  4. 微信读书vscode插件_跟我一起读源码 – 如何阅读开源代码

    阅读是最好的老师 在学习和提升编程技术的时候,通过阅读高质量的源码,来学习专家写的高质量的代码,是一种非常有效的提升自我的方式.程序员群体是一群乐于分享的群体,因此在互联网上有大量的高质量开源项目,阅 ...

  5. 读源码,对程序员重要吗?

    来源: CSDN(ID:CSDNnews) 嘿,朋友们!本文我将分享一些关于主动阅读和研究源码的一些想法.在我看来,阅读源码能够帮你成为一名更专业的开发人员.毫无疑问的是,阅读源码提高了我的软件开发水 ...

  6. 夜读源码,带你探究 Go 语言的iota

    Go 语言的 iota 怎么说呢,感觉像枚举,又有点不像枚举,它的底层是什么样的,用哪个姿势使用才算正规,今天转载一篇「Go夜读」社区上分享的文章,咱们一起学习下.Go 夜读,带你每页读源码~!  这 ...

  7. 【一起读源码】1. Java 中元组 Tuple

    1.1 问题描述 使用 Java 做数据分析.机器学习的时候,常常需要对批量的数据进行处理,如果需要处理的数据的维度不超过10时,可以考虑使用 org.javatuples 提供的 Tuple 类工具 ...

  8. Spring读源码系列之AOP--03---aop底层基础类学习

    Spring读源码系列之AOP--03---aop底层基础类学习 引子 Spring AOP常用类解释 AopInfrastructureBean---免被AOP代理的标记接口 ProxyConfig ...

  9. 读源码:PopupWindow

    读源码是为了了解并学习它的实现机制,并更好的运用它,如果在读源码之前已经知道它的怎么运用,这将会更容易理解源码.所以在这读源码开头我推荐阅读一下一位大神写的相关博文,浅显易懂,条理清晰: PopUpW ...

  10. 学会读源码,很重要!

    刚参加工作那会,没想过去读源码,更没想过去改框架的源码:总想着别人的框架应该是完美的.万能的,应该不需要改:另外即使我改了源码,怎么样让我的改动生效了?项目中引用的不还是没改的jar包吗.回想起来觉得 ...

最新文章

  1. 对比Memcached和Redis,谁才是适合你的缓存?
  2. 快速部署Telegraf Influxdb
  3. 用于计算机视觉领域的python第三方库是什么_大量Python开源第三方库资源分类整理,含菜鸟教程章节级别链接...
  4. 科达南沙电子警察“扩编”
  5. 串口服务器接入232显示乱码,串口服务器出现乱码时如何处理,解决方案
  6. flask返回json数据到前端_小白学Flask第六天| abort函数、自定义错误方法、视图函数的返回值...
  7. asp.net core的文件下载
  8. css新奇技术及其未来发展
  9. BZOJ 4000: [TJOI2015]棋盘( 状压dp + 矩阵快速幂 )
  10. 防火墙虚拟系统互访配置实例
  11. 华为击败思科 赢得阿曼2600万美元NGN合同
  12. C#获取当前时区转换方法
  13. PDF控件Aspose.Pdf 12月新版17.12发布 | 附下载
  14. OpenWrt ar71xx 添加原生 AR8035 支持的方法 (AR934X)
  15. 港科夜闻|沈向洋教授获委任为香港科大校董会主席
  16. 今日头条广告_API对接文档学习-1
  17. androidstudio的语音唤醒功能
  18. will be doing的用法
  19. Hyperledger Fabric网络环境手动配置及其链码自动化部署
  20. CNMOOC-os- ch2硬件基础

热门文章

  1. vmware虚拟机Bridged(桥接模式)、NAT(网络地址转换模式)、Host-Only(仅主机模式)详解
  2. 和刚入门的菜鸟们聊聊--什么是聚簇索引与非聚簇索引
  3. 如何给Python安装.whl文件
  4. 【GitLab、Jira、Confluence 单点登录实现】之 CAS 系统部署
  5. 调试python代码神奇ipdb
  6. 【AI】百度ai人脸识别
  7. signal------SIGCHLD
  8. poi-tl-ext扩展,实现多行表格模板替换
  9. Unity animator不剪辑动画实现 分段播放动画
  10. 微信小程序真机调试白屏,只显示Tabbar