在darknet框架上运行类似以下训练实例时必然会进入到train_detector函数,它是训练目标检测器的入口函数。

./darknet detector train cfg/coco.data cfg/yolov2.cfg darknet19_448.conv.23

./darknet detector train cfg/coco.data cfg/yolov2.cfg darknet19_448.conv.23 -gpus 0,1,2,3,4

//if not define -gpus,gpus=0,ngpus=1
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
{list *options = read_data_cfg(datacfg);char *train_images = option_find_str(options, "train", "data/train.list");//store weights?char *backup_directory = option_find_str(options, "backup", "/backup/");srand(time(0));//from /a/b/yolov2.cfg extract yolov2char *base = basecfg(cfgfile); //network configprintf("%s\n", base);float avg_loss = -1;network **nets = calloc(ngpus, sizeof(network));srand(time(0));int seed = rand();int i;for(i = 0; i < ngpus; ++i){srand(seed);
#ifdef GPUcuda_set_device(gpus[i]);
#endif//create network for every GPUnets[i] = load_network(cfgfile, weightfile, clear);nets[i]->learning_rate *= ngpus;}srand(time(0));network *net = nets[0];//subdivisions,why not divide?int imgs = net->batch * net->subdivisions * ngpus;printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);data train, buffer;//the last layer e.g. [region] for yolov2layer l = net->layers[net->n - 1];int classes = l.classes; float jitter = l.jitter;list *plist = get_paths(train_images);//int N = plist->size;char **paths = (char **)list_to_array(plist);load_args args = get_base_args(net);args.coords = l.coords;args.paths = paths;args.n = imgs;        //一次加载的数量       args.m = plist->size; //总的图片数量args.classes = classes;args.jitter = jitter;args.num_boxes = l.max_boxes;args.d = &buffer;args.type = DETECTION_DATA;//args.type = INSTANCE_DATA;args.threads = 64;//n张图片以及图片上的truth box会被加载到buffer.X,buffer.y里面去pthread_t load_thread = load_data(args); double time;int count = 0;//while(i*imgs < N*120){while(get_current_batch(net) < net->max_batches){//l.random决定是否多尺度,如果要的话每训练10个batch进行一下下面的操作if(l.random && count++%10 == 0){printf("Resizing\n");//这个会随机产生{320,352,...608}这样的尺寸int dim = (rand() % 10 + 10) * 32;//意思是最后的200个batch图片都缩放到608if (get_current_batch(net)+200 > net->max_batches) dim = 608;//int dim = (rand() % 4 + 16) * 32;printf("%d\n", dim);args.w = dim;args.h = dim;pthread_join(load_thread, 0); //wait for load_thread ternimatetrain = buffer; free_data(train);load_thread = load_data(args);#pragma omp parallel forfor(i = 0; i < ngpus; ++i){//要调整网络resize_network(nets[i], dim, dim);}net = nets[0];}time=what_time_is_it_now();//args.n数量的图像由args.threads个子线程加载完成,该线程会退出pthread_join(load_thread, 0); //加载完成的args.n张图像会存入到args.d中train = buffer;//next batch?load_thread = load_data(args);printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);time=what_time_is_it_now();float loss = 0;
#ifdef GPUif(ngpus == 1){loss = train_network(net, train);} else {loss = train_networks(nets, ngpus, train, 4);}
#elseloss = train_network(net, train);
#endifif (avg_loss < 0) avg_loss = loss;avg_loss = avg_loss*.9 + loss*.1;i = get_current_batch(net);printf("%ld: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, i*imgs);if(i%100==0){
#ifdef GPUif(ngpus != 1) sync_nets(nets, ngpus, 0);
#endifchar buff[256];sprintf(buff, "%s/%s.backup", backup_directory, base);save_weights(net, buff);}if(i%10000==0 || (i < 1000 && i%100 == 0)){
#ifdef GPUif(ngpus != 1) sync_nets(nets, ngpus, 0);
#endifchar buff[256];sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);save_weights(net, buff);}//这里要相当注意,train指针指向的空间来自于buffer,而buffer中的空间来自于load_data函数//后续逻辑中动态分配的空间,而在train被赋值为buffer以后,在下一次load_data逻辑中会//再次动态分配,这里一定要记得释放前一次分配的,否则指针将脱钩,内存泄漏不可避免free_data(train);}
#ifdef GPUif(ngpus != 1) sync_nets(nets, ngpus, 0);
#endifchar buff[256];sprintf(buff, "%s/%s_final.weights", backup_directory, base);save_weights(net, buff);
}

while循环里每一次循环代表一次训练迭代,一次训练的数据量为imgs,它等于net->batch * net->subdivisions * ngpus,我只考虑CPU的情形或只考虑只含单GPU的情形的话,ngpus就等于1。而subdivisions这个参数,我所观察到大部分cfg文件(如:yolov2.cfg)中的默认设置都为1,如果该参数不为1的话,while循环中一次加载的图像数量就是net->batch * net->subdivisions(后面人都只考虑cpu情形,所以ngpus为1),否则的话就是一个net->batch(在cfg文件中会有明确的定义)。有了数据之后紧接就可以开始训练,进入到train_network函数中。

float train_network(network *net, data d)
{assert(d.X.rows % net->batch == 0);int batch = net->batch;int n = d.X.rows / batch;int i;float sum = 0;for(i = 0; i < n; ++i){//d.X.rows is net->batch * net->subdivisions * ngpus?//this batch is not that batch?get_next_batch(d, batch, i*batch, net->input, net->truth);float err = train_network_datum(net);sum += err;}//calc average lossreturn (float)sum/(n*batch);
}

首先解释一下变量n,它等于d.X.rows/batch,d.X.rows就是我们上面在一次while循环中准备的数据量imgs,考虑到ngpus=1,那么这里求出来的n实际就应该等于subdivisions。后面就相当于每次取一个batch的数据,训练n次。

float train_network_datum(network *net)
{*net->seen += net->batch; //更新已经参与训练的图片数量net->train = 1;forward_network(net);backward_network(net);float error = *net->cost;if(((*net->seen)/net->batch)%net->subdivisions == 0) update_network(net);return error;
}

意思已经比较明显了,一次迭代训练一个batch的数据,包含前向传播(forward_network),反向传播(backward_network)以及网络更新(update_network)。

darknet源码解读-train_detector相关推荐

  1. 目标检测之DarkNet-DarkNet源码解读<一>测试篇

    目标检测-DarkNet源码解读 DarkNet源码解读 1.一些思考  1.1 DarkNet的本质  1.2 深度学习分为两条线  1.3 检测任务的步骤 2.代码走读  2.1 程序入口  2. ...

  2. 【MMDetection 源码解读之yolov3】Neck - FPN

    目录 前言 一.FPN 总结 前言 这部分接着前一篇文章 [MMDetection 源码解读 yolov3]Backbone - Darknet53 继续往后讲.搭建完了主干特征提取模块,接着就是搭建 ...

  3. 【darknet源码】:导入训练数据

    darknet源码中的权重读取由函数load_network()中的load_weight函数搞定. 导入的数据的结构体信息见:[darknet源码]:网络核心结构体 整体调用流程: detector ...

  4. Darknet源码阅读【吐血整理,持续更新中】

    github地址 https://github.com/BBuf/Darknet Darknet源码阅读 Darknet是一个较为轻型的完全基于C与CUDA的开源深度学习框架,其主要特点就是容易安装, ...

  5. Bert系列(二)——源码解读之模型主体

    本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...

  6. Bert系列(三)——源码解读之Pre-train

    https://www.jianshu.com/p/22e462f01d8c pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现 ...

  7. linux下free源码,linux命令free源码解读:Procps free.c

    linux命令free源码解读 linux命令free源码解读:Procps free.c 作者:isayme 发布时间:September 26, 2011 分类:Linux 我们讨论的是linux ...

  8. nodeJS之eventproxy源码解读

    1.源码缩影 !(function (name, definition) { var hasDefine = typeof define === 'function', //检查上下文环境是否为AMD ...

  9. PyTorch 源码解读之即时编译篇

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 作者丨OpenMMLab 来源丨https://zhuanlan.zhihu.com/ ...

最新文章

  1. 编程打开Windows服务控制管理器
  2. HarmonyOS之深入解析通知的使用
  3. 北航博士,研究所月入两万,是一种什么体验?
  4. 课程 3: Content Providers 简介
  5. Cadence OrCAD Capture交叉参考报表生成方法图文教程
  6. 河南理工大学计算机软件考研857数据结构
  7. 基于SEIRD和元胞自动机(CA)模型的传染病发展趋势预测
  8. lammps教程:Ovito分析并绘制单原子应变方法
  9. BCB6使用ReportMachine创建报表
  10. leetcode寻找重复数
  11. 企业级自动化运维工具Ansible详解(上)
  12. kafka在rack间平衡replica
  13. STM32F103C8T6实现流水灯
  14. office钓鱼学习
  15. ExifTool如何格式化日期和时间信息以进行书写
  16. 网上图书订阅系统之(招标书,投标书)
  17. excel-counta
  18. LCD液晶屏接口和显示器接口介绍
  19. Did the Microsoft Stack Kill MySpace?
  20. 【uni-app】基础知识篇

热门文章

  1. 次世代游戏建模美术教程—贴图烘培篇
  2. 安装 Android SDK
  3. 放下助人情结,尊重他人命运
  4. 工程师需要知道的计算技巧
  5. matlab中str2func函数,MATLAB 的函数句柄
  6. php js attr,jquery属性与自定义属性操作:attr()和removeAttr()
  7. uin-app 使用canvas画简易海报
  8. HTTP的303、307状态码
  9. python03---第三章:基本数据类型(天天向上的力量、文本进度条)(time库)
  10. 递归——迭代法求平方根