Introduction

  • Original paper: Concurrent Spatial and Channel ‘Squeeze & Excitation’ in Fully Convolutional Networks
    The SE modules are added on top of a U-shaped network, one after each convolution block, as shown in the figure; the sketch below outlines where the block sits in one encoder stage.
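
(A schematic sketch; normal_conv2d, split_shuffle_module, and SCAttention are the helper functions defined later in this post.)

# One encoder stage (schematic): conv block -> split-shuffle unit -> scSE attention
x = normal_conv2d(x, kernel_size=3, filters=32, strides=2)                    # downsampling convolution
x = split_shuffle_module(x, filters=32, unit_num=1, dilation_rate_value=[2])  # feature extraction
x = SCAttention(x, channel=32)                                                # scSE after the conv block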

Spatial Squeeze and Channel Excitation Block


from tensorflow import keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K
import tensorflow as tf


def normal_conv2d(x, kernel_size, filters, strides, activation='relu', Separable=False):
    """Convolution block
    :param x: input tensor [b, w, h, filter]
    :param kernel_size: kernel size
    :param filters: number of filters
    :param strides: convolution stride
    :param Separable: whether to use a separable convolution
    :param activation: activation function
    :return: output tensor [b, w, h, filters]
    """
    if Separable:
        x = layers.SeparableConv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    else:
        x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation)(x)
    return x


def CAttention(x, channel):
    x_origin = x
    x = layers.GlobalAveragePooling2D()(x)   # spatial squeeze: [b, h, w, c] -> [b, c]
    x = K.expand_dims(x, axis=1)
    x = K.expand_dims(x, axis=1)             # restore spatial dims: [b, 1, 1, c]
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=channel // 2, activation='relu')
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=channel, activation='sigmoid')
    x = layers.UpSampling2D(size=(x_origin.shape[1], x_origin.shape[2]), interpolation='nearest')(x)
    x = tf.multiply(x, x_origin)             # channel excitation: rescale each channel of the input
    return x
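
A quick sanity check that the cSE block preserves the input shape (a minimal sketch assuming TF 2.x eager execution; the dummy shapes are illustrative):

x = tf.random.normal([2, 64, 64, 32])   # dummy batch: [b, h, w, c]
y = CAttention(x, channel=32)
print(y.shape)                           # expected: (2, 64, 64, 32), same as the input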

Channel Squeeze and Spatial Excitation Block

def SAttention(x):
    x_origin = x
    # channel squeeze: a 1x1 convolution down to a single channel; the sigmoid yields a spatial mask in [0, 1]
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=1, activation='sigmoid')
    x = tf.multiply(x, x_origin)   # spatial excitation: rescale every pixel by its mask value
    return x
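
Note that this version routes the 1x1 convolution through normal_conv2d, so the sigmoid is applied after batch normalization. The paper defines sSE as a bare 1x1 convolution followed by a sigmoid, with no BN; a minimal sketch of that variant (the name SAttention_paper is my own):

def SAttention_paper(x):
    # sSE as in the paper: 1x1 conv to one channel + sigmoid, no batch normalization
    mask = layers.Conv2D(filters=1, kernel_size=1, strides=1, padding='same', activation='sigmoid')(x)
    return tf.multiply(mask, x)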

Spatial and Channel Squeeze & Excitation Block (scSE)

def SCAttention(x, channel):
    x1 = CAttention(x, channel)   # cSE branch: channel recalibration
    x2 = SAttention(x)            # sSE branch: spatial recalibration
    x = layers.Add()([x1, x2])    # element-wise sum of the two recalibrated maps, as in the paper
    return x
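
To verify the combined block end to end, wrap it in a tiny functional model (a sketch; the input shape is illustrative):

inputs = keras.Input(shape=(64, 64, 32))
outputs = SCAttention(inputs, channel=32)
demo = keras.Model(inputs, outputs)
demo.summary()   # the output shape should equal the input shape: (None, 64, 64, 32)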

Model code example

"""
SCAttentionNet: with attention
Author: XG_hechao
Begin Date: 20201113
End Date:
"""
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K
from CoordConv2 import Coord_Conv2d
import tensorflow as tf


def Channel_Split(x):
    """Channel split
    :param x: input tensor [b, w, h, filters]
    :return: two tensors [b, w, h, 0:filters/2] and [b, w, h, filters/2:]
    """
    channel_before = x.shape.as_list()[1:]         # [h, w, channels]
    split_channel_num = channel_before[2] // 2     # half of the channels
    channel_one = x[:, :, :, 0:split_channel_num]  # first half
    channel_two = x[:, :, :, split_channel_num:]   # second half
    return channel_one, channel_two


def Channel_Shuffle(x):
    """Channel shuffle
    :param x: input tensor [b, w, h, filters]
    :return: output tensor [b, w, h, filters] with the channel groups interleaved
    """
    height, width, channels = x.shape.as_list()[1:]
    channels_per_split = channels // 2
    x = K.reshape(x, [-1, height, width, 2, channels_per_split])  # split the n channels into [2, n/2] groups
    x = K.permute_dimensions(x, (0, 1, 2, 4, 3))                  # transpose the two group axes
    x = K.reshape(x, [-1, height, width, channels])               # merge back: [2, n/2] -> n
    return x


def branch(x, filters, dilation_rate, unit_num, right=False):
    """Convolution branch
    :param x: input tensor [b, w, h, filter]
    :param filters: number of filters
    :param dilation_rate: dilation rate of the dilated convolutions
    :param unit_num: unit index (used in layer names)
    :param right: branch variant; True uses the right-branch ordering (3x1 then 1x3),
                  False the left-branch ordering (1x3 then 3x1)
    :return: output tensor [b, w, h, filters]
    """
    if right:
        x = layers.Conv2D(filters=filters, kernel_size=(3, 1), padding='same', strides=1)(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(1, 3), padding='same', strides=1)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(3, 1), padding='same', strides=1, dilation_rate=dilation_rate,
                          name='dilation_conv2d_right_{0}_1_{1}'.format(unit_num, dilation_rate))(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(1, 3), padding='same', strides=1, dilation_rate=dilation_rate,
                          name='dilation_conv2d_right_{0}_2_{1}'.format(unit_num, dilation_rate * 2))(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    else:
        x = layers.Conv2D(filters=filters, kernel_size=(1, 3), padding='same', strides=1)(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(3, 1), padding='same', strides=1)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(1, 3), padding='same', strides=1, dilation_rate=dilation_rate,
                          name='dilation_conv2d_left_{0}_1_{1}'.format(unit_num, dilation_rate))(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=filters, kernel_size=(3, 1), padding='same', strides=1, dilation_rate=dilation_rate,
                          name='dilation_conv2d_left_{0}_2_{1}'.format(unit_num, dilation_rate * 2))(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x


def split_shuffle_module(x, filters, unit_num, dilation_rate_value):
    """Channel split-and-shuffle module
    :param x: input tensor [b, w, h, filter]
    :param filters: number of filters
    :param unit_num: number of repeated units
    :param dilation_rate_value: list of dilation rates, one per unit
    :return: output tensor [b, w, h, filters]
    """
    for i in range(unit_num):
        dilation_rate = dilation_rate_value[i]
        add = x
        x_one, x_two = Channel_Split(x)
        x_one = branch(x_one, filters=filters // 2, dilation_rate=dilation_rate, unit_num=i, right=True)
        x_two = branch(x_two, filters=filters // 2, dilation_rate=dilation_rate, unit_num=i, right=False)
        x = layers.Concatenate()([x_one, x_two])  # concatenate the two halves along channels
        x = layers.Add()([add, x])                # residual connection
        x = layers.Activation('relu')(x)
        x = Channel_Shuffle(x)
    return x


def normal_conv2d(x, kernel_size, filters, strides, activation='relu', Separable=False, coord_conv2d=False):
    """Convolution block
    :param x: input tensor [b, w, h, filter]
    :param kernel_size: kernel size
    :param filters: number of filters
    :param strides: convolution stride
    :param Separable: whether to use a separable convolution
    :param coord_conv2d: whether to use coordinate convolution
    :param activation: activation function
    :return: output tensor [b, w, h, filters]
    """
    if Separable:
        x = layers.SeparableConv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    else:
        if coord_conv2d:
            x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
            x = Coord_Conv2d(x)
            x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=1, padding='same')(x)
        else:
            x = layers.Conv2D(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation)(x)
    return x


def upsample(x, kernel_size, filters, strides, coord_conv2d=False):
    """Upsampling module
    :param x: input tensor [b, w, h, filter]
    :param kernel_size: kernel size
    :param filters: number of filters
    :param strides: stride of the transposed convolution
    :param coord_conv2d: whether to use coordinate convolution
    :return: output tensor [b, w, h, filters]
    """
    if coord_conv2d:
        x = Coord_Conv2d(x)
        x = layers.Conv2DTranspose(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    else:
        x = layers.Conv2DTranspose(kernel_size=kernel_size, filters=filters, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x


def CAttention(x, channel):
    x_origin = x
    x = layers.GlobalAveragePooling2D()(x)
    x = K.expand_dims(x, axis=1)
    x = K.expand_dims(x, axis=1)
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=channel // 2)
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=channel, activation='sigmoid')
    x = layers.UpSampling2D(size=(x_origin.shape[1], x_origin.shape[2]), interpolation='nearest')(x)
    x = tf.multiply(x, x_origin)
    return x


def SAttention(x):
    x_origin = x
    x = normal_conv2d(x=x, strides=1, kernel_size=1, filters=1, activation='sigmoid')
    x = tf.multiply(x, x_origin)
    return x


def SCAttention(x, channel):
    x1 = CAttention(x, channel)
    x2 = SAttention(x)
    x = layers.Add()([x1, x2])
    return x


def Encoder(x):
    """Encoder
    :param x: input tensor [b, w, h, filter]
    :return: bottleneck tensor and the list of skip-connection features
    """
    FF_layers = []
    x = normal_conv2d(x, 3, 32, 2, coord_conv2d=False)   # stage 1: downsample to 1/2
    x = split_shuffle_module(x, 32, 1, [2])
    x = SCAttention(x, 32)
    FF_layers.append(x)
    x = normal_conv2d(x, 3, 64, 2, coord_conv2d=False)   # stage 2: downsample to 1/4
    x = split_shuffle_module(x, 64, 1, [5])
    x = SCAttention(x, 64)
    FF_layers.append(x)
    x = normal_conv2d(x, 3, 128, 2, coord_conv2d=False)  # stage 3: downsample to 1/8
    x = split_shuffle_module(x, 128, 1, [8])
    x = SCAttention(x, 128)
    FF_layers.append(x)
    x = normal_conv2d(x, 3, 256, 2, Separable=True)      # bottleneck: downsample to 1/16
    return x, FF_layers


def Decoder(x, num_classes, FF_layers):
    """Decoder: upsample step by step and fuse the encoder skip features."""
    x = layers.UpSampling2D(size=2, interpolation='bilinear')(x)
    x = layers.Concatenate()([x, FF_layers[2]])
    x = upsample(x, 3, num_classes, 2, coord_conv2d=False)
    x = layers.Concatenate()([x, FF_layers[1]])
    x = upsample(x, 3, num_classes, 2, coord_conv2d=False)
    x = layers.Concatenate()([x, FF_layers[0]])
    return x


def SCAttention_Net(input_size, num_classes):
    inputs = keras.Input(shape=input_size + (3,))
    x, FF_layers = Encoder(inputs)
    x = Decoder(x, num_classes, FF_layers)
    x = upsample(x=x, strides=2, kernel_size=3, filters=num_classes)  # back to full resolution
    outputs = layers.Conv2D(filters=num_classes, kernel_size=3, padding='same', activation='softmax')(x)
    models = keras.Model(inputs, outputs)
    return models


if __name__ == '__main__':
    model = SCAttention_Net((512, 512), 5)
    model.summary()
    # keras.utils.plot_model(model, dpi=96, to_file='./SCAttention_Net.png', show_shapes=True)
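
Compiling and training the assembled network could look like the following (a hedged sketch: the optimizer, loss, and label format are my assumptions, not part of the original post):

model = SCAttention_Net((512, 512), 5)
# assumes integer label masks of shape [b, 512, 512]; use categorical_crossentropy for one-hot masks
model.compile(optimizer=keras.optimizers.Adam(1e-3),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# model.fit(train_dataset, validation_data=val_dataset, epochs=50)   # train_dataset/val_dataset are hypothetical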

The imported coord_conv2d.py (coordinate convolution)

import tensorflow.keras.backend as K


def Coord_Conv2d(inputs, radius=False):
    input_shape = K.shape(inputs)
    input_shape = [input_shape[i] for i in range(4)]
    batch_shape, dim1, dim2, channels = input_shape

    xx_ones = K.ones(K.stack([batch_shape, dim2]), dtype='int32')  # ones of shape [batch, dim2]
    xx_ones = K.expand_dims(xx_ones, axis=-1)                      # expand to [batch, dim2, 1], e.g. [4, 128, 1]
    xx_range = K.tile(K.expand_dims(K.arange(0, dim1), axis=0),
                      K.stack([batch_shape, 1]))                   # K.tile replicates: [1, 128] tiled by [4, 1] -> [4, 128]
    xx_range = K.expand_dims(xx_range, axis=1)                     # expand [4, 128] to [4, 1, 128]
    xx_channels = K.batch_dot(xx_ones, xx_range, axes=[2, 1])      # matmul: [4,128,1] x [4,1,128] = [4,128,128]
    xx_channels = K.expand_dims(xx_channels, axis=-1)              # [4, 128, 128, 1]
    xx_channels = K.permute_dimensions(xx_channels, [0, 2, 1, 3])  # swap the spatial axes

    yy_ones = K.ones(K.stack([batch_shape, dim1]), dtype='int32')
    yy_ones = K.expand_dims(yy_ones, axis=1)
    yy_range = K.tile(K.expand_dims(K.arange(0, dim2), axis=0),
                      K.stack([batch_shape, 1]))
    yy_range = K.expand_dims(yy_range, axis=-1)
    yy_channels = K.batch_dot(yy_range, yy_ones, axes=[2, 1])
    yy_channels = K.expand_dims(yy_channels, axis=-1)
    yy_channels = K.permute_dimensions(yy_channels, [0, 2, 1, 3])

    xx_channels = K.cast(xx_channels, K.floatx())             # int -> float
    xx_channels = xx_channels / K.cast(dim1 - 1, K.floatx())  # normalize to [0, 1]
    xx_channels = (xx_channels * 2) - 1.                      # rescale to [-1, 1]
    yy_channels = K.cast(yy_channels, K.floatx())
    yy_channels = yy_channels / K.cast(dim2 - 1, K.floatx())
    yy_channels = (yy_channels * 2) - 1.

    outputs = K.concatenate([inputs, xx_channels, yy_channels], axis=-1)
    if radius:
        radius_layer = K.sqrt(K.square(xx_channels - 0.5) + K.square(yy_channels - 0.5))
        outputs = K.concatenate([outputs, radius_layer], axis=-1)
    return outputs


if __name__ == '__main__':
    x = K.ones([4, 32, 32, 3])
    x = Coord_Conv2d(x)
    print(x.shape)
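
With radius=True a third coordinate channel, the distance from the normalized offset 0.5, is appended as well; a quick check of the resulting channel count (a sketch):

x = K.ones([4, 32, 32, 3])
y = Coord_Conv2d(x, radius=True)
print(y.shape)   # (4, 32, 32, 6): 3 input channels + xx + yy + radius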
