Keras Hello World

最近开始学习Keras,个人觉得最有效的学习方法就是写很多很多代码,熟能成巧。我们先从最简单的例子来学习Keras,Keras版的Hello World。
在这个例子中,我们的任务是进行花朵的分类。

  • 本文的代码来自 这里,作者将Keras与sklearn进行了比较,有详细代码的解释。本文只做Keras部分的内容,删去了原本sklearn的内容
  • 本文为个人的代码记录,只为记录在写代码时的困惑,并假设读者有一些深度学习的基本概念。
  • 完整代码在 https://github.com/jiemojiemo/Keras-Demo/tree/master/Keras_Hello_World

Keras的安装

在Keras官网中已经给出了详细的安装指南

  • Linux 安装指南
  • Windows 安装指南
  • 但是个人还是推荐使用Anaconda进行安装,在Anaconda安装好的情况下使用以下命令进行安装Keras
    • 创建名为keras的环境

    conda create -n keras python=3

    • 进入环境

    activate keras

    • 安装keras

    conda update conda
    conda update --all
    conda install mingw libpython
    conda install keras

    • 安装TensorFlow

    pip install tensorflow

    • 安装一些必要的库(缺啥装啥)

    conda install jupyter notebook numpy matplotlib

  • Anaconda的下载速度可以通过设置国内镜像来提到,具体请看Windows下有什么办法提高conda install的速度?

让我们开始吧

首先,我们先导入模块,一些常见的模块如numpy, matplotlib就不解释了。

  • seaborn 一个matplotlib的高级封装,让画图更简单漂亮,但是在这个例子中,我们主要用它来导入数据
  • Sequential 叫做“序贯模型”,是Keras模型之一。详见关于Keras模型
  • keras.layers.core 常用层模块,包括全连接层(Dense),激活层等。详见常用层
  • keras.utils utils工具模块,提供了一系列有用的工具。详见utils工具
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as npfrom sklearn.model_selection import train_test_splitimport pandas as pdfrom keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.utils import np_utils

导入数据

  • sns.load_dataset('iris')导入iris数据库,iris包含了150条花的记录,前四个length和width是花的属性,species是花的种共有三种花,分别是setosa,versicolor和virginica。
  • iris.head()查看数据。
  • sns.pairplot(iris, hue='species')用于显示数据(这不是我们要关注的内容)
iris = sns.load_dataset('iris')
iris.head()
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
sns.pairplot(iris, hue='species')
<seaborn.axisgrid.PairGrid at 0x1bdee328710>

训练集和测试集

  • train_test_split将数据集分割为训练集和测试集,train_size为训练集占整个数据集的大小,random_state为随机种子(详见这里)
X = iris.values[:, :4]
y = iris.values[:, 4]
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.8, random_state=0)

One-Hot-Encoding

  • 深度网络只能接受数字作为输入,iris数据集中species是字符串类型的,因此我们需要将其数字化。
  • One-hot encoding 就是这样一种数字化的编码方法,相关概念详见数据处理——One-Hot Encoding
  • 下面给出两种不同的one-hot encoding的实现方式,它们并没有差别
def one_hot_encode_object_array(arr):uniques, ids = np.unique(arr, return_inverse=True)return np_utils.to_categorical(ids, len(uniques))
# if you are a pandas man ...
def one_hot_encode_object_array_pandas(arr):return pd.get_dummies(arr).values
train_y_ohe = one_hot_encode_object_array_pandas(train_y)
test_y_ohe = one_hot_encode_object_array_pandas(test_y)

搭起我们的网络结构

  • 我们要搭的网络很简单,只有两层,可以表示为 4-16-3,隐层激活函数是sigmoid,输出层时softmax
  • Sequential是多个网络层的线性堆叠,也就是“一条路走到黑”
  • Dense就是全连接层了,注意的是,第一层网络需要指明输入数据的大小,后面层就不需要了,keras会自动推导
  • Activation是激活层,常见的激活函数有sigmoid, softmax, ReLU等等
model = Sequential()
# hidden layer
model.add(Dense(16, input_shape=(4,)))
model.add(Activation('sigmoid'))
# output layer
model.add(Dense(3))
model.add(Activation('softmax'))

Compile 编译

  • 对学习过程进行配置。详见编译
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=["accuracy"])

Training 训练

  • verbose 是否打印训练log
  • validation_split 验证集的大小,关于验证集,简单的说就是用于调整模型参数(模型结构,学习速率,batch_size等等)。人们通过观察训练时的模型在验证集上的表现,来对自己模型作出调整
model.fit(train_X, train_y_ohe, epochs=100, batch_size=1, verbose=0, validation_split=0.2)
<keras.callbacks.History at 0x195e2abcf98>

Test 测试

loss, accuracy = model.evaluate(test_X, test_y_ohe, verbose=0)
print('Accuracy = {:.2f}'.format(accuracy))
Accuracy = 1.00

总结

对于一个深度学习的任务,大致可以分为以下三个部分

  1. 数据的准备。包括数据的获取,数据清洗,数据预处理等等
  2. 模型的构建。采用何种模型,网络结构是怎样的等都要考虑
  3. 参数的调整。通过反复训练来进行参数的调整,通常这也是最花时间的

Keras-1 学习Keras,从Hello World开始相关推荐

  1. 怎么装python的keras库_matlab调用keras深度学习模型(环境搭建)

    matlab没有直接调用tensorflow模型的接口,但是有调用keras模型的接口,而keras又是tensorflow的高级封装版本,所以就研究一下这个--可以将model-based方法和le ...

  2. 利用深度学习(Keras)进行癫痫分类-Python案例

    目录 癫痫介绍 数据集 Keras深度学习案例 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 QQ交流群:903290195 癫痫介绍 癫痫,即俗称"羊癫风",是由多种 ...

  3. DL之Keras: Keras深度学习框架的注意事项(默认下载存放路径等)、使用方法之详细攻略

    DL之Keras: Keras深度学习框架的注意事项(自动下载存放路径等).使用方法之详细攻略 目录 Keras深度学习框架的注意事项 1.Keras自动下载默认数据集/模型存放位置 Windows系 ...

  4. 【Keras】学习笔记(一)

    传送门:Keras 中文文档 文章目录 一.准备工作 1.概述 2.安装 3.GPU设置 (1)单GPU运行 (2)多GPU运行 二.顺序模型 简单示例 1.整体流程 (1)顺序模型的构建--Sequ ...

  5. 深度学习者的入门福利-Keras深度学习笔记

    Keras深度学习笔记 最近本人在github上发现一个不错的资源,是利用keras来学习深度学习的笔记,笔记内容充实,数据完善,本人亲自实操了里面的所有例子,深感收获颇丰,今天特意推荐给大家,希望能 ...

  6. [Python人工智能] 三十.Keras深度学习构建CNN识别阿拉伯手写文字图像

    从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN). ...

  7. [深度学习] Keras 如何使用fit和fit_generator

    介绍 在本教程中,您将了解Keras .fit和.fit_generator函数的工作原理,包括它们之间的差异.为了帮助您获得实践经验,我已经提供了一个完整的示例,向您展示如何从头开始实现Keras数 ...

  8. keras的学习笔记

    简介 Keras是一个高层神经网络API,Keras由纯Python编写而成并基于Tensorflow.Theano和CNTK后端.Keras 支持快速实验,能够把你的idea迅速转换为结果,适用场景 ...

  9. 【TensorFlow-windows】keras接口学习——线性回归与简单的分类

    前言 之前有写过几篇TensorFlow相关文章,但是用的比较底层的写法,比如tf.nn和tf.layers,也写了部分基本模型如自编码和对抗网络等,感觉写起来不太舒服,最近看官方文档发现它的教程基本 ...

  10. Keras 深度学习框架中文文档

    2019独角兽企业重金招聘Python工程师标准>>> Keras深度学习框架中文文档 Keras官网:http://keras.io/ Github项目:https://githu ...

最新文章

  1. Dom4j和Xpath(转)
  2. Android Bluetooth模块学习笔记
  3. 天翼云从业认证(3.6)了解天翼云大数据SaaS服务
  4. pandas对象保存到mysql出错提示“BLOB/TEXT column used in key specification without a key length”解决办法
  5. linux iptables 如何设置允许几个 ip访问,Linux防火墙iptables限制几个特定ip才能访问服务器。...
  6. 对于《软件工程》课程的认识
  7. python文件合法模式组合_python设计模式之组合模式
  8. HTML表div布局,html使用列表 以及div的布局和table的布局
  9. Python稳基修炼的经典案例10(计算机二级、初学者必会turtle库例题)
  10. “不会Linux,会有什么影响?”资深程序员:基本等于自废武功!
  11. 解决只能滑动弹框内容不能滑动弹框底层内容
  12. 一个...买裤子的全过程
  13. linux 锐捷客户端登录密码,Linux使用经验_使用锐捷客户端登录校园网
  14. Unity知识点0001(Yanlz+协程+List+MeshRender+对象池+链条关节+PlayerPrefs+脚本生命周期+LOD+)
  15. 掌握这60个Excel小技巧
  16. 微信收藏保存服务器,微信的收藏和保存功能有啥区别?
  17. win7,win10系统安装时硬盘格式转换(MBR,GPT)
  18. 如何将大硬盘对拷到小硬盘
  19. HTTP 000 CONNECTION FAILED for url <https://repo.anaconda.com/pkgs/main/linux-64/tqdm-4.64.0-py39h06
  20. css-reset样式重置

热门文章

  1. idea错误提示不明显_微信公众号扫一扫功能提示:10003 redirect_uri域名与后台配置不一致错误解决方案...
  2. Python03 拉格朗日插值法 牛顿插值法(附代码)
  3. python读取excel写入mysql_python读取excel写入mysql
  4. 【项目调研+论文阅读】基于医学文献的实体抽取(NER)方法研究 day5
  5. python列表赋值 连续整数_Python_03_字符串_数据类型_for循环_列表操作
  6. java实验的总结_java实验总结
  7. linux c 修改用户组,Linux C Function()参照之用户组篇
  8. 安卓案例:初试谷歌图表
  9. 安卓学习笔记04:安卓平台架构
  10. 单击跳转_如何在100张工作表中快速实现查找和跳转