文章目录

    • @[TOC](文章目录)
  • 前言
  • 一、ABCNet的下载与demo
    • 1.下载
    • 2. demo
  • 二、训练自己的数据集
    • 1. 使用标注工具windows_label_tool
    • 2. 转换为json (很重要,json文件错了,会出很多问题)
    • 3. 训练
      • 1. 修改相关配置文件
      • 2. 训练
      • 3. 测试
  • 总结
  • inference

前言

这段事件跑实验,正好用到了ABCNet, 中间遇到了很多的问题,特此记录,以避免大家再遇到这样的问题


一、ABCNet的下载与demo

1.下载

ABCNet是AdelaiDet中对于BAText的一个高效的端到端场景文本定位框架
是基于Detectron2的,所以首先要下载Detectron2
我的 Requirements:
Linux with Python = 3.7.11 ,cuda = 10.1,PyTorch = 1.8.1

pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

其他torch版本:

torch版本

# 下载 detectron2
git clone https://github.com/facebookresearch/detectron2.git
cd detectron2
git checkout -f 9eb4831
cd ..
python -m pip install -e detectron2/
# 下载 AdelaiDet
git clone https://github.com/aim-uofa/AdelaiDet.git
cd AdelaiDet
python setup.py build develop

2. demo

先使用预训练的权重模型测试一下。
下载 CTW1500 数据集,

cd AdelaiDet/datasets
wget https://drive.google.com/file/d/1ntlnlnQHZisDoS_bgDvrcrYFomw9iTZ0/view?usp=sharing -O CTW1500.zip
unzip CTW1500.zip
rm CTW1500.zip

下载model,再 demo

# Download ctw1500_attn_R_50.pth above
wget -O ctw1500_attn_R_50.pth https://universityofadelaide.box.com/shared/static/okeo5pvul5v5rxqh4yg8pcf805tzj2no.pth
python demo/demo.py \--config-file configs/BAText/CTW1500/attn_R_50.yaml \--input datasets/CTW1500/ctwtest_text_image/ \--opts MODEL.WEIGHTS ctw1500_attn_R_50.pth

二、训练自己的数据集

1. 使用标注工具windows_label_tool

链接: windows_label_tool 提取码: exvx

格式如下(示例):

windows_label_tool标注格式,如下,首行是代表标注个数,下面依次是每行的标注,包含28/2 = 14个点坐标(顺序如上图),后面是文本内容
4
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84,“DOUGLASTON”
50,119,58,119,66,119,74,119,82,119,90,119,98,119,98,137,90,137,82,137,74,137,66,137,58,137,51,137,“E-313”
41,137,48,136,56,136,64,136,71,136,79,136,87,136,89,155,81,155,73,155,65,155,57,155,49,155,41,155,“L164”
39,166,56,166,74,166,92,167,110,167,128,167,146,168,140,196,123,195,107,195,90,194,74,194,57,193,41,193,“F.D.N.Y.”

2. 转换为json (很重要,json文件错了,会出很多问题)

我的标签格式为:每个txt文件中只有一行, 所以不需要标注个数

45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84||||“DOUGLASTON”

由于后面json转换代码的问题,由14个点改为了8个点即四对点

45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84||||“DOUGLASTON”

# 四对点的顺序 0 3 4 7 为顶点, 1 2 5 6 为控制点
0--1--2--3
|        |
7--6--5--4

所需classes.txt文件, 我的只有一类,所以只有 text

text

转换代码

# -*- coding: utf-8 -*-
"""@File    : convert_ann_to_json.py@Time    : 2020-8-17 16:13@Author  : yizuotian@Description    : 生成windows_label_tool工具的标注格式转换为ABCNet训练的json格式标注
"""
import argparse
import json
import os
import sys
import cv2
import numpy as npdef gen_abc_json(abc_gt_dir, abc_json_path, image_dir, classes_path):"""根据abcnet的gt标注生成coco格式的json标注:param abc_gt_dir: windows_label_tool标注工具生成标注文件目录:param abc_json_path: ABCNet训练需要json标注路径:param image_dir::param classes_path: 类别文件路径:return:"""# Desktop Latin_embed.# 这是标注列表,可以根据自己的改,但是中文在训练时需要下载 simsun.ttc 字体文件(新宋体)cV2 = ["皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑","苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤","桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁","新", '0',  '1',  '2',  '3',  '4', '5',  '6',  '7',  '8','9',  'A',  'B',  'C',  'D',  'E', 'F',  'G',  'H',  'I', 'J',  'K',  'L',  'M',  'N',  'O', 'P',  'Q',  'R',  'S', 'T',  'U',  'V',  'W',  'X',  'Y', 'Z']dataset = {'licenses': [],'info': {},'categories': [],'images': [],'annotations': []}with open(classes_path) as f:classes = f.read().strip().split()for i, cls in enumerate(classes, 1):dataset['categories'].append({'id': i,'name': cls,'supercategory': 'beverage','keypoints': ['mean','xmin','x2','x3','xmax','ymin','y2','y3','ymax','cross']  # only for BDN})def get_category_id(cls):for category in dataset['categories']:if category['name'] == cls:return category['id']# 遍历abcnet txt 标注indexes = sorted([f.split('.')[0]for f in os.listdir(abc_gt_dir)])print(indexes)j = 1  # 标注边框id号for index in indexes:# if int(index) >3: continue# print('Processing: ' + index)im = cv2.imread(os.path.join(image_dir, '{}.jpg'.format(index)))im_height, im_width = im.shape[:2]dataset['images'].append({'coco_url': '','date_captured': '','file_name': index + '.jpg','flickr_url': '','id': int(index.split('_')[-1]),  # img_1'license': 0,'width': im_width,'height': im_height})anno_file = os.path.join(abc_gt_dir, '{}.txt'.format(index))with open(anno_file) as f:lines = [line for line in f.readlines() if line.strip()]# 没有清晰的标注,跳过if len(lines) <= 1:continuefor i, line in enumerate(lines[1:]):elements = line.strip().split(',')# polygon = np.array(elements[:28]).reshape((-1, 2)).astype(np.float32)  # [14,(x,y)]# control_points = bezier_utils.polygon_to_bezier_pts(polygon, im)  # [8,(x,y)]# 由14个点改为8个点control_points = np.array(elements[:16]).reshape((-1, 2)).astype(np.float32)  # [8,(x,y)]ct = elements[-1].replace('"', '').strip()cls = 'text'# segs = [float(kkpart) for kkpart in parts[:16]]segs = [float(kkpart) for kkpart in control_points.flatten()]xt = [segs[ikpart] for ikpart in range(0, len(segs), 2)]yt = [segs[ikpart] for ikpart in range(1, len(segs), 2)]# 过滤越界边框if max(xt) > im_width or max(yt) > im_height:print('The annotation bounding box is outside of the image:{}'.format(index))print("max x:{},max y:{},w:{},h:{}".format(max(xt), max(yt), im_width, im_height))continuexmin = min([xt[0], xt[3], xt[4], xt[7]])ymin = min([yt[0], yt[3], yt[4], yt[7]])xmax = max([xt[0], xt[3], xt[4], xt[7]])ymax = max([yt[0], yt[3], yt[4], yt[7]])width = max(0, xmax - xmin + 1)height = max(0, ymax - ymin + 1)if width == 0 or height == 0:continue# 根据自己标签长度范围而定max_len = 7recs = [len(cV2) + 1 for ir in range(max_len)]ct = str(ct)# print('rec', ct)for ix, ict in enumerate(ct):if ix >= max_len:continueif ict in cV2:recs[ix] = cV2.index(ict)else:recs[ix] = len(cV2)dataset['annotations'].append({'area': width * height,'bbox': [xmin, ymin, width, height],'category_id': get_category_id(cls),'id': j,'image_id': int(index.split('_')[-1]),  # img_1'iscrowd': 0,'bezier_pts': segs,'rec': recs})j += 1# 写入json文件folder = os.path.dirname(abc_json_path)if not os.path.exists(folder):os.makedirs(folder)with open(abc_json_path, 'w') as f:json.dump(dataset, f)def main(args):gen_abc_json(args.ann_dir, args.dst_json_path, args.image_dir, args.classes_path)if __name__ == '__main__':"""Usage: python convert_ann_to_json.py \--ann-dir /path/to/gt \--image-dir /path/to/image \--dst-json-path train.json """parse = argparse.ArgumentParser()parse.add_argument("--ann-dir", type=str, default=None) # 标签路径parse.add_argument("--image-dir", type=str, default=None) # 对应的图片路径parse.add_argument("--dst-json-path", type=str, default=None) # 保存json路径parse.add_argument("--classes-path", type=str, default='./classes.txt') arguments = parse.parse_args() # sys.argv[1:]main(arguments)

部分json文件参考:

{"licenses": [],"info": {},"categories": [{"id": 1,"name": "text","supercategory": "beverage","keypoints": ["mean","xmin","x2","x3","xmax","ymin","y2","y3","ymax","cross"]}],"images": [{"coco_url": "","date_captured": "","file_name": "000001.jpg","flickr_url": "","id": 1,"license": 0,"width": 720,"height": 1160},...],"annotations": [{"area": 6868.0,"bbox": [304.0,343.0,101.0,68.0],"category_id": 1,"id": 1,"image_id": 1,"iscrowd": 0,"bezier_pts": [304.0,357.0,454.0,341.0,458.0,394.0,308.0,410.0,329.0,343.0,354.0,345.0,379.0,347.0,404.0,349.0],"rec": [19,16,20,12,19,21,23]},...]
}

3. 训练

显卡:GeForce RTX 2080 Ti *2 batch_size = 2

1. 修改相关配置文件

  • 将制作好的data数据目录放在"AdelaiDet/datasets"目录
    我的目录结构是:
COCO--annotations--train.json--val.json--train--val
  • 修改"adet/data/builtin.py"中的_PREDEFINED_SPLITS_TEXT值来指定训练测试数据,注意这里默认是在datasets下的,所以它们的相对路径都是从下层目录开始的.
_PREDEFINED_SPLITS_TEXT = {
"totaltext_train": ("totaltext/train_images", "totaltext/train.json"),
"totaltext_val": ("totaltext/test_images", "totaltext/test.json"),
...
# 以下为修改 (改为自己的)
"COCO_train": ("COCO/train/", "COCO/annotations/train.json"),
"COCO_val": ("COCO/val/", "COCO/annotations/val.json"),
  • 在需要训练的配置文件中指定数据集即可.以configs/BAText/Pretrain/Base-Pretrain.yaml为例
_BASE_: "../Base-BAText.yaml"
DATASETS:# 以下为修改(改为自己的)TRAIN: ("COCO_train",)TEST: ("COCO_val",)
  • label 中有中文, 需下载这个 simsun.ttc 字体文件 放于 AdelaiDet/simsun.ttc 中
    链接:simsun.ttc
    提取码:7tr2

2. 训练

OMP_NUM_THREADS=1 python tools/train_net.py \--config-file configs/BAText/Pretrain/v2_attn_R_50.yaml \--num-gpus 2 \OUTPUT_DIR output/batext/pretrain/coco01  # 保存路径

3. 测试

MP_NUM_THREADS=1 python tools/train_net.py \--config-file configs/BAText/Pretrain/v2_attn_R_50.yaml \--eval-only \--num-gpus 1 \OUTPUT_DIR output/batext/pretrain/coco01_result \ # 保存路径MODEL.WEIGHTS output/batext/pretrain/coco01/model_0019999.pth # model 路径

总结

中间遇到了很多问题,自己也参考了很多文章,特此记录,以便后来者参考。

inference

https://blog.csdn.net/weixin_43823854/article/details/108916498
https://www.tqwba.com/x_d/jishu/286353.html

ABCNet的下载与训练--训练自己的数据集相关推荐

  1. PyTorch安装测试训练建自己的数据集

    Pytorch安装测试训练建自己的数据集 前言 一.PyTorch是什么? 二.PyTorch环境搭建 1.设备要求 2.安装Pytorch 3.验证PyTorch 二.CIFAR10测试 1.关于C ...

  2. (详细版Win10+Pycharm)YOLOX——训练自己的VOC2007数据集,以NWPU VHR-10 dataset为例

    目录 一.搭建YOLOX环境 二.训练自己的VOC数据集 1.打开Pycharm配置Anaconda已创建好的yolo_x虚拟环境 2.在Pycharm中设置Git环境 3.修改配置文件 (1)修改Y ...

  3. MMAction2学习笔记 使用C3D训练测试自己的数据集

    新手上路,记录一下自己的学习过程,希望也能对你有所帮助. 1.数据集准备 参考官网给出的数据集准备教程 https://github.com/open-mmlab/mmaction2/blob/mas ...

  4. mmdetection的安装并训练自己的VOC数据集

    mmdetection的安装并训练自己的VOC数据集 mmdetection的安装与VOC数据集的训练 一. mmdetection的安装 1.使用conda创建虚拟环境 2.安装Cython 3.安 ...

  5. 使用CycleGAN训练自己制作的数据集,通俗教程,快速上手

    总结了使用CycleGAN训练自己制作的数据集,这里的教程例子主要就是官网给出的斑马变马,马变斑马,两个不同域之间的相互转换.教程中提供了官网给的源码包和我自己调试优化好的源码包,大家根据自己的情况下 ...

  6. 在服务器上利用mmdetection来训练自己的voc数据集

    在服务器上利用mmdetection来训练自己的voc数据集 服务器上配置mmdetection环境 在服务器上用anaconda配置自己的环境 进入自己的虚拟环境,开始配置mmdetection 跑 ...

  7. tensorflow训练自己的声音数据集进行声音分类

    ** tensorflow训练自己的声音数据集进行声音分类 ** 环境 win10 anaconda3.5 tensorflow 2.0 1.安装anaconda https://pan.baidu. ...

  8. Windows下使用Yolov3(GPU)训练+测试自己的数据集

    Windows下使用Yolov3(GPU)训练+测试自己的数据集 1.配置Yolov3 参考:Windows下使用darknet.exe跑通Yolov3 Window10+VS2017+CUDA10. ...

  9. 使用Yolov5训练自己制作的数据集,快速上手

    总结了快速上手Yolov5训练自己制作的数据集的方法,步骤都很详细,学者耐心看. 文章目录 一.准备好Yolov5框架 二.关于数据集的问题 三.VOC格式数据集转yolo格式数据集 四.训练模型 五 ...

最新文章

  1. PostgreSQL处理xml数据初步
  2. python网页编程测试_李亚涛:python编写友情链接检测工具
  3. leetcode 27. Remove Element
  4. Linux系统文件名字体不同的颜色都代表什么
  5. 点/线/面 等 几何关系运算 的网页 推荐+备忘
  6. Android之自定义带圆角的水纹波效果
  7. TiDB 源码阅读系列文章(十八)tikv-client(上) 1
  8. 机器学习ai选股_自带AI机器学习的MEMS了解一下
  9. python 查找excel内容所在的单元格_python 根据excel单元格内容获取该单元格所在的行号...
  10. JSP连接SQLServer数据库特别要注意一个小问题得到解决
  11. python session过期_设置session过期时间
  12. Android免费地图应用网址
  13. matlab has encountered,matlab运行程序时出现“matlab has encountered an internal problem
  14. SpringBoot 读取 jar 包中 BOOT-INF/lib 下的 jar包
  15. Ubuntu发烧友三部曲
  16. PLC:学习笔记(西门子)3
  17. 计算机excel基础知识教程,EXCEL基本操作技巧 一
  18. Android 语音遥控器的整体分析-主机端语音解码的添加
  19. 做人,该善良时就善良,该勇敢时就要有勇气去对应
  20. vue点击预览图片插件(可放大缩小翻转等)

热门文章

  1. C语言之文件处理(fputc fgetc函数的使用)下篇
  2. Android最新最全面试题及答案分享
  3. Palabos源码:collideAndStream
  4. 在U盘上安装Grub,并引导iso镜像
  5. ubant每30秒运行shell脚本_[mcj]Ubuntu系统定时执行bashshell命令|Ubuntu定时执行指定脚本...
  6. Pr:提高工作性能的设置
  7. vulnhub之tre1
  8. Linux总体大纲总结
  9. 精通Linux内核网络 -(以)罗森
  10. IOS8,IOS8.1等系统出现锁屏状态下WIFI断开问题的解决办法!