一、数据集格式

二、解析xml文件,生成data_center.txt

from PIL import Image
import math,os
from xml.etree import ElementTree as ETdef keep_image_size_open(path, size=(256, 256)):img = Image.open(path)temp = max(img.size)mask = Image.new('RGB', (temp, temp), (0, 0, 0))mask.paste(img, (0, 0))mask = mask.resize(size)return maskdef make_data_center_txt(xml_dir):with open('data_center.txt', 'a') as f:f.truncate(0)path=r'data/images'xml_names = os.listdir(xml_dir)for xml in xml_names:xml_path = os.path.join(xml_dir, xml)in_file = open(xml_path)tree = ET.parse(in_file)root = tree.getroot()image_path = root.find('path')polygon = root.find('outputs/object/item/polygon')data = []c_data = []data_str = ''print(xml)for i in polygon:data.append(int(i.text))data_str = data_str + ' ' + str(i.text)for i in range(0, len(data), 2):c_data.append((data[i], data[i + 1]))data_str = os.path.join(path,image_path.text.split('\\')[-1]) +data_strf.write(data_str + '\n')if __name__ == '__main__':make_data_center_txt('data/xml')

三、加载数据集

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Imagefrom heatmap_label import CenterLabelHeatMaptf = transforms.Compose([  #标准化处理transforms.ToTensor()
])class MyDataset(Dataset):def __init__(self,root): #传入路径f=open(root,'r')self.dataset=f.readlines() #读所有行def __len__(self):return len(self.dataset) #返回数据集长度def __getitem__(self, index):data=self.dataset[index] #取当前数据img_path=data.split(' ')[0] #以空格划分,并取出文件名,即data/images\0.pngimg_data=Image.open(img_path).resize((256, 256)) #打开图片# points = data.split(' ')[1:-2]  # 取出后面5个点的x,y坐标,-2是取不到的points=data.split(' ')[1:] #取出后面5个点的x,y坐标# print(img_data, points)#将坐标映射到256*256大小的图片上points = [int(points[0])*256/774, int(points[1])*256/434, int(points[2])*256/774, int(points[3])*256/434, int(points[4])*256/774, int(points[5])*256/434]# points=[int(i)/100 for i in points] #图像宽高为100,int(i)/100进行归一化# print(img_data, points)label = []for i in range(0, len(points), 2):heatmap = CenterLabelHeatMap(256, 256, points[i], points[i+1], 5)label.append(heatmap)#一个关键点会生成一个通道,3个关键点生成3个通道label = np.stack(label) #将列表转成数组的形式return tf(img_data), torch.Tensor(label) #将img_data标准化,将points转化为tensor格式if __name__ == '__main__':data=MyDataset('data_center.txt')for i in data:print(i[0].shape)print(i[1].shape)

四、构建网络

import torch
from torch import nn
from torch.nn import functional as Fclass Conv_Block(nn.Module):def __init__(self,in_channel,out_channel):super(Conv_Block, self).__init__()self.layer=nn.Sequential(nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),nn.BatchNorm2d(out_channel),nn.Dropout2d(0.3),nn.LeakyReLU(),nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),nn.BatchNorm2d(out_channel),nn.Dropout2d(0.3),nn.LeakyReLU())def forward(self,x):return self.layer(x)class DownSample(nn.Module):def __init__(self,channel):super(DownSample, self).__init__()self.layer=nn.Sequential(nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),nn.BatchNorm2d(channel),nn.LeakyReLU())def forward(self,x):return self.layer(x)class UpSample(nn.Module):def __init__(self,channel):super(UpSample, self).__init__()self.layer=nn.Conv2d(channel,channel//2,1,1)def forward(self,x,feature_map):up=F.interpolate(x,scale_factor=2,mode='nearest')out=self.layer(up)return torch.cat((out,feature_map),dim=1)class UNet(nn.Module):def __init__(self,num_classes):super(UNet, self).__init__()self.c1=Conv_Block(3,64)self.d1=DownSample(64)self.c2=Conv_Block(64,128)self.d2=DownSample(128)self.c3=Conv_Block(128,256)self.d3=DownSample(256)self.c4=Conv_Block(256,512)self.d4=DownSample(512)self.c5=Conv_Block(512,1024)self.u1=UpSample(1024)self.c6=Conv_Block(1024,512)self.u2 = UpSample(512)self.c7 = Conv_Block(512, 256)self.u3 = UpSample(256)self.c8 = Conv_Block(256, 128)self.u4 = UpSample(128)self.c9 = Conv_Block(128, 64)self.out=nn.Conv2d(64,3, 3, 1, 1)def forward(self,x):R1=self.c1(x)R2=self.c2(self.d1(R1))R3 = self.c3(self.d2(R2))R4 = self.c4(self.d3(R3))R5 = self.c5(self.d4(R4))O1=self.c6(self.u1(R5,R4))O2 = self.c7(self.u2(O1, R3))O3 = self.c8(self.u3(O2, R2))O4 = self.c9(self.u4(O3, R1))return self.out(O4)if __name__ == '__main__':x=torch.randn(2,3,256,256)net=UNet(num_classes=3)print(net(x).shape)

五、开始训练

import osfrom torch import nn,optim
import torch
from dataset import *
from net import *
from torch.utils.data import DataLoaderif __name__ == '__main__':device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')net=UNet(num_classes=3).to(device) #实例化网络并指认到设备上weights='params/unet.pth'if os.path.exists(weights): #如果有初始权值就加载net.load_state_dict(torch.load(weights)) #加载权重print('loading successfully')opt=optim.Adam(net.parameters()) #指定优化器并传入参数# loss_fun=nn.BCELoss() #定义损失函数loss_fun=nn.BCEWithLogitsLoss()dataset=MyDataset('data_center.txt') #实例化数据集data_loader=DataLoader(dataset,batch_size=2,shuffle=True) #加载数据集epoch = 1while True:for i,(image,label) in enumerate(data_loader): #用枚举的方式遍历数据集image,label=image.to(device),label.to(device) #将图片和标签指认到设备上# print(image.shape, label.shape)out=net(image) #将图片输入网络train_loss=loss_fun(out,label) #预测值和真是标签做损失print(f'{epoch}-{i}-train_loss:{train_loss.item()}') #打印当前轮次当前批次的训练损失opt.zero_grad() #梯度清零train_loss.backward() #反向传播opt.step() #更新梯度if epoch % 10 == 0: #每10轮保存一次权重torch.save(net.state_dict(),f'params/unet.pth') #保存参数print('save successfully')epoch += 1

六、利用训练好的权重进行预测

import osimport torch
from PIL import Image,ImageDraw
from dataset import *
from net import *    #import * 代表导入所有path='test_image'
net=UNet(num_classes=3) #实例化网络
net.load_state_dict(torch.load('params/unet.pth')) #加载训练好的权重
net.eval() #测试模式
for j in os.listdir(path):img=Image.open(os.path.join(path,j)).resize((256, 256))draw=ImageDraw.Draw(img) #创建画板img_data=tf(img) #标准化img_data=torch.unsqueeze(img_data,dim=0) #设置批次维度out=net(img_data)out=out.squeeze()d=torch.max_pool2d(out, 256).squeeze()print(d)rst = []for i in range(3): #有3个关键点,故有3个通道h,w=np.where(out[i]==out[i].max()) #当前通道恒等于当前通道的最大值,就取其索引# rst.append((w[0], h[0]))draw.ellipse((w[0]*774/256-2, h[0]*434/256-2, w[0]*774/256+2, h[0]*434/256+2),(255,0,0)) #画半径为2的圆img.show()img.save(f'test_result/{j}')

reference

>>>>>来自B站大佬

【深度学习关键点回归(直接回归法&heatmap热力图法)】 https://www.bilibili.com/video/BV1sS4y197J1/?p=2&share_source=copy_web&vd_source=95705b32f23f70b32dfa1721628d5874

关键点检测——heatmap热力图法相关推荐

  1. HRNet人体关键点检测

    Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019 oral) 文章地址:https:/ ...

  2. RTMPose关键点检测实战——笔记3

    文章目录 摘要 安装MMPose 安装虚拟环境 安装pytorch 安装MMCV 安装其他的安装包 下载 MMPose 下载预训练模型权重文件和视频素材 安装MMDetection 安装Pytorch ...

  3. 夸克APP端智能:文档关键点检测实践与应用

    作者:顺达 最近夸克端智能小组在做端上的实时文档检测,即输入一张RGB图像,得到文档的四个角的关键点的坐标.整个pipelines属于关键点检测算法,因此最近对相关领域的论文进行阅读和进行了实验尝试. ...

  4. OpenMMLab AI实战营第二期|人体关键点检测与MMPose学习笔记

    OpenMMLab AI实战营第二期|人体关键点检测与MMPose学习笔记 文章目录 OpenMMLab AI实战营第二期|人体关键点检测与MMPose学习笔记 一.前言 1.1 人体姿态概述 1.2 ...

  5. Python+OpenCV+OpenPose实现人体姿态估计(人体关键点检测)

    目录 1.人体姿态估计简介 2.人体姿态估计数据集 3.OpenPose库 4.实现原理 5.实现神经网络 6.实现代码 1.人体姿态估计简介 人体姿态估计(Human Posture Estimat ...

  6. 计算机视觉方向简介 | 人体骨骼关键点检测综述

    什么是人体骨骼关键点检测 人体骨骼关键点检测,即Pose Estimation,主要检测人体的一些关键点,如关节,五官等,通过关键点描述人体骨骼信息: 应用与挑战 人体骨骼关键点检测是计算机视觉的基础 ...

  7. 猫脸关键点检测大赛:三种方法,轻松实现猫脸识别!

    导语:挑战猫脸,就差你了! 今天这个比赛,得从一个做程序猿的铲屎官开始说起...... 话说,有一天「铲屎猿」早起之后,发现猫主子竟然没了身影:他找啊找啊,找了好久,可仍然到处都没找到猫主子.这时,客 ...

  8. opencv 图像雾检测_OpenCV图像处理-基于OpenPose的关键点检测

    OpenCV基于OpenPose的手部关键点检测 概述 ✔️ 手部关键点检测,旨在找出给定图片中手指上的关节点及指尖关节点, 其中手部关键点检测的应用场景主要包括: 手势识别 手语识别与理解 手部的行 ...

  9. PFLD:简单高效的实用人脸关键点检测算法

    作者丨杜敏 学校丨华中科技大学硕士 研究方向丨模式识别与智能系统 研究背景 人脸关键点检测,在很多人脸相关的任务中,属于基础模块,很关键.比如人脸识别.人脸验证.人脸编辑等等.想做人脸相关的更深层次的 ...

最新文章

  1. /UI5/IF_UI5_REP_PERSISTENCE - why I cannot deploy app to GM6
  2. Windows 11 预览版 Build 22000.120 发布
  3. 我的播客开通的第一天
  4. 初入c++ (八) c++输入和输出
  5. php eureka客户端,Spring Cloud(一)配置Eureka 服务器(示例代码)
  6. 小tip: 使用CSS将图片转换成黑白(灰色、置灰)[转]
  7. PS2251-07 海力士(金士顿U盘量产,成功!)
  8. vc与三菱PLC编程口通信C语言源代码,三菱PLC通讯与编程实例!
  9. 超星阅读器pdz文件转为xps文件或pdf文件说明
  10. windows 网卡驱动安装
  11. pudn下载地址的规律
  12. m1电脑推荐使用Google Chrome浏览器
  13. cat 常用的日志分析架构方案_深度剖析|数据库生产常用架构方案
  14. 西电微机系统课程设计步进电机开环控制系统
  15. 古人教你怎样识人不走眼
  16. ACR2010_现实医疗环境下RA缓解率低是否可以用预测因素解释
  17. C#大作业——学生信息管理系统
  18. XP中服务与后门技术
  19. Android开发面试经典题目
  20. 移动端App广告常见的10种形式

热门文章

  1. 数据字典标准与统一的重要性(码表枚举值)
  2. 06JVM运行时内存分析
  3. CTF-Crypto-(1)
  4. MATLAB 让两个或多个AXES同步旋转
  5. MIT 6.828 (三) Lab 3: User Environments
  6. 蓝桥杯第十一届7月试题(免费下载)
  7. Win10 优化设置
  8. 屏幕开发-屏幕文本的翻译
  9. PHP开源客服工单系统:PESCMS Ticket
  10. 计算机一级office题库哪个好,全国计算机等级考试上机考试与题库解析:一级MSOffice...