# Learning rate matters: with lr=2e-5 the training accuracy reaches 0.99, while with lr=1e-3 it stalls around 0.6 and the loss will not come down.
# LSTM sequences here are variable-length, so pick a sensible batch size when testing to make sure memory does not blow up.
import gluonnlp as nlp
import mxnet as mx
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet.gluon import rnn
from mxnet import gluon, autograd
import numpy as np
import functools
import time
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
context = mx.gpu()
#---------------------------------------------------- BERT module ----------------------------------------------------
bert_model, text_vocab = nlp.model.get_model(name='bert_12_768_12',
                                             dataset_name='wiki_cn_cased',
                                             pretrained=True,
                                             ctx=context,
                                             # use_pooler=False,
                                             use_decoder=False,
                                             use_classifier=False,
                                             dropout=0.1,
                                             )
bert_tokenizer = nlp.data.BERTTokenizer(vocab=text_vocab)
#---------------------------------------------------- Data preprocessing ---------------------------------------------
sentences = [
"""
油麻地今晚(15日)發生奪命火警,釀成7死11傷。消息指,火場為唐樓一樓餐廳,有尼泊爾籍人士為小童舉辦生日派對。事發時單位正進行慶祝活動,期間有物品起火,火勢迅速蔓延,火舌從窗口冒出。多人被困在近廚房位置,有人被逼爬出外牆,甚至由洗手間的窗戶跳落大廈後巷逃生。
消防共救出18名傷者,其中12人嚴重受傷,全部人送院搶救後,其中7人返魂乏術,死者包括一名9歲男童。大批親友趕赴醫院認屍,眾人哀慟不已。西九龍總區重案組第4隊已接手跟進,消防亦成立專責小隊調查,包括聚會情況及有否易燃物品等。
""",'林鄭政府 我哋做咗大量工作, 急市民所急',
'林利夜闖醫院驚見病母被綁手部流血:媽咪用紙筆寫住「打999」'
]
sentence_fn = nlp.data.batchify.Tuple(nlp.data.batchify.Pad(axis=0, pad_val=0),
                                      nlp.data.batchify.Stack(),
                                      nlp.data.batchify.Pad(axis=0, pad_val=0))
pad_fn = nlp.data.batchify.Pad(axis=0, pad_val=0)
stack_fn = nlp.data.batchify.Stack()
# sub_batchify_fn2 handles (word padding, sub-sentence valid lengths, segment padding)
sub_batchify_fn2 = nlp.data.batchify.Tuple(lambda x: pad_fn([j for i in x for j in i]),
                                           lambda x: stack_fn([j for i in x for j in i]).squeeze(),
                                           lambda x: pad_fn([j for i in x for j in i]))
# sub_batchify_fn1 handles ((word padding, sub-sentence valid lengths, segment padding), sub-sentence count)
sub_batchify_fn1 = nlp.data.batchify.Tuple(sub_batchify_fn2, lambda x: stack_fn(x).squeeze())
# batchify_fn handles (((word padding, sub-sentence valid lengths, segment padding), sub-sentence count), target)
batchify_fn = nlp.data.batchify.Tuple(sub_batchify_fn1, lambda x: stack_fn(x).squeeze())
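
# A quick toy check of what pad_fn does (a sanity-check sketch, not part of the
# pipeline): it pads a list of variable-length arrays into one matrix. The nested
# Tuples above apply the same idea while flattening each document's sub-sentence
# tuples into one flat sub-sentence batch.
print(pad_fn([np.array([1, 2]), np.array([3, 4, 5])]).asnumpy())  # -> [[1, 2, 0], [3, 4, 5]]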
transform = nlp.data.BERTSentenceTransform(bert_tokenizer, max_seq_length=512, pair=False, pad=False)

def sents_segment(sentence, seg='。??!', transform=transform):
    sentence = sentence.strip()[:200]  # truncate anything beyond 200 characters
    sents = []
    st = 0
    for idx, word in enumerate(sentence):
        if word in seg:
            if sentence[st:idx].strip():
                sent = sentence[st:idx] + word
                # Short pieces are merged with the following ones: a sub-sentence is
                # only emitted once it exceeds 10 characters, so the batch_size*seq_size
                # footprint stays small and the GPU does not run out of memory.
                if len(sent.strip()) > 10:
                    sent = transform([sent.strip()])
                    sents.append(sent)
                    st = idx + 1
    sent = transform([sentence[st:].strip()])  # every sentence must contribute at least one record, however short
    sents.append(sent)
    sents = sentence_fn(sents)
    return sents, mx.nd.array([len(sents[0])])

def get_dataloader(dataset, batch_size):
    dataset = gluon.data.SimpleDataset(dataset)
    dataset = dataset.transform_first(sents_segment)
    train_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size, batchify_fn=batchify_fn)
    return train_dataloader
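
# A quick sanity check of the segmentation above, using the `sentences` sample list
# defined earlier (a sketch only): each raw string becomes one batchified tuple of
# (padded token ids, valid lengths, padded segment ids) over its sub-sentences, plus
# the sub-sentence count that the RNN later uses as its valid length.
_sub, _n = sents_segment(sentences[0])
print(int(_n.asscalar()))  # number of sub-sentences after splitting on '。??!'
print(_sub[0].shape)       # (num_sub_sentences, max_sub_sentence_len) padded token ids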
#----------------------------------------------------- RNN module ----------------------------------------------------
def _get_cell_type(cell_type):
    """Get the object type of the cell by parsing the input

    Parameters
    ----------
    cell_type : str or type

    Returns
    -------
    cell_constructor : type
        The constructor of the RNNCell
    """
    if isinstance(cell_type, str):
        if cell_type == 'lstm':
            return rnn.LSTMCell
        elif cell_type == 'gru':
            return rnn.GRUCell
        elif cell_type == 'relu_rnn':
            return functools.partial(rnn.RNNCell, activation='relu')
        elif cell_type == 'tanh_rnn':
            return functools.partial(rnn.RNNCell, activation='tanh')
        else:
            raise NotImplementedError
    else:
        return cell_type

class RNNEncoder(HybridBlock):
    """
    Parameters
    ----------
    cell_type : str or function
        Can be "lstm", "gru" or a constructor function that can be called directly,
        like rnn.LSTMCell
    num_layers : int
        Total number of layers
    num_bi_layers : int
        Total number of bidirectional layers
    hidden_size : int
        Number of hidden units
    dropout : float
        The dropout rate
    use_residual : bool
        Whether to use residual connections. Residual connections are added in the
        uni-directional RNN layers.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str, default 'rnn_'
        Prefix for names of `Block`s (and names of weights if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells. Created if `None`.
    """
    def __init__(self, cell_type='lstm', num_layers=2, num_bi_layers=1, hidden_size=128,
                 dropout=0.0, use_residual=True,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 prefix=None, params=None):
        super(RNNEncoder, self).__init__(prefix=prefix, params=params)
        self._cell_type = _get_cell_type(cell_type)
        assert num_bi_layers <= num_layers, \
            'Number of bidirectional layers must be smaller than the total number of layers, ' \
            'num_bi_layers={}, num_layers={}'.format(num_bi_layers, num_layers)
        self._num_bi_layers = num_bi_layers
        self._num_layers = num_layers
        self._hidden_size = hidden_size
        self._dropout = dropout
        self._use_residual = use_residual
        with self.name_scope():  # the name scope keeps the child blocks' parameter names unique
            self.dropout_layer = nn.Dropout(dropout)
            self.rnn_cells = nn.HybridSequential()
            for i in range(num_layers):
                if i < num_bi_layers:
                    # bidirectional layers (a forward and a backward RNN cell)
                    self.rnn_cells.add(rnn.BidirectionalCell(
                        l_cell=self._cell_type(hidden_size=self._hidden_size,
                                               i2h_weight_initializer=i2h_weight_initializer,
                                               h2h_weight_initializer=h2h_weight_initializer,
                                               i2h_bias_initializer=i2h_bias_initializer,
                                               h2h_bias_initializer=h2h_bias_initializer,
                                               prefix='rnn%d_l_' % i),
                        r_cell=self._cell_type(hidden_size=self._hidden_size,
                                               i2h_weight_initializer=i2h_weight_initializer,
                                               h2h_weight_initializer=h2h_weight_initializer,
                                               i2h_bias_initializer=i2h_bias_initializer,
                                               h2h_bias_initializer=h2h_bias_initializer,
                                               prefix='rnn%d_r_' % i)))
                else:
                    # unidirectional (forward-only) layers
                    self.rnn_cells.add(
                        self._cell_type(hidden_size=self._hidden_size,
                                        i2h_weight_initializer=i2h_weight_initializer,
                                        h2h_weight_initializer=h2h_weight_initializer,
                                        i2h_bias_initializer=i2h_bias_initializer,
                                        h2h_bias_initializer=h2h_bias_initializer,
                                        prefix='rnn%d_' % i))

    def __call__(self, inputs, states=None, valid_length=None):
        """Encode the inputs given the states and valid sequence lengths.

        Parameters
        ----------
        inputs : NDArray
            Input sequence. Shape (batch_size, length, C_in)
        states : list of NDArrays or None
            The list of initial states
        valid_length : NDArray or None
            Valid length of each sequence. This is usually used when part of a sequence
            has been padded. Shape (batch_size,)

        Returns
        -------
        outputs : NDArray
            Outputs of the last RNN layer, masked by valid_length when it is given.
        """
        return super(RNNEncoder, self).__call__(inputs, states, valid_length)

    def forward(self, inputs, states=None, valid_length=None):  # pylint: disable=arguments-differ, missing-docstring
        # inputs are laid out NTC: (batch_size, sequence_size, hidden_size)
        _, length, _ = inputs.shape
        new_states = []
        outputs = inputs
        for i, cell in enumerate(self.rnn_cells):
            # this layer's initial state comes from the caller-supplied states, one entry per layer
            begin_state = None if states is None else states[i]
            # outputs come back in NTC layout; layer_states are the states at the last step
            outputs, layer_states = cell.unroll(
                length=length, inputs=inputs, begin_state=begin_state, merge_outputs=True,
                valid_length=valid_length, layout='NTC')
            if i < self._num_bi_layers:
                # For a bidirectional layer, outputs have shape (bs, ss, 2*hs) and
                # layer_states are [bs, hs] * 4: the first two are the forward states,
                # the last two the backward ones. Keep the backward states as the
                # initial hidden state of the next layer.
                new_states.append(layer_states[len(self.rnn_cells[i].state_info()) // 2:])
            else:
                new_states.append(layer_states)  # keep each layer's final hidden states
            # Apply Dropout
            outputs = self.dropout_layer(outputs)
            if self._use_residual:
                if i > self._num_bi_layers:
                    outputs = outputs + inputs
            inputs = outputs
        if valid_length is not None:
            outputs = mx.nd.SequenceMask(outputs, sequence_length=valid_length,
                                         use_sequence_length=True, axis=1)
        return outputs
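
# A minimal shape check for the encoder above, run on CPU with random NTC input
# (a sketch only; the real pipeline feeds BERT [CLS] vectors instead of noise):
# a batch of 2 "documents", each a sequence of 5 sub-sentence vectors of width 768.
_enc = RNNEncoder(cell_type='lstm', num_layers=1, num_bi_layers=0, hidden_size=128)
_enc.initialize(mx.init.Xavier())
_out = _enc(mx.nd.random.normal(shape=(2, 5, 768)), valid_length=mx.nd.array([5, 3]))
print(_out.shape)  # (2, 5, 128); steps beyond each valid_length are zero-masked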
#-------------------------------------------------- Multi-sub-sentence network ---------------------------------------
class MultisentNet(HybridBlock):
    def __init__(self, dropout=0.1, prefix=None, params=None):
        super(MultisentNet, self).__init__(prefix=prefix, params=params)
        with self.name_scope():
            self.embedding = None  # will be set to the BERT model later
            self.encoder = None  # will be set to the RNN encoder later
            self.output = nn.HybridSequential()
            with self.output.name_scope():
                self.output.add(nn.Dropout(dropout))
                self.output.add(nn.Dense(2, flatten=False))

    def hybrid_forward(self, F, data):
        (words, valid_len, segments), sent_segment_len = data
        seq_encoding, cls_encoding = self.embedding(words, segments, valid_len.astype('float32'))
        data = self.bert_to_rnn_inputs(cls_encoding, sent_segment_len)
        # encoded: (batch_size, sequence_size, hidden_size)
        encoded = self.encoder(data, valid_length=sent_segment_len)
        # Take the last valid step: (batch_size, hidden_size). Because the sequence
        # length varies, a Dense layer over the whole sequence could not infer its
        # input size, so only the last valid step goes into the fully connected layer.
        encoded = mx.nd.SequenceLast(encoded, sequence_length=sent_segment_len,
                                     axis=1, use_sequence_length=True)
        out = self.output(encoded)
        return out

    def bert_to_rnn_inputs(self, data, sent_segment_len):
        d1 = []
        sent_segment_len = sent_segment_len.asnumpy()
        sent_segment_cumlen = np.concatenate([np.array([0]), sent_segment_len.cumsum()]).astype(int)
        max_seg_sent = max(sent_segment_len)
        for i in range(len(sent_segment_cumlen) - 1):
            tmp = data[sent_segment_cumlen[i]:sent_segment_cumlen[i + 1]]
            if len(tmp) < max_seg_sent:
                tmp = mx.nd.concat(tmp,
                                   mx.nd.broadcast_axis(mx.nd.zeros((1, data.shape[1]), ctx=data.context),
                                                        axis=0, size=int(max_seg_sent) - len(tmp)),
                                   dim=0)
            d1.append(tmp)
        data = mx.nd.stack(*d1, axis=0)
        return data
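
# How bert_to_rnn_inputs regroups the flat BERT [CLS] matrix, as a toy sketch:
# with per-document sub-sentence counts [2, 3], a (5, hidden) matrix becomes a
# (2, 3, hidden) tensor, with the first document zero-padded from 2 rows to 3.
_net = MultisentNet()
_grouped = _net.bert_to_rnn_inputs(mx.nd.arange(20).reshape((5, 4)), mx.nd.array([2, 3]))
print(_grouped.shape)  # (2, 3, 4)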
learning_rate, batch_size = 2e-5, 6
epochs = 100
grad_clip = 0.5
log_interval = 100

def train(net, context, epochs):
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': learning_rate})
    # do not apply weight decay on LayerNorm and bias terms
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    parameters = [p for p in net.collect_params().values() if p.grad_req != 'null']
    # loss = gluon.loss.SigmoidBCELoss()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    # parameters = net.collect_params().values()
    # Training/Testing
    for epoch in range(epochs):
        start_log_interval_time = time.time()
        log_interval_L = 0.0
        epoch_L = 0.0
        train_acc = 0.0
        for i, (data, label) in enumerate(train_dataloader):
            with autograd.record():
                x = [v.as_in_context(context) for v in data[0]]
                y = data[1].as_in_context(context)
                output = net((x, y))
                label = label.as_in_context(context).astype(np.float32)
                L = loss(output, label).mean()
            L.backward()
            # pred = (output.sigmoid() > 0.5).reshape(-1)
            pred = mx.nd.argmax(mx.nd.softmax(output, axis=1), axis=1)
            batch_acc = (pred == label).sum().asscalar() / len(label)
            train_acc += batch_acc
            # Clip gradient
            if grad_clip:
                gluon.utils.clip_global_norm([p.grad(context) for p in parameters], grad_clip)
            # Update parameter
            trainer.step(1)
            log_interval_L += L.asscalar()
            epoch_L += L.asscalar()
            # if (i + 1) % log_interval == 0:
            #     print('[Epoch {} Batch {}/{}] elapsed {:.2f} s, '
            #           'avg loss {:.6f}, avg acc {:.6f}'.format(
            #               epoch, i + 1, len(train_dataloader),
            #               time.time() - start_log_interval_time,
            #               log_interval_L / log_interval, batch_acc))
            #     # Clear log interval training stats
            #     start_log_interval_time = time.time()
            #     log_interval_L = 0
        dev_avg_L, dev_acc = evaluate(net, dev_dataloader, context)
        print('[Epoch {}] train avg loss {:.6f}, train avg acc {:.2f},dev acc {:.2f}, '
              'dev avg loss {:.6f}'.format(epoch, epoch_L / len(train_dataloader),
                                           train_acc / len(train_dataloader),
                                           dev_avg_L, dev_acc))

def evaluate(net, dataloader, context):
    # loss = gluon.loss.SigmoidBCELoss()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    total_L = 0.0
    total_sample_num = 0
    total_correct_num = 0
    start_log_interval_time = time.time()
    print('Begin Testing...')
    for i, (data, label) in enumerate(dataloader):
        x = [v.as_in_context(context) for v in data[0]]
        y = data[1].as_in_context(context)
        output = net((x, y))
        label = label.as_in_context(context).astype(np.float32)
        L = loss(output, label)
        pred = mx.nd.argmax(mx.nd.softmax(output, axis=1), axis=1)
        # pred = (output.sigmoid() > 0.5).reshape(-1)
        total_L += L.sum().asscalar()
        total_sample_num += label.shape[0]
        total_correct_num += (pred == label).sum().asscalar()
        if (i + 1) % log_interval == 0:
            print('[Batch {}/{}] elapsed {:.2f} s'.format(
                i + 1, len(dataloader), time.time() - start_log_interval_time))
            start_log_interval_time = time.time()
    avg_L = total_L / float(total_sample_num)
    acc = total_correct_num / float(total_sample_num)
    return avg_L, acc

if __name__ == '__main__':
    import pandas as pd
    train_df = pd.read_csv(r'~/.fastNLP/dataset/chn_senti_corp/train.tsv', sep='\t')
    dev_df = pd.read_csv(r'~/.fastNLP/dataset/chn_senti_corp/dev.tsv', sep='\t')
    train_dataset = gluon.data.SimpleDataset(
        [tuple(i) for i in train_df[['raw_chars', 'target']].values[:1000]])
    dev_dataset = gluon.data.SimpleDataset(
        [tuple(i) for i in dev_df[['raw_chars', 'target']].values])
    train_dataloader = get_dataloader(train_dataset, batch_size)
    dev_dataloader = get_dataloader(dev_dataset, batch_size)
    multisentnet = MultisentNet()
    multisentnet.embedding = bert_model
    multisentnet.encoder = RNNEncoder(num_layers=1, num_bi_layers=0, hidden_size=128)
    multisentnet.encoder.initialize(mx.init.Xavier(), ctx=context)
    multisentnet.output.initialize(mx.init.Xavier(), ctx=context)
    train(multisentnet, mx.gpu(0), 100)
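
    # After training, the same pipeline can score new text. A minimal inference
    # sketch using the `sentences` sample list defined at the top; the 0s are
    # dummy labels that only satisfy the (text, target) dataset format.
    infer_dataloader = get_dataloader([(s, 0) for s in sentences], batch_size)
    for data, _ in infer_dataloader:
        x = [v.as_in_context(context) for v in data[0]]
        y = data[1].as_in_context(context)
        pred = mx.nd.argmax(multisentnet((x, y)), axis=1)
        print(pred.asnumpy())  # one 0/1 sentiment class per input document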

Result:

Begin Testing...
[Batch 100/200] elapsed 7.69 s
[Batch 200/200] elapsed 7.62 s
[Epoch 0] train avg loss 0.469793, train avg acc 0.79,dev acc 0.43, dev avg loss 0.844167
Begin Testing...
[Batch 100/200] elapsed 7.97 s
[Batch 200/200] elapsed 8.10 s
[Epoch 1] train avg loss 0.307943, train avg acc 0.90,dev acc 0.50, dev avg loss 0.866667
Begin Testing...
[Batch 100/200] elapsed 8.18 s
[Batch 200/200] elapsed 7.30 s
[Epoch 2] train avg loss 0.183324, train avg acc 0.96,dev acc 0.67, dev avg loss 0.862500
Begin Testing...
[Batch 100/200] elapsed 7.87 s
[Batch 200/200] elapsed 8.29 s
[Epoch 3] train avg loss 0.098188, train avg acc 0.98,dev acc 0.60, dev avg loss 0.876667
Begin Testing...
[Batch 100/200] elapsed 7.61 s
[Batch 200/200] elapsed 7.66 s
[Epoch 4] train avg loss 0.042816, train avg acc 0.99,dev acc 0.78, dev avg loss 0.859167
Begin Testing...
[Batch 100/200] elapsed 7.92 s
[Batch 200/200] elapsed 7.32 s
[Epoch 5] train avg loss 0.028148, train avg acc 1.00,dev acc 0.80, dev avg loss 0.873333

...
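
By epoch 5 the training accuracy has already hit 1.00 while the dev accuracy trails at 0.80, so the model (trained here on only the first 1,000 samples) is clearly overfitting the training set.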
