Caffe的几个重要文件

用了这么久Caffe都没好好写过一篇新手入门的博客,最近应实验室小师妹要求,打算写一篇简单、快熟入门的科普文。
利用Caffe进行深度神经网络训练第一步需要搞懂几个重要文件:

  1. solver.prototxt
  2. train_val.prototxt
  3. train.sh

接下来我们按顺序一个个说明。

solver.prototxt

solver这个文件主要存放模型训练所用到的一些超参数:

  • net := 指定待训练模型结构文件,即train_val.prototxt
  • test_interval := 测试间隔,即每隔多少次迭代进行一次测试
  • test_initialization := 指定是否进行初始测试,即模型未进行训练时的测试
  • test_iteration := 指定测试时进行的迭代次数
  • base_lr := 指定基本学习率
  • lr_policy := 学习率变更策略,这里有介绍,可供参考
  • gamma := 学习率变更策略需要用到的参数
  • power := 同上
  • stepsize := 学习率变更策略Step的变更步长(固定步长)
  • stepvalue := 学习率变更策略Multistep的变更步长(可变步长)
  • max_iter := 模型训练的最大迭代次数
  • momentum := 动量,这是优化策略(Adam, SGD, … )用到的参数
  • momentum2 := 优化策略Adam用到的参数
  • weight_decay := 权重衰减率
  • clip_gradients := 固定梯度范围
  • display := 每隔几次迭代显示一次结果
  • snapshot := 快照,每隔几次保存一次模型参数
  • snapshot_prefix := 保存模型文件的前缀,可以是路径
  • type := solver优化策略,即SGD、Adam、AdaGRAD、RMSProp、NESTROVE、ADADELTA等
  • solver_mode := 指定训练模式,即GPU/CPU
  • debug_info := 指定是否打印调试信息,这里有对启用该功能的输出作介绍
  • device_id := 指定设备号(使用GPU模式),默认为0

用户根据自己的情况进行相应设置,黑体参数为必须指定的,其余参数为可选(根据情况选择)。

train_val.prototxt

train_val文件是用来存放模型结构的地方,模型的结构主要以layer为单位来构建。下面我们以LeNet为例介绍网络层的基本组成:

name: "LeNet"
layer {name: "mnist"                                #网络层名称type: "Data"                                 #网络层类型,数据层top: "data"                                  #这一层的输出,数据top: "label"                                 #这一层的输出,标签include {    phase: TRAIN  }                 #TRAIN:=用于训练,TEST:=用于测试transform_param {    scale: 0.00390625  }    #对数据进行scaledata_param {                                 #数据层配置 source: "examples/mnist/mnist_train_lmdb"  #数据存放路径batch_size: 64                             #指定batch大小backend: LMDB                              #指定数据库格式,LMDB/LevelDB}
}
layer {name: "mnist"type: "Data"top: "data"top: "label"include {    phase: TEST  }transform_param {    scale: 0.00390625  }data_param {source: "examples/mnist/mnist_test_lmdb"batch_size: 100backend: LMDB}
}
layer{name:"conv1"       type:"Convolution" #卷积层bottom:"data"      #上一层的输出作为输入top:"conv1"        param{name:"conv1_w" lr_mult:1 decay_mult:1} #卷积层参数w的名称,学习率和衰减率(相对于base_lr和weight_decay的倍数)param{name:"conv1_b" lr_mult:2 decay_mult:0} #卷积层参数b的名称,学习率和衰减率convolution_param{num_output:20         #卷积层输出的feature map数量 kernel_size:5         #卷积层的大小pad:0                 #卷积层的填充大小stride:1              #进行卷积的步长weight_filler{type:"xavier" }      #参数w的初始话策略weight_filler{type:"constant" value:0.1}     #参数b的初始化策略}
}
layer {        #BatchNorm层,对feature map进行批规范化处理name:"bn1"type:"BatchNorm"bottom:"conv1"top:"conv1"batch_norm_param{ use_global_stats:false} #训练时为false,测试时为true
}
layer {           #池化层,即下采样层name: "pool1"type: "Pooling"bottom: "conv1"top: "pool1"pooling_param {pool: MAX   #最大值池化,还有AVE均值池化kernel_size: 2stride: 2}
}
layer {name: "conv2"type: "Convolution"bottom: "pool1"top: "conv2"param {    lr_mult: 1  }param {    lr_mult: 2  }convolution_param {num_output: 50kernel_size: 5stride: 1weight_filler {      type: "xavier"    }bias_filler {      type: "constant"    }}
}
layer {name:"bn2"type:"BatchNorm"bottom:"conv2"top:"conv2"batch_norm_param{ use_global_stats:false}
}
layer {name: "pool2"type: "Pooling"bottom: "conv2"top: "pool2"pooling_param {pool: MAXkernel_size: 2stride: 2}
}
layer {                         #全连接层name: "ip1"type: "InnerProduct"bottom: "pool2"top: "ip1"param {    lr_mult: 1  }  param {    lr_mult: 2  }inner_product_param {num_output: 500weight_filler {      type: "xavier"    }bias_filler {      type: "constant"    }}
}
layer {                             #激活函数层,提供非线性能力name: "relu1"type: "ReLU"bottom: "ip1"top: "ip1"
}
layer {name: "ip2"type: "InnerProduct"bottom: "ip1"top: "ip2"param {    lr_mult: 1  }param {    lr_mult: 2  }inner_product_param {num_output: 10weight_filler {      type: "xavier"    }bias_filler {      type: "constant"    }}
}
layer {                             #损失函数层name: "prob"type: "SoftmaxWithLoss"bottom: "ip2"bottom: "label"top: "prob"
}

参数初始化策略可参考这里, 激活函数可参考这里。

网络结构和超参数都设计完了,接下来就可以进行模型训练了。这里我介绍最常用的模型训练脚本,也是Caffe官方文档给的例子。

train.sh

这个脚本文件可写,可不写。每次运行需要写一样的命令,所以建议写一下。

TOOLS=/path/to/your/caffe/build/tools
GLOG_logtostderr=0 GLOG_log_dir=log/ \ #该行用于调用glog进行训练日志保存,使用时请把该行注释删除,否则会出错
$TOOLS/caffe train --solver=/path/to/your/solver.prototxt #--snapshot=/path/to/your/snapshot or --weights=/path/to/your/caffemodel ,snapshot和weights两者只是选一,两个参数都可以用来继续训练,区别在于是否保存solver状态

数据准备

这里我们举个简单的例子,改代码是Caffe官方文档提供的,但只能用于单标签的任务,多标签得对源码进行修改。该脚本是对图片数据生成对应的lmdb文件,博主一般使用原图,即数据层类型用ImageData。

#!/usr/bin/env sh
# Create the imagenet lmdb inputs
# N.B. set the path to the imagenet train + val data dirs
set -eEXAMPLE=""                            #存储路径
DATA=""                               #数据路径
TOOLS=/path/to/your/caffe/build/tools #caffe所在目录TRAIN_DATA_ROOT=""                   #训练数据根目录
VAL_DATA_ROOT=""                     #测试数据根目录
# RESIZE=true to resize the images to 256x256. Leave as false if images have
# already been resized using another tool.
RESIZE=false                         #重新调整图片大小
if $RESIZE; thenRESIZE_HEIGHT=256RESIZE_WIDTH=256
elseRESIZE_HEIGHT=0RESIZE_WIDTH=0
fi#检测路径是否存在
if [ ! -d "$TRAIN_DATA_ROOT" ]; thenecho "Error: TRAIN_DATA_ROOT is not a path to a directory: $TRAIN_DATA_ROOT"echo "Set the TRAIN_DATA_ROOT variable in create_imagenet.sh to the path" \"where the ImageNet training data is stored."exit 1
fiif [ ! -d "$VAL_DATA_ROOT" ]; thenecho "Error: VAL_DATA_ROOT is not a path to a directory: $VAL_DATA_ROOT"echo "Set the VAL_DATA_ROOT variable in create_imagenet.sh to the path" \"where the ImageNet validation data is stored."exit 1
fiecho "Creating train lmdb..."GLOG_logtostderr=1 $TOOLS/convert_imageset \--resize_height=$RESIZE_HEIGHT \--resize_width=$RESIZE_WIDTH \--shuffle \$TRAIN_DATA_ROOT \$DATA/train.txt \                #训练图片列表,运行时请把该行注释删除,否则会出错$EXAMPLE/mnist_train_lmdbecho "Creating val lmdb..."GLOG_logtostderr=1 $TOOLS/convert_imageset \--resize_height=$RESIZE_HEIGHT \--resize_width=$RESIZE_WIDTH \--shuffle \$VAL_DATA_ROOT \$DATA/val.txt \$EXAMPLE/mnist_test_lmdbecho "Done."

这样,我们就可以愉快的开始训练啦。


2017-05-15 记。

[Caffe]:关于caffe新手入门相关推荐

  1. 《挑战30天C++入门极限》新手入门:C/C++中枚举类型(enum)

        新手入门:C/C++中枚举类型(enum) 如果一个变量你需要几种可能存在的值,那么就可以被定义成为枚举类型.之所以叫枚举就是说将变量或者叫对象可能存在的情况也可以说是可能的值一一例举出来. ...

  2. LINUX新手入门-1.装系统

    LINUX新手入门-1.装系统 首先我们用虚拟机模拟 装linux系统,然后下一步下一步,然后完成后,编辑一些设置,把镜像放上面就可以了 选第一项,安装系统,查看镜像是否能运行,直接跳过,选择语言 和 ...

  3. 人工智能新手入门学习路线!附学习资源合集

    有段时间没跟大家分享编程资源福利了!今天为大家整理了人工智能新手入门学习路线,同时附700分钟的学习资源合集,相信这套福利可以帮你顺利入行AI!文末领取全部资料. 一.AI基础好课学习资料整理(约31 ...

  4. 【LaTeX】E喵的LaTeX新手入门教程(4)图表

    这里说的不是用LaTeX画图,而是插入已经画好的图片..想看画图可以把滚动条拉到底.前情回顾[LaTeX]E喵的LaTeX新手入门教程(1)准备篇  [LaTeX]E喵的LaTeX新手入门教程(2)基 ...

  5. 想学python都要下载什么软件-学编程闲余时间建议下载的软件_Python新手入门教程...

    原标题:学编程闲余时间建议下载的软件_Python新手入门教程 Python新手入门教程_在手机上就能学习编程的软件 很多小伙伴会问:我在学编程,想利用坐地铁坐公交吃饭间隙学编程,在手机上能学编程的软 ...

  6. 编程入门python语言是多大孩子学的-不学点编程,将来怎么给孩子辅导作业―Python新手入门教程...

    为了填满AI时代的人才缺口,编程语言教育都从娃娃抓起了!如果你还不懂Python是什么将来怎么给孩子辅导作业呢? Python新手入门教程 近期,浙江省信息技术课程改革方案出台,Python言语现已断 ...

  7. python2好还是python3好-新手入门选择Python2还是Python3

    1. 前言 Python的发展很快,几乎每年都在版本迭代.目前Python有两个主要版本,一个是python2.x,另一个是python3.x. 兔子先生最早接触Python的时候,使用的是pytho ...

  8. python新手入门-python新手入门方法

    随着人工智能 大数据的火热 Python成为了广大科学家和普通大众的学习语言.在学习Python的过程中 有很多人感到迷茫 不知道自己该从什么地方入手,今天我们就来说一些新手该如何学习Python编程 ...

  9. 【LaTeX】E喵的LaTeX新手入门教程(6)中文

    假期玩得有点凶 ._.前情回顾[LaTeX]E喵的LaTeX新手入门教程(1)准备篇  [LaTeX]E喵的LaTeX新手入门教程(2)基础排版  [LaTeX]E喵的LaTeX新手入门教程(3)数学 ...

最新文章

  1. Ubuntu .deb包安装方法
  2. NUnit学习笔记之进阶篇
  3. 新后缀再开放,投资者应谨慎对待!
  4. 状态模式 设计模式_设计模式:状态
  5. java B2B2C源码电子商务平台 --zuul跨域访问问题
  6. 新顶级域名、Cloud域名
  7. Eclipse 高清显示屏 图示太少
  8. ACE6.3.3在Linux(CentOS7.0)下的安装和使用
  9. 频率相噪中相关公式、名词注释详解
  10. 概念数据模型到逻辑数据模型的转化
  11. QQ浏览器不能播放视频怎么办?要如何解决
  12. 【WebLogic】解决opatch执行报错“Exception occured: fuser could not be located”
  13. 【c项目】网吧管理系统的设计和实现
  14. 解决易语言出现死循环代码错误提示
  15. ESL第七章 模型评估及选择 【期望】测试误差、模型偏差估计偏差、【平均】乐观、AIC、参数有效数、BIC、最小描述长度、VC/结构风险最小化、一标准误差准则/广义交叉验证、【留一】自助/.632估计
  16. kubernetes Affinity亲和性
  17. ssim算法计算图片_图像质量评估算法 SSIM(结构相似性)
  18. ARM芯片(S5PV210芯片)——串口通信详解
  19. React 高阶组件HOC详解
  20. 35_Pandas计算满足特定条件的元素的数量

热门文章

  1. [图说]舍命不舍球瞬间:罗德曼雷人 科比人仰马翻 - Qzone日志
  2. 敞口杯市场前景分析及行业研究报告
  3. openresty+lua在反向代理服务中的玩法
  4. Redis 可视化工具 Mac中文版
  5. CSS水平居中+垂直居中+水平/垂直居中的方法总结
  6. 信号隔离器解决干扰的方法
  7. QQ管家在你的电脑上不能卸载,结束进程怎么办?
  8. Axure RP的认识
  9. 简单神经网络实现手写数字图片识别
  10. exynos4412 时钟系统分析