转载,原博客:https://blog.csdn.net/iteapoy/article/details/106478462

1 数据与说明

数据下载

数据下载链接:点击下载

数据是一个data.zip压缩包,解压后的目录树如下所示:

D:.
│  eng-fra.txt
│
└─namesArabic.txtChinese.txtCzech.txtDutch.txtEnglish.txtFrench.txtGerman.txtGreek.txtIrish.txtItalian.txtJapanese.txtKorean.txtPolish.txtPortuguese.txtRussian.txtScottish.txtSpanish.txtVietnamese.txt

eng-fra.txt 是第三篇翻译任务中要用到的,这次我们只用到 /name 这个文件夹下的18个文件,每个文件以语言命名,格式为:[Language].txt。打开后,里面是该语言中常用的姓/名。

比如:打开我们最熟悉的 Chinese.txt,可以看到每一行是一个姓或者名(有一些姓/名确实有点点奇怪,但整体来说问题不大)。

Ang
Au-Yong
Bai
Ban
Bao
Bei
Bian
Bui
Cai
Cao
Cen
……

任务说明
这次任务的目标是,输入一个姓名,根据它的拼写,用循环神经网络对它分类,判断它属于哪个语言里的姓名。

比如:

$ python predict.py Hinton
(-0.47) Scottish
(-1.52) English
(-3.57) Irish$ python predict.py Schmidhuber
(-0.19) German
(-2.48) Czech
(-2.68) Dutch
  • Hinton 这个姓名很有可能是Scottish,其次可能是English,再其次可能是Irish。
  • Schmidhuber 这个姓名很有可能是German,其次可能是Czech,再其次可能是Dutch。

2 基础原理

RNN

一般的神经网络都是单向的,一层连着下一层。而循环神经网络(Recurrent Neural Network)和它的名字一样,里面引入了循环体结构,就像我们写代码的 for 或者 while 循环一样,某一步的循环体就像下面这样:

xt是第 t 步循环时的输入,ht是第 t 步循环的输出,它们都是向量,不是标量(一个数值)。这样一个循环体就可以把信息从上一步传递到下一步。不过,这样的循环体看起来不太好懂,让我们把它按时序展平(降维攻击!!!),变成一般的神经网络那样的单向传播结构。展开后就是一个链状结构:

这样我们就可以看到,从第0步到第t步之间都发生了一些什么。每一个A块里的东西都是一样的,你可以理解成 for (i=0; i<t; i++) 或者 for i in range(t) 块中的代码。所以,我们只需要写某一步的变量更新方式,然后让它循环就可以了。

现在的问题是:变量到底应该怎么更新?输入的xt
应该如何处理,才能变成输出的ht?图里的 A 内部具体的更新结构如下:

流程如下:

  1. 把上一步输出的ht−1 乘上一个权重矩阵 Wh
    ,变成Whht−1 。h是隐藏层(hidden layer)的简写
  2. 这一步输入的xt 也乘上一个权重矩阵 Wi, 变成Wixt。i是输入(input)的简写
  3. 把它们相加,变成 Whht−1 + Wixt
  4. 经过一个 tanh()函数的处理,就得到了这一步的输出:ht = tanh(Whht−1 + Wixt)

把流程1到流程4反复循环,就是一个最简单的RNN。
另外,你可能会看到一种带有偏置向量 b 的更新方式:

ht = tanh((Whht-1 + bh) + (Wixt + bi))

我们这里进行了简化,即令所有的向量 b 都为0。另外,我们在初始化向量h0
的时候,也会把它初始化成全为0的向量。

RNN这样一个结构用来处理有前后关联的序列非常有效,因此在自然语言处理里也取得了不错的成绩。因为一句话可以看成是许多词组成的序列,这些词之前有前后文/上下文关系。


LSTM

不过,普通的RNN有一个长短句依赖的问题(不细讲了,反正就是不太好使),所以有人提出了LSTM来改进RNN。LSTM是长短期记忆网络(LSTM,Long Short-Term Memory),通过三个门(遗忘门、输入门、输出门)的控制,存储短期记忆或长期记忆。它的整体流程还是这样:

但是,LSTM里的一个 A 内部的结构变成了这样子:

自从RNN整容成LSTM后,你再也不认识它了……

说实话,这张图美则美矣,我觉得还是李宏毅老师的简化版容易入门,一起贴上来吧!图里省略了几个tanh ⁡ ( ⋅ )函数,更方便理解

图中是第 t 步的更新情况, 就是σ ( ⋅ ))的S型曲线,即sigmoid函数:

我们先看李宏毅老师的高级简化版:
对同一个输入xt,乘上不同的权重Wf、Wi、W、Wo
就变成了四个不同的值。

  • zf = Wfxt,是遗忘门(Forget Gate)的输入。
  • zi = Wixt,是输入门(Input Gate)的输入。
  • z = Wxt,是真正的输入(和输入门的输入是不一样的)
  • zo = Woxt,是输出门(Output Gate)的输入

    右边的图从下往上看,我们先来看红色方框圈出来的部分,输入和输入门的更新,它负责判断是否要接受新的输入:
  • zi乘以权重Wi,加上偏置向量bi,经过输入门,变成了it = σ(Wizi + bi)
  • z乘以权重Wc,加上偏置向量bc,经过tanh(),变成了ct=σ(Wcz + bc)
  • 把itct按元素相乘(Hadamard乘积,运算符为⊙或∗),得到it*ct,然后输入cell。

然后,我们来看看蓝色圈出来的部分,是遗忘门的更新,它负责判断是否要更新cell中的值,如果更新了,就要忘记之前的值,写入新的值:

  • cell中存放了上一步的存储向量ct-1
  • zf乘以权重Wf,加上偏置向量bf,经过遗忘门,变成了ft=σ(Wfzf+bf)
  • 在cell中,把ct-1和ft同样按元素相乘,然后和刚才提到的it*ct相加,就变成了新的存储向量ct=ct-1*ft + it*ct

最后,我们来看一下橙色圈出来的部分,输出门的更新,它负责判断是否要输出最后的值:

  • 新的ct先通过一个tanh()函数,变成tanh(ct)
  • zo乘以权重Wo,加上偏置向量bo,经过输出门,变成了ot=σ(Wozo + bo)
  • 把tanh(ct)和ot按元素相乘,得到新的隐藏层状态ht=ot*tanh(ct)
  • 隐藏层状态ht通过一个softmax函数,得到最后的输出yt=softmax(ht)

以上,就是LSTM中某一步的状态更新情况。我们再回头看这张图:

之前,zf,zi,z,zo都是xt乘以不同的权重得到的。但是,光凭xt,不足以传递足够多的信息,我们把xt和上一步输出的隐藏状态ht-1拼在一起,变成一个新的输入向量[xt,ht-1],上面的更新公式变成了下面这样。




说到这里,所谓的“门”,实际上就是一个 σ ( ⋅ ) 或者 tanh ⁡ ( ⋅ ) 函数。不过,一个好的命名便于更形象化地理解。

因为 LSTM 实在比 RNN 优秀太多,所以我们一般称循环神经网络的时候,其实都是在说LSTM。


GRU

这里省略了偏置向量b。


one-hot编码

刚才说到,输入是xt。在自然语言处理中,我们不可能把一个字母作为输入,进行向量、矩阵的乘法,因此,我们需要把它变成一个特征向量。xt可以是字母的特征向量,或者单词的特征向量,或者句子的特征向量。

本文中,xt是一个字母的特征向量。

我们知道,计算机存储字母一般是用ASCll编码,比如a是97,b是98,c是99,d是100…或者我们也可以说,a是1,b是2,c是3…但是,用这样的连续值表示字母,有一个问题,这意味着a和b的关系比较近,a和z的关系比较远,但是实际上,他们并没有这种内在的关系。我们用一种编码表示它们的时候,它们应该是相互独立的,即对它们进行离散化。

我们通常用到的是 one-hot 编码。它是一个长度为 n nn 的向量,只有 1 11 个数字是1,其它的 n − 1 n-1n−1 个数字都是0。one-hot 编码使得每个字母在它们各自的维度上,与其它字母是独立的。

比如一个单词 “apple”,就分别对a、p、p、l、e 编码,作为输入,LSTM需要循环5次。

  • x0 = [1,0,0,0,0…,0,0]:a
  • x1 = [0,1,0,0,0…,0,0]:p
  • x2 = [0,1,0,0,0…,0,0]:p
  • x3 = [0,0,1,0,0…,0,0]:l
  • x4 = [0,0,0,1,0…,0,0]:e

上面就是本文提到的字母级(character-level)RNN,而单词级(word-level)RNN就是把整个单词 “apple” 编码成一个向量。通常,对单词的编码用于Seq2Seq模型,即处理的是一个序列 “An apple a day keeps the doctor away”:

  • x0 = [1,0,0,0,0,0,0,0,…,0,0]: an
  • x1 = [0,1,0,0,0,0,0,0,…,0,0]: apple
  • x2 = [0,0,1,0,0,0,0,0,…,0,0]: a
  • x3 = [0,0,0,1,0,0,0,0,…,0,0]: day
  • x4 = [0,0,0,0,1,0,0,0,…,0,0]: keeps
  • x5 = [0,0,0,0,0,1,0,0,…,0,0]: the
  • x6 = [0,0,0,0,0,0,1,0,…,0,0]: doctor
  • x7 = [0,0,0,0,0,0,0,1…,0,0]: away

不过本文中,xt是一个字母的特征向量,而不是一个单词的特征向量。

3 代码

数据预处理

首先,我们把所有的 /name/[Language].txt 文件读进来。

n_letters 表示所有字母的数量。因为某些语言的字母和常见的英文字母不太一样,所以我们需要把它转化成普普通通的英文字母,用到了 unicodeToAscii() 函数。

from __future__ import unicode_literals,print_function,division
from io import open
import glob
import osdef findFiles(path): return glob.glob(path)print(findFiles('data/names/*.txt'))import unicodedata
import stringall_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)def unicodeToAscii(s):return ''.join(c for c in unicodedata.normalize('NFD',s)if unicodedata.category(c)!='Mn'and c in all_letters)print(unicodeToAscii('Ślusàrski'))

Out:

['data/names/Greek.txt', 'data/names/Dutch.txt', 'data/names/Irish.txt', 'data/names/Arabic.txt', 'data/names/Korean.txt', 'data/names/French.txt', 'data/names/Spanish.txt', 'data/names/German.txt', 'data/names/Portuguese.txt', 'data/names/Italian.txt', 'data/names/Vietnamese.txt', 'data/names/Russian.txt', 'data/names/Scottish.txt', 'data/names/Chinese.txt', 'data/names/English.txt', 'data/names/Japanese.txt', 'data/names/Czech.txt', 'data/names/Polish.txt']
Slusarski

文件 [Language].txt 的命名中,Language 就是类别 category 。把每个文件打开,读入每一行,放入一个数组 lines = [names …] 。建立一个词典 category_lines = {language: lines}

category_lines = {}
all_categories = []def readLines(filename):lines = open(filename,encoding='utf-8').read().strip().split('\n')return [unicodeToAscii(line) for line in lines]for filename in findFiles('data/names/*.txt'):category = os.path.splitext(os.path.basename(filename))[0]all_categories.append(category)lines = readLines(filename)category_lines[category] = linesn_categories = len(all_categories)
print(all_categories)
print(category_lines['Italian'])

Out:

['Greek', 'Dutch', 'Irish', 'Arabic', 'Korean', 'French', 'Spanish', 'German', 'Portuguese', 'Italian', 'Vietnamese', 'Russian', 'Scottish', 'Chinese', 'English', 'Japanese', 'Czech', 'Polish']
['Abandonato', 'Abatangelo', 'Abatantuono', 'Abate', 'Abategiovanni', 'Abatescianni', 'Abba', 'Abbadelli', 'Abbascia', 'Abbatangelo', 'Abbatantuono', 'Abbate', 'Abbatelli', 'Abbaticchio', 'Abbiati', 'Abbracciabene', 'Abbracciabeni', 'Abelli', 'Abello', 'Abrami', 'Abramo', 'Acardi', 'Accardi', 'Accardo', 'Acciai', 'Acciaio', 'Acciaioli', 'Acconci', 'Acconcio', 'Accorsi', 'Accorso', 'Accosi', 'Accursio', 'Acerbi', 'Acone', 'Aconi', 'Acqua', 'Acquafredda', 'Acquarone', 'Acquati', 'Adalardi', 'Adami', 'Adamo', 'Adamoli', 'Addario', 'Adelardi', 'Adessi', 'Adimari', 'Adriatico', 'Affini', 'Africani', 'Africano', 'Agani', 'Aggi', 'Aggio', 'Agli', 'Agnelli', 'Agnellutti', 'Agnusdei', 'Agosti', 'Agostini', 'Agresta', 'Agrioli', 'Aiello', 'Aiolfi', 'Airaldi', 'Airo', 'Aita', 'Ajello', 'Alagona', 'Alamanni', 'Albanesi', 'Albani', 'Albano', 'Alberghi', 'Alberghini', 'Alberici', 'Alberighi', 'Albero', 'Albini', 'Albricci', 'Albrici', 'Alcheri', 'Aldebrandi', 'Alderisi', 'Alduino', 'Alemagna', 'Aleppo', 'Alesci', 'Alescio', 'Alesi', 'Alesini', 'Alesio', 'Alessandri', 'Alessi', 'Alfero', 'Aliberti', 'Alinari', 'Aliprandi', 'Allegri', 'Allegro', 'Alo', 'Aloia', 'Aloisi', 'Altamura', 'Altimari', 'Altoviti', 'Alunni', 'Amadei', 'Amadori', 'Amalberti', 'Amantea', 'Amato', 'Amatore', 'Ambrogi', 'Ambrosi', 'Amello', 'Amerighi', 'Amoretto', 'Angioli', 'Ansaldi', 'Anselmetti', 'Anselmi', 'Antonelli', 'Antonini', 'Antonino', 'Aquila', 'Aquino', 'Arbore', 'Ardiccioni', 'Ardizzone', 'Ardovini', 'Arena', 'Aringheri', 'Arlotti', 'Armani', 'Armati', 'Armonni', 'Arnolfi', 'Arnoni', 'Arrighetti', 'Arrighi', 'Arrigucci', 'Aucciello', 'Azzara', 'Baggi', 'Baggio', 'Baglio', 'Bagni', 'Bagnoli', 'Balboni', 'Baldi', 'Baldini', 'Baldinotti', 'Baldovini', 'Bandini', 'Bandoni', 'Barbieri', 'Barone', 'Barsetti', 'Bartalotti', 'Bartolomei', 'Bartolomeo', 'Barzetti', 'Basile', 'Bassanelli', 'Bassani', 'Bassi', 'Basso', 'Basurto', 'Battaglia', 'Bazzoli', 'Bellandi', 'Bellandini', 'Bellincioni', 'Bellini', 'Bello', 'Bellomi', 'Belloni', 'Belluomi', 'Belmonte', 'Bencivenni', 'Benedetti', 'Benenati', 'Benetton', 'Benini', 'Benivieni', 'Benvenuti', 'Berardi', 'Bergamaschi', 'Berti', 'Bertolini', 'Biancardi', 'Bianchi', 'Bicchieri', 'Biondi', 'Biondo', 'Boerio', 'Bologna', 'Bondesan', 'Bonomo', 'Borghi', 'Borgnino', 'Borgogni', 'Bosco', 'Bove', 'Bover', 'Boveri', 'Brambani', 'Brambilla', 'Breda', 'Brioschi', 'Brivio', 'Brunetti', 'Bruno', 'Buffone', 'Bulgarelli', 'Bulgari', 'Buonarroti', 'Busto', 'Caiazzo', 'Caito', 'Caivano', 'Calabrese', 'Calligaris', 'Campana', 'Campo', 'Cantu', 'Capello', 'Capello', 'Capello', 'Capitani', 'Carbone', 'Carboni', 'Carideo', 'Carlevaro', 'Caro', 'Carracci', 'Carrara', 'Caruso', 'Cassano', 'Castro', 'Catalano', 'Cattaneo', 'Cavalcante', 'Cavallo', 'Cingolani', 'Cino', 'Cipriani', 'Cisternino', 'Coiro', 'Cola', 'Colombera', 'Colombo', 'Columbo', 'Como', 'Como', 'Confortola', 'Conti', 'Corna', 'Corti', 'Corvi', 'Costa', 'Costantini', 'Costanzo', 'Cracchiolo', 'Cremaschi', 'Cremona', 'Cremonesi', 'Crespo', 'Croce', 'Crocetti', 'Cucinotta', 'Cuocco', 'Cuoco', "D'ambrosio", 'Damiani', "D'amore", "D'angelo", "D'antonio", 'De angelis', 'De campo', 'De felice', 'De filippis', 'De fiore', 'De laurentis', 'De luca', 'De palma', 'De rege', 'De santis', 'De vitis', 'Di antonio', 'Di caprio', 'Di mercurio', 'Dinapoli', 'Dioli', 'Di pasqua', 'Di pietro', 'Di stefano', 'Donati', "D'onofrio", 'Drago', 'Durante', 'Elena', 'Episcopo', 'Ermacora', 'Esposito', 'Evangelista', 'Fabbri', 'Fabbro', 'Falco', 'Faraldo', 'Farina', 'Farro', 'Fattore', 'Fausti', 'Fava', 'Favero', 'Fermi', 'Ferrara', 'Ferrari', 'Ferraro', 'Ferrero', 'Ferro', 'Fierro', 'Filippi', 'Fini', 'Fiore', 'Fiscella', 'Fiscella', 'Fonda', 'Fontana', 'Fortunato', 'Franco', 'Franzese', 'Furlan', 'Gabrielli', 'Gagliardi', 'Gallo', 'Ganza', 'Garfagnini', 'Garofalo', 'Gaspari', 'Gatti', 'Genovese', 'Gentile', 'Germano', 'Giannino', 'Gimondi', 'Giordano', 'Gismondi', 'Giugovaz', 'Giunta', 'Goretti', 'Gori', 'Greco', 'Grillo', 'Grimaldi', 'Gronchi', 'Guarneri', 'Guerra', 'Guerriero', 'Guidi', 'Guttuso', 'Idoni', 'Innocenti', 'Labriola', 'Laconi', 'Lagana', 'Lagomarsino', 'Lagorio', 'Laguardia', 'Lama', 'Lamberti', 'Lamon', 'Landi', 'Lando', 'Landolfi', 'Laterza', 'Laurito', 'Lazzari', 'Lecce', 'Leccese', 'Leggieri', 'Lemmi', 'Leone', 'Leoni', 'Lippi', 'Locatelli', 'Lombardi', 'Longo', 'Lupo', 'Luzzatto', 'Maestri', 'Magro', 'Mancini', 'Manco', 'Mancuso', 'Manfredi', 'Manfredonia', 'Mantovani', 'Marchegiano', 'Marchesi', 'Marchetti', 'Marchioni', 'Marconi', 'Mari', 'Maria', 'Mariani', 'Marino', 'Marmo', 'Martelli', 'Martinelli', 'Masi', 'Masin', 'Mazza', 'Merlo', 'Messana', 'Micheli', 'Milani', 'Milano', 'Modugno', 'Mondadori', 'Mondo', 'Montagna', 'Montana', 'Montanari', 'Monte', 'Monti', 'Morandi', 'Morello', 'Moretti', 'Morra', 'Moschella', 'Mosconi', 'Motta', 'Muggia', 'Muraro', 'Murgia', 'Murtas', 'Nacar', 'Naggi', 'Naggia', 'Naldi', 'Nana', 'Nani', 'Nanni', 'Nannini', 'Napoleoni', 'Napoletani', 'Napoliello', 'Nardi', 'Nardo', 'Nardovino', 'Nasato', 'Nascimbene', 'Nascimbeni', 'Natale', 'Nave', 'Nazario', 'Necchi', 'Negri', 'Negrini', 'Nelli', 'Nenci', 'Nepi', 'Neri', 'Neroni', 'Nervetti', 'Nervi', 'Nespola', 'Nicastro', 'Nicchi', 'Nicodemo', 'Nicolai', 'Nicolosi', 'Nicosia', 'Nicotera', 'Nieddu', 'Nieri', 'Nigro', 'Nisi', 'Nizzola', 'Noschese', 'Notaro', 'Notoriano', 'Oberti', 'Oberto', 'Ongaro', 'Orlando', 'Orsini', 'Pace', 'Padovan', 'Padovano', 'Pagani', 'Pagano', 'Palladino', 'Palmisano', 'Palumbo', 'Panzavecchia', 'Parisi', 'Parma', 'Parodi', 'Parri', 'Parrino', 'Passerini', 'Pastore', 'Paternoster', 'Pavesi', 'Pavone', 'Pavoni', 'Pecora', 'Pedrotti', 'Pellegrino', 'Perugia', 'Pesaresi', 'Pesaro', 'Pesce', 'Petri', 'Pherigo', 'Piazza', 'Piccirillo', 'Piccoli', 'Pierno', 'Pietri', 'Pini', 'Piovene', 'Piraino', 'Pisani', 'Pittaluga', 'Poggi', 'Poggio', 'Poletti', 'Pontecorvo', 'Portelli', 'Porto', 'Portoghese', 'Potenza', 'Pozzi', 'Profeta', 'Prosdocimi', 'Provenza', 'Provenzano', 'Pugliese', 'Quaranta', 'Quattrocchi', 'Ragno', 'Raimondi', 'Rais', 'Rana', 'Raneri', 'Rao', 'Rapallino', 'Ratti', 'Ravenna', 'Re', 'Ricchetti', 'Ricci', 'Riggi', 'Righi', 'Rinaldi', 'Riva', 'Rizzo', 'Robustelli', 'Rocca', 'Rocchi', 'Rocco', 'Roma', 'Roma', 'Romagna', 'Romagnoli', 'Romano', 'Romano', 'Romero', 'Roncalli', 'Ronchi', 'Rosa', 'Rossi', 'Rossini', 'Rotolo', 'Rovigatti', 'Ruggeri', 'Russo', 'Rustici', 'Ruzzier', 'Sabbadin', 'Sacco', 'Sala', 'Salomon', 'Salucci', 'Salvaggi', 'Salvai', 'Salvail', 'Salvatici', 'Salvay', 'Sanna', 'Sansone', 'Santini', 'Santoro', 'Sapienti', 'Sarno', 'Sarti', 'Sartini', 'Sarto', 'Savona', 'Scarpa', 'Scarsi', 'Scavo', 'Sciacca', 'Sciacchitano', 'Sciarra', 'Scordato', 'Scotti', 'Scutese', 'Sebastiani', 'Sebastino', 'Segreti', 'Selmone', 'Selvaggio', 'Serafin', 'Serafini', 'Serpico', 'Sessa', 'Sgro', 'Siena', 'Silvestri', 'Sinagra', 'Sinagra', 'Soldati', 'Somma', 'Sordi', 'Soriano', 'Sorrentino', 'Spada', 'Spano', 'Sparacello', 'Speziale', 'Spini', 'Stabile', 'Stablum', 'Stilo', 'Sultana', 'Tafani', 'Tamaro', 'Tamboia', 'Tanzi', 'Tarantino', 'Taverna', 'Tedesco', 'Terranova', 'Terzi', 'Tessaro', 'Testa', 'Tiraboschi', 'Tivoli', 'Todaro', 'Toloni', 'Tornincasa', 'Toselli', 'Tosetti', 'Tosi', 'Tosto', 'Trapani', 'Traversa', 'Traversi', 'Traversini', 'Traverso', 'Trucco', 'Trudu', 'Tumicelli', 'Turati', 'Turchi', 'Uberti', 'Uccello', 'Uggeri', 'Ughi', 'Ungaretti', 'Ungaro', 'Vacca', 'Vaccaro', 'Valenti', 'Valentini', 'Valerio', 'Varano', 'Ventimiglia', 'Ventura', 'Verona', 'Veronesi', 'Vescovi', 'Vespa', 'Vestri', 'Vicario', 'Vico', 'Vigo', 'Villa', 'Vinci', 'Vinci', 'Viola', 'Vitali', 'Viteri', 'Voltolini', 'Zambrano', 'Zanetti', 'Zangari', 'Zappa', 'Zeni', 'Zini', 'Zino', 'Zunino']

接下来,就是要对字母进行one-hot编码,转成 tensor。

假设字母表中的字母数量为 n_letters , 一个字母的向量就是 < 1 × n_letters > 维,只有 1 个维度是1,其他 n_letters-1 维是0。

一个长度为 line_length 的单词,它的向量维度是 < line_length × n_letters > 维。

在机器学习中,通常我们会按照 batch 来训练,所以这里设定一个单词的 batch 是1,单词的向量维度变成了 < line_length × 1 × n_letters >

import torchdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 返回字母 letter 的索引 index
def letterToIndex(letter):return all_letters.find(letter)# 把一个字母编码成tensor
def letterToTensor(letter):tensor = torch.zeros(1,n_letters)# 把字母 letter 的索引设定为1,其它都是0tensor[0][letterToIndex(letter)] = 1return tensor.to(device)# 把一个单词编码成tensor
def lineToTensor(line):tensor = torch.zeros(len(line),1,n_letters)# 遍历单词中的所有字母,对每个字母 letter 它的索引设定为1,其它都是0for li, letter in enumerate(line):tensor[li][0][letterToIndex(letter)] = 1return tensor.to(device)print(letterToTensor('J'))
print(lineToTensor('Jones').size())

Out:

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0.]], device='cuda:0')
torch.Size([5, 1, 57])

模型搭建

然后就是我们的模型部分,一个最普通的RNN。

它是一个两层的结构,i2h是输入xt到隐藏层ht,i2o是输入xt到输出ot,softmax是把输出ot变成预测值yt。实际上,它在这里是一个 LogSoftmax 函数,对应的损失函数是NLLLoss(),而如果它是一般的 Softmax 函数,对应的损失函数就是交叉熵损失 CrossEntropy() = Log (NLLLoss())。

这里设定隐藏层的向量维度为128维,为了简单,可以说是隐藏层的大小是128维。

模型真正的运行步骤在 forward()函数中,它的输入input即为xt,隐藏层hidden即为ht

  • combined = torch.cat((input,hidden),1):把xt和上一步的ht-1拼接在一起,变成[xt,ht-1]
  • hidden = self.i2h(combined):我们之前说,把输入[xt,ht-1]乘上权重Wh变成新的隐藏层ht=Wh[xt,ht-1],这里实际上就是通过一个线性的全连接层 i2h,它的输入大小是 input_size + hidden_size, 输出大小是 hidden_size
  • output = self.i2o(combined):把输入[xt,ht-1]乘上权重Wo,得到输出ot=Wo[xt,ht-1],同样是通过一个线性的全连接层i2o,它的输入大小是 input_size + hidden_size, 输出大小是 output_size
  • output = self.softmax(output):通过一个softmax函数,把输出ot变成预测值yt=logsoftmax(ot)
import torch.nn as nnclass RNN(nn.Module):# 初始化定义每一层的输入大小,输出大小def __init__(self, input_size, hidden_size, output_size):super(RNN,self).__init__()self.hidden_size = hidden_sizeself.i2h = nn.Linear(input_size + hidden_size, hidden_size)self.i2o = nn.Linear(input_size + hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)# 前向传播过程def forward(self, input, hidden):combined = torch.cat((input,hidden),1) hidden = self.i2h(combined)output = self.i2o(combined)output = self.softmax(output)return output, hidden# 初始化隐藏层状态 h0  def initHidden(self):return torch.zeros(1,self.hidden_size).to(device)n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_categories)
rnn = rnn.to(device)

输入一个字母 A 测试一下:

Input = letterToTensor('A')
hidden = torch.zeros(1, n_hidden).to(device)output, next_hidden = rnn(Input, hidden)
print(output)

Out:

tensor([[-2.8634, -2.8132, -2.9685, -2.8825, -2.9207, -2.8890, -2.9643, -2.8836,-2.9332, -2.9182, -2.8699, -2.8101, -2.9425, -2.8251, -2.9940, -2.8898,-2.8494, -2.8348]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)

再输入名字 Albert 的第一个字母 A 测试一下:

Input = lineToTensor('Albert')
hidden = torch.zeros(1, n_hidden).to(device)output, next_hidden = rnn(Input[0], hidden)
print(output)

Out:

tensor([[-2.8217, -2.7998, -2.8476, -2.8880, -2.9422, -2.8720, -2.8957, -2.9637,-2.9131, -2.9525, -2.9459, -2.8327, -2.8060, -2.9389, -2.8689, -3.0205,-2.9476, -2.8055]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)

定义一个函数 categoryFromOutput() 可以把yt变成对应的类别。用 Tensor.topk 选出18个概率中,概率最大的那个的下标 category_i ,就是yt的类别。

def categoryFromOutput(output):top_n, top_i = output.topk(1)category_i = top_i[0].item()return all_categories[category_i], category_iprint(categoryFromOutput(output))

Out:

('Irish', 8)

训练

因为目前模型还没有被训练,所以上面的概率可以认为是随机产生的。接下来,我们要训练模型。这个教程里不是把所有的数据都拿来训练,而是随机采样一部分数据来训练。

用 randomChoice() 从所有数据中随机采样,先采样得到类别category,再从类别category中随机采样,得到姓名line。

randomTrainingExample() 将采样得到的 category-line对变成tensor。

import randomdef randomChoice(l):return l[random.randint(0,len(l)-1)]def randomTrainingExample():category = randomChoice(all_categories)line = randomChoice(category_lines[category])category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long).to(device)line_tensor = lineToTensor(line)return category, line, category_tensor, line_tensorfor i in range(10):category, line, category_tensor, line_tensor = randomTrainingExample()print('category = ', category, '/ line = ', line)

看看随机采样10个样本的情况:
Out:

category =  Scottish / line =  Mckenzie
category =  Irish / line =  Cormac
category =  German / line =  Farber
category =  French / line =  David
category =  Russian / line =  Yanaslov
category =  Korean / line =  Gwang
category =  Chinese / line =  Hiu
category =  Russian / line =  Turchak
category =  Portuguese / line =  Madeira
category =  Spanish / line =  Castillo

定义损失函数为 NLLLoss(), 学习率0.005。

在训练的每个循环会执行以下过程:

  1. 创建输入tensor和目标tensor
  2. 初始化隐藏层状态h0
  3. 输入每个字母xt并且
  4. 保存下一个字母需要的隐藏层状态ht
  5. 将模型预测的输出yt和目标yt*之间进行对比
  6. 梯度反向传播
  7. 返回输出和损失函数
criterion = nn.NLLLoss()
learning_rate = 0.005def train(category_tensor, line_tensor):hidden = rnn.initHidden()rnn.zero_grad()# RNN的循环for i in range(line_tensor.size()[0]):output, hidden = rnn(line_tensor[i],hidden)loss = criterion(output, category_tensor)loss.backward()# 更新参数for p in rnn.parameters():p.data.add_(p.grad.data, alpha=-learning_rate)return output, loss.item()

下面正式开始训练模型。

timeSince() 可以计算出训练时间。总共训练n_iters次,每次用1个样本作为训练。每 print_every 次打印当前的训练损失,每 plot_every 次把损失保存到 all_losses 数组中,便于之后画图。

import time
import mathn_iters = 100000
print_every = 5000
plot_every = 1000current_loss = 0
all_losses = []def timeSince(since):now = time.time()s = now-sincereturn '%dm %ds'%(s//60,s%60)start = time.time()for iter in range(1, n_iters + 1):category, line, category_tensor, line_tensor = randomTrainingExample()output, loss = train(category_tensor, line_tensor)current_loss += lossif iter % print_every == 0:guess, guess_i = categoryFromOutput(output)correct = '√' if guess==category else '×(%s)'%categoryprint('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter/n_iters*100,timeSince(start),loss,line,guess,correct))if iter % plot_every == 0:all_losses.append(current_loss/plot_every)current_loss = 0

Out:

5000 5% (0m 12s) 1.8748 Yun / Chinese ×(Korean)
10000 10% (0m 23s) 1.4919 Adamczak / Polish √
15000 15% (0m 34s) 2.3264 Chavarria / Russian ×(Spanish)
20000 20% (0m 45s) 1.9709 Dziedzic / Russian ×(Polish)
25000 25% (0m 56s) 1.5231 Kang / Chinese ×(Korean)
30000 30% (1m 7s) 2.3836 Baudin / Irish ×(French)
35000 35% (1m 19s) 1.3130 Hoang / Vietnamese √
40000 40% (1m 30s) 3.3299 Gushiken / Dutch ×(Japanese)
45000 45% (1m 41s) 0.9776 Suarez / Spanish √
50000 50% (1m 53s) 0.5342 To / Vietnamese √
55000 55% (2m 5s) 0.7628 Barros / Portuguese √
60000 60% (2m 18s) 0.4310 O'Neal / Irish √
65000 65% (2m 29s) 2.0392 Shannon / English ×(Irish)
70000 70% (2m 40s) 1.4804 Sauvageau / Scottish ×(French)
75000 75% (2m 54s) 0.5012 Mizuno / Japanese √
80000 80% (3m 5s) 0.0978 Auttenberg / Polish √
85000 85% (3m 17s) 1.2776 Nisi / Japanese ×(Italian)
90000 90% (3m 27s) 1.2932 Rian / Irish √
95000 95% (3m 37s) 0.0962 Coghlan / Irish √
100000 100% (3m 47s) 0.9925 Xiang / Chinese √

画图

画出损失函数随着训练的变化情况:

import matplotlib.pyplot as plt
import matplotlib.ticker as tickerplt.figure()
plt.plot(all_losses)


为了看看模型在各个分类上的预测情况,我们要画出18国语言的混淆矩阵。每一行是真实的语言,每一列是预测的语言。用函数 evaluate() 来计算混淆矩阵。evaluate()和 train()非常相似,但是不需要梯度反向传播。

confusion = torch.zeros(n_categories, n_categories)
n_confusion = 10000def evaluate(line_tensor):hidden = rnn.initHidden()for i in range(line_tensor.size()[0]):output, hidden = rnn(line_tensor[i],hidden)return outputfor i in range(n_confusion):category, line, category_tensor,line_tensor = randomTrainingExample()output = evaluate(line_tensor)guess, guess_i = categoryFromOutput(output)category_i = all_categories.index(category)confusion[category_i][guess_i] += 1for i in range(n_categories):confusion[i] = confusion[i] / confusion[i].sum()fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
fig.colorbar(cax)ax.set_xticklabels(['']+all_categories,rotation=90)
ax.set_yticklabels(['']+all_categories)ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))plt.show()


两种语言连线处的正方形颜色越偏向暖色,表示两种语言的姓名越相似。从图中可以看到,有一些比较容易混淆语言,比如Chinese和Korean,还有Chinese和Vietnamese,English和Scottish。


预测

对于每个名字 input_line ,每次预测 n_predictions=3 个最有可能的类别,并且输出它们对应的概率:

def predict(input_line, n_predictions=3):print('\n> %s'%input_line)with torch.no_grad():output = evaluate(lineToTensor(input_line))topv, topi = output.topk(n_predictions,1,True)predictions = []for i in range(n_predictions):value = topv[0][i].item()category_index = topi[0][i].item()print('(%.2f) %s' % (value, all_categories[category_index]))predictions.append([value, all_categories[category_index]])predict('Dovesky')
predict('Jackson')
predict('Satoshi')

Out:

> Dovesky
(-1.14) Russian
(-1.23) Czech
(-1.41) Polish> Jackson
(-0.74) Scottish
(-1.17) English
(-2.67) Czech> Satoshi
(-1.54) Portuguese
(-1.64) Italian
(-1.82) Polish

【Pytorch官方教程】从零开始自己搭建RNN1 - 字母级RNN的分类任务相关推荐

  1. PyTorch官方教程大更新:增加标签索引,更加新手友好

    点击上方↑↑↑"视学算法"关注我 来源:公众号 量子位 授权 PyTorch官方教程,现已大幅更新: 提供标签索引,增加主题分类,更加新手友好. 不必再面对一整页教学文章茫然无措, ...

  2. 撒花!PyTorch 官方教程中文版正式上线,激动人心的大好事!

    点击上方"AI有道",选择"星标"公众号 重磅干货,第一时间送达 什么是 PyTorch?其实 PyTorch 可以拆成两部分:Py+Torch.Py 就是 P ...

  3. PyTorch官方教程大更新:增加标签索引,新手更加友好

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作分享,不代表本公众号立场,侵权联系删除 转载于:量子位 AI博士笔记系列推荐 周志华<机器学习>手推 ...

  4. pytorch官方教程中文版(一)PyTorch介绍

    pytorch编程环境是1.9.1+cu10.2 建议有能力的直接看官方网站英文版! 下面所示是本次教程的主要目录: pytorch官方教程中文版: PyTorch介绍 学习PyTorch 图像和视频 ...

  5. PyTorch 官方教程发布,限时免费开放!

    点击上方"AI有道",选择"星标"公众号 重磅干货,第一时间送达 PyTorch 如今已经称为最受欢迎的深度学习框架之一了!2019年1月到6月底,在arXiv ...

  6. PyTorch官方教程《Deep Learning with PyTorch》开源分享,LeCun力荐,通俗易懂

    1 前言 谈到深度学习框架,就不得不谈TensorFlow 和 PyTorch .目前来看,TensorFlow 和 PyTorch 框架是业界使用最为广泛的两个深度学习框架, TensorFlow在 ...

  7. PyTorch-Tutorials【pytorch官方教程中英文详解】- 1 Quickstart

    在PyTorch深度学习实践概论笔记5-课后练习2:pytorch官方教程[中英讲解]中跟着刘老师课后练习给的链接学习了pytorch官方教程,后来发现现在有更新版的教程,有时间正好也一起学习一下. ...

  8. Dynamic Quantization PyTorch官方教程学习笔记

    诸神缄默不语-个人CSDN博文目录 本文是PyTorch的教程Dynamic Quantization - PyTorch Tutorials 1.11.0+cu102 documentation的学 ...

  9. pytorch官方教程中文版(二)学习PyTorch

    pytorch编程环境是1.9.1+cu10.2 建议有能力的直接看官方网站英文版! 下面所示是本次教程的主要目录: pytorch官方教程中文版: PyTorch介绍 学习PyTorch 图像和视频 ...

  10. 机器翻译学习1:pytorch官方教程与代码逐行详解

    官方教程网址:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html 代码所需数据源:https:// ...

最新文章

  1. 四路由器的OSPF DR ,BDR
  2. js(Dom+Bom)第八天—Swiper(插件)
  3. 文档上传到服务器上,将文件上传到服务器上
  4. button点击后变色_汽车改色膜新潮流,2021年流行渐变色
  5. 求0-999之间的水仙花数。
  6. Boostrap(2)
  7. iOS设计模式解析(五)责任链模式
  8. zul组件、zhtml组件、native组件的区别
  9. AxureShare太慢,自己搭建产品原型分享系统
  10. 电脑硬件故障排除经验
  11. Win10开启高性能、卓越性能模式的方法
  12. 【运筹学】整数规划 ( 整数规划求解方法 | 指派问题 )
  13. 小米机型安全删除内置软件列表 miui12 miui13 可删除内置
  14. 目标检测之YOLOv4算法分析
  15. 微博怎么批量取消所有的关注
  16. [LeetCode] Largest Perimeter Triangle
  17. kali安装中文拼音输入法2
  18. 时态二--(专升本语法)
  19. LSTM(long short term memory)长短期记忆网络
  20. Benchmarking Lane-changing Decision-making for Deep Reinforcement Learning

热门文章

  1. php 知乎源代码,PHP最新仿知乎问答社区源码下载带行业打赏问答支持文章、话题、第三方登录、文章和问题打赏...
  2. Matlab 仿真——直流电机速度控制(1)直流电机建模
  3. CPU(中央处理器)和GPU(图像处理器)区别
  4. 窗口函数preceding和following字段
  5. 华为S5700_交换机_基础管理配置
  6. R语言 | 计算基因表达量 TPM R脚本
  7. Real-Time Rendering 4th Edition 实时渲染第四版 第五章 着色基础(Shading Basics)
  8. Python爬虫120例之案例58,手机APP爬虫,“武器库”的准备and皮皮虾APP的测试
  9. 找手机ic库存回收公司
  10. javaweb验证码明明输入正确却还是提示错误,验证码session不同步、不一致问题