Opening

ChatGPT has recently been on fire, practically a household name for young and old alike. Within 60 days of its public beta it reportedly passed 100 million unique visitors, with daily visits also claimed to exceed 100 million. Microsoft, the investor behind the ChatGPT team, rushed out its own ChatGPT-powered Bing search within two months, and its stock price and Bing's download numbers both surged. An earlier article already covered how and why ChatGPT-style technology may reshape the division of labor across human organizations, so I won't ramble on about that here.

Some thoughts on ChatGPT

Starting with this article, I plan to implement a mini version of ChatGPT, covering the algorithms behind it and the data preparation work. The series is expected to run to 7-8 articles. It focuses on implementation; it will not cover Transformer model internals or the mathematical derivation of PPO.

By the end you will have a question-answering text-generation tool, and you can also train a customized model for your own needs: say, an intelligent assistant that really understands you, a paper-explaining helper, a voice interface that interprets your requests to control smart-home devices, or one that paints the picture you describe by voice...

This first article introduces the overall RLHF training framework and then SFT model training: the data and the base model. We start with a single model so you can get familiar with the code and try a training run on your own machine.

The second part will rework the model and repackage the code so it can train across multiple GPUs and multiple machines; more industrial-grade.

The third part wraps the whole flow and integrates the code from all three parts. At that point you will have a pipeline that can genuinely train on Chinese corpora, and you can prepare and label your own training data.

The fourth part will walk through application experiments built on this small ChatGPT engine.

Overview

The whole pipeline consists of three parts (a schematic sketch follows the list):

  1. A text-generation Agent. To get a decent Agent, we train a solid base model on 'input-output' pairs; this step is called SFT.

  2. A Reward model that judges how good a generation is. To get it, we train a ranking/scoring model on 'input-output list' data; this step is called Reward.

  3. A PPO controller that uses the Reward model's feedback to tune the Agent.
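Before diving into real code, here is a deliberately tiny sketch of how the three stages hand off to each other. Everything below is a stand-in: the stub "models" are plain Python functions I made up for illustration, and the real SFT model, reward model, and PPO update are built over the course of this series.

# Schematic only: stub functions standing in for the three RLHF stages.

def train_sft(pairs):
    """Stage 1 (SFT): fit a base generator on (prompt, response) pairs."""
    return lambda prompt: prompt + " ..."  # stand-in for a finetuned GPT/T5

def train_reward(ranked_data):
    """Stage 2 (Reward): fit a scorer on ranked response lists per prompt."""
    return lambda prompt, response: float(len(response))  # stand-in scorer

def ppo_tune(agent, reward_fn, prompts):
    """Stage 3 (PPO): push the agent toward higher reward (real PPO also
    constrains the agent to stay close to the SFT model)."""
    for p in prompts:
        r = agent(p)
        print(p, "-> reward:", reward_fn(p, r))  # a real step would update agent params here
    return agent

agent = train_sft([("question", "answer")])
reward_fn = train_reward([("question", ["better answer", "worse answer"])])
agent = ppo_tune(agent, reward_fn, ["question"])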

Fig 1. SFT training process

Fig 2. Reward training process

Fig 3. Rank data labeling

SFT Implementation

First, train a basic model with text-generation capability; a GPT- or T5-style model can serve as the backbone. A quick demo with a pretrained Chinese GPT-2:

from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline
# uer/gpt2-chinese-lyric is a GPT-2 pretrained on Chinese lyrics; it ships with a BERT-style tokenizer
tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-lyric")
model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-lyric")
text_generator = TextGenerationPipeline(model, tokenizer)
text_generator("最美的不是下雨天,是曾与你躲过雨的屋檐", max_length=100, do_sample=True)

GPT2

Data preprocessing

Data format:

The data comes from the CNN Stories portion of the dataset curated by DeepMind (the CNN/Daily Mail summarization corpus). A raw story file looks like this:

(CNN)Syria is a Hell on Earth that is expanding in plain sight. The death toll there has doubled in a year's time, if an opposition group is right. Since civil war broke out there, 310,000 people have been killed, the Syrian Observatory for Human Rights said Thursday. A year earlier, SOHR's tally stood at 162,402. And the year before, the United Nations put the death toll at 70,000. Violence has plunged well over half of all Syrians into such destitution that they are in dire need of survival aid, the United Nations says, as food rations are being cut for lack of donations. Numbers alone can't convey the immeasurable anguish of millions, but maybe it can remind the rest of us of the magnitude of the world's currently greatest tragedy. The number of years since perpetual bloodshed began, since dictator Bashar al-Assad's security forces fired on crowds of demonstrators and armed militant groups rose up against him in March 2011. Percentage of the Syrian population killed. It would be like killing 3 to 4 million Americans. The range comes from the SOHR's death toll of 310,000 and a recent lower estimate by the U.N. of at least 220,000 dead. The number of Syrians in need of immediate life-saving aid, according to the U.N. That's the population of Moscow. Syrians driven from their homes, the U.N. says. Imagine the entire Boston metropolitan area emptied out. Syrians who have fled as refugees to neighboring countries, creating humanitarian and economic hardship across Syria's borders. Turkey has taken in 1.7 million, Lebanon 1.2 million, Jordan 625,000, and Iraq 245,000. The reduction in the size of food rations the World Food Programme says it has been forced to make due to a lack of donations. That means people receiving aid will get only 60% of the daily nutrition they need.

@highlight

More people have been displaced than live in Moscow; more people lost their homes than live in greater Boston

@highlight

The WFP has cut food ration sizes by 30% for lack of donations

The @highlight sections in the file above are the article's reference summary.

# Save this file as utils.py (the later scripts do `from utils import ...`)
import random

import numpy as np
import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer
from tqdm import tnrange


# the helper below builds the GPT-2 encoder used by all the other scripts
def add_special_tokens():
    """ Returns GPT2 tokenizer after adding separator and padding tokens """
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    special_tokens = {'pad_token': '<|pad|>', 'sep_token': '<|sep|>'}
    num_add_toks = tokenizer.add_special_tokens(special_tokens)
    return tokenizer


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
        Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def sample_seq(model, context, length, device, temperature=1, top_k=0, top_p=0.0):
    """ Generates a sequence of tokens
        Args:
            model: gpt/gpt2 model
            context: tokenized text using gpt/gpt2 tokenizer
            length: length of generated sequence.
            device: torch.device object.
            temperature > 0: used to control the randomness of predictions by scaling the logits before applying softmax.
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in tnrange(length):
            inputs = {'input_ids': generated}
            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated


def beam_search(model, context, length, beam_size, device, temperature=1):
    """ Generate sequence using beam search
        https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/
        Args:
            model: gpt/gpt2 model
            context: tokenized text using gpt/gpt2 tokenizer
            length: length of generated sequence.
            beam_size: >= 1 and <= total_no_of_tokens
            device: torch.device object.
            temperature > 0: used to control the randomness of predictions by scaling the logits before applying softmax.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    with torch.no_grad():
        inputs = {'input_ids': context}
        outputs = model(**inputs)
        next_token_logits = outputs[0][0, -1, :] / temperature
        next_token_probs = F.softmax(next_token_logits, dim=-1)
        scores, indices = torch.topk(next_token_probs, beam_size)
        indices = indices.tolist()
        sequences = [[c] for c in indices]
        for _ in tnrange(length - 1):
            logits = torch.zeros(beam_size * len(next_token_logits))
            for j in range(len(sequences)):
                new_generated = torch.cat((context, torch.tensor([sequences[j]], dtype=torch.long, device=device)), dim=1)
                inputs = {'input_ids': new_generated}
                outputs = model(**inputs)
                next_token_logits = outputs[0][0, -1, :] / temperature
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                start, stop = j * len(next_token_logits), (j + 1) * len(next_token_logits)
                logits[start:stop] = scores[j] * next_token_probs
            scores, new_logits_indices = torch.topk(logits, beam_size)
            logits = (new_logits_indices % 50259).tolist()  # 50259 = GPT-2 vocab (50257) + the 2 added special tokens
            for j in range(len(sequences)):
                sequences[j] = sequences[j] + [logits[j]]
    return scores, sequences


def generate_beam_sample(data, tokenizer, model, num=1, length=100, beam_size=3, device=torch.device('cuda')):
    """ Generate summaries for "num" articles using beam search.
        Args:
            data = GPT21024Dataset object
            tokenizer = gpt/gpt2 tokenizer
            num = number of articles for which summaries have to be generated
    """
    for i in range(num):
        sample = data[i]
        idx = sample['sum_idx']
        context = sample['article'][:idx].tolist()
        summary = sample['article'][idx + 1:][:100].tolist()
        scores, sequences = beam_search(model, context, length, beam_size, device)
        print('new_article', end='\n\n')
        print(tokenizer.decode(context[:-1]), end='\n\n')
        print('actual_summary', end='\n\n')
        print(tokenizer.decode(summary), end='\n\n')
        for i in range(len(sequences)):
            text = tokenizer.convert_ids_to_tokens(sequences[i], skip_special_tokens=True)
            text = tokenizer.convert_tokens_to_string(text)
            print("generated_summary-{} and Score is {}.".format(i + 1, scores[i]), end='\n\n')
            print(text, end='\n\n')


def generate_sample(data, tokenizer, model, num=1, eval_step=False, length=100, temperature=1, top_k=10, top_p=0.5, device=torch.device('cuda')):
    """ Generate summaries for "num" articles.
        Args:
            data = GPT21024Dataset object
            tokenizer = gpt/gpt2 tokenizer
            model = gpt/gpt2 model
            num = number of articles for which summaries have to be generated
            eval_step = True/False, whether we are generating during evaluation
    """
    for i in range(num):
        sample = data[i]
        idx = sample['sum_idx']
        context = sample['article'][:idx].tolist()
        summary = sample['article'][idx + 1:][:100].tolist()
        generated_text = sample_seq(model, context, length, device, temperature, top_k, top_p)
        generated_text = generated_text[0, len(context):].tolist()
        text = tokenizer.convert_ids_to_tokens(generated_text, skip_special_tokens=True)
        text = tokenizer.convert_tokens_to_string(text)
        if eval_step == False:
            print('new_article', end='\n\n')
            print(tokenizer.decode(context), end='\n\n')
            print("generated_summary", end='\n\n')
            print(text, end='\n\n')
            print('actual_summary', end='\n\n')
            print(tokenizer.decode(summary), end='\n\n')
        else:
            print(tokenizer.decode(context), end='\n\n')
            print("generated_summary", end='\n\n')
            print(text, end='\n\n')

Convert the data so that each article becomes one JSON file with two fields, article and abstract, and GPT-encode the text along the way. The code (saved here as prepare_data.py; that file name is my choice):

# prepare_data.py
import json
import os
import pickle
import sys
import time

from utils import add_special_tokens

dm_single_close_quote = '\u2019'  # unicode
dm_double_close_quote = '\u201d'
# acceptable ways to end a sentence
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"]


def fix_missing_period(line):
    """Adds a period to a line that is missing a period"""
    if "@highlight" in line:
        return line
    if line == "":
        return line
    if line[-1] in END_TOKENS:
        return line
    return line + " ."


def get_art_abs(lines):
    """ return as list of sentences"""
    # truncate trailing spaces, and normalize spaces
    lines = [' '.join(line.strip().split()) for line in lines]
    lines = [fix_missing_period(line) for line in lines]
    # Separate out article and abstract sentences
    article_lines = []
    highlights = []
    next_is_highlight = False
    for idx, line in enumerate(lines):
        if line == "":
            continue  # empty line
        elif line.startswith("@highlight"):
            next_is_highlight = True
        elif next_is_highlight:
            highlights.append(line)
        else:
            article_lines.append(line)
    return ' '.join(article_lines), ' '.join(highlights)


def write_json(i, article, abstract):
    """ Saves a json file."""
    file = "./gpt2_1024_data/" + str(i) + ".json"
    js_example = {'id': i, 'article': article, 'abstract': abstract}
    with open(file, 'w') as f:
        json.dump(js_example, f, ensure_ascii=False)


def main(file_names, directory):
    """ Reads txt files, extracts articles and summaries, tokenizes them and saves as json files
        Args:
            file_names: list, all the articles with total no of tokens less than 1024
            directory: string, directory where the files in file_names are stored
    """
    tokenizer = add_special_tokens()
    print("Execution Started...")
    train_ids = []
    file_id_map = {}
    i = 0
    for file in file_names:
        file = os.path.join(os.getcwd(), directory, file)
        with open(file, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n\n')
        article, abstract = get_art_abs(lines)
        article, abstract = tokenizer.encode(article), tokenizer.encode(abstract)
        if len(article) > 0 and len(abstract) > 0 and (len(article) + len(abstract)) <= 1023:
            train_ids.append(i)
            write_json(i, article, abstract)
            file_id_map[i] = os.path.basename(file).replace('.story', '')
            i += 1
            if i % 100 == 0:
                print(i, " files written")

    x, y = int(len(train_ids) * 0.8), int(len(train_ids) * 0.9)
    valid_ids = train_ids[x:y]
    test_ids = train_ids[y:]
    train_ids = train_ids[:x]
    with open("ids.json", 'w') as f:
        js = {'train_ids': train_ids, 'valid_ids': valid_ids, 'test_ids': test_ids}
        json.dump(js, f)

    # file_id_map maps the json file ids to actual cnn/dm file names ending with ".story"
    print("saving file_id_map...")
    with open("file_id_map.pickle", 'wb') as f:
        pickle.dump(file_id_map, f)
    print("file_id_map saved.")


if __name__ == '__main__':
    start = time.time()
    with open(sys.argv[1], 'rb') as f:
        file_sizes = pickle.load(f)
    file_names = [file for file, size in file_sizes.items() if size <= 1023]  # only consider files with total no of tokens less than 1024
    if sys.argv[1].startswith("cnn"):
        directory = "cnn_stories_tokenized"
        os.chdir('./CNN/')
    else:
        directory = "dm_stories_tokenized"
        os.chdir('./DM/')
    main(file_names, directory)
    print("total_time_taken: ", (time.time() - start) / 60, " minutes")
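The script reads its file list from a pickle (sys.argv[1]) mapping each .story file name to its token count, which this article does not ship. Below is one hypothetical way to build it; the name cnn_files_sizes.pickle is my choice (it just has to start with "cnn" for the CNN branch above to fire):

# Build the {file_name: token_count} pickle that prepare_data.py expects (hypothetical helper).
import os
import pickle

from utils import add_special_tokens

tokenizer = add_special_tokens()
directory = './CNN/cnn_stories_tokenized'
sizes = {}
for name in os.listdir(directory):
    with open(os.path.join(directory, name), encoding='utf-8') as f:
        sizes[name] = len(tokenizer.encode(f.read()))
with open('cnn_files_sizes.pickle', 'wb') as f:
    pickle.dump(sizes, f)
# then run: python prepare_data.py cnn_files_sizes.pickle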

The processed data looks like this:

{"id": 0, "article": [12, 43, 27912, 12, 8100, 532, 21095, 33, 12, 1377, 7214, 4621, 286, 262, 890, 5041, 351, 257, 474, 5978, 284, 534, 17627, 764, 775, 1965, 1312, 6207, 3816, 284, 2648, 5205, 286, 511, 4004, 7505, 3952, 5636, 2171, 764], "abstract": [9787, 503, 8100, 13, 785, 7183, 705, 7505, 3952, 5205, 764, 1471, 19550, 287, 319, 262, 995, 705, 82, 27627, 6386, 1660, 19392, 764]}
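To sanity-check a processed file, you can decode the ids back into text. A minimal sketch, assuming the utils.py above is importable and ./gpt2_1024_data/0.json exists:

# Decode a processed JSON file back to text to verify the encoding round-trips.
import json

from utils import add_special_tokens

tokenizer = add_special_tokens()
with open('./gpt2_1024_data/0.json') as f:
    data = json.load(f)
print('article :', tokenizer.decode(data['article'])[:200])
print('abstract:', tokenizer.decode(data['abstract']))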

Model training

# Save this code as dataset.py
import os
import json

import numpy as np
import torch
from torch.utils.data import Dataset

from utils import add_special_tokens


class GPT21024Dataset(Dataset):

    def __init__(self, root_dir, ids_file, mode='train', length=None):
        self.root_dir = root_dir
        self.tokenizer = add_special_tokens()
        # with open(ids_file, 'r') as f:
        #     if mode == 'train':
        #         self.idxs = np.array(json.load(f)['train_ids'])
        #     elif mode == 'valid':
        #         self.idxs = np.array(json.load(f)['valid_ids'])
        #     elif mode == 'test':
        #         self.idxs = np.array(json.load(f)['test_ids'])
        #     self.idxs = self.idxs - min(self.idxs)
        self.idxs = os.listdir(root_dir)
        self.mode = mode
        if length is None:
            self.len = len(self.idxs)
        else:
            self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        if self.mode == 'valid':
            idx = self.idxs[-idx]
        elif self.mode == 'test':
            idx = self.idxs[-idx - self.len]  # assuming valid and test sets of same sizes
        else:
            idx = self.idxs[idx]
        # file_name = os.path.join(self.root_dir, str(idx) + ".json")
        file_name = os.path.join(self.root_dir, str(idx))
        with open(file_name, 'r') as f:
            data = json.load(f)
        # fixed-length 1024 sequence: article + <|sep|> + abstract, right-padded with <|pad|>
        text = self.tokenizer.encode(self.tokenizer.pad_token) * 1024
        content = data['article'] + self.tokenizer.encode(self.tokenizer.sep_token) + data['abstract']
        text[:len(content)] = content
        text = torch.tensor(text)
        sample = {'article': text, 'sum_idx': len(data['article'])}
        return sample
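A quick smoke test of the dataset class (a sketch; it assumes the processed JSON files from the previous step sit in ./CNN/gpt2_1024_data):

from dataset import GPT21024Dataset

data = GPT21024Dataset('./CNN/gpt2_1024_data', './CNN/ids.json', mode='train', length=10)
sample = data[0]
print(sample['article'].shape)  # torch.Size([1024]): article + <|sep|> + abstract, right-padded
print(sample['sum_idx'])        # position of <|sep|>, i.e. where the summary starts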
# Training script
import argparse
from datetime import datetime
import os
import time

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tnrange, tqdm_notebook
from transformers import GPT2LMHeadModel, AdamW, WarmupLinearSchedule  # note: newer transformers versions replace WarmupLinearSchedule with get_linear_schedule_with_warmup

from dataset import GPT21024Dataset
from utils import add_special_tokens, generate_sample, set_seed

# please change default arguments if needed
parser = argparse.ArgumentParser()
parser.add_argument("--lr", default=5e-5, type=float, help="learning rate")
parser.add_argument("--seed", default=42, type=int, help="seed to replicate results")
parser.add_argument("--n_gpu", default=1, type=int, help="no of gpu available")
parser.add_argument("--gradient_accumulation_steps", default=2, type=int, help="gradient_accumulation_steps")
parser.add_argument("--batch_size", default=1, type=int, help="batch_size")
parser.add_argument("--num_workers", default=4, type=int, help="num of cpus available")
parser.add_argument("--device", default=torch.device('cpu'), help="torch.device object")
parser.add_argument("--num_train_epochs", default=1, type=int, help="no of epochs of training")
parser.add_argument("--output_dir", default='./output', type=str, help="path to save evaluation results")
parser.add_argument("--model_dir", default='./weights', type=str, help="path to save trained model")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="max gradient norm.")
parser.add_argument("--root_dir", default='./CNN/gpt2_1024_data', type=str, help="location of json dataset.")
parser.add_argument("--ids_file", default='./CNN/ids.json', type=str, help="location of train, valid and test file indexes")
args = parser.parse_args([])
print(args)


def train(args, model, tokenizer, train_dataset, valid_dataset, ignore_index):
    """ Trains GPT2 model and logs necessary details.
        Args:
            args: dict that contains all the necessary information passed by user while training
            model: finetuned gpt/gpt2 model
            tokenizer: GPT/GPT2 tokenizer
            train_dataset: GPT21024Dataset object for training data
            ignore_index: token not considered in loss calculation
    """
    writer = SummaryWriter('./output/logs')
    train_sampler = RandomSampler(train_dataset)
    train_dl = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size, num_workers=args.num_workers)
    loss_fct = CrossEntropyLoss(ignore_index=ignore_index)  # ignores padding token for loss calculation
    optimizer = AdamW(model.parameters(), lr=args.lr)
    scheduler = WarmupLinearSchedule(optimizer, 100, 80000)
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = tnrange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)
    for _ in train_iterator:
        epoch_iterator = tqdm_notebook(train_dl, desc="Training")
        for step, batch in enumerate(epoch_iterator):
            inputs, labels = batch['article'].to(args.device), batch['article'].to(args.device)
            model.train()
            logits = model(inputs)[0]
            # only consider loss on reference summary just like seq2seq models
            shift_logits = logits[..., batch['sum_idx']:-1, :].contiguous()
            shift_labels = labels[..., batch['sum_idx'] + 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss = loss / args.gradient_accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                writer.add_scalar('loss', (tr_loss - logging_loss) / args.gradient_accumulation_steps, global_step)
                logging_loss = tr_loss
                print("loss:", loss.item(), end='\n\n')
                if (step + 1) / args.gradient_accumulation_steps == 1.0:
                    print('After 1st update: ', end='\n\n')
                    generate_sample(valid_dataset, tokenizer, model, num=2, eval_step=False, device=args.device)
            if (step + 1) % (10 * args.gradient_accumulation_steps) == 0:
                # evaluate() (validation-loss computation) is defined in the full repo linked at the end of this article
                results = evaluate(args, model, valid_dataset, ignore_index, global_step)
                for key, value in results.items():
                    writer.add_scalar('eval_{}'.format(key), value, global_step)
                print('After', global_step + 1, 'updates: ', end='\n\n')
                generate_sample(valid_dataset, tokenizer, model, num=2, eval_step=True, device=args.device)


# creating training and validation dataset objects
train_data = GPT21024Dataset(args.root_dir, args.ids_file, mode='train', length=3000)  # training on only 3000 examples
valid_data = GPT21024Dataset(args.root_dir, args.ids_file, mode='valid', length=500)  # validating on only 500 examples

# load pretrained GPT2
tokenizer = add_special_tokens()
ignore_idx = tokenizer.pad_token_id
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))
model.to(args.device)

# training the model
start = time.time()
train(args, model, tokenizer, train_data, valid_data, ignore_idx)
print('total time: ', (time.time() - start) / 60, " minutes", end='\n\n')

print('Saving trained model...')
model_file = os.path.join(args.model_dir, 'model_data{}_trained_after_{}_epochs_only_sum_loss_ignr_pad.bin'.format(len(train_data), args.num_train_epochs))
config_file = os.path.join(args.model_dir, 'config_data{}_trained_after_{}_epochs_only_sum_loss_ignr_pad.json'.format(len(train_data), args.num_train_epochs))
torch.save(model.state_dict(), model_file)
model.config.to_json_file(config_file)
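The key lines in train() are the two slices that restrict the loss to the summary tokens: positions before sum_idx (the article) are never scored, so the model is optimized only to produce the summary, just like a seq2seq model. A toy illustration with made-up sizes:

# Toy illustration of the summary-only loss masking used above (hypothetical sizes).
import torch
from torch.nn import CrossEntropyLoss

vocab, seq_len, sum_idx = 11, 8, 4       # pretend: 4 article tokens, then <|sep|> + summary
logits = torch.randn(1, seq_len, vocab)  # what the LM head would output
labels = torch.randint(0, vocab, (1, seq_len))

# position i predicts token i+1, so predictions start at sum_idx (the separator)
shift_logits = logits[..., sum_idx:-1, :].contiguous()
shift_labels = labels[..., sum_idx + 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())  # cross-entropy over the 3 summary positions only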

Inference with the trained model

import argparse
import os

from bs4 import BeautifulSoup
from googlesearch import search
import numpy as np
import requests
import torch
from tqdm import tnrange, tqdm_notebook
from transformers import GPT2Config, GPT2LMHeadModel

from dataset import GPT21024Dataset
from utils import add_special_tokens, beam_search, generate_beam_sample, generate_sample, sample_seq, set_seed, top_k_top_p_filtering

# please change default arguments if needed
parser = argparse.ArgumentParser()
parser.add_argument("--seed", default=42, type=int, help="seed to replicate results")
parser.add_argument("--num_workers", default=4, type=int, help="num of cpus available")
parser.add_argument("--device", default=torch.device('cuda'), help="torch.device object")
parser.add_argument("--output_dir", default='./output', type=str, help="path to save evaluation results")
parser.add_argument("--model_dir", default='./weights', type=str, help="path to save trained model")
parser.add_argument("--root_dir", default='./CNN/gpt2_1024_data', type=str, help="location of json dataset.")
parser.add_argument("--ids_file", default='./CNN/ids.json', type=str, help="location of train, valid and test file indexes")
args = parser.parse_args([])
print(args)

# using the same validation and training data as during training
tokenizer = add_special_tokens()
# train_data = GPT21024Dataset(args.root_dir, args.ids_file, mode='train', length=3000)
# valid_data = GPT21024Dataset(args.root_dir, args.ids_file, mode='valid', length=500)
test_data = GPT21024Dataset(args.root_dir, args.ids_file, mode='test', length=500)

# model_file and config_file are the files used to load the finetuned model; change these names as per your files
# model_file = os.path.join(args.model_dir, 'model_data{}_trained_after_{}_epochs_only_sum_loss_ignr_pad.bin'.format(len(train_data), args.num_train_epochs))
# config_file = os.path.join(args.model_dir, 'config_data{}_trained_after_{}_epochs_only_sum_loss_ignr_pad.json'.format(len(train_data), args.num_train_epochs))

# path to model and config files
model_file = "345-model_O0_data3000_trained_after_5_epochs_only_sum_loss_ignr_pad.bin"
config_file = "345-config_O0_data3000_trained_after_5_epochs_only_sum_loss_ignr_pad.json"

config = GPT2Config.from_json_file(config_file)
model = GPT2LMHeadModel(config)
state_dict = torch.load(model_file)
model.load_state_dict(state_dict)
model.eval()
model.to(args.device)

generate_sample(test_data, tokenizer, model, num=2, length=100, temperature=1, top_k=10, top_p=0.5, device=args.device)
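With the model loaded, you can also summarize arbitrary text instead of the test set. A small sketch reusing sample_seq (already imported from utils.py above); it mirrors the training format of article + <|sep|> + summary, so keep the article well under 1024 tokens:

# Summarize your own text with the finetuned model (sketch).
article = "Paste the article you want to summarize here ..."
context = tokenizer.encode(article) + tokenizer.encode(tokenizer.sep_token)
generated = sample_seq(model, context, length=100, device=args.device,
                       temperature=1, top_k=10, top_p=0.5)
summary_ids = generated[0, len(context):].tolist()
print(tokenizer.decode(summary_ids, skip_special_tokens=True))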

Generated output:

new_article

Rome -LRB- CNN -RRB- -- A cruise ship of the Costa Cruises line is adrift off the coast of the Seychelles after a fire in its engine room, the Italian coast guard said Monday. The ship, the Allegra, is a sister of the Costa Concordia, which wrecked off the coast of Italy on January 13, killing at least 21 people. The fire left the Allegra without propulsion, although its communications equipment is intact, the authorities said. The Allegra's fire has been put out, and the passengers are all in good health, the authorities said. The Seychelles is sending a tug, and merchant ships in the area are steaming toward the Allegra, the coast guard said.

generated_summary

The ship is carrying cargo from the Seychelles . The ship was carrying cargo from the Seychelles . The ship was carrying cargo from the Seychelles . The ship was carrying cargo from the Seychelles . The ship was carrying cargo from the Seychelles . The ship was carrying cargo from the Seychelles . The ship was carrying cargo from the Seychelles . The ship was carrying cargo from the Seychelles . The ship was carrying

actual_summary

An engine room fire leaves the Costa Allegra without propulsion, authorities say. Its sister ship, the Costa Concordia, shipwrecked last month, killing at least 21. <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>

Code link: https://github.com/AigcLwq/miniChatgpt.git

T5

To be covered in the next iteration.
