目录

引言

一、数据集融合

1.1 直接权重融合

1.2 TIF融合方法

二、结果集融合

2.1 整体代码框架

2.2 标定参数

2.3 选择最优检测框

2.3.1 检测框合并

2.3.2 重复检测框剔除

2.4 结果可视化

2.4.3 单目估计距离

2.4.4 结果参数可视化

三、特征集融合

3.1 双模态数据集制作

3.2 传入双模态数据

3.2.1  dataset.py

3.2.2  train_batch显示

3.3 更改backbone

3.3.1 yaml解析文件

3.3.2 backbone

3.4 Forward更改

3.5 调试运行


引言

目前,多模态数据融合主要有三种融合方式:前端融合(early-fusion)或数据水平融合(data-level fusion)、后端融合(late-fusion)或决策水平融合(decision-level fusion)和中间融合(intermediate-fusion)。

前端融合将多个独立的数据集融合成一个单一的特征向量,然后输入到机器学习分类器中。由于多模态数据的前端融合往往无法充分利用多个模态数据间的互补性,且前端融合的原始数据通常包含大量的冗余信息。因此,多模态前端融合方法常常与特征提取方法相结合以剔除冗余信息,如主成分分析(PCA)、最大相关最小冗余算法(mRMR)、自动解码器(Autoencoders)等。

后端融合则是将不同模态数据分别训练好的分类器输出打分(决策)进行融合。这样做的好处是,融合模型的错误来自不同的分类器,而来自不同分类器的错误往往互不相关、互不影响,不会造成错误的进一步累加。常见的后端融合方式包括最大值融合(max-fusion)、平均值融合(averaged-fusion)、 贝叶斯规则融合(Bayes’rule based)以及集成学习(ensemble learning)等。其中集成学习作为后端融合方式的典型代表,被广泛应用于通信、计算机识别、语音识别等研究领域。

中间融合是指将不同的模态数据先转化为高维特征表达,再于模型的中间层进行融合。以神经网络为例,中间融合首先利用神经网络将原始数据转化成高维 特征表达,然后获取不同模态数据在高维空间上的共性。中间融合方法的一大优势是可以灵活的选择融合的位置

目标检测的融合方法主要分为三种方式,下面将给出三种方法的思路和相关的代码。融合的前提是传感器的标定,不知道如何标定的小伙伴参见下方链接,本文直接使用标定后的图片。

可见光相机与红外相机标定_OrigamiSun的博客-CSDN博客_红外相机标定

一、数据集融合

工程难度:简单
将可见光和红外图像先融合再传入网络进行训练,网络结构不用改动。融合前需要对可见光相机和红外相机进行参数标定,这样才能进行融合。

1.1 直接权重融合

# -*- coding: utf-8 -*-try:import cv2
except:import syssys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')import cv2import numpy as np
import os
import timedef weight_half_algo(p1_path, p2_path, weight=0.5):# 两幅图像的加权平均p1 = cv2.imread(p1_path, cv2.IMREAD_COLOR)  # 按彩色图像读取p2 = cv2.imread(p2_path, cv2.IMREAD_COLOR)  # .astype(np.float)p1 = p1.astype(np.float)  # 转成float类型矩阵p2 = p2.astype(np.float)  # 转成float类型矩阵p = weight * p1 + (1 - weight) * p2return pdef flist():"""对文件夹下的示例图像批量进行计算。结果写入文件夹 rootdir_Res 中:return:"""rootdir_IR = r'D:\Project\Python\ImageFusion\VIFB-master\input\IR'  # 红外图像存放路径rootdir_VI = r'D:\Project\Python\ImageFusion\VIFB-master\input\VI'  # 可见光图像存放路径rootdir_Res = r'D:\Project\Python\ImageFusion\VIFB-master\Res'  # TIF算法处理后的图像存放路径rootdir_Res_weight = r'D:\Project\Python\ImageFusion\VIFB-master\Res_weight'  # 平均加权算法处理后的图像存放路径fflist = os.listdir(rootdir_IR)  # 列出文件夹下所有的目录与文件# print(fflist)for i in range(0, len(fflist)):path1 = os.path.join(rootdir_IR, fflist[i])path2 = os.path.join(rootdir_VI, fflist[i])if os.path.isfile(path1) and os.path.isfile(path2):p = weight_half_algo(path1, path2)  # 采用两者平均加权的方法进行融合cv2.imwrite(os.path.join(rootdir_Res_weight, fflist[i]), p)if __name__ == '__main__':# 程序开始时的时间time_start = time.time()# 1 图表示可见光;2 图表示红外flist()time_end = time.time()#cv2.imwrite('./final_res.jpg', p)  print('程序运行花费时间', time_end - time_start)

直接自己设置权重,权重为定值,融合效果有限

1.2 TIF融合方法

TIF算法是将图像分成基础层和细节层,之后再按加权相加。基础层,就是将图像进行均值滤波(文中用的是35),均值滤波后的图像就是基础层,原图减去基础层就是细节层。基础层的权重是0.5,细节层的权重按公式计算,是动态的。

# -*- coding: utf-8 -*-try:import cv2
except:import syssys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')import cv2import numpy as np
import os
import timedef TIF_algo(p1, p2, median_blur_value=3, mean_blur_value=35):"""通过TIF方法融合图像:param p1_path: 图像1路径:param p2_path: 图像2路径:param median_blur_value: 中值滤波参数:param mean_blur_value:  均值滤波参数:return: 融合图像"""# median_blur_value = 3  # 中值滤波系数# mean_blur_value = 35  # 均值滤波系数# 均值滤波后(35,35)的图层,即基础层p1_b = cv2.blur(p1, (mean_blur_value, mean_blur_value))p1_b = p1_b.astype(np.float)  # 转成float类型矩阵p2_b = cv2.blur(p2, (mean_blur_value, mean_blur_value))p2_b = p2_b.astype(np.float)  # 转成float类型矩阵# cv2.imshow('picture after mean blur p1_b', p1_b)# cv2.imshow('picture after mean blur p2_b', p2_b)# 均值滤波后的细节层# p1_d = abs(p1.astype(np.float) - p1_b)# p2_d = abs(p2.astype(np.float) - p2_b)p1_d = p1.astype(np.float) - p1_bp2_d = p2.astype(np.float) - p2_b# cv2.imshow('detail layer p1', p1_d / 255.0)# cv2.imshow('detail layer p2', p2_d / 255.0)# 原图经过中值滤波后的图层p1_after_medianblur = cv2.medianBlur(p1, median_blur_value)p2_after_medianblur = cv2.medianBlur(p2, median_blur_value)# 矩阵转换,换成float型,参与后面计算p1_after_medianblur = p1_after_medianblur.astype(np.float)p2_after_medianblur = p2_after_medianblur.astype(np.float)# 计算均值和中值滤波后的误差p1_subtract_from_median_mean = p1_after_medianblur - p1_b + 0.01  # 加0.01 保证结果非NANp2_subtract_from_median_mean = p2_after_medianblur - p2_b + 0.01# cv2.imshow('subtract_from_median_mean  p1_subtract_from_median_mean', p1_subtract_from_median_mean/255.0)# cv2.imshow('subtract_from_median_mean  p2_subtract_from_median_mean', p2_subtract_from_median_mean/255.0)m1 = p1_subtract_from_median_mean[:, :, 0]m2 = p1_subtract_from_median_mean[:, :, 1]m3 = p1_subtract_from_median_mean[:, :, 2]res = m1 * m1 + m2 * m2 + m3 * m3# delta1 = np.sqrt(res)delta1 = resm1 = p2_subtract_from_median_mean[:, :, 0]m2 = p2_subtract_from_median_mean[:, :, 1]m3 = p2_subtract_from_median_mean[:, :, 2]res = m1 * m1 + m2 * m2 + m3 * m3# delta2 = np.sqrt(res) #采用平方和开根号做权重计算# delta2 = res #采用平方和做权重计算delta2 = abs(m1)  # 由于图像2 红外图像是灰度图像,直接用像素差做权重计算delta_total = delta1 + delta2  # 分母psi_1 = delta1 / delta_totalpsi_2 = delta2 / delta_totalpsi1 = np.zeros(p1.shape, dtype=np.float)psi2 = np.zeros(p2.shape, dtype=np.float)psi1[:, :, 0] = psi_1psi1[:, :, 1] = psi_1psi1[:, :, 2] = psi_1psi2[:, :, 0] = psi_2psi2[:, :, 1] = psi_2psi2[:, :, 2] = psi_2# 基础层融合p_b = 0.5 * (p1_b + p2_b)# 细节层融合p_d = psi1 * p1_d + psi2 * p2_d# 整体融合p = p_b + p_dreturn pif __name__ == '__main__':# 程序开始时的时间time_start = time.time()# 1 图表示可见光;2 图表示红外#flist()p1_path = './rgb/000000.jpg'  # 可见光图像p1 = cv2.imread(p1_path, cv2.IMREAD_COLOR)#.astype(np.float)p2_path = './t/000000.jpg'  # 红外图像p2 = cv2.imread(p2_path, cv2.IMREAD_COLOR)#.astype(np.float)p = TIF_algo(p1, p2)time_end = time.time()cv2.imwrite('./dy_weight/000000.jpg', p)print('程序运行花费时间', time_end - time_start)

可以看到融合的图片细节更加清楚 。

将融合后的图像传入神经网络进行训练。

二、结果集融合

工程难度:中等

单独训练两个网络,得到两个权重。将detect改成ros节点,对应两个callback分别加载加载两个模型权重进行检测,得到结果进行融合,选择结果最好的。

2.1 整体代码框架

可以看到双模态的主要更改有

1、两个检测器

2、怎么选择最优目标框

3、结果的可视化

# -*- coding: UTF-8 -*-import os
import sys
import rospy
import numpy as np
import time
import math
import matplotlib.pyplot as plt
import threading
import argparsefrom std_msgs.msg import Header
from sensor_msgs.msg import Imagetry:import cv2
except ImportError:import syssys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')import cv2from yolov5_detector import Yolov5Detector, draw_predictions
from mono_estimator import MonoEstimator
from functions import get_stamp, publish_image
from functions import display, print_info
from functions import simplified_nmsparser = argparse.ArgumentParser(description='Demo script for dual modal peception')
parser.add_argument('--print', action='store_true',help='Whether to print and record infos.')
parser.add_argument('--sub_image1', default='/pub_rgb', type=str,help='The image topic to subscribe.')
parser.add_argument('--sub_image2', default='/pub_t', type=str,help='The image topic to subscribe.')
parser.add_argument('--pub_image', default='/result', type=str,help='The image topic to publish.')
parser.add_argument('--calib_file', default='../conf/calibration_image.yaml', type=str,help='The calibration file of the camera.')
parser.add_argument('--modality', default='RGBT', type=str,help='The modality to use. This should be `RGB`, `T` or `RGBT`.')
parser.add_argument('--indoor', action='store_true',help='Whether to use INDOOR detection mode.')
parser.add_argument('--frame_rate', default=10, type=int,help='Working frequency.')
parser.add_argument('--display', action='store_true',help='Whether to display and save all videos.')
args = parser.parse_args()
# 在线程函数执行前,“抢占”该锁,执行完成后,“释放”该锁,则我们确保了每次只有一个线程占有该锁。这也是为什么能让RGB和T匹配的原因
image1_lock = threading.Lock() #创建锁
image2_lock = threading.Lock()#3.1 获取RGB图像的时间戳和格式转化
#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××
def image1_callback(image): global image1_stamp, image1_frame #多线程是共享资源的,使用全局变量 image1_lock.acquire() #锁定锁image1_stamp = get_stamp(image.header) #获得时间戳image1_frame = np.frombuffer(image.data, dtype=np.uint8).reshape(image.height, image.width, -1)#图片格式转化image1_lock.release() #释放
#3.2 获取红外图像的时间戳和格式转化
#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××
def image2_callback(image):global image2_stamp, image2_frameimage2_lock.acquire()image2_stamp = get_stamp(image.header)image2_frame = np.frombuffer(image.data, dtype=np.uint8).reshape(image.height, image.width, -1)image2_lock.release()#5.1 获取红外图像的时间戳和格式转化
#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××
def timer_callback(event):#获得RGB的时间戳和图片global image1_stamp, image1_frameimage1_lock.acquire()cur_stamp1 = image1_stampcur_frame1 = image1_frame.copy()image1_lock.release()#获得T的时间戳和图片global image2_stamp, image2_frameimage2_lock.acquire()cur_stamp2 = image2_stampcur_frame2 = image2_frame.copy()image2_lock.release()global frameframe += 1start = time.time()# 5.2获得预测结果#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××if args.indoor: #使用yolov5本身权重labels, scores, boxes = detector.run(cur_frame1, conf_thres=0.50, classes=[0]) # personelse:if args.modality.lower() == 'rgb': #用RGB图像权重labels, scores, boxes = detector1.run(cur_frame1, conf_thres=0.50, classes=[0, 1, 2, 3, 4]) # pedestrian, cyclist, car, bus, truckelif args.modality.lower() == 't':  #用T图像权重labels, scores, boxes = detector2.run(cur_frame2, conf_thres=0.50, classes=[0, 1, 2, 3, 4]) # pedestrian, cyclist, car, bus, truckelif args.modality.lower() == 'rgbt': #双模态都用#获取RGB的预测结果    类别、置信分数、检测框labels1, scores1, boxes1 = detector1.run(cur_frame1, conf_thres=0.50, classes=[0, 1, 2, 3, 4]) # pedestrian, cyclist, car, bus, truck#print("rgb",labels1, scores1, boxes1)#获取T的预测结果labels2, scores2, boxes2 = detector2.run(cur_frame2, conf_thres=0.50, classes=[0, 1, 2, 3, 4]) # pedestrian, cyclist, car, bus, truck#print("T",labels2, scores2, boxes2)# 确定最终结果labels = labels1 + labels2 #合并类别数组#print("labels",labels)scores = scores1 + scores2 #合并分数数组#print("scores",scores)if boxes1.shape[0] > 0 and boxes2.shape[0] > 0: #如果可见光和红外都检测到目标boxes = np.concatenate([boxes1, boxes2], axis=0) #链接两个检测框#print("boxes",boxes)# 排除重复的目标框indices = simplified_nms(boxes, scores)labels, scores, boxes = np.array(labels)[indices], np.array(scores)[indices], boxes[indices]#print("result",labels, scores, boxes)elif boxes1.shape[0] > 0: #如果只有可见光检测到boxes = boxes1#print("boxes",boxes)elif boxes2.shape[0] > 0: #如果只有红外检测到boxes = boxes2#print("boxes",boxes)else:   #都没检测到boxes = np.array([])#print("boxes",boxes)else:raise ValueError("The modality must be 'RGB', 'T' or 'RGBT'.")labels_temp = labels.copy()labels = []for i in labels_temp:labels.append(i if i not in ['pedestrian', 'cyclist'] else 'person')print("boxes",boxes)# 5.3单目估计距离#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××locations = mono.estimate(boxes)#获得目标框的世界坐标系print("locations",locations)indices = [i for i in range(len(locations)) if locations[i][1] > 0 and locations[i][1] < 200]labels, scores, boxes, locations = \np.array(labels)[indices], np.array(scores)[indices], boxes[indices], np.array(locations)[indices]distances = [(loc[0] ** 2 + loc[1] ** 2) ** 0.5 for loc in locations] #估计距离print("distances",distances)cur_frame1 = cur_frame1[:, :, ::-1].copy() # to BGRcur_frame2 = cur_frame2[:, :, ::-1].copy() # to BGR# 5.4画检测框#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××for i in reversed(np.argsort(distances)):#reversed()函数用于对可迭代对象中的元素进行反向排列cur_frame1 = draw_predictions(cur_frame1, str(labels[i]), float(scores[i]), boxes[i], location=locations[i])cur_frame2 = draw_predictions(cur_frame2, str(labels[i]), float(scores[i]), boxes[i], location=locations[i])# 5.5发布图像#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××result_frame = np.concatenate([cur_frame1, cur_frame2], axis=1) #合并图像if args.display:if not display(result_frame, v_writer, win_name='result'):print("\nReceived the shutdown signal.\n")rospy.signal_shutdown("Everything is over now.")result_frame = result_frame[:, :, ::-1] # to RGBpublish_image(pub, result_frame)delay = round(time.time() - start, 3)if args.print:print_info(frame, cur_stamp1, delay, labels, scores, boxes, locations, file_name)if __name__ == '__main__':# 初始化节点rospy.init_node("dual_modal_perception", anonymous=True, disable_signals=True)frame = 0# 一、加载标定参数#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××if not os.path.exists(args.calib_file):raise ValueError("%s Not Found" % (args.calib_file))mono = MonoEstimator(args.calib_file, print_info=args.print)# 二、初始化Yolov5Detector#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××if args.indoor:detector = Yolov5Detector(weights='weights/coco/yolov5s.pt')else:if args.modality.lower() == 'rgb':detector1 = Yolov5Detector(weights='weights/seumm_visible/yolov5s_100ep_pretrained.pt')elif args.modality.lower() == 't':detector2 = Yolov5Detector(weights='weights/seumm_lwir/yolov5s_100ep_pretrained.pt')elif args.modality.lower() == 'rgbt': #双模态detector1 = Yolov5Detector(weights='weights/yolov5s.pt')detector2 = Yolov5Detector(weights='weights/yolov5s.pt')else:raise ValueError("The modality must be 'RGB', 'T' or 'RGBT'.")# 三、进入回调函数#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××# 准备图像序列image1_stamp, image1_frame = None, Noneimage2_stamp, image2_frame = None, None#rgb回调函数rospy.Subscriber(args.sub_image1, Image, image1_callback, queue_size=1,buff_size=52428800)#红外回调函数rospy.Subscriber(args.sub_image2, Image, image2_callback, queue_size=1,buff_size=52428800)# 等待RGB和t图像都获得再进行下一次循环while image1_frame is None or image2_frame is None:time.sleep(0.1)print('Waiting for topic %s and %s...' % (args.sub_image1, args.sub_image2))print('  Done.\n')# 四、功能选择#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××#如果在命令行中输入 python3 demo_dual_modal.py --print 则会运行下面代码# 是否记录时间戳和检测结果if args.print:file_name = 'result.txt'with open(file_name, 'w') as fob:fob.seek(0)fob.truncate()# 是否保存视频if args.display:assert image1_frame.shape == image2_frame.shape, \'image1_frame.shape must be equal to image2_frame.shape.'win_h, win_w = image1_frame.shape[0], image1_frame.shape[1] * 2v_path = 'result.mp4'v_format = cv2.VideoWriter_fourcc(*"mp4v")v_writer = cv2.VideoWriter(v_path, v_format, args.frame_rate, (win_w, win_h), True)# 五、预测结果与发布#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××# 启动定时检测线程pub = rospy.Publisher(args.pub_image, Image, queue_size=1)rospy.Timer(rospy.Duration(1 / args.frame_rate), timer_callback) #每frame_rate 秒调用一次timer_callback# 与C++的spin不同,rospy.spin()的作用是当节点停止时让python程序退出rospy.spin()

2.2 标定参数

要进行单目距离估计,引入相机标定参数,calibration_image.yaml 文件内容如下。

%YAML:1.0
---
ProjectionMat: !!opencv-matrixrows: 3cols: 4dt: ddata: [859, 0, 339, 0, 0, 864, 212, 0, 0, 0, 1, 0]
Height: 2.0 # the height of the camera (meter)
DepressionAngle: 0.0 # the depression angle of the camera (degree)

2.3 选择最优检测框

2.3.1 检测框合并

想要选择最优检测框,首先要将所有框合并,一共四种情况,两路都检测到,各只有一路检测到,都没检测到

1、可见光和红外都检测到

# 可见光和红外检测出来的东西
rgb ['person'] [0.74] [[        354         245         379         329]]
T ['person', 'person'] [0.77, 0.54] [[        353         241         379         327][        114         249         128         292]]
# 合并后的数组
labels ['person', 'person', 'person']
scores [0.74, 0.77, 0.54]
boxes [[        354         245         379         329][        353         241         379         327][        114         249         128         292]]

2、只有可见光检测到

#各检测结果
rgb ['person'] [0.72] [[        357         245         382         330]]
T [] [] []
# 合并
labels ['person']
scores [0.72]
boxes [[        357         245         382         330]]

3、只有红外检测到

#各检测结果
rgb [] [] []
T ['person'] [0.72] [[        357         245         382         330]]# 合并
labels ['person']
scores [0.72]
boxes [[        357         245         382         330]]

4、都没检测到

#各检测结果
rgb [] [] []
T [] [] []# 合并
labels []
scores []
boxes [[]]

2.3.2 重复检测框剔除

针对第一种情况,会出现同一目标同时被可见光和红外都检测到,那么到底用哪个检测框来表示是一个选择问题。

首先看一下效果,通过检测框坐标知道第1个person和第3个person是同一个目标,算法排除了第2个person目标,那么怎么选择最好的检测框呢?

NMS 把置信度(预测这个网格里是否有目标的置信度)最高的那个网格的边界箱作为极大边界箱,计算极大边界箱和其他几个网格的边界箱的IOU,如果超过一个阈值,例如0.5,就认为这两个网格实际上预测的是同一个物体,就把其中置信度比较小的删除。

#1、检测
rgb ['person', 'person'] [0.82, 0.6] [[        373         244         401         333][        110         259         123         299]]
T ['person'] [0.67] [[        372         248         400         331]]
#2、合并
labels ['person', 'person', 'person']
scores [0.82, 0.6, 0.67]
boxes [[        373         244         401         333][        110         259         123         299][        372         248         400         331]]
#3、选择
result ['person' 'person'] [       0.82         0.6] [[        373         244         401         333]  [        110         259         123         299]]
  • 算法代码如下
def simplified_nms(boxes, scores, iou_thres=0.5):  #排除重复目标'''Args:boxes: (n, 4), xyxy formatscores: list(float)Returns:indices: list(int), indices to keep'''x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] #获得每个检测框的对角坐标area = (x2 - x1) * (y2 - y1) #所有检测框面积#argsort()函数是将x中的元素从小到大排列,提取其对应的index(索引),使用[::-1],可以建立X从大到小的索引。#1、置信度排序#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××idx = np.argsort(scores)[::-1] #置信度由大到小的索引排序print("idx",idx)indices = []#2、循环数组筛选#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××while len(idx) > 0:i = idx[0] #取最大置信度的索引indices.append(i)  #将这个索引放入indices中print('indices',indices)if len(idx) == 1:  #如果就一个检测框breakidx = idx[1:] #意思是去掉列表中第一个元素,对后面的元素进行操作print('idx[1:]',idx)#3、与最大置信度框的重合度来判断是否是重复检测框#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××xx1 = np.clip(x1[idx], a_min=x1[i], a_max=np.inf) #截取数组中小于或者大于某值的部分 np.inf表示无穷大yy1 = np.clip(y1[idx], a_min=y1[i], a_max=np.inf)xx2 = np.clip(x2[idx], a_min=0, a_max=x2[i])yy2 = np.clip(y2[idx], a_min=0, a_max=y2[i])w, h = np.clip((xx2 - xx1), a_min=0, a_max=np.inf), np.clip((yy2 - yy1), a_min=0, a_max=np.inf)inter = w * hunion = area[i] + area[idx] - interiou = inter / union #计算iou 越大越好最大为1print('iou',iou)#4、将重合度小的即不同目标保存到idx再次循环#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××idx = idx[iou < iou_thres] #框重合度小视为不同目标print("idx",idx)return indices #返回合并后的索引
  • 变量输出如下

重复目标

#1、检测
rgb ['person'] [0.81] [[        365         248         392         332]]
T ['person'] [0.82] [[        364         243         392         329]]
#2、合并
labels ['person', 'person']
scores [0.81, 0.82]
boxes [[        365         248         392         332][        364         243         392         329]]
#3、置信度排序选择
idx [1 0]
indices [1]
#4、对剩下的目标与置信度最大的目标进行iou计算,来判断是否是同一目标
idx[1:] [0]
iou [    0.87867] #iou为0.87 高度重合,所以是同一目标
idx [] #没有待选目标了,循环结束

不重复目标

# 1、检测
rgb ['person'] [0.78] [[        363         246         390         329]]
T ['person', 'person'] [0.84, 0.52] [[        362         245         391         332][        113         250         135         295]]
# 2、合并
labels ['person', 'person', 'person']
scores [0.78, 0.84, 0.52]
boxes [[        363         246         390         329][        362         245         391         332][        113         250         135         295]]
# 3、排序
idx [1 0 2]
indices [1]
# 4、计算iou排除重复目标
idx[1:] [0 2]
iou [    0.88823           0] #0表示不重合,为不同的目标,对应索引保存到indices
idx [2]
indices [1, 2]
result ['person' 'person'] [       0.84        0.52] [[        362         245         391         332][        113         250         135         295]]

2.4 结果可视化

2.4.3 单目估计距离

    # 5.3单目估计距离#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××locations = mono.estimate(boxes)#获得目标框的世界坐标系print("locations",locations)indices = [i for i in range(len(locations)) if locations[i][1] > 0 and locations[i][1] < 200]labels, scores, boxes, locations = \np.array(labels)[indices], np.array(scores)[indices], boxes[indices], np.array(locations)[indices]distances = [(loc[0] ** 2 + loc[1] ** 2) ** 0.5 for loc in locations] #估计距离print("distances",distances)cur_frame1 = cur_frame1[:, :, ::-1].copy() # to BGRcur_frame2 = cur_frame2[:, :, ::-1].copy() # to BGR

单目估计距离算法的代码如下

# -*- coding: UTF-8 -*- import numpy as np
import math
import cv2
from math import sin, cosclass MonoEstimator():def __init__(self, file_path, print_info=True):fs = cv2.FileStorage(file_path, cv2.FileStorage_READ)mat = fs.getNode('ProjectionMat').mat()self.fx = int(mat[0, 0])self.fy = int(mat[1, 1])self.u0 = int(mat[0, 2])self.v0 = int(mat[1, 2])self.height = fs.getNode('Height').real()self.depression = fs.getNode('DepressionAngle').real() * math.pi / 180.0if print_info:print('Calibration of camera:')print('  Parameters: fx(%d) fy(%d) u0(%d) v0(%d)' % (self.fx, self.fy, self.u0, self.v0))print('  Height: %.2fm' % self.height)print('  DepressionAngle: %.2frad' % self.depression)print()def uv_to_xyz(self, u, v):# Compute (x, y, z) coordinates in the real world, according to (u, v) coordinates in the image.# X axis - on the right side of the camera# Z axis - in front of the camerau = int(u)v = int(v)fx, fy = self.fx, self.fyu0, v0 = self.u0, self.v0h = self.heightt = self.depressiondenominator = fy * sin(t) + (v - v0) * cos(t)if denominator != 0:z = (h * fy * cos(t) - h * (v - v0) * sin(t)) / denominatorif z > 1000: z = 1000else:z = 1000x = (z * (u - u0) * cos(t) + h * (u - u0) * sin(t)) / fxy = hreturn x, y, zdef estimate(self, boxes):#单目估计距离locations = []if boxes.shape[0] > 0: #判断目标个数是否大于1for box in boxes:print(box[0])u, v = (box[0] + box[2]) / 2, box[3]x, y, z = self.uv_to_xyz(u, v)locations.append((x, z))return locations

locations是世界坐标系的xy坐标,distances是距离

2.4.4 结果参数可视化

# 5.4画检测框#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××for i in reversed(np.argsort(distances)):#reversed()函数用于对可迭代对象中的元素进行反向排列cur_frame1 = draw_predictions(cur_frame1, str(labels[i]), float(scores[i]), boxes[i], location=locations[i] #两个图像上的数据显示一样)cur_frame2 = draw_predictions(cur_frame2, str(labels[i]), float(scores[i]), boxes[i], location=locations[i])# 5.5发布图像#×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××result_frame = np.concatenate([cur_frame1, cur_frame2], axis=1) #合并图像

最后的效果如下

三、特征集融合

工程难度:困难

同时训练可见光和红外图片,需要改动网络的结构,对每层的特征进行融合。

同时需要对图片质量进行评价,给出自适应的融合权重。

论文终于发表喽,欢迎小伙伴们引用:

SLBAF-Net: Super-Lightweight bimodal adaptive fusion network for UAV detection in low recognition environment | SpringerLink

3.1 双模态数据集制作

有两个图片集,可见光images和红外images2,共用一个labels。

yaml文件配置如下

train: data/train.txt
train2: data/train2.txt
val: data/val.txt
val2: data/val2.txt
test: data/test.txt
# Classes
nc: 7  # number of classes
names: ['pedestrian','cyclist','car','bus','truck','traffic_light','traffic_sign']  # class names

3.2 传入双模态数据

3.2.1  dataset.py

这个文件存放的是读取处理数据集的一系列函数,yolov5原来使用的mosaic拼接方法,即一张图片再随机选取三张进行拼接,但这不适合双模态的融合,所以我先把dataset.py文件保留成只有最基本功能的函数,再进行修改。

dataset.py拆解后的运行效果如下,这样对于双模态的融合是非常方便的。随后把各函数的输入输出变成双模态。

主要修改的思路如下,整个修改过程还是比较艰辛的,需要对代码很熟悉,本文只提供大概思路

1、create_dataloader

这个函数主要改一下输入,增加一个path2,对应train.py中引用这个函数也需要修改

def create_dataloader(path,path2,imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):if rect and shuffle:LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')shuffle = Falsewith torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP#return torch.from_numpy(img),torch.from_numpy(img2), labels_out, self.im_files[index],self.im_files2[index], shapesdataset = LoadImagesAndLabels(path,path2, imgsz, batch_size,augment=augment,  # augmentationhyp=hyp,  # hyperparametersrect=rect,  # rectangular batchescache_images=cache,single_cls=single_cls,stride=int(stride),pad=pad,image_weights=image_weights,prefix=prefix)batch_size = min(batch_size, len(dataset))nd = torch.cuda.device_count()  # number of CUDA devicesnw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workerssampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updatesreturn loader(dataset,batch_size=batch_size,shuffle=shuffle and sampler is None,num_workers=nw,sampler=sampler,pin_memory=True,collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset

更改train.py文件中的train_loader和 val_loader

    #创建训练集train_loader, dataset = create_dataloader(train_path,train_path2, imgsz, batch_size // WORLD_SIZE, gs, single_cls,hyp=hyp, augment=True, cache=None if opt.cache == 'val' else opt.cache,rect=opt.rect, rank=LOCAL_RANK, workers=workers,image_weights=opt.image_weights, quad=opt.quad,prefix=colorstr('train: '), shuffle=True)# 获取标签中最大的类别值与类别数做比较#mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label classnb = len(train_loader)  # number of batches# 如果小于则出现问题#assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'# Process 0if RANK in [-1, 0]:# 创建测试集val_loader = create_dataloader(val_path,val_path2, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,hyp=hyp, cache=None if noval else opt.cache,rect=True, rank=-1, workers=workers * 2, pad=0.5,prefix=colorstr('val: '))[0]

2、LoadImagesAndLabels

这个函数是改动的重点,我把生成cache的代码单独封装一个函数get_cache,以获得cache2进行后续操作。

def img2label_paths(img_paths):# Define label paths as a function of image pathssa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substringsreturn [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
def get_cache(path):prefix=''cache_version = 0.6  # dataset labels *.cache versiontry:#1、获取图片f = []  # image filesfor p in path if isinstance(path, list) else [path]:p = Path(p)  # os-agnosticif p.is_dir():  # dirf += glob.glob(str(p / '**' / '*.*'), recursive=True)# f = list(p.rglob('*.*'))  # pathlibelif p.is_file():  # filewith open(p) as t:t = t.read().strip().splitlines()parent = str(p.parent) + os.sep #上级目录os.sep是分隔符f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path# f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)else:raise Exception(f'{prefix}{p} does not exist')# 2、过滤不支持格式的图片im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlibassert im_files, f'{prefix}No images found'except Exception as e:raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')label_files = img2label_paths(im_files)  # 获取labelscache_path = (p if p.is_file() else Path(label_files[0]).parent).with_suffix('.cache')try:cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dictassert cache['version'] == cache_version  # same versionassert cache['hash'] == get_hash(label_files + im_files)  # same hash 判断hash值是否改变except Exception:cache, exists = cache_labels(cache_path,im_files,label_files, prefix), False  # cache# Display cache  过滤结果打印nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, totalif exists:d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT)  # display cache resultsif cache['msgs']:LOGGER.info('\n'.join(cache['msgs']))  # display warningsreturn cache
#↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
#一、数据处理
class LoadImagesAndLabels(Dataset):# YOLOv5 train_loader/val_loader, loads images and labels for training and validationdef __init__(self, path, path2, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):#创建参数self.img_size = img_sizeself.augment = augment #是否数据增强self.hyp = hyp #超参数self.image_weights = image_weights #图片采样权重self.rect = False if image_weights else rect #矩阵训练self.stride = stride #下采样步数self.path = pathself.path2 = path2self.albumentations = Albumentations() if augment else Nonecache = get_cache(self.path)cache2 = get_cache(self.path2)# Read cache[cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items[cache2.pop(k) for k in ('hash', 'version', 'msgs')]  # remove itemslabels, shapes, self.segments = zip(*cache.values())self.labels = list(labels)self.shapes = np.array(shapes, dtype=np.float64)self.im_files = list(cache.keys())  # update 图片列表self.im_files2 = list(cache2.keys())  # update 图片列表self.label_files = img2label_paths(cache.keys())  # update 标签列表n = len(shapes)  # number of images 14329bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index 将每一张图片batch索引nb = bi[-1] + 1  # number of batchesself.batch = bi  # batch index of imageself.n = nself.indices = range(n)# Update labels#过滤类别include_class = []  # filter labels to include only these classes (optional)include_class_array = np.array(include_class).reshape(1, -1)for i, (label, segment) in enumerate(zip(self.labels, self.segments)):if include_class:j = (label[:, 0:1] == include_class_array).any(1)self.labels[i] = label[j]if segment:self.segments[i] = segment[j]if single_cls:  # single-class training, merge all classes into 0 把所有目标归为一类self.labels[i][:, 0] = 0if segment:self.segments[i][:, 0] = 0#是否采用矩形构造if self.rect:# Sort by aspect ratios = self.shapes  # whar = s[:, 1] / s[:, 0]  # aspect ratio #高和宽的比irect = ar.argsort() #根据ar排序self.im_files = [self.im_files[i] for i in irect]self.label_files = [self.label_files[i] for i in irect]self.labels = [self.labels[i] for i in irect]self.shapes = s[irect]  # whar = ar[irect]# Set training image shapes 设置训练图片的shapes# 对同个batch进行尺寸处理shapes = [[1, 1]] * nbfor i in range(nb):ari = ar[bi == i]mini, maxi = ari.min(), ari.max()if maxi < 1:shapes[i] = [maxi, 1]elif mini > 1:shapes[i] = [1, 1 / mini]self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * strideself.ims = [None] * self.nself.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]self.npy_files2 = [Path(f).with_suffix('.npy') for f in self.im_files2]

3、__getitem__

接着更改图像增强部分的输入输出,时刻检查两图像是否是同一张图片。

    #二、图片增强def __getitem__(self, index):#根据每个类别数量获得图片采样权重,获取新的下标i = self.indices[index]  # linear, shuffled, or image_weightshyp = self.hyp#↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓# Load image resize图片im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i], #判断有没有这个图片im2, f2, fn2 = self.ims[i], self.im_files2[i], self.npy_files2[i], #判断有没有这个图片img, (h0, w0), (h, w) = self.load_image(im, f, fn)img2, (h02, w02), (h2, w2) = self.load_image(im2, f2, fn2)# Letterboxshape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shapeimg, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)img2, ratio2, pad2 = letterbox(img2, shape, auto=False, scaleup=self.augment)labels = self.labels[index].copy()shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescalingif labels.size:  # normalized xywh to pixel xyxy formatlabels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])nl = len(labels)  # number of labelsif nl:labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)labels_out = torch.zeros((nl, 6))if nl:labels_out[:, 1:] = torch.from_numpy(labels)# Convertimg = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGBimg = np.ascontiguousarray(img)img2 = img2.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGBimg2 = np.ascontiguousarray(img2)return torch.from_numpy(img),torch.from_numpy(img2), labels_out, self.im_files[index],self.im_files2[index], shapes#↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑@staticmethoddef collate_fn(batch): #如何取样本im,im2, label, path,path2, shapes = zip(*batch)  # transposedfor i, lb in enumerate(label):lb[:, 0] = i  # add target image index for build_targets()return torch.stack(im, 0),torch.stack(im2, 0), torch.cat(label, 0), path,path2, shapes@staticmethoddef collate_fn4(batch):im,im2, label, path,path2, shapes = zip(*batch)  # transposedn = len(shapes) // 4im4,im42, label4, path4,path42, shapes4 = [],[], [], path[:n],path2[:n], shapes[:n]ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scalefor i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHWi *= 4if random.random() < 0.5:im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[0].type(img[i].type())im2 = F.interpolate(img2[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[0].type(img[i].type())lb = label[i]else:im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)im2 = torch.cat((torch.cat((img2[i], img2[i + 1]), 1), torch.cat((img2[i + 2], img2[i + 3]), 1)), 2)lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * sim4.append(im)im42.append(im2)label4.append(lb)for i, lb in enumerate(label4):lb[:, 0] = i  # add target image index for build_targets()return torch.stack(im4, 0),torch.stack(im42, 0), torch.cat(label4, 0), path4,path42,shapes4

4、pbar更改

因为传入的是双模态数据集,所以batch和原来的不一样

由im,targets, paths, shapes 变成了im,im2,targets, paths,paths2, shapes

所以对应的train.py 和 val.py进行修改

train.py

 if RANK in [-1, 0]:# 通过tqdm创建进度条,方便训练信息的展示pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress baroptimizer.zero_grad() #梯度训练for i, (imgs,imgs2, targets, paths,paths2, _) in pbar:  # batch -------------------------------------------------------------# 计算迭代次数ni = i + nb * epoch  # number integrated batches (since train start)imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0

val.py

    pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress barfor batch_i, (im,im2,targets, paths,paths2, shapes) in enumerate(pbar):t1 = time_sync()if cuda:im = im.to(device, non_blocking=True)targets = targets.to(device)im = im.half() if half else im.float()  # uint8 to fp16/32im /= 255  # 0 - 255 to 0.0 - 1.0nb, _, height, width = im.shape  # batch size, channels, height, widtht2 = time_sync()dt[0] += t2 - t1

如果改动的比较顺利,现在运行原本的yolov5命令,代码是可以运行的,因为现在虽然传入了两个模态,但只用了一个模态,并不影响。如果运行不了,根据报错仔细的查看代码进行修改,一般都是输入输出的问题。

3.2.2  train_batch显示

yolov5会在训练的前三个batch生成train_batch.jpg图片,方便检查传入网络的图片和label是否正确。我想检查一下生成的红外与可见光图片是否一致,以及labels是否正确,就需要改相应的函数,下面提供思路。

1、首先要找到生成train_batch图片的位置

采取关键字查找法,在ubuntu使用如下命令进行查找

grep -rn "train_batch" *

搜寻结果如下,可以看到在utils/loggers/__init__.py文件中

2、更改函数

直接拷贝这个函数,改一下生成图片的名字以做区分

    def on_train_batch_end(self, ni, model, imgs, targets, paths, plots, sync_bn):# Callback runs on train batch endif plots:if ni == 0:if not sync_bn:  # tb.add_graph() --sync known issue https://github.com/ultralytics/yolov5/issues/3754with warnings.catch_warnings():warnings.simplefilter('ignore')  # suppress jit trace warningself.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])if ni < 3:f = self.save_dir / f'train_batch{ni}.jpg'  # filenameThread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()if self.wandb and ni == 10:files = sorted(self.save_dir.glob('train*.jpg'))self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})def on_train_batch_end2(self, ni, model, imgs, targets, paths, plots, sync_bn):# Callback runs on train batch endif plots:if ni == 0:if not sync_bn:  # tb.add_graph() --sync known issue https://github.com/ultralytics/yolov5/issues/3754with warnings.catch_warnings():warnings.simplefilter('ignore')  # suppress jit trace warningself.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])if ni < 3:f = self.save_dir / f'train_batch2{ni}.jpg'  # filenameThread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()if self.wandb and ni == 10:files = sorted(self.save_dir.glob('train*.jpg'))self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

3、train.py引用

在train.py添加这个函数,让img2引用

我当时改到这一步认为已经改完了,但代码还是跑不起来于是,我搜寻了一下原来on_train_batch_end所寻在的文件,结果如下

看到在callbacks.py还存在这个函数,于是将更改的函数同样加到这里,如下所示

之后再次运行代码,跑通!!!!可以看到红外和可见光的图片相对应,labels也正确。

到此数据输入的部分算是搞定了,下面该是更改网络部分。

3.3 更改backbone

想要更改backbone需要对yolo.py十分熟悉,整个改代码调试的过程太过艰辛,这里只分享主要思路和代码。改更改代码之前首先要理清楚几个概念。

3.3.1 yaml解析文件

yolo.py是十分重要的文件,包括了网络的读取和构建,yolov5是使用yaml文件读取网络的,新手刚开始看十分迷惑,网上的注释很多可以借鉴,这里不再阐述。

parse_model函数是yaml的解析文件,这里需要改成双模态的,首先双模态需要两个分开来的输入和输出,避免在调用时候混淆。因此我加了ch2、c12、c22。随后根据层的序列来判断是走哪个backbone。需要注意的是在最后ch2需要包括ch的所有层,如果不包含的话,索引是不匹配的。

#解析网络配置文件构建模型
def parse_model(d, ch,ch2):  # model_dict, input_channels(3)LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")#记录日志#1、以下是读取配置dict里的参数#————————————————————————————————————————————————————————————————anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchorsno = na * (nc + 5)  # number of outputs = anchors * (classes + 5) 输出的通道数# 2、开始搭建网络#————————————————————————————————————————————————————————————————'''layers: 保存每一层的层结构# save: 记录下所有层结构中from中不是-1的层结构序号# c2: 保存当前层的输出channel'''layers, save, c2,c22 = [], [], ch[-1] ,ch2[-1] # layers, savelist, ch out初始化# from(当前层输入来自哪些层), number(当前层次数 初定模型深度), module(当前层类别), args(当前层类参数 初定)for i, (f, n, m, args) in enumerate(d['backbone1']+d['backbone2']+d['head']):  # 遍历backbone和head的每一层m = eval(m) if isinstance(m, str) else m  #得到当前层的真实类名 for j, a in enumerate(args): #循环模块参数argstry:args[j] = eval(a) if isinstance(a, str) else a  # eval stringsexcept NameError:passn = n_ = max(round(n * gd), 1) if n > 1 else nif i<10:if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, conv_bn_relu_maxpool, Shuffle_Block]:c1, c2 = ch[f], args[0]if c2 != no:  c2 = make_divisible(c2 * gw, 8)  #保证通道是8的倍数# 在初始arg的基础上更新 加入当前层的输入channel并更新当前层# [in_channel, out_channel, *args[1:]]args = [c1, c2, *args[1:]] # depth gain 控制深度  如v5s: n*0.33   n: 当前模块的次数(间接控制深度)else : if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, conv_bn_relu_maxpool, Shuffle_Block]:c12, c22 = ch2[f], args[0] #保存输出if c22 != no:  c22 = make_divisible(c22 * gw, 8)  #保证通道是8的倍数args = [c12, c22, *args[1:]] if m in [BottleneckCSP, C3, C3TR, C3Ghost]:args.insert(2, n)   #在第二个位置插入bottleneck个数nn = 1  #重置elif m is nn.BatchNorm2d: # BN层只需要返回上一层的输出channelargs = [ch2[f]]elif m is Concat2: #Concat:f是所有需要拼接层的索引,则输出通道c2是所有层的和c22 = 0for x in f:if x==-1:c2p = ch2[x]else:c2p = ch[x]c22 = c22+c2pelif m is Concat: #Concat:f是所有需要拼接层的索引,则输出通道c2是所有层的和c22 = sum(ch2[x] for x in f)elif m is Detect:#args先填入每个预测层的输入通道数,然后填入生成所有预测层对应的预测框的初始高宽的列表。args.append([ch2[x] for x in f]) #在args中加入三个Detect层的输出channelif isinstance(args[1], int):  # number of anchorsargs[1] = [list(range(args[1] * 2))] * len(f)elif m is Contract:c22 = ch2[f] * args[0] ** 2elif m is Expand:c22 = ch2[f] // args[0] ** 2else:c22 = ch2[f]#拿args里的参数去构建了module m,然后模块的循环次数用参数n控制。# m_: 得到当前层module  如果n>1就创建多个m(当前层结构), 如果n=1就创建一个mm_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # modulet = str(m)[8:-2].replace('__main__.', '')  # 打印当前层结构的一些基本信息np = sum(x.numel() for x in m_.parameters())  # 计算这一层的参数量m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number paramsLOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # 打印日志文件信息(每一层module构建的编号、参数量等)# append to savelist  把所有层结构中from不是-1的值记下 save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelistlayers.append(m_) # 将当前层结构module加入layers中if i == 0:ch = []ch.append(c2)if i == 10:ch2 = [32, 64, 64, 128, 128, 256, 256, 512, 512, 512] #层编号是包含backbone1的ch2.append(c22)return nn.Sequential(*layers), sorted(save)      #当循环结束后再构建成模型 

3.3.2 backbone

yaml文件改好以后就可以改backbone了如下,这里我为了不让和正常的concat混淆,加了一个concat2,一模一样只是函数名字变了

# 双模型态度
nc: 7
depth_multiple: 0.33
width_multiple: 0.50#边界框的设置
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32backbone1:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2 #64代表通道数,3表示3*3的卷积核,2代表步长为2,2表示分两组卷积#320×320×32[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4#160×160×64[-1, 3, C3, [128]],#160×160×64[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8#80×80×128[-1, 6, C3, [256]],#80×80×128[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16#40×40×256[-1, 9, C3, [512]],#40×40×256[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32#20×20×512[-1, 3, C3, [1024]],#20×20×512[-1, 1, SPPF, [1024, 5]],  # 9#20×20×512]backbone2:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  #10#320×320×32[-1, 1, Conv, [128, 3, 2]],  # #160×160×64[-1, 3, C3, [128]], #160×160×64 [-1, 1, Conv, [256, 3, 2]],#80×80×128  [-1, 6, C3, [256]],  # 14#80×80×128[[-1, 4], 1, Concat2, [1]],  ##80×80×256[-1, 6, C3, [256]],  #16#80×80×128[14, 1, Conv, [512, 3, 2]],  # #40×40×256[-1, 9, C3, [512]],#40×40×256[[-1, 6], 1, Concat2, [1]],  ##40×40×512[-1, 6, C3, [512]],  #20#40×40×256[17, 1, Conv, [1024, 3, 2]],  # #20×20×512[[-1, 9], 1, Concat2, [1]],  # #20×20×1024[-1, 3, C3, [1024]],#20×20×512[-1, 1, SPPF, [1024, 5]],  #24#20×20×512]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],#20×20×256[-1, 1, nn.Upsample, [None, 2, 'nearest']],#40×40×256[[-1, 20], 1, Concat, [1]],  # cat backbone P4#40×40×512[-1, 3, C3, [512, False]],  # 28#40×40×256[-1, 1, Conv, [256, 1, 1]],#40×40×128[-1, 1, nn.Upsample, [None, 2, 'nearest']],#80×80×128[[-1, 16], 1, Concat, [1]],  # cat backbone P3#80×80×256[-1, 3, C3, [256, False]],  # 32 (P3/8-small)#80×80×128[-1, 1, Conv, [512, 3, 2]],#40×40×256[[-1, 28], 1, Concat, [1]],  # cat head P4#40×40×512[-1, 3, C3, [512, False]],  # 35 (P4/16-medium)#40×40×256[-1, 1, Conv, [1024, 3, 2]],#20×20×512[[-1, 24], 1, Concat, [1]],  # cat head P5#20×20×1024[-1, 3, C3, [1024, False]],  # 38 (P5/32-large)#20×20×512[[32, 35, 38], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

3.4 Forward更改

Forward是网络向前传播的逻辑所在,要解决的问题是x1和x2怎么判断进入到不同的backbone。

    def _forward_once(self, x1,x2, profile=False, visualize=False):y, dt = [], []  # outputsfor m in self.model:if m.i<10:x1 = m(x1)x2=x2y.append(x1 if m.i in self.save else None)  # save outputelse:if m.f != -1:  # if not from previous layerx2 = y[m.f] if isinstance(m.f, int) else [x2 if j == -1 else y[j] for j in m.f]  # from earlier layersif profile:self._profile_one_layer(m, x2, dt)x2 = m(x2)  # runx1=x1y.append(x2 if m.i in self.save else None)  # save outputif visualize:feature_visualization(x2, m.type, m.i, save_dir=visualize)feature_vis = Falseif m.type == 'models.common.C3' and feature_vis and m.i==17:print(m.type, m.i)feature_visualization2(x2, m.type, m.i,128,8,16)return x2

3.5 调试运行

python train.py --data data/dual.yaml --cfg models/dualyolov5s.yaml --weights weights/yolov5s.pt --batch-size 12 --epochs 50

网络结构如下

              from  n    params  module                                  arguments                     0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2]              1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]                2                -1  1     18816  models.common.C3                        [64, 64]                      3                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               4                -1  2    148992  models.common.C3                        [128, 128]                    5                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]              6                -1  3    889344  models.common.C3                        [256, 256]                    7                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]              8                -1  1   1182720  models.common.C3                        [512, 512]                    9                -1  1    656896  models.common.SPPF                      [512, 512, 5]                 10                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2]              11                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]                12                -1  1     18816  models.common.C3                        [64, 64, 1]                   13                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               14                -1  2    115712  models.common.C3                        [128, 128, 2]                 15           [-1, 4]  1         0  models.common.Concat2                   [1]                           16                -1  2    132096  models.common.C3                        [256, 128, 2]                 17                14  1    295424  models.common.Conv                      [128, 256, 3, 2]              18                -1  3    625152  models.common.C3                        [256, 256, 3]                 19           [-1, 6]  1         0  models.common.Concat2                   [1]                           20                -1  2    526336  models.common.C3                        [512, 256, 2]                 21                17  1   1180672  models.common.Conv                      [256, 512, 3, 2]              22           [-1, 9]  1         0  models.common.Concat2                   [1]                           23                -1  1   1444864  models.common.C3                        [1024, 512, 1]                24                -1  1    656896  models.common.SPPF                      [512, 512, 5]                 25                -1  1    131584  models.common.Conv                      [512, 256, 1, 1]              26                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          27          [-1, 20]  1         0  models.common.Concat                    [1]                           28                -1  1    361984  models.common.C3                        [512, 256, 1, False]          29                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]              30                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          31          [-1, 16]  1         0  models.common.Concat                    [1]                           32                -1  1     90880  models.common.C3                        [256, 128, 1, False]          33                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]              34          [-1, 28]  1         0  models.common.Concat                    [1]                           35                -1  1    361984  models.common.C3                        [512, 256, 1, False]          36                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]              37          [-1, 24]  1         0  models.common.Concat                    [1]                           38                -1  1   1444864  models.common.C3                        [1024, 512, 1, False]         39      [32, 35, 38]  1     32364  models.yolo.Detect                      [7, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
Model Summary: 530 layers, 13493740 parameters, 13493740 gradients

可以到看效果还是不错的,训练13个epoch就map0.5就达到了0.85左右

经过漫长的调试,也总算是能够运行了,整个更改过程是我自己摸索出来的,网上没有相关的参考,如果有更简洁的改法可以讨论。后期还有大量的优化工作,有时间在更新。

工程(八)——yolov5可见光+红外双模态融合(代码)相关推荐

  1. 【图像融合】自适应参考图像的可见光与热红外彩色图像融合算法

    彩色图像融合技术   可见光与红外热图像的彩色图像融合技术,主要用于夜视应用.基于人眼彩色视觉特性的彩色夜视技术充分利用微光和红外波段的夜视图像信息,可使观察者的目标识别速度和准确度提高30%~60% ...

  2. 多模态深度学习综述:网络结构设计和模态融合方法汇总

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨小奚每天都要学习@知乎(已授权) 来源丨https://zhuanlan.zhihu.com/p ...

  3. 自动驾驶 Apollo 源码分析系列,感知篇(八):感知融合代码的基本流程

    说起自动驾驶感知系统,大家都会谈论到感知融合,这涉及到不同传感器数据在时间.空间的对齐和融合,最终的结果将提升自动驾驶系统的感知能力,因为我们都知道单一的传感器都是有缺陷的.本篇文章梳理 Apollo ...

  4. RGBD模态融合问题

    当图像特征中RGB图像和depth图像特征有鸿沟时,怎么进行融合呢?这里从RGBD SOD显著性检测的领域找了一些最新论文的方法,并附上代码,以供参考. 1.2022 Multi-modal inte ...

  5. 可见近红外双发射荧光-钙钛矿异质结量子点

    量子点是一类半导体纳米粒子,其激子在三维空间受限在与其激子玻尔直径相当或更小的尺度范围内,从而表现出与对应的宏观物质不同的物理性质.近年来,新兴的钙钛矿量子点以其优异而独特的光电特性,如可见光波段能级 ...

  6. 八十八枚红手印背后的故事

    八十八枚红手印背后的故事 杨桂林 一 这是2019年的一天."听说咱们的第一书记要回单位了!"这个消息一阵风似的传遍了乌兰敖包嘎查方圆几十公里的山坳里的牧户.草场.牧民们一些一时没 ...

  7. 【C4AI-2022】基于可见光与激光雷达数据融合的航天器三维精细结构智能重建

    基于可见光与激光雷达数据融合的航天器三维精细结构智能重建 选题背景 随着航天技术的快速发展,空间活动任务类型呈现出多元化的发展趋势.其中,太空垃圾快速清除.故障卫星在轨维修.空间目标监视寄生.空间卫星 ...

  8. Yolov5总结文档(理论、代码、实验结果)

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨Mr.Hang@知乎 来源丨https://zhuanlan.zhihu.com/p/44925 ...

  9. iOS 11开发教程(十八)iOS11应用视图之使用代码添加按钮

    iOS 11开发教程(十八)iOS11应用视图之使用代码添加按钮 由于使用编辑界面添加视图的方式比较简单,所以不在介绍.这里,直接讲解代码中如何添加.使用代码为主视图添加一个按钮的方式和在1.3.3节 ...

最新文章

  1. MOSS字段编辑权限控制方案--发布源码
  2. Python3 列表的基本操作
  3. 搭建rabbitmq的docker集群
  4. LUA string的状态
  5. 关于jquery-Validate
  6. 【大牛疯狂教学】mysqlinnodb和myisam
  7. ExtJs2.0学习系列(15)--extjs换肤
  8. 开课吧:数据分析师常用的分析方法有哪些?
  9. RHEL5_x64上安装oracle 11.2
  10. python urllib2详解及实例
  11. extjs4 视频教程
  12. 计数器代码php,php的计数器程序_php
  13. 双注入法/开路短路法
  14. 插入数据报错: Incorrect string value: ‘\xE8\xB5\xB5\xE9\x9B\xB7‘ for column ‘Sname‘ at row 1
  15. 论文笔记(二十二):Soft Tracking Using Contacts for Cluttered Objects to Perform Blind Object Retrieval
  16. 安全防护与信息加密:一个新的挑战
  17. MySQL主从同步数据
  18. [开源精品] C#.NET im 聊天通讯架构设计 -- FreeIM 支持集群、职责分明、高性能
  19. 常见 Web 攻击介绍
  20. 淮北职业技术学院大一计算机考试,淮北职业技术学院2020年录取分数线(附2018-2020年分数线)...

热门文章

  1. pymysql的安装及使用
  2. MQTT客户端软件mqtt-spy使用教程
  3. 流量限制(rate-limiting)
  4. PostgreSQL字符串连接
  5. 网站301怎么开启并设置
  6. IIS 7的配置问题
  7. opencv imshow函数报cv::exception错误,以及sift算法的使用问题
  8. CSS —— BFC机制
  9. 微信小程序如何拍照?
  10. AQS 从后往前遍历寻找继任者