Tensorflow Embedding
一、随机初始化Embedding
1.1 原理
Embedding其实就是个lookup table, 通过tf.nn.embedding_lookup()来调用Embedding.
注意:在调用Embedding后,可以考虑使用dropout层。
注意:在Embedding内,可以考虑对提取的vector做缩放。见于《Attention is all you need》
1.2 示例代码
这是关于Embedding层的相关代码。
def embedding(inputs, vocab_size, num_units, pre_embed=None, scale=True,scope="embedding", reuse=None):with tf.variable_scope(name_or_scope=scope, reuse=reuse):if pre_embed is not None:# 如果指定了词表. lookup_table = tf.get_variable(name="lookup_table",initializer=tf.Variable(pre_embed, dtype=tf.float32),trainable=False)else:lookup_table = tf.get_variable(name="lookup_table",shape=[vocab_size, num_units],initializer=tf.contrib.layers.xavier_initializer())outputs = tf.nn.embedding_lookup(lookup_table, inputs)if scale:# attention is all you need. 中使用了.outputs = outputs * (num_units ** 0.5)return outputs
下面的代码是:
- load_my_vocab:
加载模型中使用的vocab词表,用于在预训练的Embedding中,查找对应向量。 - get_vocab_embedding:
从预训练的Embedding中,提取需要的向量,并返回。
def load_my_vocab(vocab_file):vocab_list = []with open(vocab_file, 'r', encoding='utf-8') as fr:line = fr.readline()while line:vocab_list.append(line.strip())line = fr.readline()return vocab_listdef get_vocab_embedding(vocab_list, pre_embedding_file):lookup_dict = {}for vocab in vocab_list:lookup_dict[vocab] = []with open(pre_embedding_file, 'r', encoding='utf-8') as fr:line = fr.readline().strip("\n")while line:vocab, embed = line.split(" ", 1)if vocab in lookup_dict:embed = [float(digit_str) for digit_str in embed.strip().split()]lookup_dict[vocab].append(embed)line = fr.readline().strip("\n")lookup_table = [np.mean(lookup_dict[vocab], axis=0) for vocab in vocab_list]return np.array(lookup_table)
下面是主程序:
if __name__ == "__main__":vocab_file = "./vocab.txt"pre_embedding_file = "./glove.840B.300d.txt"vocab_list = load_my_vocab(vocab_file)lookup_table = get_vocab_embedding(vocab_list, pre_embedding_file)vocab_size = 5inputs = tf.Variable([[1, 2, 3], [2, 2, 2], [0, 0, 1]], dtype=tf.int32)word_embed = embedding(inputs=inputs, vocab_size=5, num_units=300, pre_embed=lookup_table)with tf.Session() as sess:sess.run(tf.global_variables_initializer())outputs = sess.run(word_embed)print(outputs)
Tensorflow Embedding相关推荐
- TensorFlow 2.X中的动手NLP深度学习模型准备
简介:为什么我写这篇文章 (Intro: why I wrote this post) Many state-of-the-art results in NLP problems are achiev ...
- 解决tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,32] = -1 is not in [0, 50)
今晚遇到了一个很神奇的问题,在tensorflow embedding过程中: File "/root/Handwriting/Diffusion-Handwriting-Generatio ...
- TensorFlow 特征列介绍
文 / TensorFlow 团队 欢迎阅读介绍 TensorFlow 数据集和估算器系列的第 2 部分(第一部分戳这里).我们将在这篇文章中介绍特征列 (Feature Column) - 一种说明 ...
- TensorFlow Estimator 官方文档之----Feature column
Feature column 本文档详细介绍了特征列(feature columns).您可以将特征列视为原始数据和 Estimator 之间的媒介.特征列非常丰富,使您可以将各种原始数据转换为 Es ...
- Metapath2vec:Scalable Representation Learning for Heterogeneous Networks(结构化深度网络特征表示)
目录 1.图嵌入背景介绍 1.1 什么是图嵌入 1.2 为什么要使用图嵌入 2.论文背景介绍 2.1 同质网络 & 异质网络 2.2 异质网络与Metapath2vec 3.Metapath2 ...
- [源码解析] NVIDIA HugeCTR,GPU版本参数服务器--- (5) 嵌入式hash表
[源码解析] NVIDIA HugeCTR,GPU版本参数服务器- (5) 嵌入式hash表 文章目录 [源码解析] NVIDIA HugeCTR,GPU版本参数服务器--- (5) 嵌入式hash表 ...
- 利用卷积神经网络对座头鲸进行声学探测
文 / Matt Harvey,Google AI Perception 软件工程师 在过去几年中,Google AI Perception 团队开发出音频事件分析技术,并将其应用于 YouTube ...
- TensorFlow tf.keras.layers.Embedding
参数 参数 描述 input_dim 词汇表的维度(总共有多少个不相同的词) output_dim 嵌入词空间的维度 input_length 输入语句的长度 embeddings_initializ ...
- tensorflow基础知识10,三维可视化embedding
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow.c ...
最新文章
- 揭秘双11丝滑般剁手之路背后的网络监控技术
- 根据一级分类查询所有子级分类
- 使用Java的RESTful Web服务
- matplotlib —— 添加文本信息(text)
- 闲鱼如何利用端计算提升推荐场景的ctr
- Speerio Skinergy 'Image' is ambiguous 错误
- qq浏览器私密空间在哪 具体操作步骤
- 数据预处理与数据分类预测
- 极光推送---安卓Demo
- oracle awr报告生成_5.性能测试 - Oracle体系结构和性能优化简介
- Cannot read property 'style' of null 问题
- 适合Java零基础学习的视频教程资源合集(小白入门到项目实战)
- 游戏测试和软件测试哪个好点?
- 全网最详细numpy的argmin与argmax解析(一次性理解np.argmin)
- Shopee关键词广告投放策略解析-马六甲erp
- 生物特征识别六大技术,你知道多少?
- 2 第二章 集群环境搭建(kubeadm 方式)
- Prometheus pod 流量监控
- Discuz论坛 创始人密码忘记解决办法!
- CCPC-Wannafly Comet OJ 夏季欢乐赛(2019)部分题解
热门文章
- html5css设置链接颜色,html超链接颜色设置
- 13.溯源分析(寻找攻击目标,警方破案)
- 神州租车:为消费者出行增添幸福感
- APIO2016 Fireworks
- 开发微信小程序实现上传图片 拍照功能
- Vue-i18n在Routerd动态路由下实现国际化
- Mask R-CNN:UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot sho
- FPGA-基本IP核的应用-FIFO(同步)
- 《视觉SLAM进阶:从零开始手写VIO》第一讲作业
- 使用 ESP-Prog _ Jlink 进行 JTAG 调试时的常见错误及解决办法