一、随机初始化Embedding

1.1 原理

Embedding其实就是个lookup table, 通过tf.nn.embedding_lookup()来调用Embedding.

注意:在调用Embedding后,可以考虑使用dropout层。
注意:在Embedding内,可以考虑对提取的vector做缩放。见于《Attention is all you need》

1.2 示例代码

这是关于Embedding层的相关代码。

def embedding(inputs, vocab_size, num_units, pre_embed=None, scale=True,scope="embedding", reuse=None):with tf.variable_scope(name_or_scope=scope, reuse=reuse):if pre_embed is not None:# 如果指定了词表. lookup_table = tf.get_variable(name="lookup_table",initializer=tf.Variable(pre_embed, dtype=tf.float32),trainable=False)else:lookup_table = tf.get_variable(name="lookup_table",shape=[vocab_size, num_units],initializer=tf.contrib.layers.xavier_initializer())outputs = tf.nn.embedding_lookup(lookup_table, inputs)if scale:# attention is all you need. 中使用了.outputs = outputs * (num_units ** 0.5)return outputs

下面的代码是:

  • load_my_vocab:
    加载模型中使用的vocab词表,用于在预训练的Embedding中,查找对应向量。
  • get_vocab_embedding:
    从预训练的Embedding中,提取需要的向量,并返回。
def load_my_vocab(vocab_file):vocab_list = []with open(vocab_file, 'r', encoding='utf-8') as fr:line = fr.readline()while line:vocab_list.append(line.strip())line = fr.readline()return vocab_listdef get_vocab_embedding(vocab_list, pre_embedding_file):lookup_dict = {}for vocab in vocab_list:lookup_dict[vocab] = []with open(pre_embedding_file, 'r', encoding='utf-8') as fr:line = fr.readline().strip("\n")while line:vocab, embed = line.split(" ", 1)if vocab in lookup_dict:embed = [float(digit_str) for digit_str in embed.strip().split()]lookup_dict[vocab].append(embed)line = fr.readline().strip("\n")lookup_table = [np.mean(lookup_dict[vocab], axis=0) for vocab in vocab_list]return np.array(lookup_table)

下面是主程序:

if __name__ == "__main__":vocab_file = "./vocab.txt"pre_embedding_file = "./glove.840B.300d.txt"vocab_list = load_my_vocab(vocab_file)lookup_table = get_vocab_embedding(vocab_list, pre_embedding_file)vocab_size = 5inputs = tf.Variable([[1, 2, 3], [2, 2, 2], [0, 0, 1]], dtype=tf.int32)word_embed = embedding(inputs=inputs, vocab_size=5, num_units=300, pre_embed=lookup_table)with tf.Session() as sess:sess.run(tf.global_variables_initializer())outputs = sess.run(word_embed)print(outputs)

Tensorflow Embedding相关推荐

  1. TensorFlow 2.X中的动手NLP深度学习模型准备

    简介:为什么我写这篇文章 (Intro: why I wrote this post) Many state-of-the-art results in NLP problems are achiev ...

  2. 解决tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,32] = -1 is not in [0, 50)

    今晚遇到了一个很神奇的问题,在tensorflow embedding过程中: File "/root/Handwriting/Diffusion-Handwriting-Generatio ...

  3. TensorFlow 特征列介绍

    文 / TensorFlow 团队 欢迎阅读介绍 TensorFlow 数据集和估算器系列的第 2 部分(第一部分戳这里).我们将在这篇文章中介绍特征列 (Feature Column) - 一种说明 ...

  4. TensorFlow Estimator 官方文档之----Feature column

    Feature column 本文档详细介绍了特征列(feature columns).您可以将特征列视为原始数据和 Estimator 之间的媒介.特征列非常丰富,使您可以将各种原始数据转换为 Es ...

  5. Metapath2vec:Scalable Representation Learning for Heterogeneous Networks(结构化深度网络特征表示)

    目录 1.图嵌入背景介绍 1.1 什么是图嵌入 1.2 为什么要使用图嵌入 2.论文背景介绍 2.1 同质网络 & 异质网络 2.2 异质网络与Metapath2vec 3.Metapath2 ...

  6. [源码解析] NVIDIA HugeCTR,GPU版本参数服务器--- (5) 嵌入式hash表

    [源码解析] NVIDIA HugeCTR,GPU版本参数服务器- (5) 嵌入式hash表 文章目录 [源码解析] NVIDIA HugeCTR,GPU版本参数服务器--- (5) 嵌入式hash表 ...

  7. 利用卷积神经网络对座头鲸进行声学探测

    文 / Matt Harvey,Google AI Perception 软件工程师 在过去几年中,Google AI Perception 团队开发出音频事件分析技术,并将其应用于 YouTube ...

  8. TensorFlow tf.keras.layers.Embedding

    参数 参数 描述 input_dim 词汇表的维度(总共有多少个不相同的词) output_dim 嵌入词空间的维度 input_length 输入语句的长度 embeddings_initializ ...

  9. tensorflow基础知识10,三维可视化embedding

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow.c ...

最新文章

  1. 揭秘双11丝滑般剁手之路背后的网络监控技术
  2. 根据一级分类查询所有子级分类
  3. 使用Java的RESTful Web服务
  4. matplotlib —— 添加文本信息(text)
  5. 闲鱼如何利用端计算提升推荐场景的ctr
  6. Speerio Skinergy 'Image' is ambiguous 错误
  7. qq浏览器私密空间在哪 具体操作步骤
  8. 数据预处理与数据分类预测
  9. 极光推送---安卓Demo
  10. oracle awr报告生成_5.性能测试 - Oracle体系结构和性能优化简介
  11. Cannot read property 'style' of null 问题
  12. 适合Java零基础学习的视频教程资源合集(小白入门到项目实战)
  13. 游戏测试和软件测试哪个好点?
  14. 全网最详细numpy的argmin与argmax解析(一次性理解np.argmin)
  15. Shopee关键词广告投放策略解析-马六甲erp
  16. 生物特征识别六大技术,你知道多少?
  17. 2 第二章 集群环境搭建(kubeadm 方式)
  18. Prometheus pod 流量监控
  19. Discuz论坛 创始人密码忘记解决办法!
  20. CCPC-Wannafly Comet OJ 夏季欢乐赛(2019)部分题解

热门文章

  1. html5css设置链接颜色,html超链接颜色设置
  2. 13.溯源分析(寻找攻击目标,警方破案)
  3. 神州租车:为消费者出行增添幸福感
  4. APIO2016 Fireworks
  5. 开发微信小程序实现上传图片 拍照功能
  6. Vue-i18n在Routerd动态路由下实现国际化
  7. Mask R-CNN:UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot sho
  8. FPGA-基本IP核的应用-FIFO(同步)
  9. 《视觉SLAM进阶:从零开始手写VIO》第一讲作业
  10. 使用 ESP-Prog _ Jlink 进行 JTAG 调试时的常见错误及解决办法