文章目录

  • 1.keras.Sequential
  • 2.Layer/Model
  • 3.自定义层
  • 4.自定义网络
  • 5.自定义网络实战-手写数字识别

1.keras.Sequential


2.Layer/Model

3.自定义层

# 自定义Dense层
class MyDense(layers.Layer):# 初始化方法def __init__(self,inp_dim,outp_dim):# 调用母类的初始化super(MyDense,self).__init__()# self.add_variable作用是在创建这两个Variable时,同时告诉类这两个variable是需要创建的# 当两个容器拼接时,会把这两个variable交给上面的容器来管理,统一管理,不需要人为管理参数# 这个函数在母类中实现,所以可以直接调用self.kernel = self.add_variable('w',[inp_dim,outp_dim])self.bias = self.add_variable('b',[outp_dim])def call(self,inputs,training = None):out = inputs @   self.kernel + self.biasreturn out

4.自定义网络

# 利用自定义层,创建自定义网络(5层)
class MyModel(keras.Model):def __init__(self):super(MyModel,self).__init__()self.fc1 = MyDense(28*28,256)self.fc2 = MyDense(256,128)self.fc3 = MyDense(128,64)self.fc4 = MyDense(64,32)self.fc5 = MyDense(32,10)# 定义前向传播def call(self,inputs,training = None):x = self.fc1(inputs)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)   x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)return x

5.自定义网络实战-手写数字识别

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow import keras# 数据预处理
def preprocess(x, y):"""x is a simple image, not a batch"""x = tf.cast(x, dtype=tf.float32) / 255.x = tf.reshape(x, [28 * 28])y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, depth=10)return x, ybatchsz = 128
# 数据集加载
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)sample = next(iter(db))
print(sample[0].shape, sample[1].shape)# 构建多层网络
network = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary()# 自定义构建多层网络
# 自定义层
class MyDense(layers.Layer):def __init__(self, inp_dim, outp_dim):super(MyDense, self).__init__()self.kernel = self.add_variable('w', [inp_dim, outp_dim])self.bias = self.add_variable('b', [outp_dim])def call(self, inputs, training=None):out = inputs @ self.kernel + self.biasreturn out# 自定义网络
class MyModel(keras.Model):def __init__(self):super(MyModel, self).__init__()self.fc1 = MyDense(28 * 28, 256)self.fc2 = MyDense(256, 128)self.fc3 = MyDense(128, 64)self.fc4 = MyDense(64, 32)self.fc5 = MyDense(32, 10)def call(self, inputs, training=None):x = self.fc1(inputs)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)return xnetwork = MyModel()network.compile(optimizer=optimizers.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])network.fit(db, epochs=5, validation_data=ds_val,validation_freq=2)network.evaluate(ds_val)sample = next(iter(ds_val))
x = sample[0]
y = sample[1]  # one-hot
pred = network.predict(x)  # [b, 10]
# convert back to number
y = tf.argmax(y, axis=1)
pred = tf.argmax(pred, axis=1)print(pred)
print(y)

深度学习2.0-22.Keras高层接口之自定义层或网络相关推荐

  1. 【深度学习】(6) tensorflow2.0使用keras高层API

    各位同学好,今天和大家分享一下TensorFlow2.0深度学习中借助keras的接口减少神经网络代码量.主要内容有: 1. metrics指标:2. compile 模型配置:3. fit 模型训练 ...

  2. Keras深度学习实战(22)——生成对抗网络详解与实现

    Keras深度学习实战(22)--生成对抗网络详解与实现 0. 前言 1. 生成对抗网络原理 2. 模型分析 3. 利用生成对抗网络生成手写数字图像 小结 系列链接 0. 前言 生成对抗网络 (Gen ...

  3. 常用深度学习框——Caffe/TensorFlow / Keras/ PyTorch/MXNet

    常用深度学习框--Caffe/TensorFlow / Keras/ PyTorch/MXNet 一.概述 近几年来,深度学习的研究和应用的热潮持续高涨,各种开源深度学习框架层出不穷,包括Tensor ...

  4. 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)

    初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...

  5. 深度学习笔记(22) Padding

    深度学习笔记(22) Padding 1. 卷积的缺陷 2. Padding 3. Valid卷积 4. Same卷积 5. 奇数的过滤器 1. 卷积的缺陷 为了构建深度神经网络,需要学会使用的一个基 ...

  6. 日月光华深度学习(一、二)深度学习基础和tf.keras

    日月光华深度学习(一.二)深度学习基础和tf.keras [2.2]--tf.keras实现线性回归 [2.5]--多层感知器(神经网络)的代码实现 [2.6]--逻辑回归与交叉熵 [2.7]--逻辑 ...

  7. halcon 深度学习标注_HALCON深度学习工具0.4 早鸟版发布了

    原标题:HALCON深度学习工具0.4 早鸟版发布了 HALOCN深度学习工具在整个深度学习过程中扮演着重要的作用,而且在将来将扮演更重要的辅助作用,大大加快深度学习的开发流程,目前发布版本工具的主要 ...

  8. halcon显示坐标_HALCON深度学习工具0.4 早鸟版发布了

    HALOCN深度学习工具在整个深度学习过程中扮演着重要的作用,而且在将来将扮演更重要的辅助作用,大大加快深度学习的开发流程,目前发布版本工具的主要作用是图像数据处理和目标检测和分类中的标注. 标注训练 ...

  9. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  10. 深度学习之卷积神经网络(7)池化层

    深度学习之卷积神经网络(7)池化层 在卷积层中,可以通过调节步长参数s实现特征图的高宽成倍缩小,从而降低了网络的参数量.实际上,处理通过设置步长,还有一种专门的网络层可以实现尺寸缩减功能,它就是这里要 ...

最新文章

  1. Sphinx武林秘籍(上)
  2. curl 忽略证书访问 https
  3. NTU 笔记 6422quiz 复习(1~3节)
  4. ECCV 2020 Spotlight | CFBI:前背景整合的协作式视频目标分割
  5. 路由器mysql密码重置密码_【验证】mysql root密码恢复
  6. RocketMQ的组织架构和基本概念,Dledger高可用集群架构原理
  7. 详细了解为什么支持Postman Chrome应用程序已被弃用?
  8. android samba github,安卓手机访问树莓派samba文件共享出错解决
  9. HDU 5773 The All-purpose Zero(O(nlgn)求LIS)
  10. 为什么这本书大家都称好
  11. 让我们一起ML吧(一)聚类分析
  12. [深入React] 8.refs
  13. php提取文本数据处理,PHP文件处理—读取文件(一个字符,字串)
  14. python 高并发 tomcat_TOMCAT 高并发配置
  15. python str函数数字转换成字符串,Pandas将数字转换为字符串意外结果
  16. 【例题】利用伴随矩阵求逆矩阵
  17. 沟通在日常管理工作中的重要性
  18. Flutter 2.8 更新详解
  19. 关于iOS中UITableView下拉距离短刷新没事,下拉距离长就会崩溃的问题解决方案
  20. ViewFlipper实现带索引效果的自动播放也可手动滑动的广告栏

热门文章

  1. 偏差-方差分解 Bias-Variance Decomposition(转载)
  2. uva 10562 Undraw the Trees
  3. [NOI2006] 神奇口袋
  4. Windows 10下使用Xshell5连接虚拟机的ubuntu18系统
  5. 用Python自动发送邮件
  6. 高可用Kubernetes集群原理介绍
  7. selenium--python如何定位一组元素并返回文本值
  8. 仅当使用了列列表并且 IDENTITY_INSERT 为 ON 时,才能为表'XXX'中的标识列指定显式值。...
  9. Vert.x(vertx) 认证和授权
  10. powershell 常用命令之取磁盘分区信息