这里基于PyTorch框架,实现通过Faster RCNN算法检测图像中的小麦麦穗。当然,用YOLO算法也同样能够完成。本文最终实现的效果如下:

麦穗检测示例

一、数据下载

数据集名:Global Wheat Head Dataset

下载地址:www.kaggle.com/c/global-wheat-detection

更多深度学习数据集:https://www.cvmart.net/dataSets

相关论文:Global Wheat Head Detection (GWHD) Dataset: A Large and Diverse Dataset of High-Resolution RGB-Labelled Images to Develop and Benchmark Wheat Head Detection Methods

数据描述:全球麦穗数据集由来自7个国家的9个研究机构领导,东京大学、国家农业、营养和环境研究所、Arvalis、ETHZ、萨斯喀彻温大学、昆士兰大学、南京农业大学和洛桑研究所。包括全球粮食安全研究所、DigitAg、Kubota和Hiphen在内的许多机构都加入了这些机构的行列,致力于精确的小麦麦穗检测。

数据集贡献机构

数据集为室外小麦植物图像,包括来自全球各地不同平台采集的4698张RGB图像,标记了193,634个小麦麦穗,1024×1024像素,每张图像含有20~70个麦穗。2020年通过Kaggle举办了相关比赛,并在2021年更新了数据集。该数据集可以用于麦穗检测,评估穗数和大小。研究成果有助于准确估计不同品种小麦麦穗的密度和大小。

数据集示例

二、代码实战

2.1 导入所需要的包

# 导入所需要的包
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch  import torch.nn as nn
import albumentations as A   # pip install albumentations==1.1.0
from albumentations.pytorch import ToTensorV2
import torchvision
from torchvision import datasets,transforms
from tqdm import tqdm
import cv2
from torch.utils.data import Dataset,DataLoader
import torch.optim as optim
from PIL import Image
import os
import torch.nn.functional as F
import ast

2.2 参数配置

# 定义参数
LR = 1e-4
SPLIT = 0.2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
EPOCHS = 2
DATAPATH = '../global-wheat-detection'

2.3 读取数据

# 读取 train.csv文件
df = pd.read_csv(DATAPATH + '/train.csv')
df.bbox = df.bbox.apply(ast.literal_eval)   # # 将string of list 转成list数据  #  # 利用groupby 将同一个image_id的数据进行聚合,方式为list进行,并且用reset_index直接转变成dataframe
df = df.groupby("image_id")["bbox"].apply(list).reset_index(name="bboxes")

2.4 划分数据

# # 划分数据集
def train_test_split(dataFrame,split):  len_tot = len(dataFrame)  val_len = int(split*len_tot)  train_len = len_tot-val_len  train_data,val_data = dataFrame.iloc[:train_len][:],dataFrame.iloc[train_len:][:]  return train_data,val_data  len(df)  train_data_df,val_data_df = train_test_split(df,SPLIT)  # 划分 train val 8:2
len(train_data_df), len(val_data_df)  # 查看数据
train_data_df

2.5 构建Dataset类

# 定义WheatDataset 返回 图片,标签
class WheatDataset(Dataset):  def __init__(self,data,root_dir,transform=None,train=True):  self.data = data  self.root_dir = root_dir  self.image_names = self.data.image_id.values  self.bboxes = self.data.bboxes.values  self.transform = transform  self.isTrain = train  def __len__(self):  return len(self.data)  def __getitem__(self,index):
#         print(self.image_names)
#         print(self.bboxes)  img_path = os.path.join(self.root_dir,self.image_names[index]+".jpg")  # 拼接路径  image = cv2.imread(img_path, cv2.IMREAD_COLOR)   # 读取图片  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)  # BGR2RGB  image /= 255.0    # 归一化  bboxes = torch.tensor(self.bboxes[index],dtype=torch.float64)
#         print(bboxes)  """  As per the docs of torchvision  we need bboxes in format (xmin,ymin,xmax,ymax)  Currently we have them in format (xmin,ymin,width,height)  """  bboxes[:,2] = bboxes[:,0]+bboxes[:,2]   # 格式转换 (xmin,ymin,width,height)-----> (xmin,ymin,xmax,ymax)  bboxes[:,3] = bboxes[:,1]+bboxes[:,3]
#         print(image.size,type(image))  """  we need to return image and a target dictionary  target:  boxes,labels,image_id,area,iscrowd  """  area = (bboxes[:,3]-bboxes[:,1])*(bboxes[:,2]-bboxes[:,0])   # 计算面积  area = torch.as_tensor(area,dtype=torch.float32)  # there is only one class  labels = torch.ones((len(bboxes),),dtype=torch.int64)   # 标签  # suppose all instances are not crowded  iscrowd = torch.zeros((len(bboxes),),dtype=torch.int64)  target = {}   # target是个字典 里面 包括 boxes,labels,image_id,area,iscrowd  target['boxes'] = bboxes  target['labels']= labels  target['image_id'] = torch.tensor([index])  target["area"] = area  target['iscrowd'] = iscrowd  if self.transform is not None:  sample = {  'image': image,  'bboxes': target['boxes'],  'labels': labels  }  sample = self.transform(**sample)  image = sample['image']  # 沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状,
#             把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠  target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)  return image,target

2.6 数据增强

# 训练与验证数据增强,利用albumentations  随机翻转转换,随机图片处理
# 对象检测的增强与正常增强不同,因为在这里需要确保 bbox 在转换后仍然正确与对象对齐
train_transform = A.Compose([  A.Flip(0.5),  ToTensorV2(p=1.0)
],bbox_params = {'format':"pascal_voc",'label_fields': ['labels']})
val_transform = A.Compose([  ToTensorV2(p=1.0)
],bbox_params = {'format':"pascal_voc","label_fields":['labels']})
`### 2.7 数据整理`"""
collate_fn默认是对数据(图片)通过torch.stack()进行简单的拼接。对于分类网络来说,默认方法是可以的(因为传入的就是数据的图片),
但是对于目标检测来说,train_dataset返回的是一个tuple,即(image, target)。
如果我们还是采用默认的合并方法,那么就会出错。
所以我们需要自定义一个方法,即collate_fn=train_dataset.collate_fn
"""
def collate_fn(batch):  return tuple(zip(*batch))

2.8 创建数据加载器

# 创建数据加载器  train_data = WheatDataset(train_data_df,DATAPATH+"/train",transform=train_transform)
valid_data = WheatDataset(val_data_df,DATAPATH+"/train",transform=val_transform)

2.9 查看数据

# 查看一个训练集中的数据
image,target = train_data.__getitem__(0)
plt.imshow(image.numpy().transpose(1,2,0))
print(image.shape)  

训练集示例

2.10 定义模型

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor  model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features,num_classes)

2.11 定义Averager类

# 这一个类来保存对应的loss
class Averager:  def __init__(self):  self.current_total = 0.0  self.iterations = 0.0  def send(self, value):  self.current_total += value  self.iterations += 1  @property  def value(self):  if self.iterations == 0:  return 0  else:  return 1.0 * self.current_total / self.iterations  def reset(self):  self.current_total = 0.0  self.iterations = 0.0

2.12 构建训练和测试 dataloader

# 构建训练和测试 dataloader
train_dataloader = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_fn)
val_dataloader = DataLoader(valid_data,batch_size=BATCH_SIZE,shuffle=False,collate_fn=collate_fn)

2.13 定义模型参数

# 定义模型, 优化器,损失, 迭代,以及 学习率
train_loss = []
# val_loss = []
model = model.to(DEVICE)
params =[p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(params,lr=LR)
loss_hist = Averager()
itr = 1
lr_scheduler=None  loss_hist = Averager()
itr = 1

2.14 模型训练

if __name__ == '__main__':  for epoch in range(EPOCHS):  loss_hist.reset()  for images, targets in train_dataloader:  # print(images)  # print(targets)  # for image in images:  #     print(image.dtype)  # torch.float32  # for t in targets:  #     for k, v in t.items():  #         print(k ,v.dtype)  images = list(image.to(DEVICE) for image in images)  targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]  loss_dict = model(images, targets)  # for loss in loss_dict.values():  #     print(loss.dtype)  losses = sum(loss for loss in loss_dict.values())  loss_value = losses.item()  loss_hist.send(loss_value)  optimizer.zero_grad()  losses.backward()  optimizer.step()  if itr % 50 == 0:  print(f"Iteration #{itr} loss: {loss_value}")  itr += 1  # update the learning rate  if lr_scheduler is not None:  lr_scheduler.step()  print(f"Epoch #{epoch} loss: {loss_hist.value}")

2.15 模型保存

# 模型保存
torch.save(model.state_dict(), 'fasterrcnn_resnet50_fpn.pth')

训练好的模型                                                whaosoft aiot http://143ai.com

2.16 加载模型进行预测

images, targets = next(iter(val_dataloader))
images = list(img.to(DEVICE) for img in images)
# print(images[0].shape)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
boxes = targets[1]['boxes'].cpu().numpy().astype(np.int32)
sample = images[1].permute(1, 2, 0).cpu().numpy()  model.eval()
cpu_device = torch.device("cpu")
# print(images[0].shape)  outputs = model(images)
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
# print(outputs[1]['boxes'].detach().numpy().astype(np.int32))  pred_boxes = outputs[1]['boxes'].detach().numpy().astype(np.int32)  fig, ax = plt.subplots(1, 1, figsize=(16, 8))  for b, box in zip(boxes, pred_boxes):  # 绘制预测边框 红色表示  cv2.rectangle(sample,  (box[0], box[1]),  (box[2], box[3]),  (220, 0, 0), 3)  # 绘制实际边框  绿色表示  cv2.rectangle(sample,  (b[0], b[1]),  (b[2], b[3]),  (0, 220, 0), 3)  ax.set_axis_off()
ax.imshow(sample)
plt.show()  

检测结果

对比预测框与实际框,可以看出模型能够很好的预测出麦穗。可以尝试测试不同的麦穗图片,来进行测试查看效果。

PyTorch~Faster RCNNの小麦麦穗检测相关推荐

  1. 基于Faster RCNN的医学图像检测(肺结节检测)

    Faster-R-CNN算法由两大模块组成:1.PRN候选框提取模块 2.Fast R-CNN检测模块.其中,RPN是全卷积神经网络,用于提取候选框:Fast R-CNN基于RPN提取的proposa ...

  2. 计算机视觉与深度学习 | 基于Faster R-CNN的目标检测(深度学习Matlab代码)

    ===================================================== github:https://github.com/MichaelBeechan CSDN: ...

  3. 【论文解读】Faster R-CNN 实时目标检测

    前言 Faster R-CNN 的亮点是使用RPN来提取候选框:RPN全称是Region Proposal Network,也可理解为区域生成网络,或区域候选网络:它是用来提取候选框的.RPN特点是耗 ...

  4. 面试真题总结:Faster Rcnn,目标检测,卷积,梯度消失,Adam算法

    目标检测可以分为两大类,分别是什么,他们的优缺点是什么呢? 答案:目标检测算法分为单阶段和双阶段两大类.单阶段目标验测算法(one-stage),代表算法有 yolo 系列,SSD 系列:直接对图像进 ...

  5. iCAN使用faster r-cnn得到目标检测结果文件为空

    问题在于图片文件夹后少了/,添加上/后解决 -/tf-faster-rcnn/tools/Object_Detector.py --img_dir /home/featurize/Data/exima ...

  6. 目标检测算法Faster R-CNN简介

    在博文https://blog.csdn.net/fengbingchun/article/details/87091740 中对Fast R-CNN进行了简单介绍,这里在Fast R-CNN的基础上 ...

  7. 【目标检测】Faster RCNN算法详解

    转载自:http://blog.csdn.net/shenxiaolu1984/article/details/51152614 Ren, Shaoqing, et al. "Faster ...

  8. 人工智能:物体检测之Faster RCNN模型

    人工智能:物体检测之Faster RCNN模型 物体检测 Faster RCNN模型 简介 卷积层 RPN Roi Pooling Classifier 物体检测 什么是物体检测 物体检测应用场景 物 ...

  9. 深度学习之目标检测:R-CNN、Fast R-CNN、Faster R-CNN

    object detection 就是在给定的图片中精确找到物体所在位置,并标注出物体的类别.object detection 要解决的问题就是物体在哪里,是什么这整个流程的问题.然而,这个问题不是容 ...

最新文章

  1. 上手必备!不可错过的TensorFlow、PyTorch和Keras样例资源
  2. python资料免费-MicroPython最全资料免费获取
  3. 全球及中国生物质能利用产业现状调研及十四五建设布局规划报告2021-2027年
  4. jtag引脚定义_硬件学习之通过树莓派操控 jtag
  5. Java数据类型与各数据库类型对应一览表
  6. 《VMware vSphere设计(原书第2版)》——1.3 设计原则
  7. 使用WireMock进行更好的集成测试
  8. parseInt(string, radix)
  9. STM32H743+CubeMX-QSPI读写外部FLASH(W25Q128JVSQ)
  10. P2057 [SHOI2007]善意的投票 最小割
  11. mysql导入库指令_mysql数据库指令导入导出
  12. 【论文写作】JSP在线考试系统如何写功能描述
  13. diffpatch升级_Tinker资源补丁原理解析
  14. 免费报表软件有哪些?5款热门工具
  15. 学金融离不开计算机,如何看待学计算机的被学金融的鄙视
  16. hdu 6357 Hills And Valleys (DP)
  17. 考研失败最根本的5个原因!
  18. python上海交通大学赵璐_上海交通大学
  19. 如何确定喜神、财神、福神方位
  20. 嵌入式软件开发之------浅析linux根文件系统挂载(九)

热门文章

  1. ESXI 系统密码登录失败登录不上
  2. deep_learning_初学neural network
  3. 计算机桌面图标设置打字图标,电脑的打字图标不见了怎么办
  4. 计算机日常维护小知识,计算机日常维护小常识
  5. 东北师范大学计算机考研参考书,东北师范大学(专业学位)计算机技术考研参考书目...
  6. 手拉手带你开启Vue3世界的鬼斧神工
  7. 解决windows蓝屏 STOP c000021a {Fatal System Error} (by 星空武哥)
  8. 如何在多台电脑间同步
  9. 什么是软件需求,什么是功能需求?——论需求的三个层次和三个方面(2)
  10. excel使用:countif