原文

简介

Tensorflow.js是google推出的一个开源的基于JavaScript的机器学习库,相对与基于其他语言的tersorflow库,它的最特别之处就是允许我们直接把模型的训练和数据预测放在前端,置于浏览器内。

本文会用一个简单的demo介绍如何从零开始训练一个tensorflow模型,并在浏览器内实现手写数字识别,最终效果大约如下:

手写数字识别示例

本文会假设你有基本的python和JavaScript的知识。项目的完整代码参考github。

准备

项目代码的目录结构如下:

项目目录结构

整个结构大概分成server和web两个部分,分别是服务端和浏览器端的代码。

我们的流程大概如下:

  1. 下载训练数据集,用python的tensorflow训练模型,并保存模型文件。
  2. 使用python的flask启动服务,使模型文件可以作为本地服务的静态文件被访问。
  3. 在网页html内,用canvas创建一个可以随意涂抹的画布,并能够获取画布上的像素信息。
  4. 在JavaScript脚本内导入tf.js,载入训练模型,通过模型计算画布上的信息的预测结果,并显示在图表上。

我们需要的所有依赖如下:

python:

建议使用3.5以上的版本。我不能保证在<3.5的版本中它是否能正常工作。Tensorflow的兼容性问题一向令人头疼。注意在mac和linux上默认的python是python2。

  • numpy —— 一个知名的python数学计算库,在矩阵和数组运算方面非常强大
  • tensorflow —— 机器学习库,直接用pip安装的是cpu版本。如果你的pc有一个足够好的独立显卡,可以试试tensorflow-gpu。它可以使训练的速度更快。但tensorflow-gpu的配置方法比较复杂。我们的模型比较简单,即使用cpu训练也不会耗时太久。
  • tensorflowjs —— 用于导出并保存可以被浏览器使用的模型文件
  • flask —— 一个轻量级的python网络服务框架
  • flask-cors —— 用于支持flask跨域请求的一个库

这个demo内已经包含一个已经训练好的模型,所以你如果并不想自己再训练一次,可以不安装tensorflow和tensorflowjs。所有这些依赖都可以通过pip安装。

JavaScript:

你不需要特别安装任何东西,因为我们的库都是通过链接导入的。

  • tf.js —— 它就是本文要介绍的,尽管只会涉及它的极小的一点。
  • fabric.js —— 可选,用于比较方便地构造画布。
  • Chart.js —— 可选,只是用来画出下边的图表的。你也可以不要它,如果你对这种可视化的结果不感兴趣。

浏览器:

反正在chrome浏览器里是能跑起来的……

训练

项目文件里面已经包含了一个训练好的模型,位于{项目路径}/server/models/mnist文件夹内。

我们使用MNIST数据集来训练模型。MNIST是一个知名的手写数字识别的数据集。对很多机器学习的初学者而言,这很可能是他们接触到的第一个数据集。这个数据集中包含60000张训练图片以及10000张测试图片,每张图片都是一个28×28像素的手写数字图片。如下图所示:

mnist.png

MNIST用一个28×28的矩阵来代表这样的一张数字图片,矩阵内的每个元素表示对应点位置的灰度,在0~255之间。

下载数据:

事实上,你可以跳过下载数据这一步而直接开始训练,因为在训练函数中会自动下载数据,但鉴于国内糟糕的网络环境,我还是建议你先把数据手动下载下来。我会优先从本地读取数据。

下载地址:mnist.npz
下载完成后保存在路径{项目路径}/server/datasets/mnist.npz的位置。npz是numpy的一种数据压缩格式。文件大小大概11m。然后我们用load_data函数载入数据:

import numpy as np
from tensorflow.keras import layers, datasetsdef load_data(path):try:with np.load(path) as f:x_train, y_train = f['x_train'], f['y_train']x_test, y_test = f['x_test'], f['y_test']x_train, x_test = x_train/255.0, x_test/255.0return (x_train, y_train), (x_test, y_test)except FileNotFoundError:return datasets.mnist.load_data()

其中,x_train是一个60000×28×28的3维向量,代表60000张图片;y_train是长度60000的向量,每一项代表对应图片的实际数字,是一个0~9的整数。x_test,y_test是测试集上的对应数据,测试集大小为10000。注意x_train, x_test = x_train/255.0, x_test/255.0这一步是把每个灰度数字转换为一个0~1之间的小数。

训练模型:

我们使用tensorflow.keras的接口来实现一个简单的卷积神经网络(Convolutional Neural Network, CNN)模型。它包含了一个卷积层,一个池化层,和两个全连接层。我不会这里解释全部概念——对新手来说,它们过于令人困惑而费解。而且你也不需要在这里理解它。如果你真的很想从直观上把握它的话,你可以试试这篇博客:An Intuitive Explanation of Convolutional Neural Networks。它有点长,但为此花一些时间依然是值得的。

在server/train.py文件下可以看到训练函数的代码:

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers, datasets
import tensorflowjs as tfjsdef train_modle(data):(x_train, y_train), (x_test, y_test) = datamodel = Sequential([layers.Reshape((28, 28, 1), input_shape=(28, 28)),layers.Conv2D(16, (5, 5), padding='valid', input_shape=(28, 28, 1), activation='relu'),layers.MaxPooling2D(pool_size=2),layers.Dropout(0.2),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(10, activation='softmax')])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])model.fit(x_train, y_train, epochs=5, batch_size=64)model.evaluate(x_test, y_test)tfjs.converters.save_keras_model(model, model_path)

从上至下,这个训练函数做的:

  • 读取从npz中或自动下载的训练数据。
  • 构造模型的层序列,tensorflow.keras是以layer这种对象组织计算过程的。每一层输出是下一层的输入,最后的输出就是模型的输出。它每层依次是:
    • Reshape。注意我们的输入的数据对应的是图片的每个像素的灰度,是没有‘深度’的。而卷积层要求的输入必须是有‘深度’的。所以我们首先为数据额外地加一个‘深度’为1的第三个维度。
    • 卷积层。这一层的作用是提取一个图片的每一个点周围的‘局部特征’,并传递给下一层。我们需要16种特征,对每一种特征,我们用一个5×5大小的矩阵,‘扫描’图片,并据此计算出一个值。所以这一步,是把每个点都映射到一个长16的向量,来代表这个点的16种不同的局部特征。
    • 池化层。这一步是为了降低数据的大小。在所有的相邻的2×2的的范围内,我们只保留其中的最大值。
    • Dropout。在训练过程中,每次更新参数时,随机地把一部分输入节点忽略掉。这是一种防止过拟合的简单技巧。它只会应用于训练时,不会用在预测上。
    • Flatten。输入展平成一个一维数组。如果你的下一层是全连接层,那么这一步是必要的(除非你想把输出格式搞得一团糟)。
    • Dense。大小为128的全连接层。上一层的所有点都与这层的所有点相连。
    • Dense。大小为10的最末端的全连接层,它的输出就是模型的预测结果,对应一张图片是每个数字的概率。
  • 损失函数是用于估计模型预测结果和正确结果的偏差的函数。比如实际的数字为2,那么我们期待的结果应该是[ 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ],即除第3位为1外其他的都是0;而我们的预测结果可能是[0.2, 0.3, 0.1,...]。我们这里使用交叉熵算法来评估两种概率分布间的差别。训练的目的就是使得这样的损失函数的值尽量接近0。
  • 优化函数决定了在预测结果和正确结果的特定偏差下,应该如何更新参数。这里我们使用adam优化器。
  • fit。使用训练集来训练。我们每次在大小60000数据中取出64个作为一批,计算损失函数并优化参数。在整个数据集上,重复5次。
  • evaluate。使用测试集来评估训练结果。只计算损失函数,不做参数优化。
  • 最后一步,是把模型的训练结果保存成文件,在预测时可以调用。

撇开数学上的概念理解不谈,一般初学者在训练过程中最容易让人弄错的地方是数据的格式(shape)。

运行文件,开始训练

python server/train.py

如果你的环境配置正确,你应该会看到这样的输出:

60000/60000 [==============================] - 11s 185us/sample - loss: 0.1896 - acc: 0.9453
Epoch 2/5
60000/60000 [==============================] - 14s 225us/sample - loss: 0.0678 - acc: 0.9791
Epoch 3/5
60000/60000 [==============================] - 13s 221us/sample - loss: 0.0504 - acc: 0.9840
Epoch 4/5
60000/60000 [==============================] - 14s 233us/sample - loss: 0.0377 - acc: 0.9881
Epoch 5/5
60000/60000 [==============================] - 14s 231us/sample - loss: 0.0301 - acc: 0.9900
10000/10000 [==============================] - 1s 93us/sample - loss: 0.0360 - acc: 0.9879

在我的i5 cpu电脑上,整个训练过程大约耗时不到1分钟。
这个输出的结果显示了每一个epoch的耗时、损失函数的值和准确率。最后一行是在测试集上的结果。可以看到,我们的训练结果在测试集上有 98.79% 的准确率。同时,在{项目路径}/server/models/mnist内的文件也会被覆盖更新。你可以调整模型的结构和条件,多试几次,来评估不同条件下的训练结果。

{项目路径}/server/models/mnist下有两个文件,一个很小的model.json文件和另一个大小约1m以上的.bin文件。model.json文件可以直接打开,里面包括了模型的一些总体信息,如模型的结构和参数文件的位置,也就是.bin文件,这个文件记录了这个模型训练出来的所有参数。另外,如果你已经改动过模型的结构或者其他条件重新训练,那么这样的参数文件可能不止一个。

服务

我们已经训练好了模型,但这个模型文件是不能直接被浏览器载入使用的,因为现代浏览器一般都会阻止js直接读取本地文件内容。并且在设计上,这个模型文件也应该是保存在服务端而不是客户端。
我们需要做的,是启动一个服务,并使得这个模型成为这个服务的静态资源,这样js就可以通过请求拉取文件内容。
在server目录下的main.py文件:

from flask import Flask
from flask_cors import CORSapp = Flask(__name__,static_url_path='/models', static_folder='models')cors = CORS(app) @app.route("/")
def hello():return "Hello World!"if __name__ == '__main__':app.run(debug=True)

这是一个非常简单的flask应用代码。在这个应用中,我们把url路径/models映射到了文件目录models(相对于本文件),外界通过{host}/models就能访问到models内的文件。
在项目的根目录下,用命令行启动这个文件:

python server/main.py

如果一切正常的话,你应该会看到这样的输出

* Serving Flask app "main" (lazy loading)
* Environment: production
WARNING: This is a development server. Do not use it in a production deployment
Use a production WSGI server instead.
* Debug mode: on
* Restarting with stat
* Debugger is active!
* Debugger PIN: 267-971-636
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)

现在,你可以打开 http://127.0.0.1:5000/ 或者 http://localhost:5000/,如果你在屏幕上看到了“Hello World!”,就说明服务已经启动成功了。ctrl+c可以退出服务。
此时如果你打开http://localhost:5000/models/mnist/model.json就可以看到我们之前训练出来的模型的json文件。
另外,注意在代码中,我们还加了一句cors = CORS(app),这是为了让这个服务接受跨域请求。本文在这里不会展开讨论这个问题,简单地说:如果在js脚本中试图请求拉取的后端资源的协议或域名与js本身的不一致,那么浏览器会阻止这个请求——这是一种安全保护策略,除非你加了这行代码让后端资源接受跨域。

预测

我们在html头引入这几个库:

    <head><script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script><script src='http://cdnjs.cloudflare.com/ajax/libs/fabric.js/1.4.0/fabric.min.js'></script><script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.4.0/Chart.min.js"></script><script src='./model.js'></script></head>

最后一个model.js是本地的js文件,我们会把模型的导入和数据预测函数都封装在这里。

在model.js文件里,导入模型:

const MODEL_URL = 'http://localhost:5000/models/mnist/model.json' var loadModel = (async function() {window.model = await tf.loadLayersModel(MODEL_URL);console.log('load model')return model;
})
loadModel();

MODEL_URL 就是模型的model.json文件的url地址。用tf.loadLayersModel函数来载入模型并绑定在window上。当你在浏览器控制台里看到 'load model',模型就载入成功了。注意tf.loadLayersModel是异步函数,它返回的是一个Promise对象,你需要用await或者.then()的回调式方法来获取载入的模型对象。

在html里,添加一个id="canvas"的canvas和两个按钮,一个用于识别,另一个用于清空canvas。

        <div id='container'><canvas width="140" height="140" id="canvas" class="canvas"></canvas><div class='button-container'><button onclick="recognize()">recognize</button><button onclick="clear_canvas()">clear</button>   </div></div>

在html内<script>内加上:

        var fabric_canvas = new fabric.Canvas('canvas', {backgroundColor: "#000000"});fabric_canvas.renderTop();fabric_canvas.isDrawingMode = true;fabric_canvas.freeDrawingBrush.width = 12;fabric_canvas.freeDrawingBrush.color = "#ffffff";var recognize = async function() {var results = await predict('canvas');console.log(results);}var clear_canvas = function() {fabric_canvas.clear();}

我们使用fabric.js来构造可以任意涂抹的画图。这并不是必要的,只是可以少写一点代码。

在识别图片recognize函数内,我们调用了一个predict函数,并传入了canvas的id。我们希望这个函数返回的结果就是预测结果。

const width = 28;
const height = 28;var predict = async function(id) {var model = window.model;var canvas = document.getElementById(id);var example = this.load_img(canvas);var prediction = await model.predict(example).data();var results = Array.from(prediction);return results
}var load_img = function(img) {var tensor = tf.browser.fromPixels(img).resizeNearestNeighbor([width, height]).mean(2).expandDims().toFloat().div(255.0)return tensor;
};

predict函数的逻辑也相当直接了当:

  • 获取canvas对象
  • 调用load_img函数,从图片得到tensor张量对象
  • 调用已经载入的模型,预测结果。这里值得注意的是模型的预测函数model.predict同样是一个异步函数,需toFloat要在它的回调的.data方法中取出预测结果。这个结果默认是Float32Array类型,可以转换为Array。

与之前训练模型类似,最麻烦的地方是数据格式shape的处理,我们在load_img里有这么几步:

  • 用tf.browser.fromPixels方法读取canvas的像素信息。这个方法同样可以读取图片的信息。返回的结果是一个张量tensor,shape是140×140×3。最后一维是这张图片的每个像素在3个颜色通道上的值,每个值是一个0~255之间的整数。
  • 把140×140的数据resize成一个28×28的数据。因为我们训练的模型只接受28×28大小。此时的大小为28×28×3。
  • 计算灰度。.mean方法是求平均值的方法,用它我们把第3维的颜色转换为灰度。考虑到我们的图片是黑白的,它在3个颜色上应该是一样的,所以在这里我们也可以用.min.max(最小值、最大值)来计算灰度。此时的大小是28×28。
  • 我们的模型接受的必须是多个图片,所以我们用.expandDims加上一维。此时大小为1×28×28。这里可以用.reshape([1, 28, 28])来达到同样的效果。
  • toFloat,把tensor的元素转换为Float类型。
  • 记住我们的模型在处理mnist输入之前的时候曾做过一个除以255的操作,把灰度转为了0~1之间的小数,这里我们也要做一个同样的处理.div(255.0)

现在我们在canvas上写数据,再点击recognize按钮,就能在浏览器的控制台里看到预测的结果:

Array(10) [ 2.229090443993517e-15, 1.264737121454973e-12, 6.231850036009234e-10, 0.9999980926513672, 7.358067470207216e-14, 7.870837634982308e-7, 3.1836545118929527e-13, 6.341550395916329e-9, 8.096231454146618e-7, 1.0121870008816813e-10 ]

这个长度为10的数组表示模型预测canvas上的图片是0~9之间每个数字的概率。在我的项目里我还加了一个直方图表来表示这个数据,本文略过此处。

其他

一些其他的值得注意的地方:

  • 在这个demo中,训练使用的是python,tf.js只用于预测。而tf.js本身也可以用于训练数据。只是在浏览器里做训练意义不大。
  • 我们的网页是通过直接打开index.html来打开的。事实上,我们也可以把web文件夹下的index.html和model.js放在flask服务的静态资源路径里,这样就能通过url来访问网页了,并且这样flask也不用启用CORS,因为没有跨域。

所有的代码都在这里:digits-recognition-tfjs。作者十分感谢这篇博客:Recognizing Digits using TensorFlow.js in Google Chrome,它对本文启发很大。

如果你觉得这篇文章有帮助的话,记得赞赏。

用tensorflow.js实现浏览器内的手写数字识别相关推荐

  1. tensorflow应用:双向LSTM神经网络手写数字识别

    tensorflow应用:双向LSTM神经网络手写数字识别 思路 Python程序1.建模训练保存 Tensorboard检查计算图及训练结果 打开训练好的模型进行预测 思路 将28X28的图片看成2 ...

  2. Tensorflow 学习入门(二) 初级图像识别——手写数字识别

    初级图像识别--手写数字识别 背景知识储备 Softmax Regression MNIST 矩阵相乘 One Hot 编码 Cross Entropy(交叉熵) 代码实现 引入数据 设计数据结构 完 ...

  3. matlab 对mnist手写数字数据集进行判决分析_人工智能TensorFlow(十四)MINIST手写数字识别...

    MNIST是一个简单的视觉计算数据集,它是像下面这样手写的数字图片: MNIST 每张图片还额外有一个标签记录了图片上数字是几,例如上面几张图的标签就是:5.0.4.1. MINIST数据 MINIS ...

  4. 北京大学曹健——Tensorflow笔记 05 MNIST数据集输出手写数字识别准确率

              # 前向传播:描述了网络结构 minist_forward.py # 反向传播:描述了模型参数的优化方法 mnist_backward.py # 测试输出准确率minist_tes ...

  5. pytorch实现手写数字识别_送源码!人工智能实现:识别图片中的手写数字,值得收藏...

    作者|小林同学 关注<高手杰瑞>,每天有不一样的实用小教程发布哦! 哈喽,大家好我是杰瑞.今天我给大家带来一个用机器学习的方法来实现手写数字识别的教程,就像C语言中输出的那一行" ...

  6. 小白玩机器学习(6)--- 基于Tensorflow.js的在线手写数字识别

    一.题目要求 1.三个js文件,分别完成:网络训练以及模型保存.模型加载及准确率测试.在线手写数字识别: 2.模型测试准确率要高于99.3%(尽量): 3.在线手写数字识别需要能够通过鼠标在画布中写入 ...

  7. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  8. TensorFlow 2.0 快速上手教程与手写数字识别例子讲解

    文章目录 TensorFlow 基础 自动求导机制 参数优化 TensorFlow 模型建立.训练与评估 通用模型的类结构 多层感知机手写数字识别 Keras Pipeline * TensorFlo ...

  9. Tensorflow 神经网络作业手写数字识别 训练、回测准确率

    大白话讲解卷积神经网络工作原理,推荐一个bilibili的讲卷积神经网络的视频,up主从youtube搬运过来,用中文讲了一遍. 这篇文章是 TensorFlow 2.0 Tutorial 入门教程的 ...

最新文章

  1. 投稿2877篇,EMNLP 2019公布4篇最佳论文
  2. linux shell 去掉 文本换行符
  3. Ext JS学习第十六天 事件机制event(一)
  4. 027_Badge标记
  5. Forrester 首席分析师对话阿里云容器服务负责人:容器的未来趋势是什么?
  6. 把Sql数据转换为业务数据的几种方法
  7. nginx和apache的伪静态区别
  8. PHP通知弹窗代码_公告弹窗
  9. kafka修改默认端口号
  10. 【随记】Q号解除限制一波三折
  11. linux window nginx性能,KVM虚拟机 Nginx性能测试
  12. unity下载教育版_新的现场学习系列为Unity教育工作者提供支持
  13. 服务器电源ic芯片,8种常见电源管理IC芯片介绍
  14. BGP 模式下 Calico 与 MetalLB 的组合
  15. Python代码大全,海量代码任你下载
  16. 如何进入docker系统
  17. python启动netron
  18. 整活~使用webAI做一个网页AR吃豆人小游戏
  19. 计算机内存4G,笔记本电脑4g内存和8g内存的区别
  20. 【基于QMediaPlayer的简易视频播放器】— 3、结合QSlider实现播放进度控制和音量控制

热门文章

  1. 华为云ECS下安装MySQL
  2. Nginx正向代理和反向代理配置
  3. 系统定时重启服务脚本案例
  4. 学习总结-《父与子的编程之旅》chapter 20
  5. 利用黑客手段一台手机“变”出千万台,新型诈骗技术曝光
  6. web实验2 制作简单网页(HTML+CSS)
  7. Codeforces 715A Plus and Square Root
  8. Linux Signal (2): signal函数
  9. 所有的 Boost 库文档的索引
  10. 彻底掌握 Javascript(十一)日期-曾亮-专题视频课程