目录

  • 任务介绍
  • 官方baseline地址
  • 网络结构
  • 网络输入
  • Code
    • import
    • DataLoader
      • dataset类
      • 进行的transforms
    • NetWork
    • 评估指标
      • Train & Val
    • 训练模型超参数
  • 其他的一些说明
    • 测试集和验证集的划分
    • 结果文件生成
  • 优化方向

任务介绍

任务一:基于多模态眼底影像的青光眼分级

在本任务中,我们的目的是分析2D眼底图像和3D OCT扫描体数据两种模态的临床数据,根据视觉特征将样本分级为无青光眼、早期青光眼、中或晚期青光眼三个类别。

官方baseline地址

基于paddle实现的baseline

网络结构

网络的backbone由两个ResNet组成,一个负责提取2D模态图像特征,一个负责提取3D模态图像特征,经过卷积层后,将提取到的多维特征压成一维,通过concat合并为一个一维数组,最后输出分类结果。

网络输入

数据格式[batch, channel, height, width],3D图像使用opencv读入是3通道的,需要转换成单通道灰度图。

Code

这里只解释一些非常规的代码,常规代码看paddle官方文档即可

import


import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
# 评估函数
from sklearn.metrics import cohen_kappa_scoreimport paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import resnet34
# 这个transforms是一个自行实现的文件,不是第三方库
import transforms as transimport warnings

DataLoader

dataset类

class GAMMA_sub1_dataset(paddle.io.Dataset):"""getitem() output:fundus_img: RGB uint8 image with shape (3, image_size, image_size)oct_img:    Uint8 image with shape (256, oct_img_size[0], oct_img_size[1])"""def __init__(self,img_transforms,oct_transforms,dataset_root,label_file='',filelists=None,num_classes=3,mode='train'):self.dataset_root = dataset_rootself.img_transforms = img_transformsself.oct_transforms = oct_transformsself.mode = mode.lower()self.num_classes = num_classesif self.mode == 'train':label = {row['data']: row[1:].values for _, row in pd.read_excel(label_file).iterrows()}self.file_list = [[f, label[int(f)]] for f in os.listdir(dataset_root)]elif self.mode == "test":self.file_list = [[f, None] for f in os.listdir(dataset_root)]if filelists is not None:self.file_list = [item for item in self.file_list if item[0] in filelists]def __getitem__(self, idx):real_index, label = self.file_list[idx]fundus_img_path = os.path.join(self.dataset_root, real_index, real_index + ".jpg")# 这里有个问题就是图片读入是乱序的不是按照序号0到255读入oct_series_list = sorted(os.listdir(os.path.join(self.dataset_root, real_index, real_index)), key=lambda x: int(x.strip("_")[0]))fundus_img = cv2.imread(fundus_img_path)[:, :, ::-1] # BGR -> RGBoct_series_0 = cv2.imread(os.path.join(self.dataset_root, real_index, real_index, oct_series_list[0]), cv2.IMREAD_GRAYSCALE)oct_img = np.zeros((len(oct_series_list), oct_series_0.shape[0], oct_series_0.shape[1], 1), dtype="uint8")for k, p in enumerate(oct_series_list):oct_img[k] = cv2.imread(os.path.join(self.dataset_root, real_index, real_index, p), cv2.IMREAD_GRAYSCALE)[..., np.newaxis]# 如果对应图像数据transforms不为空,则对数据进行transformsif self.img_transforms is not None:fundus_img = self.img_transforms(fundus_img)if self.oct_transforms is not None:oct_img = self.oct_transforms(oct_img)# normlize on GPU to save CPU Memory and IO consuming.# fundus_img = (fundus_img / 255.).astype("float32")# oct_img = (oct_img / 255.).astype("float32")# 眼底图片改变通道顺序fundus_img = fundus_img.transpose(2, 0, 1) # H, W, C -> C, H, W# oct图像去掉最后一维oct_img = oct_img.squeeze(-1) # D, H, W, 1 -> D, H, Wif self.mode == 'test':return fundus_img, oct_img, real_indexif self.mode == "train":label = label.argmax()return fundus_img, oct_img, labeldef __len__(self):return len(self.file_list)

进行的transforms

# 训练集眼底图片进行的transforms
img_train_transforms = trans.Compose([# 随机resize裁剪trans.RandomResizedCrop(image_size, scale=(0.90, 1.1), ratio=(0.90, 1.1)),# 随机水平翻转trans.RandomHorizontalFlip(),# 随机垂直翻转trans.RandomVerticalFlip(),# 随机角度翻转0~30度trans.RandomRotation(30)
])
# 训练集oct图片进行的transforms
oct_train_transforms = trans.Compose([# 对图片进行中央剪裁trans.CenterCrop([256] + oct_img_size),trans.RandomHorizontalFlip(),trans.RandomVerticalFlip()
])
# 验证集眼底图片进行的transforms
img_val_transforms = trans.Compose([# 裁剪中央正方形trans.CropCenterSquare(),trans.Resize((image_size, image_size))
])
# 验证集oct图片进行的transforms
oct_val_transforms = trans.Compose([trans.CenterCrop([256] + oct_img_size)
])

NetWork

网络模型从paddle.vision.models import了resnet34,只要init的时候num_classes设置为0,网络结构就不含有末端全连接层

class Model(nn.Layer):"""simply create a 2-branch network, and concat global pooled feature vector.each branch = single resnet34"""def __init__(self):super(Model, self).__init__()# 带pretrained代表使用paddle的预训练模型,做Transfer Learningself.fundus_branch = resnet34(pretrained=True, num_classes=0) # remove final fcself.oct_branch = resnet34(pretrained=True, num_classes=0) # remove final fcself.decision_branch = nn.Linear(512 * 1 * 2, 3) # ResNet34 use basic block, expansion = 1# replace first conv layer in oct_branch# 对oct提取特征的resnet34的第一层卷积层进行修改,修改通道跟oct图像的输入通道一致,都是256self.oct_branch.conv1 = nn.Conv2D(256, 64,kernel_size=7,stride=2,padding=3,bias_attr=False)# 网络组网def forward(self, fundus_img, oct_img):b1 = self.fundus_branch(fundus_img)b2 = self.oct_branch(oct_img)# 将图像压成一维b1 = paddle.flatten(b1, 1)b2 = paddle.flatten(b2, 1)# 将两个一维tensor concat在一起,再通过一个全连接层,最后做softmax处理获得分类结果logit = self.decision_branch(paddle.concat([b1, b2], 1))return logit

评估指标

以下为该任务的评估指标,评估函数使用sklearn库的cohen_kappa_score函数计算

Train & Val

def train(model, iters, train_dataloader, val_dataloader, optimizer, criterion, log_interval, eval_interval):iter = 0model.train()# 储存平均loss和平均kappa的listavg_loss_list = []avg_kappa_list = []best_kappa = 0.while iter < iters:for data in train_dataloader:iter += 1if iter > iters:break# 输入图片int32转float32fundus_imgs = (data[0] / 255.).astype("float32")oct_imgs = (data[1] / 255.).astype("float32")labels = data[2].astype('int64')# 模型输入眼底和oct图像logits = model(fundus_imgs, oct_imgs)# 计算lossloss = criterion(logits, labels)# acc = paddle.metric.accuracy(input=logits, label=labels.reshape((-1, 1)), k=1)for p, l in zip(logits.numpy().argmax(1), labels.numpy()):avg_kappa_list.append([p, l])loss.backward()optimizer.step()model.clear_gradients()avg_loss_list.append(loss.numpy()[0])if iter % log_interval == 0:# 计算平均loss, list转ndarray再求均值avg_loss = np.array(avg_loss_list).mean()# list转ndarrayavg_kappa_list = np.array(avg_kappa_list)# 计算平均kappaavg_kappa = cohen_kappa_score(avg_kappa_list[:, 0], avg_kappa_list[:, 1], weights='quadratic')# 对两个list进行清空avg_loss_list = []avg_kappa_list = []print("[TRAIN] iter={}/{} avg_loss={:.4f} avg_kappa={:.4f}".format(iter, iters, avg_loss, avg_kappa))if iter % eval_interval == 0:# 进行验证操作, 获得验证集的avg_loss和avg_kappa avg_loss, avg_kappa = val(model, val_dataloader, criterion)print("[EVAL] iter={}/{} avg_loss={:.4f} kappa={:.4f}".format(iter, iters, avg_loss, avg_kappa))# 储存指标最优模型if avg_kappa >= best_kappa:best_kappa = avg_kappapaddle.save(model.state_dict(),os.path.join("best_model_{:.4f}".format(best_kappa), 'model.pdparams'))model.train()def val(model, val_dataloader, criterion):model.eval()avg_loss_list = []cache = []with paddle.no_grad():for data in val_dataloader:fundus_imgs = (data[0] / 255.).astype("float32")oct_imgs = (data[1] / 255.).astype("float32")labels = data[2].astype('int64')logits = model(fundus_imgs, oct_imgs)for p, l in zip(logits.numpy().argmax(1), labels.numpy()):cache.append([p, l])loss = criterion(logits, labels)# acc = paddle.metric.accuracy(input=logits, label=labels.reshape((-1, 1)), k=1)avg_loss_list.append(loss.numpy()[0])cache = np.array(cache)kappa = cohen_kappa_score(cache[:, 0], cache[:, 1], weights='quadratic')avg_loss = np.array(avg_loss_list).mean()return avg_loss, kappa

训练模型超参数

学习率在baseline是个固定值,batch_size和iteration根据自己实际情况而定,在这里就不写了

Key Value
优化器 Adam
loss函数 CrossEntropyLoss

其他的一些说明

测试集和验证集的划分

使用的是sklearn的train_test_split函数,进行测试集合验证集的划分

from sklearn.model_selection import train_test_split
val_ratio = 0.2 # 80 / 20
# 省略filelists生成
train_filelists, val_filelists = train_test_split(filelists, test_size=val_ratio, random_state=42)

结果文件生成

# cache为val操作生成的变量
submission_result = pd.DataFrame(cache, columns=['data', 'dense_pred'])
submission_result['non'] = submission_result['dense_pred'].apply(lambda x: int(x[0] == 0))
submission_result['early'] = submission_result['dense_pred'].apply(lambda x: int(x[0] == 1))
submission_result['mid_advanced'] = submission_result['dense_pred'].apply(lambda x: int(x[0] == 2))
# 最后生成提交结果文件
submission_result[['data', 'non', 'early', 'mid_advanced']].to_csv("./submission_sub1.csv", index=False)

优化方向

  1. 使用更优的预训练backbone替代resnet34
  2. 两种模态数据直接使用concat合并太直接,设计一个self-attention模块,让网络自行学习两者的比重
  3. 对oct图像进行降噪等处理
  4. 先把眼底图片的视盘分割出来再进行分类
  5. 官方提供的训练集为小样本,训练一个GAN去生成更多的样本

MICCAI2021 Contest : GAMMA任务一:<基于多模态眼底影像的青光眼分级>官方Baseline代码解释相关推荐

  1. 基于深度学习的眼底影像分析最新综述

    医学影像是深度学习取得极大成功的一个领域,而眼底图像是其中一个重要的分支.眼底图像是由单目相机捕获到的眼底的2D图像. 使用眼底图像可以用于对眼科疾病诊断分级.对病变点和重要的生物标记进行分割等等,对 ...

  2. 基于多模态数据挖掘算法matlab,多模态生物数据分析与挖掘研究

    多模态生物数据分析与挖掘研究 [摘要]:近年来,随着生物测量技术的飞速发展,在生命科学研究的不同领域都积累了大量的生物数据.这些数据中蕴藏着丰富信息,使得我们从不同角度全方位地了解与疾病或是特定表型相 ...

  3. 快速鲁棒的多模态遥感影像配准系统(可下载,支持大尺寸遥感影像),性能超越国际著名遥感商业软件ERDAS和ENVI

    文章目录 一.引言 二.多模态配准系统的技术流程 1.特征点检测 2.同名点匹配 3.误差剔除 4.图像校正 三.实验结果 四.总结 五.知识产权 一.引言 近些年来,随着国产卫星的不断发射,导致各种 ...

  4. 【笔记】基于低空无人机影像和 YOLOv3 实现棉田杂草检测

    <基于低空无人机影像和 YOLOv3 实现棉田杂草检测> 单位:石河子大学信息科学与技术学院 作者:薛金利 数据获取 设备:大疆 DJI 四旋翼无人机悟 Inspire l PRO 相机: ...

  5. 基于ArcSDE的影像数据管理-解决篇(转载)

    本文为转载http://www.cnblogs.com/rib06/category/56544.html 疑惑篇中简单介绍了基于ArcSDE的影像数据管理的基本方法.策略及其缺陷.那么要想基于Arc ...

  6. input不可编辑属性_谁不喜欢图文并茂呢:基于多模态信息的属性抽取

    0. 前言 最近做属性抽取,并且基于多模态信息(文本+图片)那种,然后发现了一个比较经典的论文"Multimodal Attribute Extraction".正好就顺着这个论文 ...

  7. 论文阅读:基于多模态词向量的语句距离计算方法

    论文信息 华阳. 基于多模态词向量的语句距离计算方法[D].哈尔滨工业大学,2018. 1.主要工作 简述语句间的距离问题:自然语言处理任务是度量文本间的距离:不同阶段语言学习的难度可以抽象为距离,本 ...

  8. 眼底影像血管分割(一):选择通道

    一:通道选择 一张眼底影像是RGB三色的,我们在做血管分割时,需要选择比较适合的图像来作为原始图像进行分割.那么选择哪个通道呢? 绿色通道?红色通道?蓝色通道? 好了,上图: 上图中四张图均来自同一张 ...

  9. 文献阅读_基于多模态数据语义融合的旅游在线评论有用性识别研究

    文献来源:马超,李纲,陈思菁,毛进,张霁.基于多模态数据语义融合的旅游在线评论有用性识别研究[J].情报学报,2020,39(02):199-207. 基于多模态数据语义融合的旅游在线评论有用性识别研 ...

最新文章

  1. Bootsrap基本应用
  2. Nagios监控lvs服务
  3. m.2接口和nvme区别_透明款散热不好,那么ORICO 全铝NVMe固态硬盘盒了解一下?
  4. [半翻] 设计面向DDD的微服务
  5. java2048设计说明,Html5中的本地存储设计理念
  6. Github readme语法-- markdown
  7. 我的ROS学习之路——服务通信
  8. Python 爬取zw年鉴
  9. html查重报告转换,知网查重报告网页版如何转换成PDF和WORD?
  10. 洛谷 U80415 懒懒的Seaway
  11. 华三交换机snmp配置
  12. html 预选单选按钮,关于html:单选按钮的预选
  13. 三阶魔方CFOP还原方法图解
  14. 【2023计算机考研】双非院校录取分数线汇总
  15. 派大星如期反馈小程序的生命周期
  16. 酸菜鱼用什么鱼最好吃
  17. 关于灰色软件(Grayware)及其危害你了解多少?
  18. 安装ubuntu18.04之后遇到的问题,及运行ROS-Academy-for-Beginners遇到的问题
  19. HTML5游戏《被淹没的王国》截图
  20. UsbDeviceManager.java

热门文章

  1. 百度引领AI大生态,产业联盟谋破局
  2. 期货CTP接口与程序化(量化交易)的对接(1)
  3. 基于物联网的远程温湿度监测系统 --- ESP8266 + 机智云
  4. VLC rtsp服务分割/打包HEVC(h265)
  5. Tableau10——人口金字塔,漏斗图,箱型图
  6. html table-cell用法,CSS中的table-cell属性使用实例教程
  7. 亚泰盛世邀您参观第66届中国教育装备展示会
  8. elasticsearch安装ik分词器
  9. Flink CDC + OceanBase 全增量一体化数据集成方案
  10. 三相电路线电压(电流)与相电压(电流)的关系