前言:

2018年阿里的论文《Semantatic Human Matting》给出了抠图领域的一个新方法,可惜阿里并没有公布源码,而牛人在Github上对这个论文进行了复现,我也是依赖Github上的工程进行钻研,而在调试的过程中,发现有一些地方原作者并没有检验通过就上传,导致训练过程出错,这篇博客就是讲解如何调试通过Github上的Semantic_Human_Matting工程的训练以及测试的代码。

-------------------------------------------------------------------------------------------------------------------------

申明:

  • 写博客的初衷:一是为了记录,二是为了给后来人填坑——测试效果的好坏受算法结构、受数据集、受训练次数等因素的影响,留言板处不要因为你的结果表现不优良而无视博主无偿付出、甚至恶评相向,这样的白嫖党我劝你善良。

-------------------------------------------------------------------------------------------------------------------------

一、SHM网络简单讲解

通过下面Semantic_Human_Matting网络图开始讲解SHM的网络设计:

SHM的网络大致分为三个部分:

  1. T-Net网络部分:这部分的作用主要是预测生成trimap图。网络的输入是原图 + mask图;
  2. M-Net网络部分:这部分的作用主要是预测生成alpha图。网络的输入来源于三部分:第一个是原图(上图最左边的那张),第二个是原图对应的mask图(真正输入到网络中的mask图会被拆分成前景图 + 背景图两部分,也就是上图中的FsF_sFs​和BsB_sBs​),第三个是trimap图(真正输入到网络中的只要trimap图的不确定区域,也就是上图中的UsU_sUs​),预测得到上图中的 αrα_rαr​
  3. Fusion Module这部分的作用主要是融合得到精准的alpha图。最后精准αpα_pαp​遮罩图的概率估计是: αp=Fs+Usαrα_p = F_s + U_sα_rαp​=Fs​+Us​αr​

-------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------

二、SHM数据集调整说明

2.1、工程下载,以及环境配置

Github上的Semantic_Human_Matting工程链接在此处此处此处,先下载解压;

根据工程主页上的说明,需要的是python3.5/3.6,torch>=0.4.0,以及opencv-python,我配置的机器环境是ubuntu16.04 + cuda10.0 + python3.6.12 + torch0.4.1 + opencv3.4.3。Windows机器我好像配置过,好像没通过(记不大清楚了,有兴趣的去试一试)

-------------------------------------------------------------------------------------------------------------------------

2.2、下载数据集

2.2.1、最头痛的就是数据集的建立,因为建立大型数据集耗时耗力。所幸工程主页里作者给出了他找到的数据集,在这里对作者及爱分割公司表示感谢,数据集的链接在此处此处此处:https://pan.baidu.com/share/init?surl=R9PJJRT-KjSxh-2-3wCGxQ,密码是:dzsn,下载解压。

2.2.2、解压后可以看到其下主要包含两个文件夹:

  • clip_img文件夹:其下都是原图;
  • matting文件夹:其下都是原图对应的mask图,但是需要处理一下;
  • 其中matting/1803201916/._matting_00000000是错误文件,需要手动删除!
  • 其中clip_img/1803201916/clip_00000000/1803201916-00000117.png文件,需要改成jpg后缀!

注意:整个数据集包含3W+张图片,预处理全部文件的话很耗时,所以在调试阶段博主强烈建议用其中某一个文件夹就行了。
注意:整个数据集包含3W+张图片,预处理全部文件的话很耗时,所以在调试阶段博主强烈建议用其中某一个文件夹就行了。
注意:整个数据集包含3W+张图片,预处理全部文件的话很耗时,所以在调试阶段博主强烈建议用其中某一个文件夹就行了。

2.2.3、在工程data目录下新建mattingclip_img文件夹,再将数据集mattingclip_img文件夹下的挑选任意一个相同文件夹对应放入工程目录中,隶属关系如下:

-------------------------------------------------------------------------------------------------------------------------

2.3、matting图生成对应的mask图:

先在data文件夹下新建zcm_matting_get_mask.py文件,代码如下,然后执行这个py文件,完成后可以在data目录下看到生成了一个新的mask文件夹,其下存储着黑白底的mask图。

import os
import cv2matting_path = "matting/"
mask_path = "mask/"# test
# for mask_name in os.listdir(matting_path):
#     in_image = cv2.imread(matting_path + mask_name, cv2.IMREAD_UNCHANGED)
#     alpha = in_image[:,:,3]
#     cv2.imwrite(mask_path + mask_name, alpha)for name_0 in os.listdir(matting_path):if not os.path.exists(mask_path + "/" + name_0):os.makedirs(mask_path + "/" + name_0)for name_1 in os.listdir(matting_path + "/" + name_0):if not os.path.exists(mask_path + name_0 + "/" + name_1):os.mkdir(mask_path + name_0 + "/" + name_1)for name_2 in os.listdir(matting_path + "/" + name_0 + "/" + name_1):pic_input_path = matting_path + "/" + name_0 + "/" + name_1 + "/" + name_2pic_output_path = mask_path + "/" + name_0 + "/" + name_1 + "/" + name_2print("pic_input_path=", pic_input_path)in_image = cv2.imread(pic_input_path, cv2.IMREAD_UNCHANGED)alpha = in_image[:, :, 3]cv2.imwrite(pic_output_path, alpha)

-------------------------------------------------------------------------------------------------------------------------

2.4、生成训练数据的TXT目录:

先在data文件夹下新建zcm_get_train_txt.py文件,代码如下,然后执行这个py文件,完成后可以在data目录下看到生成了一个新的train.txt文件,打开里面存储图片的路径。

import ospic_path = "matting/"with open("train.txt", "w", encoding="UTF-8") as ff:for name_0 in os.listdir(pic_path):for name_1 in os.listdir(pic_path + "/" + name_0):for name_2 in os.listdir(pic_path + "/" + name_0 + "/" + name_1):pic_input_path = name_0 + "/" + name_1 + "/" + name_2ff.write(pic_input_path + "\n")ff.close()
print("well done____________!")

-------------------------------------------------------------------------------------------------------------------------

2.5、由mask图生成trimap图:

2.5.1:像下面一样注释掉gen_trimap.py第36/42/48行的断言语句;

# assert(cnt1 == cnt2 + cnt3)

2.5.2:在gen_trimap.py第四行添加语句,引入os库;

import os

2.5.3:在gen_trimap.py第64行后,添加如下代码;

trimap_name_1 = trimap_name.split("/")[:-1]
trimap_path = "/".join(trimap_name_1)
if not os.path.exists(trimap_path):os.makedirs(trimap_path)

2.5.4:执行sh gen_trimap.sh脚本,就可以生成得到trimap文件夹,及其其下的trimap图片;

-------------------------------------------------------------------------------------------------------------------------

2.6、生成alpha图:

说明:这里给出两种生成alpha图的方法:

  1. 用工程自带的knn_matting.sh脚本生成alpha图;
  2. 直接拷贝mask文件夹,将mask图作为精确的alpha图注入训练;

第一种方法我在简单测试中使用过,该方法非常非常非常的耗时间,而且用该方法处理爱分割公司提供的数据集得到了alpha图,将其注入训练后,对最后的预测的准确率的影响并不大;有兴趣的朋友可以对knn_matting继续改进,将时间效率提高;

我也阐述使用第二种方法的依据:因为爱分割公司的数据集的mask图是精确的,是直接通过matting文件夹生成的。爱分割公司在提供数据集的时候,mask图就是他们人工扣出来的。而knn_matting.sh脚本存在的意义,是对于正常情况下,我们如使用faster-RCNN,DeepLab这样的分割算法得到的mask图是不精准的,才需要使用knn_matting算法处理边界,得到精准的alpha图。

所以这一步,在data文件夹下新建alpha文件夹后,再执行下面复制语句,将mask文件夹下所有文件复制到alpha文件夹;

cp -r mask/* alpha/

至此,数据集准备工作全部做完。

-------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------

三、训练细节调整说明

3.1、写入训练code:

先在Semantic_Human_Matting工程目录下,新建train_code.txt文件,写入如下指令:

# # T-Net训练指令
python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=200 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=1e-5 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='pre_train_t_net'# # M-Net训练指令
python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=400 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=5e-6 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='end_to_end'

第一段是T-Net训练代码,第二段是M-Net训练代码

-------------------------------------------------------------------------------------------------------------------------

3.2、修改train.py文件:

train.py文件第29行后添加一条语句,用来指示GPU的使用情况

parser.add_argument('--gpus', default='0,1,2,3', help='gpus number')

-------------------------------------------------------------------------------------------------------------------------

3.3、修改dataset.py文件:

3.3.1:用如下语句替换dataset.py文件第17/18/19行

image_name = os.path.join(data_dir, 'clip_img', file_name['image'].replace("matting", "clip").replace("png", "jpg"))
trimap_name = os.path.join(data_dir, 'trimap', file_name['trimap'].replace("clip", "matting"))
alpha_name = os.path.join(data_dir, 'alpha', file_name['alpha'].replace("clip", "matting"))

3.3.2:用如下语句替换dataset.py文件第101/102/103行:

trimap[trimap == 0] = 0
trimap[trimap >= 250] = 2
trimap[np.where(~((trimap == 0) | (trimap == 2)))] = 1


这里是整个代码中错误最隐蔽的一个,当初也是花了我很长时间才搞定。我解释一下为什么这样做:我们知道trimap图是三色图,但是它的“三色”并不像上图中0/128/255只有这三色,它是在[0, 255]这个区间范围内。所以新改的代码,将这“三色”用区间区分,作为三种不同的label传入训练。

-------------------------------------------------------------------------------------------------------------------------

3.4、开启T-Net训练:

运行train_code.txt第一行代码,开启T-Net训练,如果你报内存不足的错误,就适当调小patch_size,nThreads,train_batch的数值;

python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=200 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=1e-5 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='pre_train_t_net'

下图是我T-Net训练过程的loss变化,你也可以为得到更好的结果而增大nEpochs训练轮数;

-------------------------------------------------------------------------------------------------------------------------

3.5、开启M-Net训练:

运行train_code.txt第二行代码,开启M-Net微调训练

python3 train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train.txt' --lrdecayType='keep' --nEpochs=400 --save_epoch=1 --load='human_matting' --patch_size=320 --lr=5e-6 --gpus='0,1,2,3' --nThreads=24 --train_batch=48 --train_phase='end_to_end'

下图是我M-Net训练过程的loss变化,你也可以为得到更好的结果而增大nEpochs训练轮数;

-------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------

四、测试细节调整说明

4.1:新建test_camera_used.py文件

写入如下代码,代码与test_camera.py文件很相似,只是改了一部分需求,让过程更简洁;

'''test camera Author: Zhengwei Li
Date  : 2018/12/28
'''
import time
import cv2
import torch
import argparse
import numpy as np
import os
import torch.nn.functional as F
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'parser = argparse.ArgumentParser(description='human matting')
parser.add_argument('--model', default='./ckpt/human_matting/model/model_obj.pth', help='preTrained model')
parser.add_argument('--size', type=int, default=320, help='input size')
parser.add_argument('--without_gpu', action='store_true', default=False, help='no use gpu')args = parser.parse_args()torch.set_grad_enabled(False)#################################
#----------------
if args.without_gpu:print("use CPU !")device = torch.device('cpu')
else:if torch.cuda.is_available():n_gpu = torch.cuda.device_count()print("----------------------------------------------------------")print("|       use GPU !      ||   Available GPU number is {} !  |".format(n_gpu))print("----------------------------------------------------------")device = torch.device('cuda: 0, 1, 2, 3')#################################
#---------------
def load_model(args):print('Loading model from {}...'.format(args.model))if args.without_gpu:myModel = torch.load(args.model, map_location=lambda storage, loc: storage)else:myModel = torch.load(args.model)myModel.eval()myModel.to(device)# myModel.cuda()return myModeldef seg_process(args, image, net):# opencvorigin_h, origin_w, c = image.shapeimage_resize = cv2.resize(image, (args.size,args.size), interpolation=cv2.INTER_CUBIC)image_resize = (image_resize - (104., 112., 121.,)) / 255.0tensor_4D = torch.FloatTensor(1, 3, args.size, args.size)tensor_4D[0,:,:,:] = torch.FloatTensor(image_resize.transpose(2,0,1))inputs = tensor_4D.to(device)trimap, alpha = net(inputs)trimap_np = trimap[0, 0, :, :].cpu().data.numpy()trimap_np = cv2.resize(trimap_np, (origin_w, origin_h), interpolation=cv2.INTER_CUBIC)mask_result = np.multiply(trimap_np[..., np.newaxis], image)trimap_1 = mask_result.copy()mask_result[trimap_1 < 10] = 255mask_result[trimap_1 >= 10] = 0cv2.imwrite("mask_result.png", mask_result)if args.without_gpu:alpha_np = alpha[0,0,:,:].data.numpy()else:alpha_np = alpha[0,0,:,:].cpu().data.numpy()alpha_np = cv2.resize(alpha_np, (origin_w, origin_h), interpolation=cv2.INTER_CUBIC)fg = np.multiply(alpha_np[..., np.newaxis], image)# cv2.imwrite("fg.png", fg)# bg = image# bg_gray = np.multiply(1 - alpha_np[..., np.newaxis], image)# bg_gray = cv2.cvtColor(bg_gray, cv2.COLOR_BGR2GRAY)# # print("bg_gray=", bg_gray)# bg[:,:,0] = bg_gray# bg[:,:,1] = bg_gray# bg[:,:,2] = bg_gray## # fg[fg<=0] = 0# # fg[fg>255] = 255# # fg = fg.astype(np.uint8)# # out = cv2.addWeighted(fg, 0.7, bg, 0.3, 0)## # out = fg + bg# # out[out<0] = 0# # out[out>255] = 255# # out = out.astype(np.uint8)## out = fg.copy()# out[out<10] = 0# out[out>=10] = 255# out = out.astype(np.uint8)return fg, mask_resultdef camera_seg(args, net):# videoCapture = cv2.VideoCapture(0)## while(1):#     # get a frame#     ret, frame = videoCapture.read()#     frame = cv2.flip(frame,1)#     frame_seg = seg_process(args, frame, net)###     # show a frame#     cv2.imshow("capture", frame_seg)##     if cv2.waitKey(1) & 0xFF == ord('q'):#         break# videoCapture.release()test_pic_path = "test_pic/"output_path = "result/"if not os.path.exists(output_path):os.mkdir(output_path)time_0 = time.time()for name_ in os.listdir(test_pic_path):frame = cv2.imread(test_pic_path + name_)fg, mask_result = seg_process(args, frame, net)print("SUCCESS_____!", test_pic_path + name_)cv2.imwrite(output_path + name_.split(".")[0] + "_fg.jpg", fg)cv2.imwrite(output_path + name_, mask_result)print("time_all = ", time.time() - time_0)def main(args):time_1 = time.time()myModel = load_model(args)print("lodding_model_time = ", time.time() - time_1)camera_seg(args, myModel)if __name__ == "__main__":main(args)

4.2:测试过程

在主目录下新建test_pic文件夹,将测试所用的pic图片存入其中后,运行test_camera_used.py文件,就能在result文件夹下得到预测的结果图。

-------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------

五、最后的说明:

  1. 爱分割公司提供的数据集中,某一个目录中有一个没用的隐藏文件,如果不删除的话,数据准备过程、训练过程会报错——文件地址是:matting/1803201916/._matting_00000000
  2. 我训练了一个较好的model,所用的设备是具有4个Tesla V100的显卡服务器,用上了爱分割公司全部数据集 + 自建的一些数据集,因为公司的保密协议,我不能公布这个model,只展示我测试的结果。左边是预测生成图,右边是原图;
  3. 有问题欢迎留言垂询;



【SHM】Semantic Human Matting抠图算法调试相关推荐

  1. 【Image Matting】Semantic Human Matting

    [MM 18] Semantic Human Matting Paper : https://arxiv.org/pdf/1809.01354.pdf 摘要 首次实现无需Trimap方式生成alpha ...

  2. SHM(Semantic Human Mating)算法详解

    论文题目:Semantic Human Matin 论文链接:论文链接 论文代码:None 目录 1.人像抠图简介 2.本文的贡献 3.本文整体框架详解 4.论文实现细节详解 4.1 T-Net详解 ...

  3. 论文阅读:Semantic Human Matting

    论文地址:https://arxiv.org/pdf/1809.01354.pdf 内容简介 这个网络是用来做人像抠图的(Matting),只能抠人不能抠别的 制作了一个很大的高质量人像抠图数据集(5 ...

  4. 基于阿里Semantatic Human Matting算法,实现精细化人物抠图

    人像抠图 基于深度学习技术研发的人像抠图技术.可识别视频图像中的人像区域,包括头部.半身.全身位置,抠出人像部分后,配以不同背景图片.效果,实现娱乐化需求,支持用户玩转更多个性化操作,常用于直播.视频 ...

  5. c#实现SharedMatting抠图算法

    内容简介 将Alpha Matting抠图算法由c++ 版本移植至c#环境. 主要采用OpenCV的C#版本Emgu取代c++支撑的OpenCV. 参考资料 http://www.inf.ufrgs. ...

  6. 图像抠图算法学习 - Shared Sampling for Real-Time Alpha Matting

    一.序言   陆陆续续的如果累计起来,我估计至少有二十来位左右的朋友加我QQ,向我咨询有关抠图方面的算法,可惜的是,我对这方面之前一直是没有研究过的.除了利用和Photoshop中的魔棒一样的技术或者 ...

  7. opencv 图像 抠图 算法_图像抠图算法学习 - Shared Sampling for Real-Time Alpha Matting

    一.序言 陆陆续续的如果累计起来,我估计至少有二十来位左右的朋友加我QQ,向我咨询有关抠图方面的算法,可惜的是,我对这方面之前一直是没有研究过的.除了利用和Photoshop中的魔棒一样的技术或者Ph ...

  8. 图像抠图算法学习 - Shared Sampling for Real-Time Alpha Matting

    本篇博文来自博主Imageshop,打赏或想要查阅更多内容可以移步至Imageshop. 转载自:https://www.cnblogs.com/Imageshop/p/3550185.html    ...

  9. 转载:图像抠图算法学习 -Shared Sampling for Real-Time Alpha Matting

    原文地址:https://www.cnblogs.com/Imageshop/p/3550185.html 一.序言   陆陆续续的如果累计起来,我估计至少有二十来位左右的朋友加我QQ,向我咨询有关抠 ...

最新文章

  1. 图灵奖得主Yann LeCun万字访谈:DNN“史前文明”、炼金术及新的寒冬
  2. 关于DSP开发的步骤
  3. 《Java和Android开发学习指南(第2版)》—— 1.5 本章小结
  4. PDE9 wave equation: general solution
  5. oracle md,Oracle笔记.md
  6. 在Orderby子句中使用CASE 语句
  7. IDEA中安装TeaVM插件
  8. el-select的写法
  9. C#数据库编程实战经典
  10. 教你屏蔽CSDN广告
  11. 守望空城,一位摄影师镜头下的武汉
  12. [unreal] 切换关卡
  13. 单选按钮、字体的设置、沿着y轴旋转、面向用户的这一面不可见、三维效果、背景线性渐变、将背景剪切至文本
  14. Spring 实战-第六章-渲染Web视图-6.2创建JSP视图
  15. 随興8作者雨落下無痕
  16. 输出100以内不能被7整除的数
  17. Linux下使用uinput创建虚拟设备(Ubuntu20.04.2)
  18. ITU标准介绍及下载索引
  19. 用Python自带的tkinter制作一款简易音乐播放器(附工程文件)
  20. 中小型服装店如何选择管理软件?

热门文章

  1. 国际象棋棋盘有64格,若在第1格放1粒谷;第2格放2粒谷;第3格放4粒谷;第4格放8粒谷……如此一直放到第n格(n小于等于64)。假设2000000粒谷有一吨重,问需要多少吨谷才能存满n格?
  2. 溆浦职业中等学校计算机,溆浦县职业中等专业学校2021年学费
  3. 基于小米NOTE的安卓手机刷nethunter通用包的教程
  4. Redis key过期事件监听实现 - 30分钟自动取消未支付订单
  5. 迷你php搭建,laravel 30分站搭建迷你博客
  6. 【文献调研】半监督菌群优化因果特征选择是否可行?
  7. python 3.9版本安装教程(超详细)
  8. Oracle查看用户密码过期,修改永不过期
  9. 领峰:现货黄金行情的解读和分析要重视哪几点
  10. oracle数据库instr用法,postgresql instr函数功能实现(实现oracle plsql instr相同功能)