使⽤tensorflow完成模型搭建和训练,实现对fashion_mnist数据集的分类

import  os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'import tensorflow as tf
from    tensorflow import keras
from    tensorflow.keras import datasets, layers, optimizers, Sequential, metricsassert tf.__version__.startswith('2.')def preprocess(x, y):#预处理x = tf.cast(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)return x,y(x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
print(x.shape, y.shape)batchsz = 128db = tf.data.Dataset.from_tensor_slices((x,y))
db = db.map(preprocess).shuffle(10000).batch(batchsz)db_test = tf.data.Dataset.from_tensor_slices((x_test,y_test))
db_test = db_test.map(preprocess).batch(batchsz)db_iter=iter(db)
sample=next(db_iter)
print('batch:',sample[0].shape,sample[1].shape)model = Sequential([layers.Dense(256, activation=tf.nn.relu), # [b, 784] => [b, 256]layers.Dense(128, activation=tf.nn.relu), # [b, 256] => [b, 128]layers.Dense(64, activation=tf.nn.relu), # [b, 128] => [b, 64]layers.Dense(32, activation=tf.nn.relu), # [b, 64] => [b, 32]layers.Dense(10) # [b, 32] => [b, 10], 330 = 32*10 + 10
])
model.build(input_shape=[None, 28*28])
model.summary()
#w=w-lr*grad
optimizer=optimizers.Adam(lr=1e-3)#优化器def main():for epoch in range(30):for step, (x,y) in enumerate(db):# x: [b, 28, 28] => [b, 784]# y: [b]x = tf.reshape(x, [-1, 28*28])with tf.GradientTape() as tape:# [b, 784] => [b, 10]logits = model(x)y_onehot = tf.one_hot(y, depth=10)# [b]loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))loss_ce = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)loss_ce = tf.reduce_mean(loss_ce)grads = tape.gradient(loss_ce, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'loss:', float(loss_ce), float(loss_mse))# testtotal_correct = 0total_num = 0for x,y in db_test:# x: [b, 28, 28] => [b, 784]# y: [b]x = tf.reshape(x, [-1, 28*28])# [b, 10]logits = model(x)# logits => prob, [b, 10]prob = tf.nn.softmax(logits, axis=1)# [b, 10] => [b], int64pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)# pred:[b]# y: [b]# correct: [b], True: equal, False: not equalcorrect = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))total_correct += int(correct)total_num += x.shape[0]acc = total_correct / total_numprint(epoch, 'test acc:', acc)if __name__ == '__main__':main()

其中

import  os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

assert tf.__version__.startswith('2.')

这二个代码不知道怎么回事,请大佬们知道的,告诉我哦!

fashion minst相关推荐

  1. 嵌入和降维投影 数据集--fashion minst

    工具的评价 这个仅仅是视觉化的工具,可以帮助理解模型和数据,因为是降维投影,损失了信息,并不能作为提高模型效果的算法. 数据集和模型 数据集--fashion minst 简单的模型--2层全连接,效 ...

  2. 使用TensorFlow进行深度学习-第2部分

    Hi All, this is a series of blogs that I intend to write about how to use TensorFlow 2.0 for deep le ...

  3. TensorFlow中的Fashion MNIST图像识别实战

    1.导入相应的库: 关于Fashion MNIST数据集的介绍:看这位博主: https://blog.csdn.net/qq_28869927/article/details/85079808 im ...

  4. Fashion MNIST自编码器网络实战

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import tensorflow as kera ...

  5. 使用神经网络做二分类,输出层需要几个神经元?应该选择哪一种激活函数?如果要处理minst数据、输出层需要几个神经元?使用那种激活函数?如果使用神经网络预测房价,输出层需要几个神经元、使用什么激活函数?

    使用神经网络做二分类,输出层需要几个神经元?应该选择哪一种激活函数?如果要处理minst数据.输出层需要几个神经元?使用那种激活函数?如果使用神经网络预测房价,输出层需要几个神经元.使用什么激活函数? ...

  6. signature=17cdfa42b38e299201383f4fa6ccc23f,EYE FOR FASHION

    摘要: This article features photos of celebrities and comments regarding their fashion choices. Singer ...

  7. FASHION STORE OPENCART 2.X 自适应主题模板 ABC-0588

    2019独角兽企业重金招聘Python工程师标准>>> FASHION STORE OPENCART 2.X 自适应主题模板 ABC-0588 FEATURES HTML5 and  ...

  8. Minst 0-9特征迭代次数曲线表达式

    本文尝试收集minst 0-9对不同收敛标准δ的特征迭代次数,并拟合n(δ)曲线. 制作一个有三个输出的网络,输入minst数据集0-9的前200张图片, 将这个网络简写成 S(Minst 0-9)8 ...

  9. 用特征迭代次数区分minst数据集的0和1

    既然前面大量的实验都证明了,对于特定结构特定收敛标准的网络的收敛迭代次数是特征的,而这个值和输入有关,那能不能用这个特性去用来对输入进行分类. 本文制作了一个81*11*11-11*11*1的网络 让 ...

  10. mysql is fashion ctf_《亲爱的,热爱的》中的 CTF 大赛是什么?参加这个比赛的体验怎么样?...

    一个CTF选手来回答一下参赛体验 1.很休闲,很养生.什么36小时,48小时,72小时持续作战,不存在的!什么8罐红牛,两箱雀巢,不存在的!什么通宵熬夜黑眼圈,不存在的! 2.很娱乐,很友善.比赛中, ...

最新文章

  1. layoutSubviews 调用情况
  2. 大脑如何判断该睡觉了?可能是这80种蛋白说了算
  3. sql server 数据库 ' ' 附近有语法错误
  4. Java 策略模式和状态模式
  5. Java初学者必知 关于Java字符串问题
  6. curl命令php,php生成curl命令行的方法
  7. python安装nodejs_linux上nodejs安装
  8. sql vb xml 换行_vb中换行代码 vb代码输出怎么换行
  9. 山科大离散数学期末考试_西安电子科技大学网络与继续教育学院 2020 学年上学期 《离散数学》期末考试试题...
  10. request域中放入参数几种方法
  11. oswatch的安装和使用
  12. gltf 2.0快速入门
  13. linux桥接模式配置
  14. Linux系统根目录详解
  15. Android Studio 3.1 正式版
  16. 计算机系统必须配置,AI运行需要什么电脑配置?(复杂路径,且流畅)
  17. MATLAB--数字图像处理 im2col()
  18. 陈力:传智播客古代 珍宝币 泡泡龙游戏开发第46讲:PHP程序设计中的session应用实例
  19. Quartus16怎么修改IP核
  20. palm 680入手使用记录

热门文章

  1. Delphi2007的重构功能
  2. Win11字体显示不全怎么解决?
  3. APISpace 手机号码归属地API
  4. 大学excel题库含答案_Excel练习题及答案
  5. NVIDIA更新驱动之后,NVIDIA控制面板消失不见的解决办法
  6. java poi jar包下载_poi.jar包下载
  7. 协创物联网合肥产业园项目远程预付费电能管理系统的设计与应用
  8. 图书馆占座系统-产品需求规格说明书
  9. 推荐C语言视频<<跟着星仔学C语言>>
  10. 如何在Vue项目中使用websql数据库