文 / Zaid Alyafeai

我们将创建一个简单的工具来识别图纸并输出当前图纸的名称。 此应用程序将直接在浏览器上运行,无需任何安装。我们会使用 Google Colab 来训练模型,并使用 TensorFlow.js 在浏览器上部署它。

【想获取 TensorFlow js. 视频教程,请前往 Bilibili,TensorFlow 渠道查看:https://www.bilibili.com/video/BV1D54y1p7PQ】

代码和演示

在 GitHub 上找到现场演示和代码。 另外,请务必在此处测试 Google Colab 上的 notebook。

注:此处链接

https://colab.research.google.com/github/zaidalyafeai/zaidalyafeai.github.io/blob/master/sketcher/Sketcher.ipynb

数据集

我们将使用 CNN 识别不同类型的图样。 CNN 将在 Quick Draw 数据集上进行训练。 该数据集包含大约 345 个类别 5000 万个图样。

类的子集

传递途径

我们将使用 Keras 在 Google Colab 的 GPU 上免费训练模型,然后使用 TensorFlow.js(tfjs)直接在浏览器上运行。 我在 TensorFlow.js 上制作了一个教程,烦请阅读之后再继续。 这是该项目的传递途径:

在 Colab 上培训

Google 在 GPU 上提供免费处理能力。 您可以在本教程中看到如何创建笔记本和激活 GPU 编程。

输入

我们将基于 tensorflow 使用 keras

1    import os    2    import glob    3    import numpy as np    4    from tensorflow.keras import layers    5    from tensorflow import keras    6    import tensorflow as tf    

加载数据

由于内存有限,我们不会对所有类别进行训练。 我们只使用 100 个数据集。 每个类别的数据在 Google Cloud 上可用作形状为 [N,784] 的 numpy 数组,其中 N 是该特定类的图像数。 我们首先下载数据集

1    import urllib.request    2    def download():    34        base = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'    5        for c in classes:    6            cls_url = c.replace('_', '%20')    7            path = base+cls_url+'.npy'    8            print(path)    9            urllib.request.urlretrieve(path, 'data/'+c+'.npy')  

由于内存有限,我们只会将每个类别中的 5000 张图像加载到内存。 还保留 20% 的未经测试的数据

1    def load_data(root, vfold_ratio=0.2, max_items_per_class= 5000 ):    2        all_files = glob.glob(os.path.join(root, '*.npy'))    3    4        #initialize variables     5        x = np.empty([0, 784])    6        y = np.empty([0])    7        class_names = []    89        #load a subset of the data to memory     10        for idx, file in enumerate(all_files):    11            data = np.load(file)    12            data = data[0: max_items_per_class, :]    13            labels = np.full(data.shape[0], idx)    1415            x = np.concatenate((x, data), axis=0)    16            y = np.append(y, labels)    1718            class_name, ext = os.path.splitext(os.path.basename(file))    19            class_names.append(class_name)    2021        data = None    22        labels = None    2324        #separate into training and testing     25        permutation = np.random.permutation(y.shape[0])    26        x = x[permutation, :]    27        y = y[permutation]    2829        vfold_size = int(x.shape[0]/100*(vfold_ratio*100))    3031        x_test = x[0:vfold_size, :]    32        y_test = y[0:vfold_size]    3334        x_train = x[vfold_size:x.shape[0], :]    35        y_train = y[vfold_size:y.shape[0]]    return x_train, y_train, x_test, y_test, class_names

预处理数据

我们预处理数据准备开始训练。

1    # Reshape and normalize    2    x_train = x_train.reshape(x_train.shape[0], image_size, image_size, 1).astype('float32')    3    x_test = x_test.reshape(x_test.shape[0], image_size, image_size, 1).astype('float32')    45    x_train /= 255.0    6    x_test /= 255.0    78    # Convert class vectors to class matrices    9    y_train = keras.utils.to_categorical(y_train, num_classes)    10    y_test = keras.utils.to_categorical(y_test, num_classes)

创建模型

我们将创建一个简单的 CNN。 请注意,参数数量越少,模型越简单越好。 实际上,我们将在浏览器转换后运行模型,并且我们希望让模型快速运行并进行预测。 以下模型包含 3 个转换层和 2 个密集层。

1    # Define model    2    model = keras.Sequential()    3    model.add(layers.Convolution2D(16, (3, 3),    4                            padding='same',    5                            input_shape=x_train.shape[1:], activation='relu'))    6    model.add(layers.MaxPooling2D(pool_size=(2, 2)))7    model.add(layers.Convolution2D(32, (3, 3), padding='same', activation= 'relu'))    8    model.add(layers.MaxPooling2D(pool_size=(2, 2)))9    model.add(layers.Convolution2D(64, (3, 3), padding='same', activation= 'relu'))    10    model.add(layers.MaxPooling2D(pool_size =(2,2)))11    model.add(layers.Flatten())    12    model.add(layers.Dense(128, activation='relu'))    13    model.add(layers.Dense(100, activation='softmax'))    14    # Train model    15    adam = tf.train.AdamOptimizer()    16    model.compile(loss='categorical_crossentropy',    17                    optimizer=adam,    18                    metrics=['top_k_categorical_accuracy']) 19    print(model.summary())

适配,验证和测试

之后,我们基于 5 个 epochs 和 256 个 batch 训练模型。

1    #fit the model     2    model.fit(x = x_train, y = y_train, validation_split=0.1, batch_size = 256, verbose=2, epochs=5)    34    #evaluate on unseen data    5    score = model.evaluate(x_test, y_test, verbose=0)6    print('Test accuarcy: {:0.2f}%'.format(score[1] * 100)) 

以下是训练的结果

测试精度为 92.20%。

准备 Web 格式的模型

在我们对模型的准确性感到满意之后,我们将其保存以便进行转换

1    model.save('keras.h5')

我们安装了 tfjs 包进行转换

1    !pip install tensorflowjs 

之后我们转换该模型

1    !mkdir model    2    !tensorflowjs_converter --input_format keras keras.h5 model/

这将创建一些权重文件以及包含模型体系结构的 json 文件。

压缩模型,准备将其下载到本地计算机

1    !zip -r model.zip model

最后下载模型

1    from google.colab import files    2    files.download('model.zip')

浏览器推断

在本节中,我们将展示如何加载模型并进行推理。 假设我们有一个尺寸为 300 x 300 的画布。 关于界面和 TensorFlow.js 部分我就不一一详细展开了。

加载模型

为了使用 TensorFlow.js 首先我们使用以下的脚本

1    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>

你需要在本地计算机上运行服务器来承载权重文件。 可以像我一样在项目上创建一个 apache 服务器或在 GitHub 上托管页面。

之后,将模型加载到浏览器

1    model = await tf.loadModel('model/model.json')

使用 await 等待浏览器加载模型。

预处理 

我们需要在进行预测之前预处理数据。 首先从画布中获取图像数据

1    //the minimum boudning box around the current drawing2    const mbb = getMinBox()3    //cacluate the dpi of the current window 4    const dpi = window.devicePixelRatio5    //extract the image data 6    const imgData = canvas.contextContainer.getImageData(mbb.min.x * dpi, mbb.min.y * dpi,7                                                (mbb.max.x - mbb.min.x) * dpi, (mbb.max.y - mbb.min.y) * dpi);

我们以后再解释 getMinBox()。 变量 dpi 用于根据屏幕像素的密度拉伸画布。

我们将画布的当前图像数据转换为张量,调整大小并进行标准化。

1    function preprocess(imgData)    2    {    3    return tf.tidy(()=>{    4            //convert the image data to a tensor     5            let tensor = tf.fromPixels(imgData, numChannels= 1)    6            //resize to 28 x 28     7            const resized = tf.image.resizeBilinear(tensor, [28, 28]).toFloat()    8            // Normalize the image     9            const offset = tf.scalar(255.0);    10        const normalized = tf.scalar(1.0).sub(resized.div(offset));    11            //We add a dimension to get a batch shape     12            const batched = normalized.expandDims(0)13            return batched    14    })    15    }

对于预测,我们使用 model.predict 这将返回形状为 [N,100]的概率。

1    const pred = model.predict(preprocess(imgData)).dataSync()

然后我们可以使用简单的函数来找到前 5 个概率。

提高准确率

请记住,我们的模型接受形状为 [N,28,28,1] 的张量。 我们的绘图画布尺寸为 300 x 300,对于绘图来说可能是太大了,或者说,用户可能想绘制一个小图。 最好裁剪到仅包含当前图形大小的框。 为此,我们通过查找左上角和右下角来提取图形周围的最小边界框

1    //record the current drawing coordinates       2    function recordCoor(event)    3    {    4        //get current mouse coordinate     5        var pointer = canvas.getPointer(event.e);    6        var posX = pointer.x;    7        var posY = pointer.y;89        //record the point if withing the canvas and the 10        mouse is pressed     if(posX >=0 && posY >= 0 && mousePressed)    11        {      12            coords.push(pointer)    13        }    14    }1516    //get the best bounding box by finding the top left and bottom right cornders        17    function getMinBox(){1819        var coorX = coords.map(function(p) {return p.x});20        var coorY = coords.map(function(p) {return p.y}); 21        //find top left corner     22        var min_coords = {    23        x : Math.min.apply(null, coorX),    24        y : Math.min.apply(null, coorY)    25        }    26        //find right bottom corner     27        var max_coords = {    28        x : Math.max.apply(null, coorX),    29        y : Math.max.apply(null, coorY)    30        }    31        return {    32        min : min_coords,    33        max : max_coords    34        }    35    }   

测试绘图

下面是大家初次绘图出现的最频繁的图样。 我用鼠标画了所有的图样。 如果使用笔,准确性会更高。

想了解 TensorFlow js. 组件更多实操案例,请前往 Bilibili Google 中国—— TensorFlow 频道查看 Made With TensorFlow js. 中文系列视频。

https://www.bilibili.com/video/BV1D54y1p7PQ

有关 TensorFlow 更多资讯,可前往 TensorFlow 中国官网(tensorflow.google.cn)查看,或扫描下方二维码,关注 TensorFlow 官方公众号!

使用 Colab 在 tf.keras 中训练模型,并使用 TensorFlow.js 在浏览器中运行相关推荐

  1. 狗和披萨:使用TensorFlow.js在浏览器中实现计算机视觉

    目录 起点 托管说明 MobileNet v1 运行物体识别 终点线 下一步是什么?绒毛动物? 下载TensorFlowJS示例-6.1 MB TensorFlow + JavaScript.现在,最 ...

  2. 使用TensorFlow.js在浏览器中进行深度学习入门

    目录 设置TensorFlow.js 创建训练数据 检查点 定义神经网络模型 训练AI 测试结果 终点线 内存使用注意事项 下一步是什么?狗和披萨? 下载TensorFlowJS示例-6.1 MB T ...

  3. 使用face-api和Tensorflow.js在浏览器中进行AI年龄估计

    目录 性别和年龄检测 下一步是什么? 下载源-10.6 MB 在上一篇文章中,我们学习了如何使用face-api.js和Tensorflow.js在浏览器中对人的情绪进行分类. 如果您尚未阅读该文章, ...

  4. 图像迁移风格保存模型_用TensorFlow.js在浏览器中部署可进行任意图像风格迁移的模型...

    风格迁移一直是很多读者感兴趣的内容之一,近日,网友ReiichiroNakano公开了自己的一个实现:用TensorFlow.js在浏览器中部署可进行任意图像风格迁移的模型.让我们一起去看看吧! Gi ...

  5. 用 TensorFlow.js 在浏览器中训练一个计算机视觉模型(手写数字分类器)

    文章目录 Building a CNN in JavaScript Using Callbacks for Visualization Training with the MNIST Dataset ...

  6. 使用迁移学习和TensorFlow.js在浏览器中进行AI情感检测

    目录 KNN分类器 迁移学习 我们的技术栈 配置 使用KNN分类器 将代码放在一起 测试结果 下一步是什么? 下载源-10.6 MB 在上一篇文章中,我们已经看到了加载预训练模型有多么容易.在本文中, ...

  7. 用TensorFlow.js在浏览器中进行实时语义分割 | MixLab算法系列

    语义分割是监测和描绘图像中每个感兴趣对象的问题 当前,有几种方法可以解决此问题并输出结果 如下图示: 语义分割示例 这种分割是对图像中的每个像素进行预测,也称为密集预测. 十分重要且要注意的是,同一类 ...

  8. 有了TensorFlow.js,浏览器中也可以实时人体姿势估计

    翻译文章,内容有删减.原文地址:https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-te ...

  9. js判断wifi_使用JS在浏览器中判断当前网络连接状态的几种方法

    使用JS在浏览器中判断当前网络状态的几种方法如下: 1. navigator.onLine 2. ajax请求 3. 获取网络资源 4. bind() 1. navigator.onLine 通过na ...

最新文章

  1. 修改mint-ui的主题色
  2. Java——容器(Comparable)
  3. java面试mysql的引擎_面试官:你用过mysql哪些存储引擎,请分别展开介绍一下
  4. Nginx全局块的其他配置指令
  5. SAP UI5 BarcodeScannerButton 的初始化逻辑 - feature 检测,Cordova API 检测等逻辑
  6. 如何在ASP.NET Core 中快速构建PDF文档
  7. 算法一看就懂之「 堆栈 」
  8. php生成excel范例,支持任意行列
  9. 基本操作2-常用命令
  10. 《我的眼睛--图灵识别》第九章:训练:制作识别字库
  11. 影响力最大化算法——degreediscount以及python实现代码
  12. 金蝶专业版怎么反过账当月_金蝶KIS专业版没有反过账功能,怎么反过账
  13. Google与百度、搜狗合作,共同推进移动网络发展
  14. 开关柜绝缘状态检测与故障诊断
  15. 隐藏IP地址的4个好处
  16. 人工蜂群算法python_python如何实现人工蜂群算法 python实现人工蜂群算法代码示例...
  17. 华硕服务器设置固态盘启动不了系统盘,华硕uefi引导启动不了系统安装系统安装...
  18. 计算机辅助数学教学论文,计算机辅助数学教学论文
  19. 操作系统中用户态和内核态(系统态)是什么?用户态如何变成内核态?
  20. samp自建服务器教程,网管实战:十分钟建立SAMP开发环境

热门文章

  1. python 股票信息分析
  2. EasyGBS云台控制对讲功能因端口不通导致功能失效如何解决?
  3. js中eval方法的使用
  4. sqlserver 根据汉字获取拼音首字母 函数
  5. 微信公众号开发者模式回复信息带表情(QQ,emoji)
  6. [Ynoi2017]舌尖上的由乃
  7. 利用ENVI进行辐射定标和投影转换
  8. 基于Hi3516AV200/Hi3519V101的Qt绘图优化
  9. gm怎么刷东西 rust_腐蚀RUST开挂玩家识别方法 如何识别玩家开挂
  10. 我推出了微博寻人在线系统