tf.nn.nce_loss(weights,biases,labels,inputs,num_sampled,num_classes,num_true=1,sampled_values=None,remove_accidental_hits=False,partition_strategy='mod',name='nce_loss'
)

对于 nce_loss 的了解源于 word2vec,主要是通过 负采样的方式减少 softmax 函数的计算,具体函数值讲解,可以字节看疼我tensorflow,这里提供一个简单的案例,说明怎么使用 nce_loss。

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mathtrain_loss_lst = []
train_accuracy_lst = []
test_accuracy_lst = []def deepn(x):"""Args:x: an input tensor with the dimensions (N_examples, 784), where 784 is thenumber of pixels in a standard MNIST image.Returns:y: (N_examples, 64)X(N_examples, 784) x w1(784, 64) + b1(64,) = y(N_examples, 64)"""w1 = tf.Variable(initial_value=tf.truncated_normal(shape=[784, 64]),name="w1")b1 = tf.Variable(initial_value=tf.random_uniform(shape=[64,], minval=0, maxval=1))fc1 = tf.matmul(x, w1) + b1keep_prob = tf.placeholder(tf.float32)fc1_drop = tf.nn.dropout(fc1, keep_prob)return fc1_drop, keep_probdef main():# Import datamnist = input_data.read_data_sets("/mnist_data", one_hot=True)X = tf.placeholder(tf.float32, [None, 784])y_ = tf.placeholder(tf.float32, [None, 10])y_idx = tf.placeholder(tf.float32, [None, 1])fc1_drop, keep_prob = deepn(X)num_sampled = 1vocabulary_size = 10embedding_size = 64nce_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size],stddev=1.0/math.sqrt(embedding_size)),name="embed")nce_biases = tf.Variable(tf.zeros([vocabulary_size]))loss = tf.reduce_mean(tf.nn.nce_loss(weights=nce_weights,biases=nce_biases,labels=y_idx,inputs=fc1_drop,num_sampled=num_sampled,num_classes=vocabulary_size),)train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)output = tf.matmul(fc1_drop, tf.transpose(nce_weights)) + nce_biasescorrect_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(y_, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(20000):print("num: %d"%i)batch = mnist.train.next_batch(20)idx_ = np.argmax(batch[1], axis=1)[:, np.newaxis].astype("float32")train_accuracy, train_loss, _ = sess.run([accuracy, loss, train_step],feed_dict={X: batch[0],y_: batch[1],y_idx: idx_,keep_prob: 1.0})print("loss: ", train_loss)print("train accuracy: ", train_accuracy)idx_ = np.argmax(mnist.test.labels, axis=1)[:, np.newaxis].astype("float32")test_accuracy = sess.run(accuracy,feed_dict={X: mnist.test.images,y_: mnist.test.labels,y_idx: idx_,keep_prob: 1.0})print("test accuracy: ", test_accuracy)train_loss_lst.append(train_loss)train_accuracy_lst.append(train_accuracy)test_accuracy_lst.append(test_accuracy)def test():mnist = input_data.read_data_sets("./mnist_data", one_hot=True)X = tf.placeholder(tf.float32, [None, 784])fc1_drop, keep_prob = deepn(X)with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(2):batch = mnist.train.next_batch(20)idx_ = np.argmax(batch[1], axis=1)[:, np.newaxis].astype("float32")fc1 = sess.run([fc1_drop], feed_dict={X: batch[0], keep_prob:1.0})print(np.array(fc1).shape)main()
def summary(x, tag, path):"""根据提供的 x 列表绘制曲线"""print(x)loss = 0.0# tf.summary模块的定义位于summary.py文件中,该文件中主要定义了在进行可视化将要用到的各种函数loss_summary = tf.Summary()# 调用tf.summary.Summary.Value子类loss_summary.value.add(tag=tag, simple_value=loss)  # tag就是待会产生的图标名称with tf.Session() as sess:# 生成一个写日志的writer,将当前tensorflow计算图写入日志。summary_writer1 = tf.summary.FileWriter(path, sess.graph)tf.global_variables_initializer().run()for i in range(len(x)):# 固定用法,具体为什么我也不懂loss_summary.value[0].simple_value = x[i]summary_writer1.add_summary(loss_summary, i)summary(train_loss_lst, tag="loss", path="./train_loss")
summary(train_accuracy_lst, tag="accuracy", path="./train_accuracy")
summary(test_accuracy_lst, tag="accuracy", path="./test_accuracy")

tf.nn.nce_loss 函数应用案例相关推荐

  1. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作...

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...

  2. tf.nn.embedding_lookup函数的用法

    tf.nn.embedding_lookup函数的用法主要是选取一个张量里面索引对应的元素.tf.nn.embedding_lookup(params, ids):params可以是张量也可以是数组等 ...

  3. tf.nn.moments( )函数的使用

    先来看个例子吧. 测试代码: x = np.arange(12,dtype=np.float32).reshape(3,4) a = tf.nn.moments(tf.constant(x),[0]) ...

  4. tf.nn.nce_loss分析

    tf官方代码 def nce_loss(weights,biases,labels,inputs,num_sampled,num_classes,num_true=1,sampled_values=N ...

  5. 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits 函数:求交叉熵损失

    [TensorFlow]tf.nn.softmax_cross_entropy_with_logits的用法_xf__mao的博客-CSDN博客 https://blog.csdn.net/mao_x ...

  6. tf.nn.bidirectional_dynamic_rnn()函数详解

    转载自:https://blog.csdn.net/zhylhy520/article/details/86364789 首先我们了解一下函数的参数 bidirectional_dynamic_rnn ...

  7. tf.nn.conv2d()函数详解(strides与padding的关系)

    tf.nn.conv2d()是TensorFlow中用于创建卷积层的函数,这个函数的调用格式如下: def conv2d(input: Any,filter: Any,strides: Any,pad ...

  8. tf.nn.leaky_relu()函数

    **计算Leaky ReLU激活函数 tf.nn.leaky_relu( features, alpha=0.2, name=None ) 参数: features:一个Tensor,表示预激活 al ...

  9. 深度学习-函数-tf.nn.embedding_lookup 与tf.keras.layers.Embedding

    embedding函数用法 1. one_hot编码 1.1. 简单对比 1.2.优势分析: 1.3. 缺点分析: 1.4. 延伸思考 2. embedding的用途 2.1 embedding有两个 ...

最新文章

  1. css的padding
  2. linux 修改ssh banner
  3. 虚拟Web主机(基于域名配置,基于ip地址,基于端口)
  4. StarWind RAM 磁盘仿真程序
  5. oracle非管理员锁表,oracle默认管理员的帐号和密码以及密码修改和解除锁定
  6. jQuery easing动画效果扩展
  7. 尚品汇Vue项目 前台+后台完成品源码(含在线演示)
  8. 推荐算法初探---CF、LR
  9. 业务安全漏洞挖掘归纳总结
  10. html添加B站视频,iframe嵌入BiliBili视频方法B站视频外链
  11. # 搜狗输入法~快捷键总结
  12. 身为一名Java程序员,在面试的时候常常被问到的,下面我总结一些常常别问到的问题。
  13. flowable modeler6.5.0集成spring boot
  14. Conflux DAO 社区技术委员会成立 助力生态繁荣发展
  15. 开课吧java广告,开课吧Java面试题:虚引用与软引用和弱引用的区别
  16. 启动virtualbox虚拟机显示Attempted to kill the idle task错误
  17. 关于win10 链接安卓设备报错winusb.sys未经签名的解决办法
  18. Microsoft Visio 2010 - 弧线
  19. python一个月收入_我月薪5000,靠Python搞副业月入3万
  20. e575 viminfo 错误.

热门文章

  1. yarn run lint
  2. Unity 3D : 解富士 RAF 檔案
  3. squirrel sql 使用
  4. 浪潮服务器外接显示器,浪潮服务器简易配置手册.doc
  5. 百度云视频利用chrome进行倍速播放
  6. 斌酱归档---C语言实现Linux cp命令
  7. 3D屏保: 彩色盘子
  8. 在农村种植什么赚钱快又赚钱,不妨来看看这4种种植项目!
  9. 使用QQ账号,新浪微博账号登录第三方应用
  10. 智慧采购管理系统电子招投标优势浅析,助力建筑工程企业高效做好采购管理工作