场景:对包含单个数字的图片进行识别,识别出图片中的数字

训练数据: 采用 mnist 数据集中的 60000张灰度图像(每个像素值范围:0-255),每张图像用一个 28x28 像素的矩阵表示,以及每张图像表示的是 0-9 中的哪一个数字。

输入:一个 28x28 像素的灰度图像 (目标:对输入的这个图片进行数字识别)

输出:0-9 的数字 (识别出来的数字)

模型训练的代码实现:

以下代码先加载 mnist 的图片数据集,然后构建模型进行训练,评估模型,图形化展示训练集和测试集的损失和准确度。最后保存模型到文件。

train.py 代码

#### train.py 训练手写数字体图片识别的模型
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # or any {'0', '1', '2'}import tensorflow as tf
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd#载入 MNIST 数据集,并将整型转换为浮点型,除以 255 是为了归一化。
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0#使用 tf.keras.Sequential 建立模型,并且选择优化器和损失函数
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy', metrics=['accuracy'])#训练模型
history = model.fit(x_train, y_train, epochs=5,validation_data=(x_test,y_test))#模型评估
model.evaluate(x_test,  y_test, verbose=2)#查看训练集与测试集的均方误差和准确率变化情况
history.history.keys()#查看 training set, validation set 的损失和准确率
plt.plot(history.epoch,history.history.get('loss'),label='Loss')
plt.plot(history.epoch,history.history.get('val_loss'),label='Validation Loss')
plt.legend()
plt.show()plt.plot(history.epoch,history.history.get('accuracy'),label='Accuracy')
plt.plot(history.epoch,history.history.get('val_accuracy'),label='Validation Accuracy')
plt.legend()
plt.show()# 保存全模型
model.save('tf_model.h5')

在 Mac M1 上运行代码

python3 train.py

代码运行报错, 如下

Epoch 1/5
2021-10-28 00:47:57.991 python3[17544:2291672] -[MPSGraph adamUpdateWithLearningRateTensor:beta1Tensor:beta2Tensor:epsilonTensor:beta1PowerTensor:beta2PowerTensor:valuesTensor:momentumTensor:velocityTensor:maximumVelocityTensor:gradientTensor:name:]: unrecognized selector sent to instance 0x12e992940
2021-10-28 00:47:58.013 python3[17544:2291672] *** Terminating app due to uncaught exception 'NSInvalidArgumentException', reason: '-[MPSGraph adamUpdateWithLearningRateTensor:beta1Tensor:beta2Tensor:epsilonTensor:beta1PowerTensor:beta2PowerTensor:valuesTensor:momentumTensor:velocityTensor:maximumVelocityTensor:gradientTensor:name:]: unrecognized selector sent to instance 0x12e992940'
*** First throw call stack:
(0   CoreFoundation                      0x0000000191a9f838 __exceptionPreprocess + 2401   libobjc.A.dylib                     0x00000001917c90a8 objc_exception_throw + 602   CoreFoundation                      0x0000000191b30694 -[NSObject(NSObject) __retain_OA] + 03   CoreFoundation                      0x0000000191a00cd4 ___forwarding___ + 14444   CoreFoundation                      0x0000000191a00670 _CF_forwarding_prep_0 + 965   libmetal_plugin.dylib               0x000000011f89a290 _ZN12metal_plugin14MPSApplyAdamOpIfEC2EPNS_20OpKernelConstructionE + 6566   libmetal_plugin.dylib               0x000000011f899ebc _ZN12metal_pluginL14CreateOpKernelINS_14MPSApplyAdamOpIfEEEEPvP23TF_OpKernelConstruction + 527   libtensorflow_framework.2.dylib     0x00000001159d85d4 _ZN10tensorflow12_GLOBAL__N_120KernelBuilderFactory6CreateEPNS_20OpKernelConstructionE + 88…
30  _pywrap_tfe.so                      0x0000000116e6e41c _ZN10tensorflow32TFE_Py_ExecuteCancelable_wrapperERKN8pybind116handleEPKcS5_S3_S3_PNS_19CancellationManagerES3_ + 16031  _pywrap_tfe.so                      0x0000000116e9f208 _ZZN8pybind1112cpp_function10initializeIZL25pybind11_init__pywrap_tfeRNS_7module_EE4$_44NS_6objectEJRKNS_6handleEPKcSA_S8_S8_S8_EJNS_4nameENS_5scopeENS_7siblingEEEEvOT_PFT0_DpT1_EDpRKT2_ENUlRNS_6detail13function_callEE_8__invokeESR_ + 18432  _pywrap_tfe.so                      0x0000000116e810e0 _ZN8pybind1112cpp_function10dispatcherEP7_objectS2_S2_ + 321633  python3                             0x0000000100d07398 cfunction_call + 8034  python3                             0x0000000100cb31e8 _PyObject_MakeTpCall + 34035  python3                             0x0000000100dc36ac call_function + 724
….77  python3                             0x0000000100e1ad48 PyRun_SimpleFileExFlags + 81678  python3                             0x0000000100e3de84 Py_RunMain + 291679  python3                             0x0000000100e3f018 pymain_main + 127280  python3                             0x0000000100c59ddc main + 5681  libdyld.dylib                       0x0000000191941430 start + 4
)
libc++abi: terminating with uncaught exception of type NSException

运行这个错误是因为 ADAM 的优化函数在执行的时候出错, 把代码中的 adam 换成 sdg在 m1 上可以正常执行。

model.compile(optimizer='sgd',loss='sparse_categorical_crossentropy', metrics=['accuracy'])

很快运行出了结果,随着迭代的不断进行,准确度也越来越高。

Metal device set to: Apple M1systemMemory: 16.00 GB
maxCacheSize: 5.33 GBEpoch 1/5
1875/1875 [==============================] - 7s 4ms/step - loss: 0.7205 - accuracy: 0.8033 - val_loss: 0.3633 - val_accuracy: 0.9048
Epoch 2/5
1875/1875 [==============================] - 7s 4ms/step - loss: 0.3846 - accuracy: 0.8904 - val_loss: 0.2933 - val_accuracy: 0.9189
Epoch 3/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.3211 - accuracy: 0.9083 - val_loss: 0.2533 - val_accuracy: 0.9303
Epoch 4/5
1875/1875 [==============================] - 7s 4ms/step - loss: 0.2833 - accuracy: 0.9200 - val_loss: 0.2282 - val_accuracy: 0.9358
Epoch 5/5
1875/1875 [==============================] - 7s 4ms/step - loss: 0.2560 - accuracy: 0.9271 - val_loss: 0.2074 - val_accuracy: 0.9426
313/313 - 1s - loss: 0.2074 - accuracy: 0.9426

前面5行代码具体是什么作用,后面再做详细的讲解。先了解建模的步骤。

接下来,使用前面创建的模型来做预测。 有3个输入的图片, digit-number-3.jpg, digit-number-4.jpg, digit-number-7.jpg,分别对应数字 3, 4, 7。 这3个图片中, 7, 4 是手写的一个数字,3 是从一张图片上截取下来的片段。图片在代码仓库中有。

下面加载模型, 实现一个函数, 使用模型 来对图片文件的内容做预测(单个数字的识别):

#### predict.py
import tensorflow as tf
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from silence_tensorflow import silence_tensorflow
silence_tensorflow()#调用模型
new_model = tf.keras.models.load_model('tf_model.h5')#调用模型对输入的图片进行识别,输出一个预测的数字
def predict_digit(filename):im = Image.open(filename)  #读取图片路径im = im.resize((28,28)) #调整大小和模型输入大小一致im = np.array(im)#对图片进行灰度化处理p3 = im.min(axis = -1)plt.imshow(p3,cmap = 'gray')plt.show()#将白底黑字变成黑底白字   由于训练模型是这种格式for i in range(28):for j in range(28):p3[i][j] = 255-p3[i][j]#模型输出结果是每个类别的概率,取最大的概率的类别就是预测的结果ret = new_model.predict((p3/255).reshape((1,28,28)))number = np.argmax(ret) return numberinput_file = "digit-number-7.jpg"
print("filename: %s predicted:%s" % ( input_file, predict_digit(input_file) ) )input_file = "digit-number-4.jpg"
print("filename: %s predicted:%s" % ( input_file, predict_digit(input_file) ) )input_file = "digit-number-3.jpg"
print("filename: %s predicted:%s" % ( input_file, predict_digit(input_file) ) )

下面是预测的结果

Metal device set to: Apple M1systemMemory: 16.00 GB
maxCacheSize: 5.33 GBfilename: digit-number-7.jpg predicted:7
filename: digit-number-4.jpg predicted:4
filename: digit-number-3.jpg predicted:8

7, 4 两张图片的数字识别是准确的。 3的识别成了8,识别的是不准确的。 完整的代码参考:

https://github.com/davideuler/beauty-of-math-in-deep-learning.git

这是一个非常简单的构建和使用神经网络的例子。 实际的图片识别中,往往不是识别单个的字符,而是识别连续的字符, 那么还需要使用图片分割的算法对图片进行分割。 同时识别的也不仅仅是数字,可能还有字母,中文,或者其他语言的文字,都可以使用类似的方法来进行训练和识别。

神经网络模型不仅仅用于图像识别, 语音识别,语义理解,图像分割,机器翻译等等领域都可以用到。 神经网络是机器学习的一种, 其处理过程包含两个步骤:

学习:输入的训练集进行学习(从已知结果/打过标签的对象和输入的特征进行学习)

预测/推理:对未知的对象,根据输入特征,自动做推理预测(打标签)

其他的场景都可以类似前面的代码过程来处理。 后面的文章继续介绍如何手写一个神经网络,你会对前面的代码有更多理解。

深度学习代码实践(三)5行代码创建手写数字体识别的Tensorflow模型相关推荐

  1. 深度学习(32)随机梯度下降十: 手写数字识别问题(层)

    深度学习(32)随机梯度下降十: 手写数字识别问题(层) 1. 数据集 2. 网络层 3. 网络模型 4. 网络训练 本节将利用前面介绍的多层全连接网络的梯度推导结果,直接利用Python循环计算每一 ...

  2. 深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别

    深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别 一.前言 二.网络结构 三.可解释性 四.记忆主线 五.遗忘门 六.输入门 七.输出门 八.手写数字识别实战 8.1 引入依赖库 8. ...

  3. 【图像分类】基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)

    写在前面: 首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌. 在https://blog.csdn.net/A ...

  4. pytorch 预测手写体数字_深度学习之PyTorch实战(3)——实战手写数字识别

    如果需要小编其他论文翻译,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 上一节,我们已经 ...

  5. 【图像分类】基于PyTorch搭建LSTM实现MNIST手写数字体识别(单向LSTM,附完整代码和数据集)

    写在前面: 首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌. 提起LSTM大家第一反应是在NLP的数据集上比较 ...

  6. 【深度学习】实验1答案:Softmax实现手写数字识别

    DL_class 学堂在线<深度学习>实验课代码+报告(其中实验1和实验6有配套PPT),授课老师为胡晓林老师.课程链接:https://www.xuetangx.com/training ...

  7. 【神经网络与深度学习】第一章 使用神经网络来识别手写数字

    人类的视觉系统,是大自然的奇迹之一. 来看看下面一串手写的数字: 大多数人可以毫不费力地认出这些数字是504192.这种轻松是欺骗性的,我们觉得很轻松的一瞬,其实背后过程非常复杂. 在我们大脑的每个半 ...

  8. 【深度学习】基于Numpy实现的神经网络进行手写数字识别

    直接先用前面设定的网络进行识别,即进行推理的过程,而先忽视学习的过程. 推理的过程其实就是前向传播的过程. 深度学习也是分成两步:学习 + 推理.学习就是训练模型,更新参数:推理就是用学习到的参数来处 ...

  9. tensorflow 语义slam_研究《视觉SLAM十四讲从理论到实践第2版》PDF代码+《OpenCV+TensorFlow深度学习与计算机视觉实战》PDF代码笔记...

    我们知道随着人工神经网络和深度学习的发展,通过模拟视觉所构建的卷积神经网络模型在图像识别和分类上取得了非常好的效果,借助于深度学习技术的发展,使用人工智能去处理常规劳动,理解语音语义,帮助医学诊断和支 ...

最新文章

  1. 程序员的视角:java GC
  2. flutter打开android界面,在已有Android项目中使用Flutter
  3. 定位插件_微创新 | 开发PL/SQL插件,快速定位所需字段
  4. mysql把一个数据库中的数据复制到另一个数据库中的表 2个表结构相同
  5. 利用cookies实现对弹出窗口频率的控制
  6. [Offer收割]编程练习赛48
  7. 凸优化第八章几何问题 8.6 分类
  8. dynamix判定_Dynamix
  9. 廖雪峰Git教程笔记(十一)添加远程库
  10. 解决el-table 树形结构expand 操作后 stripe 显示失效问题
  11. 第3章 Linux内核调试手段之内核打印
  12. karabiner-elements Mac下实现按键全定制 capslox完美替代品
  13. 推荐系统深度学习篇-NFM 模型介绍(1)
  14. 阿里巴巴高德地图春季2023届校园招聘正式启动!
  15. 《那些年啊,那些事——一个程序员的奋斗史》——40
  16. 【vim】禁止vim生成 un~文件
  17. 请问,我要去工商局申请一个工作室,法律上需要那些流程
  18. Exercise14_11
  19. centos7部署openwhisk
  20. java中<<与>>的意思

热门文章

  1. 日期选择器时间选择范围限制
  2. debian下配置网络 安装无线网卡驱动 Broadcom BCMXX系列
  3. 使用Python 录音、调整音量、播放
  4. Zynga收购Rovio和PopCap无果,企业文化是关键
  5. 卡尔.波普尔摘要: 三个世界
  6. joost(p2p)
  7. ASP: Response 对象 错误 'ASP 0251 : 80004005' 解决办法
  8. 微信网页授权(前端)
  9. 金蝶EAS标准登录接口EASLogin
  10. TI-BASIC 计算器游戏开发之文字、图形、音频教程 II:图形处理