文章目录

  • 基类Layer的实现:
  • 激活层的实现:
  • CostLayer的实现:

基类Layer的实现:

前面已经提到过一个layer的包含:shape,激活函数,梯度的处理以及输出层的处理。

import numpy as npclass Layer:def __init__(self, shape):self.shape = shapedef _activate(self, x, predict):passdef __str__(self):return self.__class__.__name__def __repr__(self):return str(self)@propertydef name(self):return str(self)# v^i = f^(u^i)  , y is v^kdef derivative(self, y):pass# forward pass : XW + bdef activate(self, x, w, bias):return x.dot(w) + bias# y = v^i, w = w^i, prev_delta = delta^i+1# f' = f *(1-f), so if y is v^i, it is easy to derivate# 反向传播计算误差,非输出层的处理def bp(self, y, w, pre_delta):return pre_delta.dot(w.T)*self.derivative(y)

激活层的实现:

class Sigmoid(Layer):def _activate(self, x):return 1 / (1 + np.exp(-x))def derivative(self, y):return y*(1-y)

CostLayer的实现:

涉及到2个部分处理,一是损失的计算方式,而是特殊的变换,还有最后一层的误差需要单独计算。

# =============================================================================
#     没有激活函数、但可能会有特殊的变换函数(比如说 Softmax),同时还需要定义某个损失函数
#     定义导函数时,需要考虑到自身特殊的变换函数并计算相应的、整合后的梯度
# =============================================================================
# 输出层有两个功能:一是特殊的变换,二是计算最后一层的梯度,最后一层的梯度需要特殊处理
class CostLayer(Layer):"""初始化结构self._available_cost_functions:记录所有损失函数的字典self._available_transform_functions:记录所有特殊变换函数的字典self._cost_function、self._cost_function_name:记录损失函数及其名字的两个属性self._transform_function 、self._transform:记录特殊变换函数及其名字的两个属性"""def __init__(self, shape, cost_function="MSE", transform=None):super(CostLayer, self).__init__(shape)self._available_cost_functions = {"MSE": CostLayer._mse,"SVM": CostLayer._svm,"CrossEntropy": CostLayer._cross_entropy}self._available_transform_functions = {"Softmax": CostLayer._softmax,"Sigmoid": CostLayer._sigmoid}self._cost_function_name = cost_functionself._cost_function = self._available_cost_functions[cost_function]if transform is None and cost_function == "CrossEntropy":self._transform = "Softmax"self._transform_function = CostLayer._softmaxelse:self._transform = transformself._transform_function = self._available_transform_functions.get(transform,None)def __str__(self):return self._cost_function_namedef _activate(self, x, predict):if self._transform_function is None:return xreturn self._transform_function(x)def _derivative(self, y, delta=None):pass@staticmethoddef safe_exp(x):return np.exp(x - np.max(x,axis=1,keepdims=True))# 特殊变换函数@staticmethoddef _softmax(y, diff=False):if diff:return y*(1-y)exp_y = CostLayer.safe_exp(y)return exp_y / np.sum(exp_y, axis=1, keepdims=True)@staticmethoddef _sigmoid(y, diff=False):if diff:return y * (1 - y)return 1 / (1 + np.exp(-y))# 单独计算输出层的误差def bp_first(self, y, y_pred):if self._cost_function_name == "CrossEntropy" and (self._transform == "Softmax" or self._transform =="Sigmoid"):return y - y_pred        # 否则、就只能用普适性公式进行计算:#            (没有特殊变换函数)#  (有特殊变换函数)dy = -self._cost_function(y, y_pred)if self._transform_function is None:return dyreturn dy * self._transform_function(y_pred, diff=True)@propertydef calculate(self):return lambda y, y_pred: self._cost_function(y,y_pred,False)# 损失函数@staticmethoddef _mse(y, y_pred, diff=True):if diff:return y_pred - yreturn 0.5*np.average((y-y_pred)**2)@staticmethoddef _cross_entropy(y, y_pred, diff=True, eps=1e-8):if diff:return -y / (y_pred + eps) + (1-y)/ (1- y_pred +eps)return np.average(-y * np.log(y_pred + eps) - (1 - y) * np.log(1 - y_pred + eps))

深度学习:tensorflow Layers的实现,numpy实现深度学习(二)相关推荐

  1. 中文版!学习TensorFlow、PyTorch、机器学习、深度学习和数据结构五件套!(附免费下载)...

    编辑:深度学习冲鸭公众号 学习深度学习以及面试肯定离不开下面的5个重要的资料,更何况是中文版! 获得方式: 1. 关注[深度学习冲鸭]公众号 2. 在[深度学习冲鸭]公众号后台回复 五件套 即可. 1 ...

  2. 中文版!学习 TensorFlow、PyTorch、机器学习、深度学习和数据结构五件套!(附免费下载)...

    学习深度学习以及面试肯定离不开下面的5个重要的资料,更何况是中文版! 获得方式: 1. 关注[AI有道]公众号 2. 在[AI有道]公众号后台回复 五件套 即可. 1. TensorFlow深度学习 ...

  3. 学习TensorFlow、PyTorch、机器学习、深度学习和数据结构五件套!附下载链接!...

    学习深度学习以及面试肯定离不开下面的5个重要的资料,更何况是中文版! 资料领取: 扫码后台回复:3070,即可获取电子版 内容简介 1. TensorFlow深度学习 书籍特点             ...

  4. python深度学习tensorflow和fme结合,实现档案扫描件数据自动分类

    文章目录 前言 一.深度学习基础知识简介 1.什么是深度学习 2.深度学习的原理 3.深度学习应用场景 二.深度学习环境搭建 1.深度学习库的安装 2.CUDA和对应版本的cudnn下载 三.实战教学 ...

  5. 【神经网络与深度学习-TensorFlow实践】-中国大学MOOC课程(八)(TensorFlow基础))

    [神经网络与深度学习-TensorFlow实践]-中国大学MOOC课程(八)(TensorFlow基础)) 8 TensorFlow基础 8.1 TensorFlow2.0特性 8.1.1 Tenso ...

  6. 深度学习 TensorFlow入门

    文章目录 一.深度学习框架-TensorFlow 1.1 TensorFlow介绍 1.2 TensorFlow的安装 1.3 张量及其操作 1.3.1 张量Tensor 1.基本方法 2.转换成nu ...

  7. 【神经网络与深度学习-TensorFlow实践】-中国大学MOOC课程(十四)(卷积神经网络))

    [神经网络与深度学习-TensorFlow实践]-中国大学MOOC课程(十四)(卷积神经网络)) 14 卷积神经网络 14.1 深度学习基础 14.1.1 深度学习的基本思想 14.1.2 深度学习三 ...

  8. 使用TensorFlow.js在浏览器中进行深度学习入门

    目录 设置TensorFlow.js 创建训练数据 检查点 定义神经网络模型 训练AI 测试结果 终点线 内存使用注意事项 下一步是什么?狗和披萨? 下载TensorFlowJS示例-6.1 MB T ...

  9. 从TensorFlow到PyTorch:九大深度学习框架哪款最适合你?

    人工智能AI与大数据技术实战  公众号: weic2c 开源的深度学习神经网络正步入成熟,而现在有许多框架具备为个性化方案提供先进的机器学习和人工智能的能力.那么如何决定哪个开源框架最适合你呢?本文试 ...

  10. 使用TensorFlow进行股票价格预测的简单深度学习模型

    使用TensorFlow进行股票价格预测的简单深度学习模型(翻译) 原文链接:https://medium.com/mlreview/a-simple-deep-learning-model-for- ...

最新文章

  1. Java中的等待/通知机制(wait/notify)
  2. wpf checkbox选中触发事件_Cypress 可操作事件
  3. log4jdbc mysql jdbc_spring boot 集成log4jdbc 查看完整sql
  4. 【转载】通俗理解极大似然估计
  5. Ubuntu 13.04 安装 SSH SERVER记
  6. python消息队列celery高可用_分布式消息队列-Celery
  7. php免登录接口,PHPWind 8.0 论坛免登陆发布接口发布
  8. 平台建设的根节与基础保障——互联网平台建设系列...
  9. ajax中的url怎么写_简历中的自我评价怎么写,才能成功吸引HR?
  10. Python生成带自定义信息和头像图片的二维码
  11. LaTeX(6)——LaTeX引用使用(\label)
  12. Linux系统下常用的帮助man,whatis,info,help总结
  13. maven项目中:java.io.IOException: java.io.FileNotFoundException--- (文件名、目录名或卷标语法不正确。)
  14. 【马仔创业感悟】公司售前和售后维护制度思考
  15. python鼠标移动到网页上、获取网页信息_python 调用pyautogui 实时获取鼠标的位置、移动鼠标的方法...
  16. 如何对研发团队绩效进行考核--附各环节人员考核参考表
  17. warning: mysql_fetch_array_php提示Warning:mysql_fetch_array() expects的解决方法,expects
  18. 由EIG牵头的财团与沙特阿美达成124亿美元的基础设施交易
  19. target找不到*.xml和*.properties文件 报错:FileNotFoundException
  20. IM即时通讯项目讲解(一) 实现类似qq微信表情面板无缝切换

热门文章

  1. 双边z变换公式_光通信与数学 傅里叶变换
  2. roundrobin来历_Linux系统管理
  3. java运行机制以及 运行流程
  4. vs java调试_基于VSCode的Java编程语言的构建调试环境搭建指南(作业三)
  5. 运行war包的命令及linux下实时查看日志
  6. wxml修改样式_微信小程序 动态绑定事件并实现事件修改样式
  7. python解码base64_在python中解码Base64 Gzip
  8. python如何进阶提升_Python序列操作之进阶篇
  9. 【maven】配置多个仓库
  10. linux的网络不可达问题,我的服务器日志中的linux – (网络不可达)错误