every blog every motto:

0. 前言

以fashion_mnist 为例,自定义流程,针对一机多卡的情况。

1. 代码部分

1. 导入模块

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras# os.environ['CUDA_VISIBLE_DEVICES'] = '/gpu:0'
print(tf.__version__)
print(sys.version_info)
for module in mpl,np,pd,sklearn,tf,keras:print(module.__name__,module.__version__)

2. GPU设置

tf.debugging.set_log_device_placement(True) # 查看变量分布在哪个GPU上
gpus = tf.config.experimental.list_physical_devices('GPU') # 获取物理GPU
print(gpus)# 设置选中的(最后一个)GPU可见
tf.config.experimental.set_visible_devices(gpus[-1],'GPU')for gpu in gpus: # 物理GPU 设置成自增长tf.config.experimental.set_memory_growth(gpu,True)print(len(gpus))
print('='*10)
logical_gpus = tf.config.experimental.list_logical_devices('GPU') # 获取逻辑GPU
print(len(logical_gpus))

3. 数据读取与处理

3.1 读取数据

fashion_mnist = keras.datasets.fashion_mnist
# print(fashion_mnist)
(x_train_all,y_train_all),(x_test,y_test) = fashion_mnist.load_data()
x_valid,x_train = x_train_all[:5000],x_train_all[5000:]
y_valid,y_train = y_train_all[:5000],y_train_all[5000:]
# 打印格式
print(x_valid.shape,y_valid.shape)
print(x_train.shape,y_train.shape)
print(x_test.shape,y_test.shape)

3.2 数据归一化

# 数据归一化
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
# x_train:[None,28,28] -> [None,784]
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1,1)).reshape(-1,28,28,1)
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,28,28,1)
x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28,28,1)

3.3 生成dataset

# 生成dataset
def make_dataset(images,labels,epochs,batch_size,shuffle=True):dataset = tf.data.Dataset.from_tensor_slices((images,labels))if shuffle:dataset = dataset.shuffle(10000)dataset = dataset.repeat(epochs).batch(batch_size).prefetch(50)return datasetbatch_size = 256
train_dataset = make_dataset(x_train_scaled,y_train,1,batch_size)
valid_dataset = make_dataset(x_valid_scaled,y_valid,1,batch_size)

4. 构建模型

# tf.keras.models.Sequential()
# 构建模型
model = keras.models.Sequential()# 卷积神经网络
model.add(keras.layers.Conv2D(filters=128,kernel_size=3,padding="same",activation='relu',input_shape=(28,28,1)))
model.add(keras.layers.Conv2D(filters=128,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))model.add(keras.layers.Conv2D(filters=256,kernel_size=3,padding="same",activation='relu'))
model.add(keras.layers.Conv2D(filters=256,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))model.add(keras.layers.Conv2D(filters=512,kernel_size=3,padding="same",activation='relu'))
model.add(keras.layers.Conv2D(filters=512,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))# 展平
model.add(keras.layers.Flatten())# 全连接层
model.add(keras.layers.Dense(512,activation='relu'))# 输出层
model.add(keras.layers.Dense(10,activation="softmax"))
model.summary()

5. 自定义部分

# customized training loop.
# 1. define losses functions
# 2. define function train_step
# 3. define function test_step
# 4. for-loop training looploss_func = keras.losses.SparseCategoricalCrossentropy(reduction=keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
test_loss = keras.metrics.Mean(name='test_loss')train_accuracy = keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")
test_accuracy = keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")optimizer = keras.optimizers.SGD(lr=0.01)@tf.function
def train_step(inputs):images,labels = inputswith tf.GradientTape() as tape:predictions = model(images,training=True)loss = loss_func(labels,predictions)gradients = tape.gradient(loss,model.trainable_variables)optimizer.apply_gradients(zip(gradients,model.trainable_variables))train_accuracy.update_state(labels,predictions)return loss@tf.function
def test_step(inputs):images,labels = inputspredictions = model(images)t_loss = loss_func(labels,predictions)test_loss.update_state(t_loss)test_accuracy.update_state(labels,predictions)epochs = 10
for epoch in range(epochs):total_loss = 0.0num_batches = 0for x in train_dataset:start_time = time.time()total_loss += train_step(x)run_time = time.time() - start_timenum_batches += 1print('\rtotal_loss: %3.3f,num_batches: %d,average_loss: %3.3f,time: %3.3f'%(total_loss,num_batches,total_loss/num_batches,run_time),end='')train_loss = total_loss / num_batchesfor x in valid_dataset:test_step(x)print('\rEpoch: %d, Loss: %3.3f Acc: %3.3f,Val_Loss: %3.3f,Val_Acc:%3.3f'%(epoch+1,train_loss,train_accuracy.result(),test_loss.result(),test_accuracy.result()))test_loss.reset_states()train_accuracy.reset_states()test_accuracy.reset_states()

从零基础入门Tensorflow2.0 ----八、42. 自定义流程相关推荐

  1. python零基础入门教程视频下载-Python零基础入门学习视频教程全42集,资源教程下载...

    课程名称 Python零基础入门学习视频教程全42集,资源教程下载 课程目录 001我和Python的第一次亲密接触 002用Python设计第一个游戏 003小插曲之变量和字符串 004改进我们的小 ...

  2. python基础教程视频教程百度云-Python零基础入门学习视频教程全42集百度云网盘下载...

    课程简介 Python零基础入门学习视频教程全42集百度云网盘下载 课程目录 042魔法方法:算术运算 041魔法方法:构造和析构 040类和对象:一些相关的BIF 039类和对象拾遗 038类和对象 ...

  3. python基础教程百度云-Python零基础入门学习视频教程全42集百度云网盘下载

    课程简介 Python零基础入门学习视频教程全42集百度云网盘下载 课程目录 042魔法方法:算术运算 041魔法方法:构造和析构 040类和对象:一些相关的BIF 039类和对象拾遗 038类和对象 ...

  4. 视频编码零基础入门(0):零基础,史上最通俗视频编码技术入门

    [来源申明]本文引用了微信公众号"鲜枣课堂"的<视频编码零基础入门>文章内容.为了更好的内容呈现,即时通讯网在引用和收录时内容有改动,转载时请注明原文来源信息,尊重原作 ...

  5. SQL零基础入门学习(八)

    SQL零基础入门学习(七) SQL 连接(JOIN) SQL join 用于把来自两个或多个表的行结合起来. 下图展示了 LEFT JOIN.RIGHT JOIN.INNER JOIN.OUTER J ...

  6. C语言零基础入门习题(八)四则运算

    前言 C语言是大多数小白走上程序员道路的第一步,在了解基础语法后,你就可以来尝试解决以下的题目.放心,本系列的文章都对新手非常友好. Tips:题目是英文的,但我相信你肯定能看懂 一.四则运算 题目 ...

  7. 指针02 - 零基础入门学习C语言42

    第八章:指针02 让编程改变世界 Change the world by program 对"&"和"*"运算符再做些说明 如果已执行了语句 point ...

  8. Apache Flink 零基础入门(十八)Flink Table APISQL

    什么是Flink关系型API? 虽然Flink已经支持了DataSet和DataStream API,但是有没有一种更好的方式去编程,而不用关心具体的API实现?不需要去了解Java和Scala的具体 ...

  9. Apache Flink 零基础入门(十七)Flink 自定义Sink

    需求:socket发送过来的数据,把String类型转成对象,然后把Java对象保存到Mysql数据库中. 创建数据库和表 create database imooc_flink; create ta ...

  10. SQL零基础入门学习(九)

    SQL零基础入门学习(八) SQL UNION 操作符 UNION 操作符用于合并两个或多个 SELECT 语句的结果集. 请注意,UNION 内部的每个 SELECT 语句必须拥有相同数量的列.列也 ...

最新文章

  1. SAP MM 不常用事务代码之MB59
  2. 使用WinSCP上传文件到指定服务器
  3. 坐在马桶上看算法:Dijkstra最短路算法
  4. busybox date 时间的加减
  5. Python selenium chrome 环境配置
  6. 计算机课教案学法,计算机应用基础教学方法初探
  7. pytorch 与 numpy 的相互转换
  8. 第12章 决策树 学习笔记下 决策树的学习曲线 模型复杂度曲线
  9. 对称加密、非对称加密、数字签名、数字证书、签名加密
  10. jmeter beanshell 之常用的代码
  11. 世界城市与北京时差表
  12. 与以太坊同源异流,eCash“PoW+雪崩”组合共识各司其职
  13. 图像修复:专栏博文推荐查阅顺序
  14. Python人脸识别项目-人脸识别-获取人脸图片
  15. java excel 导出 下载_使用Java导出Excel表格并由浏览器直接下载
  16. python去除Excel重复项
  17. Superset 数据分析平台搭建及使用 1
  18. mysql大写数字转阿拉伯数字_阿拉伯数字转化为大写
  19. R语言实战笔记--第十五章 处理缺失数据
  20. python dataframe打乱行

热门文章

  1. 深入理解计算机系统第四版_深入理解计算机系统第三版2.4节中文版的一处翻译问题及英文版可能的一处错误...
  2. python股票交易微信提醒_python实现秒杀商品的微信自动提醒功能(代码详解)
  3. macbook python安装_mac下安装Python3.*(最新版本)
  4. 让部署到服务器上的springboot项目持续运行(nohup)
  5. python 读取access_python读取数据access出错
  6. linux 解压 7z 乱码,7z-linux下解决中文名乱码的终极办法
  7. 多台电脑集群运算_Linux服务器集群概念辨识
  8. 幂等校验是什么意思_什么是接口的幂等性,如何实现接口幂等性?一文搞定
  9. pythonsys用法_Python 使用sys模块
  10. 算法笔记_面试题_8.零钱兑换