使用 Colab 在 tf.keras 中训练模型,并使用 TensorFlow.js 在浏览器中运行
文 / Zaid Alyafeai
我们将创建一个简单的工具来识别图纸并输出当前图纸的名称。 此应用程序将直接在浏览器上运行,无需任何安装。我们会使用 Google Colab 来训练模型,并使用 TensorFlow.js 在浏览器上部署它。
【想获取 TensorFlow js. 视频教程,请前往 Bilibili,TensorFlow 渠道查看:https://www.bilibili.com/video/BV1D54y1p7PQ】
代码和演示
在 GitHub 上找到现场演示和代码。 另外,请务必在此处测试 Google Colab 上的 notebook。
注:此处链接
https://colab.research.google.com/github/zaidalyafeai/zaidalyafeai.github.io/blob/master/sketcher/Sketcher.ipynb
数据集
我们将使用 CNN 识别不同类型的图样。 CNN 将在 Quick Draw 数据集上进行训练。 该数据集包含大约 345 个类别 5000 万个图样。
类的子集
传递途径
我们将使用 Keras 在 Google Colab 的 GPU 上免费训练模型,然后使用 TensorFlow.js(tfjs)直接在浏览器上运行。 我在 TensorFlow.js 上制作了一个教程,烦请阅读之后再继续。 这是该项目的传递途径:
在 Colab 上培训
Google 在 GPU 上提供免费处理能力。 您可以在本教程中看到如何创建笔记本和激活 GPU 编程。
输入
我们将基于 tensorflow 使用 keras
1 import os 2 import glob 3 import numpy as np 4 from tensorflow.keras import layers 5 from tensorflow import keras 6 import tensorflow as tf
加载数据
由于内存有限,我们不会对所有类别进行训练。 我们只使用 100 个数据集。 每个类别的数据在 Google Cloud 上可用作形状为 [N,784] 的 numpy 数组,其中 N 是该特定类的图像数。 我们首先下载数据集
1 import urllib.request 2 def download(): 34 base = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/' 5 for c in classes: 6 cls_url = c.replace('_', '%20') 7 path = base+cls_url+'.npy' 8 print(path) 9 urllib.request.urlretrieve(path, 'data/'+c+'.npy')
由于内存有限,我们只会将每个类别中的 5000 张图像加载到内存。 还保留 20% 的未经测试的数据
1 def load_data(root, vfold_ratio=0.2, max_items_per_class= 5000 ): 2 all_files = glob.glob(os.path.join(root, '*.npy')) 3 4 #initialize variables 5 x = np.empty([0, 784]) 6 y = np.empty([0]) 7 class_names = [] 89 #load a subset of the data to memory 10 for idx, file in enumerate(all_files): 11 data = np.load(file) 12 data = data[0: max_items_per_class, :] 13 labels = np.full(data.shape[0], idx) 1415 x = np.concatenate((x, data), axis=0) 16 y = np.append(y, labels) 1718 class_name, ext = os.path.splitext(os.path.basename(file)) 19 class_names.append(class_name) 2021 data = None 22 labels = None 2324 #separate into training and testing 25 permutation = np.random.permutation(y.shape[0]) 26 x = x[permutation, :] 27 y = y[permutation] 2829 vfold_size = int(x.shape[0]/100*(vfold_ratio*100)) 3031 x_test = x[0:vfold_size, :] 32 y_test = y[0:vfold_size] 3334 x_train = x[vfold_size:x.shape[0], :] 35 y_train = y[vfold_size:y.shape[0]] return x_train, y_train, x_test, y_test, class_names
预处理数据
我们预处理数据准备开始训练。
1 # Reshape and normalize 2 x_train = x_train.reshape(x_train.shape[0], image_size, image_size, 1).astype('float32') 3 x_test = x_test.reshape(x_test.shape[0], image_size, image_size, 1).astype('float32') 45 x_train /= 255.0 6 x_test /= 255.0 78 # Convert class vectors to class matrices 9 y_train = keras.utils.to_categorical(y_train, num_classes) 10 y_test = keras.utils.to_categorical(y_test, num_classes)
创建模型
我们将创建一个简单的 CNN。 请注意,参数数量越少,模型越简单越好。 实际上,我们将在浏览器转换后运行模型,并且我们希望让模型快速运行并进行预测。 以下模型包含 3 个转换层和 2 个密集层。
1 # Define model 2 model = keras.Sequential() 3 model.add(layers.Convolution2D(16, (3, 3), 4 padding='same', 5 input_shape=x_train.shape[1:], activation='relu')) 6 model.add(layers.MaxPooling2D(pool_size=(2, 2)))7 model.add(layers.Convolution2D(32, (3, 3), padding='same', activation= 'relu')) 8 model.add(layers.MaxPooling2D(pool_size=(2, 2)))9 model.add(layers.Convolution2D(64, (3, 3), padding='same', activation= 'relu')) 10 model.add(layers.MaxPooling2D(pool_size =(2,2)))11 model.add(layers.Flatten()) 12 model.add(layers.Dense(128, activation='relu')) 13 model.add(layers.Dense(100, activation='softmax')) 14 # Train model 15 adam = tf.train.AdamOptimizer() 16 model.compile(loss='categorical_crossentropy', 17 optimizer=adam, 18 metrics=['top_k_categorical_accuracy']) 19 print(model.summary())
适配,验证和测试
之后,我们基于 5 个 epochs 和 256 个 batch 训练模型。
1 #fit the model 2 model.fit(x = x_train, y = y_train, validation_split=0.1, batch_size = 256, verbose=2, epochs=5) 34 #evaluate on unseen data 5 score = model.evaluate(x_test, y_test, verbose=0)6 print('Test accuarcy: {:0.2f}%'.format(score[1] * 100))
以下是训练的结果
测试精度为 92.20%。
准备 Web 格式的模型
在我们对模型的准确性感到满意之后,我们将其保存以便进行转换
1 model.save('keras.h5')
我们安装了 tfjs 包进行转换
1 !pip install tensorflowjs
之后我们转换该模型
1 !mkdir model 2 !tensorflowjs_converter --input_format keras keras.h5 model/
这将创建一些权重文件以及包含模型体系结构的 json 文件。
压缩模型,准备将其下载到本地计算机
1 !zip -r model.zip model
最后下载模型
1 from google.colab import files 2 files.download('model.zip')
浏览器推断
在本节中,我们将展示如何加载模型并进行推理。 假设我们有一个尺寸为 300 x 300 的画布。 关于界面和 TensorFlow.js 部分我就不一一详细展开了。
加载模型
为了使用 TensorFlow.js 首先我们使用以下的脚本
1 <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
你需要在本地计算机上运行服务器来承载权重文件。 可以像我一样在项目上创建一个 apache 服务器或在 GitHub 上托管页面。
之后,将模型加载到浏览器
1 model = await tf.loadModel('model/model.json')
使用 await 等待浏览器加载模型。
预处理
我们需要在进行预测之前预处理数据。 首先从画布中获取图像数据
1 //the minimum boudning box around the current drawing2 const mbb = getMinBox()3 //cacluate the dpi of the current window 4 const dpi = window.devicePixelRatio5 //extract the image data 6 const imgData = canvas.contextContainer.getImageData(mbb.min.x * dpi, mbb.min.y * dpi,7 (mbb.max.x - mbb.min.x) * dpi, (mbb.max.y - mbb.min.y) * dpi);
我们以后再解释 getMinBox()。 变量 dpi 用于根据屏幕像素的密度拉伸画布。
我们将画布的当前图像数据转换为张量,调整大小并进行标准化。
1 function preprocess(imgData) 2 { 3 return tf.tidy(()=>{ 4 //convert the image data to a tensor 5 let tensor = tf.fromPixels(imgData, numChannels= 1) 6 //resize to 28 x 28 7 const resized = tf.image.resizeBilinear(tensor, [28, 28]).toFloat() 8 // Normalize the image 9 const offset = tf.scalar(255.0); 10 const normalized = tf.scalar(1.0).sub(resized.div(offset)); 11 //We add a dimension to get a batch shape 12 const batched = normalized.expandDims(0)13 return batched 14 }) 15 }
对于预测,我们使用 model.predict 这将返回形状为 [N,100]的概率。
1 const pred = model.predict(preprocess(imgData)).dataSync()
然后我们可以使用简单的函数来找到前 5 个概率。
提高准确率
请记住,我们的模型接受形状为 [N,28,28,1] 的张量。 我们的绘图画布尺寸为 300 x 300,对于绘图来说可能是太大了,或者说,用户可能想绘制一个小图。 最好裁剪到仅包含当前图形大小的框。 为此,我们通过查找左上角和右下角来提取图形周围的最小边界框
1 //record the current drawing coordinates 2 function recordCoor(event) 3 { 4 //get current mouse coordinate 5 var pointer = canvas.getPointer(event.e); 6 var posX = pointer.x; 7 var posY = pointer.y;89 //record the point if withing the canvas and the 10 mouse is pressed if(posX >=0 && posY >= 0 && mousePressed) 11 { 12 coords.push(pointer) 13 } 14 }1516 //get the best bounding box by finding the top left and bottom right cornders 17 function getMinBox(){1819 var coorX = coords.map(function(p) {return p.x});20 var coorY = coords.map(function(p) {return p.y}); 21 //find top left corner 22 var min_coords = { 23 x : Math.min.apply(null, coorX), 24 y : Math.min.apply(null, coorY) 25 } 26 //find right bottom corner 27 var max_coords = { 28 x : Math.max.apply(null, coorX), 29 y : Math.max.apply(null, coorY) 30 } 31 return { 32 min : min_coords, 33 max : max_coords 34 } 35 }
测试绘图
下面是大家初次绘图出现的最频繁的图样。 我用鼠标画了所有的图样。 如果使用笔,准确性会更高。
想了解 TensorFlow js. 组件更多实操案例,请前往 Bilibili Google 中国—— TensorFlow 频道查看 Made With TensorFlow js. 中文系列视频。
https://www.bilibili.com/video/BV1D54y1p7PQ
有关 TensorFlow 更多资讯,可前往 TensorFlow 中国官网(tensorflow.google.cn)查看,或扫描下方二维码,关注 TensorFlow 官方公众号!
使用 Colab 在 tf.keras 中训练模型,并使用 TensorFlow.js 在浏览器中运行相关推荐
- 狗和披萨:使用TensorFlow.js在浏览器中实现计算机视觉
目录 起点 托管说明 MobileNet v1 运行物体识别 终点线 下一步是什么?绒毛动物? 下载TensorFlowJS示例-6.1 MB TensorFlow + JavaScript.现在,最 ...
- 使用TensorFlow.js在浏览器中进行深度学习入门
目录 设置TensorFlow.js 创建训练数据 检查点 定义神经网络模型 训练AI 测试结果 终点线 内存使用注意事项 下一步是什么?狗和披萨? 下载TensorFlowJS示例-6.1 MB T ...
- 使用face-api和Tensorflow.js在浏览器中进行AI年龄估计
目录 性别和年龄检测 下一步是什么? 下载源-10.6 MB 在上一篇文章中,我们学习了如何使用face-api.js和Tensorflow.js在浏览器中对人的情绪进行分类. 如果您尚未阅读该文章, ...
- 图像迁移风格保存模型_用TensorFlow.js在浏览器中部署可进行任意图像风格迁移的模型...
风格迁移一直是很多读者感兴趣的内容之一,近日,网友ReiichiroNakano公开了自己的一个实现:用TensorFlow.js在浏览器中部署可进行任意图像风格迁移的模型.让我们一起去看看吧! Gi ...
- 用 TensorFlow.js 在浏览器中训练一个计算机视觉模型(手写数字分类器)
文章目录 Building a CNN in JavaScript Using Callbacks for Visualization Training with the MNIST Dataset ...
- 使用迁移学习和TensorFlow.js在浏览器中进行AI情感检测
目录 KNN分类器 迁移学习 我们的技术栈 配置 使用KNN分类器 将代码放在一起 测试结果 下一步是什么? 下载源-10.6 MB 在上一篇文章中,我们已经看到了加载预训练模型有多么容易.在本文中, ...
- 用TensorFlow.js在浏览器中进行实时语义分割 | MixLab算法系列
语义分割是监测和描绘图像中每个感兴趣对象的问题 当前,有几种方法可以解决此问题并输出结果 如下图示: 语义分割示例 这种分割是对图像中的每个像素进行预测,也称为密集预测. 十分重要且要注意的是,同一类 ...
- 有了TensorFlow.js,浏览器中也可以实时人体姿势估计
翻译文章,内容有删减.原文地址:https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-te ...
- js判断wifi_使用JS在浏览器中判断当前网络连接状态的几种方法
使用JS在浏览器中判断当前网络状态的几种方法如下: 1. navigator.onLine 2. ajax请求 3. 获取网络资源 4. bind() 1. navigator.onLine 通过na ...
最新文章
- 修改mint-ui的主题色
- Java——容器(Comparable)
- java面试mysql的引擎_面试官:你用过mysql哪些存储引擎,请分别展开介绍一下
- Nginx全局块的其他配置指令
- SAP UI5 BarcodeScannerButton 的初始化逻辑 - feature 检测,Cordova API 检测等逻辑
- 如何在ASP.NET Core 中快速构建PDF文档
- 算法一看就懂之「 堆栈 」
- php生成excel范例,支持任意行列
- 基本操作2-常用命令
- 《我的眼睛--图灵识别》第九章:训练:制作识别字库
- 影响力最大化算法——degreediscount以及python实现代码
- 金蝶专业版怎么反过账当月_金蝶KIS专业版没有反过账功能,怎么反过账
- Google与百度、搜狗合作,共同推进移动网络发展
- 开关柜绝缘状态检测与故障诊断
- 隐藏IP地址的4个好处
- 人工蜂群算法python_python如何实现人工蜂群算法 python实现人工蜂群算法代码示例...
- 华硕服务器设置固态盘启动不了系统盘,华硕uefi引导启动不了系统安装系统安装...
- 计算机辅助数学教学论文,计算机辅助数学教学论文
- 操作系统中用户态和内核态(系统态)是什么?用户态如何变成内核态?
- samp自建服务器教程,网管实战:十分钟建立SAMP开发环境