ABCNet的下载与训练--训练自己的数据集
文章目录
- @[TOC](文章目录)
- 前言
- 一、ABCNet的下载与demo
- 1.下载
- 2. demo
- 二、训练自己的数据集
- 1. 使用标注工具windows_label_tool
- 2. 转换为json (很重要,json文件错了,会出很多问题)
- 3. 训练
- 1. 修改相关配置文件
- 2. 训练
- 3. 测试
- 总结
- inference
前言
这段事件跑实验,正好用到了ABCNet, 中间遇到了很多的问题,特此记录,以避免大家再遇到这样的问题
一、ABCNet的下载与demo
1.下载
ABCNet是AdelaiDet中对于BAText的一个高效的端到端场景文本定位框架
是基于Detectron2的,所以首先要下载Detectron2
我的 Requirements:
Linux with Python = 3.7.11 ,cuda = 10.1,PyTorch = 1.8.1
pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
其他torch版本:
torch版本
# 下载 detectron2
git clone https://github.com/facebookresearch/detectron2.git
cd detectron2
git checkout -f 9eb4831
cd ..
python -m pip install -e detectron2/
# 下载 AdelaiDet
git clone https://github.com/aim-uofa/AdelaiDet.git
cd AdelaiDet
python setup.py build develop
2. demo
先使用预训练的权重模型测试一下。
下载 CTW1500 数据集,
cd AdelaiDet/datasets
wget https://drive.google.com/file/d/1ntlnlnQHZisDoS_bgDvrcrYFomw9iTZ0/view?usp=sharing -O CTW1500.zip
unzip CTW1500.zip
rm CTW1500.zip
下载model,再 demo
# Download ctw1500_attn_R_50.pth above
wget -O ctw1500_attn_R_50.pth https://universityofadelaide.box.com/shared/static/okeo5pvul5v5rxqh4yg8pcf805tzj2no.pth
python demo/demo.py \--config-file configs/BAText/CTW1500/attn_R_50.yaml \--input datasets/CTW1500/ctwtest_text_image/ \--opts MODEL.WEIGHTS ctw1500_attn_R_50.pth
二、训练自己的数据集
1. 使用标注工具windows_label_tool
链接: windows_label_tool 提取码: exvx
格式如下(示例):
windows_label_tool标注格式,如下,首行是代表标注个数,下面依次是每行的标注,包含28/2 = 14个点坐标(顺序如上图),后面是文本内容
4
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84,“DOUGLASTON”
50,119,58,119,66,119,74,119,82,119,90,119,98,119,98,137,90,137,82,137,74,137,66,137,58,137,51,137,“E-313”
41,137,48,136,56,136,64,136,71,136,79,136,87,136,89,155,81,155,73,155,65,155,57,155,49,155,41,155,“L164”
39,166,56,166,74,166,92,167,110,167,128,167,146,168,140,196,123,195,107,195,90,194,74,194,57,193,41,193,“F.D.N.Y.”
2. 转换为json (很重要,json文件错了,会出很多问题)
我的标签格式为:每个txt文件中只有一行, 所以不需要标注个数
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84||||“DOUGLASTON”
由于后面json转换代码的问题,由14个点改为了8个点即四对点
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84||||“DOUGLASTON”
# 四对点的顺序 0 3 4 7 为顶点, 1 2 5 6 为控制点
0--1--2--3
| |
7--6--5--4
所需classes.txt文件, 我的只有一类,所以只有 text
text
转换代码
# -*- coding: utf-8 -*-
"""@File : convert_ann_to_json.py@Time : 2020-8-17 16:13@Author : yizuotian@Description : 生成windows_label_tool工具的标注格式转换为ABCNet训练的json格式标注
"""
import argparse
import json
import os
import sys
import cv2
import numpy as npdef gen_abc_json(abc_gt_dir, abc_json_path, image_dir, classes_path):"""根据abcnet的gt标注生成coco格式的json标注:param abc_gt_dir: windows_label_tool标注工具生成标注文件目录:param abc_json_path: ABCNet训练需要json标注路径:param image_dir::param classes_path: 类别文件路径:return:"""# Desktop Latin_embed.# 这是标注列表,可以根据自己的改,但是中文在训练时需要下载 simsun.ttc 字体文件(新宋体)cV2 = ["皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑","苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤","桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁","新", '0', '1', '2', '3', '4', '5', '6', '7', '8','9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']dataset = {'licenses': [],'info': {},'categories': [],'images': [],'annotations': []}with open(classes_path) as f:classes = f.read().strip().split()for i, cls in enumerate(classes, 1):dataset['categories'].append({'id': i,'name': cls,'supercategory': 'beverage','keypoints': ['mean','xmin','x2','x3','xmax','ymin','y2','y3','ymax','cross'] # only for BDN})def get_category_id(cls):for category in dataset['categories']:if category['name'] == cls:return category['id']# 遍历abcnet txt 标注indexes = sorted([f.split('.')[0]for f in os.listdir(abc_gt_dir)])print(indexes)j = 1 # 标注边框id号for index in indexes:# if int(index) >3: continue# print('Processing: ' + index)im = cv2.imread(os.path.join(image_dir, '{}.jpg'.format(index)))im_height, im_width = im.shape[:2]dataset['images'].append({'coco_url': '','date_captured': '','file_name': index + '.jpg','flickr_url': '','id': int(index.split('_')[-1]), # img_1'license': 0,'width': im_width,'height': im_height})anno_file = os.path.join(abc_gt_dir, '{}.txt'.format(index))with open(anno_file) as f:lines = [line for line in f.readlines() if line.strip()]# 没有清晰的标注,跳过if len(lines) <= 1:continuefor i, line in enumerate(lines[1:]):elements = line.strip().split(',')# polygon = np.array(elements[:28]).reshape((-1, 2)).astype(np.float32) # [14,(x,y)]# control_points = bezier_utils.polygon_to_bezier_pts(polygon, im) # [8,(x,y)]# 由14个点改为8个点control_points = np.array(elements[:16]).reshape((-1, 2)).astype(np.float32) # [8,(x,y)]ct = elements[-1].replace('"', '').strip()cls = 'text'# segs = [float(kkpart) for kkpart in parts[:16]]segs = [float(kkpart) for kkpart in control_points.flatten()]xt = [segs[ikpart] for ikpart in range(0, len(segs), 2)]yt = [segs[ikpart] for ikpart in range(1, len(segs), 2)]# 过滤越界边框if max(xt) > im_width or max(yt) > im_height:print('The annotation bounding box is outside of the image:{}'.format(index))print("max x:{},max y:{},w:{},h:{}".format(max(xt), max(yt), im_width, im_height))continuexmin = min([xt[0], xt[3], xt[4], xt[7]])ymin = min([yt[0], yt[3], yt[4], yt[7]])xmax = max([xt[0], xt[3], xt[4], xt[7]])ymax = max([yt[0], yt[3], yt[4], yt[7]])width = max(0, xmax - xmin + 1)height = max(0, ymax - ymin + 1)if width == 0 or height == 0:continue# 根据自己标签长度范围而定max_len = 7recs = [len(cV2) + 1 for ir in range(max_len)]ct = str(ct)# print('rec', ct)for ix, ict in enumerate(ct):if ix >= max_len:continueif ict in cV2:recs[ix] = cV2.index(ict)else:recs[ix] = len(cV2)dataset['annotations'].append({'area': width * height,'bbox': [xmin, ymin, width, height],'category_id': get_category_id(cls),'id': j,'image_id': int(index.split('_')[-1]), # img_1'iscrowd': 0,'bezier_pts': segs,'rec': recs})j += 1# 写入json文件folder = os.path.dirname(abc_json_path)if not os.path.exists(folder):os.makedirs(folder)with open(abc_json_path, 'w') as f:json.dump(dataset, f)def main(args):gen_abc_json(args.ann_dir, args.dst_json_path, args.image_dir, args.classes_path)if __name__ == '__main__':"""Usage: python convert_ann_to_json.py \--ann-dir /path/to/gt \--image-dir /path/to/image \--dst-json-path train.json """parse = argparse.ArgumentParser()parse.add_argument("--ann-dir", type=str, default=None) # 标签路径parse.add_argument("--image-dir", type=str, default=None) # 对应的图片路径parse.add_argument("--dst-json-path", type=str, default=None) # 保存json路径parse.add_argument("--classes-path", type=str, default='./classes.txt') arguments = parse.parse_args() # sys.argv[1:]main(arguments)
部分json文件参考:
{"licenses": [],"info": {},"categories": [{"id": 1,"name": "text","supercategory": "beverage","keypoints": ["mean","xmin","x2","x3","xmax","ymin","y2","y3","ymax","cross"]}],"images": [{"coco_url": "","date_captured": "","file_name": "000001.jpg","flickr_url": "","id": 1,"license": 0,"width": 720,"height": 1160},...],"annotations": [{"area": 6868.0,"bbox": [304.0,343.0,101.0,68.0],"category_id": 1,"id": 1,"image_id": 1,"iscrowd": 0,"bezier_pts": [304.0,357.0,454.0,341.0,458.0,394.0,308.0,410.0,329.0,343.0,354.0,345.0,379.0,347.0,404.0,349.0],"rec": [19,16,20,12,19,21,23]},...]
}
3. 训练
显卡:GeForce RTX 2080 Ti *2 batch_size = 2
1. 修改相关配置文件
- 将制作好的data数据目录放在"AdelaiDet/datasets"目录
我的目录结构是:
COCO--annotations--train.json--val.json--train--val
- 修改"adet/data/builtin.py"中的_PREDEFINED_SPLITS_TEXT值来指定训练测试数据,注意这里默认是在datasets下的,所以它们的相对路径都是从下层目录开始的.
_PREDEFINED_SPLITS_TEXT = {
"totaltext_train": ("totaltext/train_images", "totaltext/train.json"),
"totaltext_val": ("totaltext/test_images", "totaltext/test.json"),
...
# 以下为修改 (改为自己的)
"COCO_train": ("COCO/train/", "COCO/annotations/train.json"),
"COCO_val": ("COCO/val/", "COCO/annotations/val.json"),
- 在需要训练的配置文件中指定数据集即可.以configs/BAText/Pretrain/Base-Pretrain.yaml为例
_BASE_: "../Base-BAText.yaml"
DATASETS:# 以下为修改(改为自己的)TRAIN: ("COCO_train",)TEST: ("COCO_val",)
- label 中有中文, 需下载这个 simsun.ttc 字体文件 放于 AdelaiDet/simsun.ttc 中
链接:simsun.ttc
提取码:7tr2
2. 训练
OMP_NUM_THREADS=1 python tools/train_net.py \--config-file configs/BAText/Pretrain/v2_attn_R_50.yaml \--num-gpus 2 \OUTPUT_DIR output/batext/pretrain/coco01 # 保存路径
3. 测试
MP_NUM_THREADS=1 python tools/train_net.py \--config-file configs/BAText/Pretrain/v2_attn_R_50.yaml \--eval-only \--num-gpus 1 \OUTPUT_DIR output/batext/pretrain/coco01_result \ # 保存路径MODEL.WEIGHTS output/batext/pretrain/coco01/model_0019999.pth # model 路径
总结
中间遇到了很多问题,自己也参考了很多文章,特此记录,以便后来者参考。
inference
https://blog.csdn.net/weixin_43823854/article/details/108916498
https://www.tqwba.com/x_d/jishu/286353.html
ABCNet的下载与训练--训练自己的数据集相关推荐
- PyTorch安装测试训练建自己的数据集
Pytorch安装测试训练建自己的数据集 前言 一.PyTorch是什么? 二.PyTorch环境搭建 1.设备要求 2.安装Pytorch 3.验证PyTorch 二.CIFAR10测试 1.关于C ...
- (详细版Win10+Pycharm)YOLOX——训练自己的VOC2007数据集,以NWPU VHR-10 dataset为例
目录 一.搭建YOLOX环境 二.训练自己的VOC数据集 1.打开Pycharm配置Anaconda已创建好的yolo_x虚拟环境 2.在Pycharm中设置Git环境 3.修改配置文件 (1)修改Y ...
- MMAction2学习笔记 使用C3D训练测试自己的数据集
新手上路,记录一下自己的学习过程,希望也能对你有所帮助. 1.数据集准备 参考官网给出的数据集准备教程 https://github.com/open-mmlab/mmaction2/blob/mas ...
- mmdetection的安装并训练自己的VOC数据集
mmdetection的安装并训练自己的VOC数据集 mmdetection的安装与VOC数据集的训练 一. mmdetection的安装 1.使用conda创建虚拟环境 2.安装Cython 3.安 ...
- 使用CycleGAN训练自己制作的数据集,通俗教程,快速上手
总结了使用CycleGAN训练自己制作的数据集,这里的教程例子主要就是官网给出的斑马变马,马变斑马,两个不同域之间的相互转换.教程中提供了官网给的源码包和我自己调试优化好的源码包,大家根据自己的情况下 ...
- 在服务器上利用mmdetection来训练自己的voc数据集
在服务器上利用mmdetection来训练自己的voc数据集 服务器上配置mmdetection环境 在服务器上用anaconda配置自己的环境 进入自己的虚拟环境,开始配置mmdetection 跑 ...
- tensorflow训练自己的声音数据集进行声音分类
** tensorflow训练自己的声音数据集进行声音分类 ** 环境 win10 anaconda3.5 tensorflow 2.0 1.安装anaconda https://pan.baidu. ...
- Windows下使用Yolov3(GPU)训练+测试自己的数据集
Windows下使用Yolov3(GPU)训练+测试自己的数据集 1.配置Yolov3 参考:Windows下使用darknet.exe跑通Yolov3 Window10+VS2017+CUDA10. ...
- 使用Yolov5训练自己制作的数据集,快速上手
总结了快速上手Yolov5训练自己制作的数据集的方法,步骤都很详细,学者耐心看. 文章目录 一.准备好Yolov5框架 二.关于数据集的问题 三.VOC格式数据集转yolo格式数据集 四.训练模型 五 ...
最新文章
- PostgreSQL处理xml数据初步
- python网页编程测试_李亚涛:python编写友情链接检测工具
- leetcode 27. Remove Element
- Linux系统文件名字体不同的颜色都代表什么
- 点/线/面 等 几何关系运算 的网页 推荐+备忘
- Android之自定义带圆角的水纹波效果
- TiDB 源码阅读系列文章(十八)tikv-client(上) 1
- 机器学习ai选股_自带AI机器学习的MEMS了解一下
- python 查找excel内容所在的单元格_python 根据excel单元格内容获取该单元格所在的行号...
- JSP连接SQLServer数据库特别要注意一个小问题得到解决
- python session过期_设置session过期时间
- Android免费地图应用网址
- matlab has encountered,matlab运行程序时出现“matlab has encountered an internal problem
- SpringBoot 读取 jar 包中 BOOT-INF/lib 下的 jar包
- Ubuntu发烧友三部曲
- PLC:学习笔记(西门子)3
- 计算机excel基础知识教程,EXCEL基本操作技巧 一
- Android 语音遥控器的整体分析-主机端语音解码的添加
- 做人,该善良时就善良,该勇敢时就要有勇气去对应
- vue点击预览图片插件(可放大缩小翻转等)
热门文章
- C语言之文件处理(fputc fgetc函数的使用)下篇
- Android最新最全面试题及答案分享
- Palabos源码:collideAndStream
- 在U盘上安装Grub,并引导iso镜像
- ubant每30秒运行shell脚本_[mcj]Ubuntu系统定时执行bashshell命令|Ubuntu定时执行指定脚本...
- Pr:提高工作性能的设置
- vulnhub之tre1
- Linux总体大纲总结
- 精通Linux内核网络 -(以)罗森
- IOS8,IOS8.1等系统出现锁屏状态下WIFI断开问题的解决办法!