语义分割注意力机制SE模块tensorflow代码实现
介绍
- 论文原文：Concurrent Spatial and Channel ‘Squeeze & Excitation’ in Fully Convolutional Networks（MICCAI 2018）
在 U 形（U-Shape）网络模型的基础上添加 scSE 注意力模块：在编码器的每个卷积模块（卷积 + split-shuffle 模块）之后各加入一个 scSE 模块（原文此处附有结构示意图）。
Spatial Squeeze and Channel Excitation Block（cSE，通道注意力）
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K
import tensorflow as tf


def normal_conv2d(x, kernel_size, filters, strides, activation='relu', Separable=False):
    """Convolution -> BatchNormalization -> activation.

    :param x: 4-D input tensor (assumed channels-last [b, h, w, c]; BatchNormalization
              normalizes the last axis by default)
    :param kernel_size: convolution kernel size
    :param filters: number of output channels
    :param strides: convolution stride ('same' padding, so spatial size becomes
                    ceil(size / strides))
    :param activation: name of the activation applied after BatchNormalization
    :param Separable: if True use SeparableConv2D instead of a regular Conv2D
    :return: output tensor with `filters` channels
    """
    if Separable:
        x = layers.SeparableConv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    else:
        x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation)(x)
    return x


def CAttention(x, channel, reduction=2):
    """Spatial squeeze & channel excitation (cSE) block.

    Global average pooling squeezes every channel to one value; two 1x1
    convolutions (ReLU, then sigmoid) turn those values into per-channel
    weights in (0, 1) that rescale `x`.

    :param x: 4-D channels-last input tensor [b, h, w, channel]; h and w may be
              unknown (None) when the graph is built
    :param channel: number of channels of `x`
    :param reduction: bottleneck ratio of the first 1x1 convolution
                      (default 2, the ratio used originally)
    :return: `x` rescaled channel-wise, same shape as `x`
    """
    x_origin = x
    squeezed = layers.GlobalAveragePooling2D()(x)  # [b, channel]
    squeezed = K.expand_dims(squeezed, axis=1)
    squeezed = K.expand_dims(squeezed, axis=1)  # [b, 1, 1, channel]
    # Keep at least one filter in the bottleneck (channel // reduction may be 0).
    hidden = max(channel // reduction, 1)
    excited = normal_conv2d(x=squeezed, strides=1, kernel_size=1, filters=hidden, activation='relu')
    weights = normal_conv2d(x=excited, strides=1, kernel_size=1, filters=channel, activation='sigmoid')
    # tf.multiply broadcasts [b, 1, 1, channel] over [b, h, w, channel]. This gives
    # the same values as nearest-upsampling the weights to (h, w) first, but does
    # not need h and w to be known statically.
    return tf.multiply(weights, x_origin)
Channel Squeeze and Spatial Excitation Block（sSE，空间注意力）
def SAttention(x):
    """Channel squeeze & spatial excitation (sSE) block.

    A 1x1 convolution (followed by BatchNormalization and a sigmoid, via
    normal_conv2d) collapses all channels into a single-channel map that then
    rescales every channel of `x` at each spatial position.

    :param x: 4-D input tensor (assumed channels-last [b, h, w, c])
    :return: `x` rescaled position-wise, same shape as `x`
    """
    spatial_weights = normal_conv2d(x=x, strides=1, kernel_size=1, filters=1, activation='sigmoid')
    # [b, h, w, 1] broadcasts across the channel axis of x.
    return tf.multiply(spatial_weights, x)
Spatial and Channel Squeeze & Excitation Block (scSE)：将 cSE 与 sSE 的输出逐元素相加
def SCAttention(x, channel):
    """Concurrent spatial and channel squeeze & excitation (scSE) block.

    :param x: 4-D input tensor with `channel` channels
    :param channel: number of channels of `x`, forwarded to CAttention
    :return: element-wise sum of the cSE and sSE recalibrated feature maps
    """
    # The list is built left to right, so the cSE layers are created before the sSE layers.
    return layers.Add()([CAttention(x, channel), SAttention(x)])
模型代码示例
"""
SCAttentionNet: with attention
Author: XG_hechao
Begin Date: 20201113
End Date:
"""
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K
from CoordConv2 import Coord_Conv2d
import tensorflow as tf


def Channel_Split(x):
    """Split a feature map into two halves along the channel axis.

    :param x: 4-D tensor [b, h, w, c] with a static channel count c
    :return: (x[..., :c // 2], x[..., c // 2:]); for odd c the second half
             receives the extra channel
    """
    channels = x.shape.as_list()[-1]
    split_channel_num = channels // 2
    channel_one = x[:, :, :, 0:split_channel_num]
    channel_two = x[:, :, :, split_channel_num:]
    return channel_one, channel_two


def Channel_Shuffle(x):
    """Interleave the two channel halves (channel shuffle with 2 groups).

    Output channel 2*k + g is input channel g*(c // 2) + k, i.e. the halves
    [a0 a1 ... | b0 b1 ...] become [a0 b0 a1 b1 ...].

    :param x: 4-D tensor [b, h, w, c]; c must be static and even, while h and w
              may be unknown (None)
    :return: tensor of the same shape with reordered channels
    :raises ValueError: if c is unknown or odd
    """
    channels = x.shape.as_list()[-1]
    if channels is None or channels % 2:
        raise ValueError('Channel_Shuffle needs a static, even channel count, got {}'.format(channels))
    channels_per_split = channels // 2
    # Static channel permutation; unlike a reshape to [-1, h, w, 2, c/2] it does
    # not need the spatial dimensions to be known when the graph is built.
    order = [(j % 2) * channels_per_split + j // 2 for j in range(channels)]
    return tf.gather(x, order, axis=-1)


def branch(x, filters, dilation_rate, unit_num, right=False):
    """One convolution branch of the split-shuffle unit.

    Two factorized 3x3 convolutions: a plain pair, then a dilated pair.
    The right branch uses (3, 1) followed by (1, 3) kernels; the left branch
    uses (1, 3) followed by (3, 1).

    :param x: 4-D input tensor [b, h, w, c]
    :param filters: number of filters of every convolution
    :param dilation_rate: dilation rate of the dilated pair
    :param unit_num: index of the unit, only used in layer names
    :param right: True for the right-branch kernel order, False for the left
    :return: tensor [b, h, w, filters]
    """
    if right:
        first_kernel, second_kernel, side = (3, 1), (1, 3), 'right'
    else:
        first_kernel, second_kernel, side = (1, 3), (3, 1), 'left'

    x = layers.Conv2D(filters=filters, kernel_size=first_kernel, padding='same', strides=1)(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filters=filters, kernel_size=second_kernel, padding='same', strides=1)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters=filters, kernel_size=first_kernel, padding='same', strides=1,
                      dilation_rate=dilation_rate,
                      name='dilation_conv2d_{0}_{1}_1_{2}'.format(side, unit_num, dilation_rate))(x)
    x = layers.Activation('relu')(x)
    # The name suffix is dilation_rate * 2 although the layer is dilated by
    # dilation_rate; kept unchanged so layer names (and saved weights) stay compatible.
    x = layers.Conv2D(filters=filters, kernel_size=second_kernel, padding='same', strides=1,
                      dilation_rate=dilation_rate,
                      name='dilation_conv2d_{0}_{1}_2_{2}'.format(side, unit_num, dilation_rate * 2))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x


def split_shuffle_module(x, filters, unit_num, dilation_rate_value):
    """Stack of channel split -> two branches -> residual add -> channel shuffle units.

    :param x: 4-D input tensor [b, h, w, filters] (the residual add needs the
              input to already have `filters` channels)
    :param filters: number of channels of every unit's output
    :param unit_num: number of stacked units
    :param dilation_rate_value: sequence with one dilation rate per unit, or a
                                single int used for every unit
    :return: tensor [b, h, w, filters]
    """
    if isinstance(dilation_rate_value, int):
        dilation_rate_value = [dilation_rate_value] * unit_num
    for i in range(unit_num):
        dilation_rate = dilation_rate_value[i]
        shortcut = x
        x_one, x_two = Channel_Split(x)
        x_one = branch(x_one, filters=filters // 2, dilation_rate=dilation_rate, unit_num=i, right=True)
        x_two = branch(x_two, filters=filters // 2, dilation_rate=dilation_rate, unit_num=i, right=False)
        x = layers.Concatenate()([x_one, x_two])
        x = layers.Add()([shortcut, x])
        x = layers.Activation('relu')(x)
        x = Channel_Shuffle(x)
    return x


def normal_conv2d(x, kernel_size, filters, strides, activation='relu', Separable=False, coord_conv2d=False):
    """Convolution -> BatchNormalization -> activation.

    :param x: 4-D input tensor [b, h, w, c]
    :param kernel_size: convolution kernel size
    :param filters: number of output channels
    :param strides: convolution stride ('same' padding)
    :param activation: name of the activation applied after BatchNormalization
    :param Separable: if True use SeparableConv2D (takes precedence over coord_conv2d)
    :param coord_conv2d: if True, append coordinate channels (Coord_Conv2d) after
                         the strided convolution and apply a second stride-1 convolution
    :return: tensor with `filters` channels
    """
    if Separable:
        x = layers.SeparableConv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    elif coord_conv2d:
        x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
        x = Coord_Conv2d(x)
        x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=1, padding='same')(x)
    else:
        x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation)(x)
    return x


def upsample(x, kernel_size, filters, strides, coord_conv2d=False):
    """Transposed convolution -> BatchNormalization -> ReLU.

    :param x: 4-D input tensor [b, h, w, c]
    :param kernel_size: transposed-convolution kernel size
    :param filters: number of output channels
    :param strides: upsampling factor of the transposed convolution
    :param coord_conv2d: if True, append coordinate channels before the
                         transposed convolution
    :return: tensor with `filters` channels
    """
    if coord_conv2d:
        x = Coord_Conv2d(x)
    x = layers.Conv2DTranspose(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x


def CAttention(x, channel, reduction=2):
    """Spatial squeeze & channel excitation (cSE) block.

    :param x: 4-D channels-last tensor [b, h, w, channel]; h and w may be None
    :param channel: number of channels of `x`
    :param reduction: bottleneck ratio of the first 1x1 convolution (default 2)
    :return: `x` rescaled channel-wise, same shape as `x`
    """
    x_origin = x
    squeezed = layers.GlobalAveragePooling2D()(x)  # [b, channel]
    squeezed = K.expand_dims(squeezed, axis=1)
    squeezed = K.expand_dims(squeezed, axis=1)  # [b, 1, 1, channel]
    # Keep at least one filter in the bottleneck (channel // reduction may be 0).
    hidden = max(channel // reduction, 1)
    excited = normal_conv2d(x=squeezed, strides=1, kernel_size=1, filters=hidden)
    weights = normal_conv2d(x=excited, strides=1, kernel_size=1, filters=channel, activation='sigmoid')
    # Broadcasting [b, 1, 1, channel] over [b, h, w, channel] equals nearest
    # upsampling of the weights, without requiring static h and w.
    return tf.multiply(weights, x_origin)


def SAttention(x):
    """Channel squeeze & spatial excitation (sSE) block.

    :param x: 4-D input tensor [b, h, w, c]
    :return: `x` rescaled by a single-channel sigmoid map, same shape as `x`
    """
    x_origin = x
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=1, activation='sigmoid')
    return tf.multiply(x, x_origin)


def SCAttention(x, channel):
    """scSE block: element-wise sum of the cSE and sSE outputs."""
    x1 = CAttention(x, channel)
    x2 = SAttention(x)
    return layers.Add()([x1, x2])


def Encoder(x):
    """Encoder: three (strided conv -> split-shuffle -> scSE) stages plus a bottleneck.

    :param x: input tensor [b, H, W, c]
    :return: (bottleneck tensor at 1/16 resolution with 256 channels,
              list of skip features at 1/2, 1/4 and 1/8 resolution with
              32, 64 and 128 channels)
    """
    FF_layers = []
    x = normal_conv2d(x, 3, 32, 2, coord_conv2d=False)
    x = split_shuffle_module(x, 32, 1, [2])
    x = SCAttention(x, 32)
    FF_layers.append(x)

    x = normal_conv2d(x, 3, 64, 2, coord_conv2d=False)
    x = split_shuffle_module(x, 64, 1, [5])
    x = SCAttention(x, 64)
    FF_layers.append(x)

    x = normal_conv2d(x, 3, 128, 2, coord_conv2d=False)
    x = split_shuffle_module(x, 128, 1, [8])
    x = SCAttention(x, 128)
    FF_layers.append(x)

    x = normal_conv2d(x, 3, 256, 2, Separable=True)
    return x, FF_layers


def Decoder(x, num_classes, FF_layers):
    """Decoder: upsample step by step and concatenate the encoder skip features.

    :param x: bottleneck tensor from Encoder (1/16 resolution)
    :param num_classes: channels produced by each transposed convolution
    :param FF_layers: skip features from Encoder at 1/2, 1/4, 1/8 resolution
    :return: tensor at 1/2 input resolution
    """
    x = layers.UpSampling2D(size=2, interpolation='bilinear')(x)
    x = layers.Concatenate()([x, FF_layers[2]])
    x = upsample(x, 3, num_classes, 2, coord_conv2d=False)
    x = layers.Concatenate()([x, FF_layers[1]])
    x = upsample(x, 3, num_classes, 2, coord_conv2d=False)
    x = layers.Concatenate()([x, FF_layers[0]])
    return x


def SCAttention_Net(input_size, num_classes):
    """Build the segmentation model.

    :param input_size: (H, W) tuple; the network downsamples by 16, so H and W
                       should be divisible by 16 for the skip concatenations
                       to line up
    :param num_classes: number of output classes
    :return: keras.Model mapping [b, H, W, 3] images to per-pixel softmax
             scores [b, H, W, num_classes]
    """
    inputs = keras.Input(shape=input_size + (3,))
    x, FF_layers = Encoder(inputs)
    x = Decoder(x, num_classes, FF_layers)
    x = upsample(x=x, strides=2, kernel_size=3, filters=num_classes)
    outputs = layers.Conv2D(filters=num_classes, kernel_size=3, padding='same', activation='softmax')(x)
    models = keras.Model(inputs, outputs)
    return models


if __name__ == '__main__':
    model = SCAttention_Net((512, 512), 5)
    model.summary()
    # keras.utils.plot_model(model, dpi=96, to_file='./SCAttention_Net.png', show_shapes=True)
上文所导入的坐标卷积模块 CoordConv2.py（提供 Coord_Conv2d 函数）
import tensorflow.keras.backend as K


def Coord_Conv2d(inputs, radius=False):
    """Append normalized coordinate channels to a feature map (CoordConv coordinates).

    No convolution is performed here; callers apply their own convolution to
    the result.

    :param inputs: 4-D channels-last tensor [b, dim1, dim2, c]
    :param radius: if True, also append a radius channel
    :return: tensor [b, dim1, dim2, c + 2] (c + 3 with radius). The first added
             channel holds the row index along dim1 and the second the column
             index along dim2, both scaled to [-1, 1]. A dimension of size 1
             maps to -1 instead of producing NaN.
    """
    input_shape = K.shape(inputs)
    batch_shape, dim1, dim2 = input_shape[0], input_shape[1], input_shape[2]

    # Row-index channel: ones [b, dim2, 1] @ range(dim1) [b, 1, dim1] -> [b, dim2, dim1],
    # then transposed to [b, dim1, dim2, 1]; entry (i, j) equals i.
    xx_ones = K.ones(K.stack([batch_shape, dim2]), dtype='int32')
    xx_ones = K.expand_dims(xx_ones, axis=-1)
    xx_range = K.tile(K.expand_dims(K.arange(0, dim1), axis=0), K.stack([batch_shape, 1]))
    xx_range = K.expand_dims(xx_range, axis=1)
    xx_channels = K.batch_dot(xx_ones, xx_range, axes=[2, 1])
    xx_channels = K.expand_dims(xx_channels, axis=-1)
    xx_channels = K.permute_dimensions(xx_channels, [0, 2, 1, 3])

    # Column-index channel: range(dim2) [b, dim2, 1] @ ones [b, 1, dim1] -> [b, dim2, dim1],
    # then transposed to [b, dim1, dim2, 1]; entry (i, j) equals j.
    yy_ones = K.ones(K.stack([batch_shape, dim1]), dtype='int32')
    yy_ones = K.expand_dims(yy_ones, axis=1)
    yy_range = K.tile(K.expand_dims(K.arange(0, dim2), axis=0), K.stack([batch_shape, 1]))
    yy_range = K.expand_dims(yy_range, axis=-1)
    yy_channels = K.batch_dot(yy_range, yy_ones, axes=[2, 1])
    yy_channels = K.expand_dims(yy_channels, axis=-1)
    yy_channels = K.permute_dimensions(yy_channels, [0, 2, 1, 3])

    # Scale indices to [-1, 1]. The denominator is clamped to >= 1 so that a
    # spatial size of 1 does not compute 0 / 0 = NaN.
    xx_channels = K.cast(xx_channels, K.floatx())
    xx_channels = xx_channels / K.cast(K.maximum(dim1 - 1, 1), K.floatx())
    xx_channels = (xx_channels * 2) - 1.
    yy_channels = K.cast(yy_channels, K.floatx())
    yy_channels = yy_channels / K.cast(K.maximum(dim2 - 1, 1), K.floatx())
    yy_channels = (yy_channels * 2) - 1.

    outputs = K.concatenate([inputs, xx_channels, yy_channels], axis=-1)
    if radius:
        # NOTE(review): coordinates are centred on 0 after scaling to [-1, 1], yet
        # the radius is measured from (0.5, 0.5); kept unchanged for compatibility.
        radius_layer = K.sqrt(K.square(xx_channels - 0.5) + K.square(yy_channels - 0.5))
        outputs = K.concatenate([outputs, radius_layer], axis=-1)
    return outputs


if __name__ == '__main__':
    x = K.ones([4, 32, 32, 3])
    x = Coord_Conv2d(x)
    print(x.shape)
语义分割注意力机制SE模块tensorflow代码实现相关推荐
- 网络中加入注意力机制SE模块
SENet是由自动驾驶公司Momenta在2017年公布的一种全新的图像识别结构,它通过对特征通道间的相关性进行建模,把重要的特征进行强化来提升准确率.SENet 是2017 ILSVRC竞赛的冠军. ...
- 注意力机制(SE, ECA, CBAM, SKNet, scSE, Non-Local, GCNet, ASFF) Pytorch代码
注意力机制 1 SENet 2 ECANet 3 CBAM 3.1 通道注意力 3.2 空间注意力 3.3 CBAM 4 展示网络层具体信息 5 SKNet 6 scSE 7 Non-Local Ne ...
- yolov5-6.0/6.1加入SE、CBAM、CA注意力机制(理论及代码)
自从yolov5-5.0加入se.cbam.eca.ca发布后,反响不错,也经常会有同学跑过来私信我能不能出一期6.0版本加入注意力的博客.个人认为是没有必要专门写一篇来讲,因为步骤几乎一样,但是问的 ...
- 通道注意力机制keras_在TensorFlow+Keras环境下使用RoI池化一步步实现注意力机制
项目地址:https://gist.github.com/Jsevillamol/0daac5a6001843942f91f2a3daea27a7 理解 RoI 池化 RoI 池化的概念由 Ross ...
- 注意力机制的本质中文版代码
为什么这样就是注意力机制可以参考 https://dongfangyou.blog.csdn.net/article/details/116389080 import torch from torch ...
- MATLAB算法实战应用案例精讲-【人工智能】语义分割(附实战应用案例及代码)
目录 前言 几个相关概念 算法原理 语义分割存在的难点 语义分割类型
- ❀论文篇❀注意力机制SE论文的理解
Squeeze-and-Excitation Networks(SENet) 论文地址:https://arxiv.org/abs/1709.01507 主要思想: 提出SE block 优点: 增强 ...
- 开源|如何利用Tensorflow实现语义分割全卷积网络(附源码)
导读:本项目是基于论文<语义分割全卷积网络的Tensorflow实现>的基础上实现的,该实现主要是基于论文作者给的参考代码.该模型应用于麻省理工学院(http://sceneparsing ...
- 89. 注意力机制以及代码实现Nadaraya-Watson 核回归
1. 心理学 动物需要在复杂环境下有效关注值得注意的点 心理学框架:人类根据随意线索和不随意线索选择注意点 随意:随着自己的意识,有点强调主观能动性的意味. 2. 注意力机制 2. 非参注意力池化层 ...
最新文章
- Python API简单验证
- java jtable不可编辑_java – 使JTable单元不可编辑
- 《飞鸽传书》把写程序和文学创作相提并论
- 内存经销商穷困潦倒 七元午饭都赊账
- js 防止重复提交方案
- 通用即插即用监视器驱动下载_大楚云控下载-大楚云控电脑客户端1.0.7 官方版...
- PHP通过身份证号码获取性别、出生日期、年龄等信息
- 一文看懂STM32单片机和51单片机区别
- 网络可视化工具netron
- 一文了解驱动程序及更新方法
- java中成员变量和局部变量的区别
- 城市系统应用其一-表征城市交通模式
- adc芯片分享,人体脂肪秤芯片CS1256
- 金弘同创教育:拼多多店铺分数多久清算一次
- 大众点评评论抓取-CSS加密破解
- linux 中的rime 输入法 自定义 新世纪五笔输入法
- armbian 斐讯n1_斐讯N1-ArmBian系统写入EMMC及优化
- 关于硬盘扇区的基本知识
- cadence allegro导入dxf文件
- 兰州城市学院计算机专业在哪个校区,兰州城市学院 代码
热门文章
- connect的使用
- 几所大学新增计算机相关专业研究生招生
- SAP 学习 1---基础篇
- android service常驻内存的一点思考
- 雅克比矩阵(Jacobian Matrix)在正运动学中的应用
- arcgis 只能查看指定行政区域_ArcGIS之宗地分割与编号
- CSS字体连写及外观属性
- Unity3D游戏开发之使用AssetBundle和Xml实现场景的动态加载
- supervised使用教程
- Mixed Content: The page at“https://xxx”was loaded over HTTPS, but requested an insecure