在VsCode中利用TensorFlow.js结合迁移学习实现商标识别。

一、加载商标数据并可视化

数据保存在data文件夹下面,需要先在data文件夹下创建一个静态服务器,用于加载图片。

http-server data --cors
Available on:http://192.168.4.167:8080http://127.0.0.1:8080
Hit CTRL-C to stop the server

编写获取图片的脚本文件。

const IMAGE_SIZE = 224;const loadImg = (src) => {return new Promise(resolve => {const img = new Image();img.crossOrigin = "anonymous";img.src = src;img.width = IMAGE_SIZE;img.height = IMAGE_SIZE;img.onload = () => resolve(img);});
};
export const getInputs = async () => {const loadImgs = [];const labels = [];for (let i = 0; i < 30; i += 1) {['android', 'apple', 'windows'].forEach(label => {const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;const img = loadImg(src);loadImgs.push(img);labels.push([label === 'android' ? 1 : 0,label === 'apple' ? 1 : 0,label === 'windows' ? 1 : 0,]);});}const inputs = await Promise.all(loadImgs);return {inputs,labels,};
}

创建index.html文件,作为程序的入口文件,在index.html中利用script标签跳转到script.js文件,在script.js中编写主要代码。

加载图片数据。

import {getInputs} from "./data";window.onload = async() =>{const {inputs, labels} = await getInputs();console.log(inputs, labels)};

利用TensorFlow.js中的tfvis进行可视化。

    import * as tfvis from "@tensorflow/tfjs-vis"// 可视化图片const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });inputs.forEach(img => {surface.drawArea.appendChild(img);});

每行显示两个,旁边的滚动体可以拉动查看更多图片。

二、定义模型结构

加载MobileNet模型并截断所有的卷积池化操作,生成截断模型。并定义新的全连接层。

    import * as tf from "@tensorflow/tfjs"// mobilenet模型存放位置const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';// 加载MobileNet模型const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);// 查看模型结构mobilenet.summary();// 截断mobilenet卷积操作const layer = mobilenet.getLayer('conv_pw_13_relu');const truncatedMobilenet = tf.model({inputs: mobilenet.inputs,outputs: layer.output});// 定义全连接层const model = tf.sequential();model.add(tf.layers.flatten({inputShape: layer.outputShape.slice(1)}));model.add(tf.layers.dense({units: 10,activation: 'relu'}));// 定义输出层model.add(tf.layers.dense({units: NUM_CLASSES,activation: 'softmax'}));// 配置损失函数和优化器model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });

三、迁移学习下的模型训练

首先先定义一个工具类utils.js,用于处理输入到截断模型(mobilenet)中的数据。

import * as tf from '@tensorflow/tfjs';
// img格式转成tensor
export function img2x(imgEl){return tf.tidy(() => {const input = tf.browser.fromPixels(imgEl).toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3]);return input;});
}
// 图片文件转成img格式
export function file2img(f) {return new Promise(resolve => {const reader = new FileReader();reader.readAsDataURL(f);reader.onload = (e) => {const img = document.createElement('img');img.src = e.target.result;img.width = 224;img.height = 224;img.onload = () => resolve(img);};});
}

训练数据经过截断模型输出,转为可以用于自定义的全连接层的输入数据。

    // 先经过截断模型const { xs, ys } = tf.tidy(() => {const xs = tf.concat(inputs.map(imgEl =>truncatedMobilenet.predict(img2x(imgEl))));const ys = tf.tensor(labels);return { xs, ys };});// 截断模型的输出当成自定义模型的输入await model.fit(xs, ys, {epochs: 20,callbacks: tfvis.show.fitCallbacks({ name: '训练效果' },['loss'],{ callbacks: ['onEpochEnd'] })});

可以看出训练损失值降得非常低,因为采用迁移模型,卷积层的参数是使用别人训练好的,这部分参数的训练结果是非常优秀的。

四、预测

编写前端页面用于上传带预测图片,就是编写一个上传按钮。

<script src="script.js"></script><input type="file" onchange="predict(this.files[0])">

将预测图片先经过mobileNet预测,吐出来的结果再经过自定义模型预测。

    window.predict = async (file) => {const img = await file2img(file);document.body.appendChild(img);const pred = tf.tidy(() => {const x = img2x(img);const input = truncatedMobilenet.predict(x);return model.predict(input);});const index = pred.argMax(1).dataSync()[0];setTimeout(() => {alert(`预测结果:${BRAND_CLASSES[index]}`);}, 0);};

五、完整代码

index.html

<script src="script.js"></script><input type="file" onchange="predict(this.files[0])">

script.js 

import * as tf from "@tensorflow/tfjs"
import * as tfvis from "@tensorflow/tfjs-vis"
import {getInputs} from "./data";
import {img2x, file2img} from "./utils"// mobilenet模型存放位置
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];window.onload = async() =>{// 加载图片const {inputs, labels} = await getInputs();// 可视化图片const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });inputs.forEach(img => {surface.drawArea.appendChild(img);});// 加载MobileNet模型const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);// 查看模型结构mobilenet.summary();// 截断mobilenet卷积操作const layer = mobilenet.getLayer('conv_pw_13_relu');const truncatedMobilenet = tf.model({inputs: mobilenet.inputs,outputs: layer.output});// 定义全连接层const model = tf.sequential();model.add(tf.layers.flatten({inputShape: layer.outputShape.slice(1)}));model.add(tf.layers.dense({units: 10,activation: 'relu'}));// 定义输出层model.add(tf.layers.dense({units: NUM_CLASSES,activation: 'softmax'}));// 配置损失函数和优化器model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });// 先经过截断模型const { xs, ys } = tf.tidy(() => {const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));const ys = tf.tensor(labels);return { xs, ys };});// 截断模型的输出当成自定义模型的输入await model.fit(xs, ys, {epochs: 20,callbacks: tfvis.show.fitCallbacks({ name: '训练效果' },['loss'],{ callbacks: ['onEpochEnd'] })});// 预测window.predict = async (file) => {const img = await file2img(file);document.body.appendChild(img);const pred = tf.tidy(() => {const x = img2x(img);const input = truncatedMobilenet.predict(x);return model.predict(input);});const index = pred.argMax(1).dataSync()[0];setTimeout(() => {alert(`预测结果:${BRAND_CLASSES[index]}`);}, 0);};};

data.js 

const IMAGE_SIZE = 224;const loadImg = (src) => {return new Promise(resolve => {const img = new Image();img.crossOrigin = "anonymous";img.src = src;img.width = IMAGE_SIZE;img.height = IMAGE_SIZE;img.onload = () => resolve(img);});
};
export const getInputs = async () => {const loadImgs = [];const labels = [];for (let i = 0; i < 30; i += 1) {['android', 'apple', 'windows'].forEach(label => {const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;const img = loadImg(src);loadImgs.push(img);labels.push([label === 'android' ? 1 : 0,label === 'apple' ? 1 : 0,label === 'windows' ? 1 : 0,]);});}const inputs = await Promise.all(loadImgs);return {inputs,labels,};
}

utils.js 

import * as tf from '@tensorflow/tfjs';export function img2x(imgEl){return tf.tidy(() => {const input = tf.browser.fromPixels(imgEl).toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3]);return input;});
}export function file2img(f) {return new Promise(resolve => {const reader = new FileReader();reader.readAsDataURL(f);reader.onload = (e) => {const img = document.createElement('img');img.src = e.target.result;img.width = 224;img.height = 224;img.onload = () => resolve(img);};});
}

TensorFlow.js实现商标识别相关推荐

  1. Tensorflow.js||使用 CNN 识别手写数字

    Tensorflow官方的tesorflow.js实操课程 链接为:link 使用 CNN 识别手写数字 文章目录 使用 CNN 识别手写数字 1. 简介 2. 设置操作 3. 加载数据 4. 定义模 ...

  2. 在浏览器中进行深度学习:TensorFlow.js (四)用基本模型对MNIST数据进行识别

    2019独角兽企业重金招聘Python工程师标准>>> 在了解了TensorflowJS的一些基本模型的后,大家会问,这究竟有什么用呢?我们就用深度学习中被广泛使用的MINST数据集 ...

  3. 绒毛动物探测器:通过TensorFlow.js中的迁移学习识别浏览器中的自定义对象

    目录 起点 MobileNet v1体系结构上的迁移学习 修改模型 训练新模式 运行物体识别 终点线 下一步是什么?我们可以检测到脸部吗? 下载TensorFlowJS-Examples-master ...

  4. 用tensorflow.js实现浏览器内的手写数字识别

    原文 简介 Tensorflow.js是google推出的一个开源的基于JavaScript的机器学习库,相对与基于其他语言的tersorflow库,它的最特别之处就是允许我们直接把模型的训练和数据预 ...

  5. 利用tensorflow.js在线实现图像要素识别提取

    什么是Tensorflow.js? TensorFlow.js是一个开源的基于硬件加速的JavaScript库,用于训练和部署机器学习模型.谷歌推出的第一个基于TensorFlow的前端深度学习框架T ...

  6. 小白玩机器学习(6)--- 基于Tensorflow.js的在线手写数字识别

    一.题目要求 1.三个js文件,分别完成:网络训练以及模型保存.模型加载及准确率测试.在线手写数字识别: 2.模型测试准确率要高于99.3%(尽量): 3.在线手写数字识别需要能够通过鼠标在画布中写入 ...

  7. 来自前端开发者的灵魂发问:TensorFlow.js 好学吗?

    本文作者 蔡善清(Shanqing Cai),谷歌公司软件工程师,深度参与了 TensorFlow 和 TensorFlow.js 的开发工作.从清华大学毕业后,他前往约翰斯 · 霍普金斯大学和麻省理 ...

  8. 独家 | 在浏览器中使用TensorFlow.js和Python构建机器学习模型(附代码)

    作者:MOHD SANAD ZAKI RIZVI 翻译:吴金笛 校对:丁楠雅 本文约5500字,建议阅读15分钟. 本文首先介绍了TensorFlow.js的重要性及其组件,并介绍使用其在浏览器中构建 ...

  9. TensorFlow2020:如何使用Tensorflow.js执行计算机视觉应用程序?

    本文转载自公众号"读芯术"(ID:AI_Discovery). 很多人都能运行操作计算机视觉应用程序.是的,学习并执行它并不难,现在有很多库可以用来执行如此强大的计算机视觉应用程序 ...

最新文章

  1. Kafka如何对Topic元数据进行细粒度的懒加载、同步等待?
  2. linux正则表达式BRE
  3. easyui select ajax,easyui的combobox根据后台数据实现自动输入提示功能
  4. ZOJ 1760 How Many Shortest Path
  5. Maven学习总结(51)——Maven 常用属性和常量说明
  6. centos的服务管理
  7. java jdbc 表存在_JDBC / Java – 如何检查数据库中是否存在表和列?
  8. 6D姿态估计算法汇总(上)
  9. 【图像重建】基于matlab卷积神经网络的图像超分辨率重建【含Matlab源码 1816期】
  10. python注释程序_Python程序里的注释和#号
  11. 如何有效解决企业敏感文件泄露问题
  12. 机器学习算法之SVM的多分类
  13. 消息队列以及非常牛的kafka
  14. 互联网信息安全与加密技术
  15. python pdfminer3k_python 使用pdfminer3k 读取PDF文档的例子
  16. arm开发板上电设置静态ip_Tiny4412友善之臂ARM开发板静态IP设置(重启有效)
  17. ajax提交图片流,img显示
  18. RocketMQ源码分析(十五)之文件恢复
  19. cf596B. Wilbur and Array
  20. 提问的艺术 - 敏捷教练技巧

热门文章

  1. SSN 社会安全号码
  2. 华工计算机学院院长篡改考研成绩,多名院级领导涉嫌篡改研究生复试成绩?华南理工大学正式回应...
  3. 我和Android娘情缘
  4. 如何部署 H5 游戏到云服务器?
  5. qos cbs_我在CBS Interactive担任视频软件工程师实习生的夏天
  6. 运行python文件、电脑突然黑屏_Python初学者请注意!别这样直接运行Python命令,否则电脑等于“裸奔”...
  7. Axure的简单了解
  8. LeetCode75-颜色分类
  9. Cambridge Pixel发布新产品雷达信号输出卡HPx-310
  10. 手把手教会全局透明壁纸,Android2.3以下操作系统适用