一.简介

SegNet是Cambridge提出旨在解决自动驾驶或者智能机器人的图像语义分割深度网络,网络由编码器Encoder和解码器Decoder两大部分组成。SegNet基于FCN,编码器部分修改VGG-16网络得到,解码器部分进行多次上采样得到。

编码器用于提取图像的特征,解码器放大特征图,然后对每个像素的进行分类得到输出层。

二.编码器Encoder

Encoder部分用于特征提取,一般对特征进行四次压缩,每次压缩后特征图的大小都减小一半。Encoder的主干网络可以使用vgg16,不过vgg16模型有些繁重,也可以换成轻量级的MobileNet

1.基于vgg16的Encoder

from tensorflow.keras.layers import *# 基于 vgg16 的 segnet 编码器
def segnet_encoder_vgg16(height=416, width=416):img_input = Input(shape=(height, width, 3))# block1# 416,416,3 -- 208,208,64x = Conv2D(64, 3, padding='same', activation='relu', name='b1_c1')(img_input)x = Conv2D(64, 3, padding='same', activation='relu', name='b1_c2')(x)x = MaxPooling2D((2, 2), strides=2, name='b1_pool')(x)out1 = x# block2# 208,208,64 -- 104,104,128x = Conv2D(128, 3, padding='same', activation='relu', name='b2_c1')(x)x = Conv2D(128, 3, padding='same', activation='relu', name='b2_c2')(x)x = MaxPooling2D((2, 2), strides=2, name='b2_pool')(x)out2 = x# block3# 104,104,128 -- 52,52,256x = Conv2D(256, 3, padding='same', activation='relu', name='b3_c1')(x)x = Conv2D(256, 3, padding='same', activation='relu', name='b3_c2')(x)x = Conv2D(256, 3, padding='same', activation='relu', name='b3_c3')(x)x = MaxPooling2D((2, 2), strides=2, name='b3_pool')(x)out3 = x# block4# 52,52,256 -- 26,26,512x = Conv2D(512, 3, padding='same', activation='relu', name='b4_c1')(x)x = Conv2D(512, 3, padding='same', activation='relu', name='b4_c2')(x)x = Conv2D(512, 3, padding='same', activation='relu', name='b4_c3')(x)x = MaxPooling2D((2, 2), strides=2, name='b4_pool')(x)out4 = x# block5# 26,26,512 -- 13,13,512x = Conv2D(512, 3, padding='same', activation='relu', name='b5_c1')(x)x = Conv2D(512, 3, padding='same', activation='relu', name='b5_c2')(x)x = Conv2D(512, 3, padding='same', activation='relu', name='b5_c3')(x)x = MaxPooling2D((2, 2), strides=2, name='b5_pool')(x)out5 = xreturn img_input, out4

2.基于MobilenetV1的Encoder

深度可分离卷积在tensorflow2中有两种实现方法,(DepthwiseConv2D + Conv1x1 ) 实现与(SeparableConv2D)实现

(1).DepthwiseConv2D + Conv1x1

from tensorflow.keras.layers import *def conv_block(inputs, filters, kernel, strides):x = ZeroPadding2D(1)(inputs)x = Conv2D(filters, kernel, strides, padding='valid', use_bias=False)(x)x = BatchNormalization()(x)x = ReLU(max_value=6)(x)return xdef dw_pw_block(inputs, dw_strides, pw_filters, name):x = ZeroPadding2D(1)(inputs)# dwx = DepthwiseConv2D((3, 3), dw_strides, padding='valid', use_bias=False, name=name)(x)x = BatchNormalization()(x)x = ReLU(max_value=6)(x)# pwx = Conv2D(pw_filters, (1, 1), 1, padding='valid', use_bias=False)(x)x = BatchNormalization()(x)x = ReLU(max_value=6)(x)return x# 基于 Mobilenet 的 segnet 编码器(DepthwiseConv2D + Conv1x1 实现)
def segnet_encoder_MobilenetV1_1(height=416, width=416):img_input = Input(shape=(height, width, 3))# block1:con1 + dw_pw_1# 416,416,3 -- 208,208,32 -- 208,208,64x = conv_block(img_input, 32, (3, 3), (2, 2))x = dw_pw_block(x, 1, 64, 'dw_pw_1')# block2:dw_pw_2# 208,208,64 -- 104,104,128x = dw_pw_block(x, 2, 128, 'dw_pw_2_1')x = dw_pw_block(x, 1, 128, 'dw_pw_2_2')# block3:dw_pw_3# 104,104,128 -- 52,52,256x = dw_pw_block(x, 2, 256, 'dw_pw_3_1')x = dw_pw_block(x, 1, 256, 'dw_pw_3_2')# block4:dw_pw_4# 52,52,256 -- 26,26,512x = dw_pw_block(x, 2, 512, 'dw_pw_4_1')for i in range(5):x = dw_pw_block(x, 1, 512, 'dw_pw_4_' + str(i + 2))out4 = x# block5:dw_pw_5# 26,26,512 -- 13,13,1024x = dw_pw_block(x, 2, 1024, 'dw_pw_5_1')x = dw_pw_block(x, 1, 1024, 'dw_pw_5_2')return img_input, out4

(2).SeparableConv2D

from tensorflow.keras.layers import *def sp_block(x, dw_strides, pw_filters, name):x = ZeroPadding2D(1)(x)x = SeparableConv2D(pw_filters, (3, 3), dw_strides, use_bias=False, name=name)(x)x = BatchNormalization()(x)x = ReLU(max_value=6)(x)return x# 基于 Mobilenet 的 segnet 编码器(SeparableConv2D实现)
def segnet_encoder_MobilenetV1_2(height=416, width=416):img_input = Input(shape=(height, width, 3))# block1:con1 + dw_pw_1# 416,416,3 -- 208,208,32 -- 208,208,64x = conv_block(img_input, 32, (3, 3), (2, 2))x = sp_block(x, 1, 64, 'dw_pw_1')# block2:dw_pw_2# 208,208,64 -- 104,104,128x = sp_block(x, 2, 128, 'dw_pw_2_1')x = sp_block(x, 1, 128, 'dw_pw_2_2')# block3:dw_pw_3# 104,104,128 -- 52,52,256x = sp_block(x, 2, 256, 'dw_pw_3_1')x = sp_block(x, 1, 256, 'dw_pw_3_2')# block4:dw_pw_4# 52,52,256 -- 26,26,512x = sp_block(x, 2, 512, 'dw_pw_4_1')for i in range(5):x = sp_block(x, 1, 512, 'dw_pw_4_' + str(i + 2))out4 = x# block5:dw_pw_5# 26,26,512 -- 13,13,1024x = sp_block(x, 2, 1024, 'dw_pw_5_1')x = sp_block(x, 1, 1024, 'dw_pw_5_2')return img_input, out4

三.解码器Decoder

from tensorflow.keras.layers import *# segnet的解码器
def segnet_decoder(feature,n_classes):# 直接进行上采样时会出现一些问题,这里先Zeropadding# 26,26,512 -- 26,26,512x = ZeroPadding2D(1)(feature) # 26,26,512 -- 28,28,512x = Conv2D(512,3,padding='valid')(x)    # 28,28,512 -- 26,26,512x = BatchNormalization()(x)# 上采样 3 次(编码器总共编码5次,每次图像缩小一半,但是只用第4次的结果)# 1/16 -- 1/8 ; 26,26,512 -- 52,52,256# 1/8 -- 1/4  ; 52,52,256 -- 104,104,128# 1/4 -- 1/2  ; 104,104,128 -- 208,208,64filters = [256,128,64]for i,filter in enumerate(filters):x = UpSampling2D(2,name='Up_'+str(i+1))(x)x = ZeroPadding2D(1)(x)x = Conv2D(filter,3,padding='valid')(x)x = BatchNormalization()(x)# 208,208,64 -- 208,208,n_classesout = Conv2D(n_classes,3,padding='same')(x)return out

四.模型创建

from encoders import *
from tensorflow.keras.models import Model# 创建 segnet 模型
def build_segnet(n_classes,encoder_type='vgg16',input_height=416,input_width=416):# 1.获取encoder的输出 (416,416,3--26,26,512)if encoder_type == 'vgg16':img_input,feature = segnet_encoder_vgg16(input_height,input_width)elif encoder_type == 'MobilenetV1_1':img_input, feature = segnet_encoder_MobilenetV1_1(input_height, input_width)elif encoder_type == 'MobilenetV1_2':img_input, feature = segnet_encoder_MobilenetV1_2(input_height, input_width)else:raise RuntimeError('segnet encoder name is error')# 2.获取decoder的输出 (26,26,512--208,208,n_classes)out = segnet_decoder(feature,n_classes)# 3.结果Reshape (208*208,n_classes)out = Reshape((int(input_height/2)*int(input_height/2),-1))(out)out = Softmax()(out)# 4.创建模型model = Model(img_input,out)return model

五.Segnet训练斑马线语义分割

from segnet import build_segnet
from tensorflow.keras.callbacks import ModelCheckpoint,ReduceLROnPlateau,EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
import numpy as np
from PIL import Image
import os
import argparsedef parse_opt():parse = argparse.ArgumentParser()parse.add_argument('--datasets_path',type=str,default='../../datasets/banmaxian',help='数据集路径')parse.add_argument('--n_classes',type=int,default=2,help='标签种类(含背景)')parse.add_argument('--height',type=int,default=416,help='图片高度')parse.add_argument('--width',type=int,default=416,help='图片宽度')parse.add_argument('--batch_size',type=int,default=2)parse.add_argument('--lr',type=float,default=0.0001)parse.add_argument('--epochs',type=int,default=50)parse.add_argument('--encoder_type',type=str,default='MobilenetV1_2',help='segnet模型编码器的类型[vgg16,MobilenetV1_1,MobilenetV1_2]')opt = parse.parse_args()return optdef get_data_from_file(opt):datasets_path,height,width,n_classes = opt.datasets_path,opt.height,opt.width,opt.n_classeswith open(os.path.join(datasets_path,'train.txt')) as f:lines = f.readlines()lines = [line.replace('\n','') for line in lines]X = []Y = []for i in range(len(lines)):names = lines[i].split(';')real_name = names[0]    # xx.jpglabel_name = names[1]   # xx.png# 读取真实图像real_img = Image.open(os.path.join(datasets_path,'jpg',real_name))real_img = real_img.resize((height,width))real_img = np.array(real_img)/255   # (416,416,3) [0,1]X.append(real_img)# 读取标签图像,3通道,每个通道的数据都一样,每个像素点就是对应的类别,0表示背景label_img = Image.open(os.path.join(datasets_path, 'png', label_name))label_img = label_img.resize((int(height/2), int(width/2)))label_img = np.array(label_img) # (208,208,3) [0,1]# 根据标签图像来创建训练标签数据,n类对应的 seg_labels 就有n个通道# 此时 seg_labels 每个通道的都值为 0seg_labels = np.zeros((int(height/2), int(width/2),n_classes))  # (208,208,2)# 第0通道表示第0类# 第1通道表示第1类# .....# 第n_classes通道表示第n_classes类for c in range(n_classes):seg_labels[:,:,c] = (label_img[:,:,0]==c).astype(int)# 此时 seg_labels 每个通道的值为0或1, 1 表示该像素点是该类,0 则不是seg_labels = np.reshape(seg_labels,(-1,n_classes))  # (208*208,2)Y.append(seg_labels)return np.array(X),np.array(Y)if __name__ == '__main__':# 1.参数初始化opt = parse_opt()# 2.获取数据集X,Y = get_data_from_file(opt)# 3.创建模型# 每5个epoch保存一次weight_path = 'segnet_' + opt.encoder_type+'_weight/'model = build_segnet(opt.n_classes,opt.encoder_type,opt.height,opt.width)os.makedirs(weight_path,exist_ok=True)checkpoint = ModelCheckpoint(filepath=weight_path+'acc{accuracy:.4f}-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',monitor='val_loss',verbose=1,save_best_only=True,save_weights_only=True,period=5)lr_sh = ReduceLROnPlateau(monitor='val_loss',factor=0.5,patience=5,verbose=1)es = EarlyStopping(monitor='val_loss',patience=10,verbose=1)model.compile(loss=BinaryCrossentropy(),optimizer=Adam(opt.lr),metrics='accuracy')# 4.模型训练model.fit(x=X,y=Y,batch_size=opt.batch_size,epochs=opt.epochs,callbacks=[checkpoint,lr_sh],verbose=1,validation_split=0.1,shuffle=True,)# 5.模型保存model.save_weights(weight_path+'/last.h5')

六.测试

from segnet import build_segnet
from PIL import Image
import numpy as np
import copy
import os
import argparsedef parse_opt():parse = argparse.ArgumentParser()parse.add_argument('--test_imgs', type=str, default='test_imgs', help='测试数据集')parse.add_argument('--test_res', type=str, default='test_res', help='测试数据集')parse.add_argument('--n_classes', type=int, default=2, help='标签种类(含背景)')parse.add_argument('--height', type=int, default=416, help='输入模型的图片高度')parse.add_argument('--width', type=int, default=416, help='输入模型的图片宽度')parse.add_argument('--encoder_type', type=str, default='MobilenetV1_2', help='segnet模型编码器的类型[vgg16,MobilenetV1_1,MobilenetV1_2]')opt = parse.parse_args()return optdef resize_img(path,real_width,real_height):img_names = os.listdir(path)for img_name in img_names:img = Image.open(os.path.join(path, img_name))img = img.resize((real_width,real_height))img.save(os.path.join(path, img_name))if __name__ == '__main__':# 1.参数初始化opt = parse_opt()# class_colors 要根据图像的语义标签来设定;n_classes 行 3 列;# 3列为RGB的值class_colors = [[0, 0, 0],[0, 255, 0]]imgs_path = os.listdir(opt.test_imgs)imgs_test = []imgs_init = []jpg_names = []real_width,real_height = 1280,720resize_img(opt.test_imgs, real_width,real_height)# 2.获取测试图片for i,jpg_name in enumerate(imgs_path):img_init = Image.open(os.path.join(opt.test_imgs, jpg_name))img = copy.deepcopy(img_init)img = img.resize((opt.width,opt.height))img = np.array(img) / 255  # (416,416,3) [0,1]imgs_test.append(img)imgs_init.append(img_init)jpg_names.append(jpg_name)imgs_test = np.array(imgs_test)  # (-1,416,416,3)# 3.模型创建weight_path = 'segnet_' + opt.encoder_type + '_weight/'model = build_segnet(opt.n_classes, opt.encoder_type,opt.height,opt.width)model.load_weights(os.path.join(weight_path, 'last.h5'))# 4.模型预测语义分类结果prs = model.predict(imgs_test)  # (-1, 43264, 2)# 结果 reshapeprs = prs.reshape(-1, int(opt.height / 2), int(opt.width / 2), opt.n_classes)  # (-1, 208, 208, 2)# 找到结果每个像素点所属类别的索引 两类就是 0 或 1prs = prs.argmax(axis=-1)   # (-1, 208, 208)# 此时 prs 就是预测出来的类别,argmax 求得是最大值所在的索引,这个索引和类别值相同# 所以 prs 每个像素点就是对应的类别# 5.创建语义图像# 和训练集中的语义标签图像不同,这里要显示图像,所以固定3通道imgs_seg = np.zeros((len(prs), int(opt.height / 2), int(opt.width / 2), 3)) # (-1,208,208,3)for c in range(opt.n_classes):# 每个通道都要判断是否属于第0,1,2... n-1 类,是的话就乘以对应的颜色,每个类别都要判断一次# 因为是RGB三个通道,所以3个通道分别乘以class_colors的每个通道颜色值imgs_seg[:,:,:,0] += ((prs[:,:,:]==c)*(class_colors[c][0])).astype(int)imgs_seg[:,:,:,1] += ((prs[:,:,:]==c)*(class_colors[c][1])).astype(int)imgs_seg[:,:,:,2] += ((prs[:,:,:]==c)*(class_colors[c][2])).astype(int)# 6.保存结果save_path = opt.test_out+'_'+opt.encoder_typeos.makedirs(save_path,exist_ok=True)for img_init,img_seg,img_name in zip(imgs_init,imgs_seg,jpg_names):img_seg = Image.fromarray(np.uint8(img_seg)).resize((real_width,real_height))images = Image.blend(img_init,img_seg,0.3)images.save(os.path.join(opt.test_out+'_'+opt.encoder_type,img_name))

SegNet代码实战相关推荐

  1. 一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

    <繁凡的深度学习笔记>第 15 章 元学习详解 (上)万字中文综述(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net ...

  2. R语言使用tryCatch函数调试R代码实战:tryCatch函数运行正常R代码、tryCatch函数运行有错误(error)的R代码示例/tryCatch函数运行有警告(warning)的R代码示例

    R语言使用tryCatch函数调试R代码实战:tryCatch函数运行正常R代码.tryCatch函数运行有错误(error)的R代码示例/tryCatch函数运行有警告(warning)的R代码示例 ...

  3. sklearn基于make_scorer函数为Logistic模型构建自定义损失函数并可视化误差图(lambda selection)和系数图(trace plot)+代码实战

    sklearn基于make_scorer函数为Logistic模型构建自定义损失函数并可视化误差图(lambda selection)和系数图(trace plot)+代码实战 # 自定义损失函数 i ...

  4. sklearn基于make_scorer函数为Logistic模型构建自定义损失函数+代码实战(二元交叉熵损失 binary cross-entropy loss)

    sklearn基于make_scorer函数为Logistic模型构建自定义损失函数+代码实战(二元交叉熵损失 binary cross-entropy loss) # 广义线性模型中的各种连接函数: ...

  5. 深度学习时间序列预测:LSTM算法构建时间序列单变量模型预测大气压( air pressure)+代码实战

    深度学习时间序列预测:LSTM算法构建时间序列单变量模型预测大气压( air pressure)+代码实战 长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主 ...

  6. 深度学习时间序列预测:卷积神经网络(CNN)算法构建单变量时间序列预测模型预测空气质量(PM2.5)+代码实战

    深度学习时间序列预测:卷积神经网络(CNN)算法构建单变量时间序列预测模型预测空气质量(PM2.5)+代码实战 神经网络(neual networks)是人工智能研究领域的一部分,当前最流行的神经网络 ...

  7. 深度学习时间序列预测:GRU算法构建单变量时间序列预测模型+代码实战

    深度学习时间序列预测:GRU算法构建单变量时间序列预测模型+代码实战 GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种 ...

  8. 深度学习时间序列预测:GRU算法构建多变量时间序列预测模型+代码实战

    深度学习时间序列预测:GRU算法构建多变量时间序列预测模型+代码实战 注意参考:深度学习多变量时间序列预测:GRU算法构建单变量时间序列预测模型+代码实战 GRU(Gate Recurrent Uni ...

  9. 深度学习时间序列预测:LSTM算法构建时间序列单变量模型预测空气质量(PM2.5)+代码实战

    深度学习时间序列预测:LSTM算法构建时间序列单变量模型预测空气质量(PM2.5)+代码实战 # 导入需要的包和函数: from __future__ import print_function im ...

最新文章

  1. asp.net 2.0小TIPS两则
  2. Css Sprites 多张图片整合在一张图片上
  3. 【LeetCode笔记】剑指 Offer 93. 复原 IP 地址(Java、DFS、字符串)
  4. 【Linux笔记(000) 】-- 系统启动过程
  5. Angular - ng-repeat高级用法
  6. v4l2接口,结构图
  7. 最近写mapreduce程序从hbase中抽取程序遇到的一些问题
  8. topjui/easyui 表格分页简单实例
  9. sis最新ip地址2020入口一_最新天猫双十一2020淘宝双十一红包活动加码揭秘 京东双11玩法攻略_互联网...
  10. 董明珠成为带货女王,并非格力值得高兴的事情
  11. node.js + busboy 多文件上传
  12. 基于Zigbee的智能家居系统
  13. MVX Android设计架构浅析-MVVM
  14. 华硕 小布 类似机器人_“嗨 小布跟着我” | 华硕首款智能机器人“小布”正式发布...
  15. ChucK初步(6)
  16. 安卓虚拟摄像头_华为Mate 40 Pro爆料,后置摄像头类似苹果iPod
  17. TMO (time-triggered message-triggered object)
  18. 快速傅里叶变换(FFT)的原理及公式
  19. 【超融合】超融合“火不火”?
  20. 莫凡Python学习笔记 一

热门文章

  1. Titanic - (XGBoost,RF随机森林,Fastai-tabular_learner)总结
  2. LeetCode(89):格雷编码 Gray Code(Java)
  3. EasyCVR接入Ehome协议的设备,无法观看设备录像是什么原因?
  4. ! undefined control sequence \begin{the bibliography}{0}的解决方法
  5. 运动控制第二篇之闭环控制直流电机调速系统仿真
  6. PHP实现随机发牌功能
  7. FFmpeg 推流不同视频格式参数
  8. python自动化[poco篇]
  9. 网站的SEO优化(提高搜索引擎收录,类似百度)
  10. 百度关键词ad竞价的优劣势分析,信息流优化师必看