最近看了一篇有关多示例学习的paper,题目为Data Efficient and Weakly Supervised Computational Pathologyon Whole Slide Images,对里面提出的模型比较感兴趣,特此做一下笔记。
github地址:https://github.com/mahmoodlab/CLAM
paper地址:https://arxiv.org/abs/2004.09666

笔记

这篇paper提出了一个Clustering-constrained Attention Multiple Instance Learning的模型,简称为CLAM模型,整个模型的框架图如下:

由于整篇paper对其模型框架介绍的很模糊,而且光看上面的流程图也看不明白。通过对CLAM仓库代码的研究,总算弄明白了整体模型的结构。
还是采用基于embedding的两阶段的训练方式(如果不明白什么是基于embedding,可以看我上一篇文章背景知识部分——Dual-stream multiple instance learning networks for tumor detection in Whole Slide Image——论文笔记)。

(一)特征提取部分

特征提取部分很简单,并没有采用复杂的方式,而是直接采用pytorch提供的预训练权重。仓库代码默认是采用resnet50的ImageNet预训练权重来提取特征的,这部分没什么好说的。

(二)示例分类器

示例分类器的模型结构可以简单概括为门注意力层+全连接层,整体的模型并不复杂,
训练部分可以分为以下步骤,以resnet50提取的特征为例,特征维度为1024:

  1. 使用gate_attention将M×1024的特征向量转为M×1的注意力分数和M×512的特征映射。(M为一个slide中所有的patch数量)
  2. 对注意力分数进行排序,取出最大最小的topk个分数对应的特征映射(topk默认设置为8)。
  3. 将最大topk的标签设为1,最小topk的标签设为0,作为instance标签。
  4. 对2×topk的特征映射输入N个二分类全连接层,得到N个二分类输出。(N为预测类别)
  5. 计算N个二分类输出与instance标签的SmoothTop1SVM Loss(这个loss就是instance loss)。
  6. 将注意力分数乘于对应的特征映射并将所有特征映射相加相加,得到512×1的特征向量。
  7. 将特征向量输进去全连接层,得到bag分类结果。
  8. 计算bag分类的CE Loss。(这个loss就是bag loss)
  9. 总Loss定义为:bag_weight * bag_loss+(1-bag_weight)*instance_loss。(bag_weight默认是0.8)

上面介绍的是整体流程的思路,具体的模型是怎么定义的, 怎么组合的,仓库代码都有,我就不一一赘述了,都是一些很简单模型构件,并不难理解。

(三)一些细节部分

这些细节部分都是从仓库提供的源代码推出来的,并不保证一定是正确的。

  1. dataloader的batch size是固定设为1。
  2. 无论是二分类还是多分类问题,关于计算instance loss时候的标签都是最大topk为1,最小topk为0。
  3. 仓库里面提供了两个模型,一个单分支的clam_sb,另外一个是多分支的clam_mb。 两者的区别在于:
    (1) mb计算注意力分数是计算了每个类别的分数,即N×C个分数,最后分类层也是有多少类别就有多少个全连接层,每个全连接层输出每个类别bag分数。
    (2) sb无论多少类别,计算注意力分数时都是N×1个分数,最后分类层就一个全连接层,输出为类别数。

CLAM——论文笔记相关推荐

  1. ORB-SLAM3 论文笔记

    ORB-SLAM3 论文笔记 这篇博客 ORB-SLAM3系统 相机模型的抽象(Camera Model) 重定位的问题 图片矫正的问题 视觉惯性SLAM的工作原理 相关公式 IMU初始化 跟踪和建图 ...

  2. 【论文笔记】 LSTM-BASED DEEP LEARNING MODELS FOR NONFACTOID ANSWER SELECTION

    一.简介 这篇论文由IBM Watson发表在2016 ICLR,目前引用量92.这篇论文的研究主题是answer selection,作者在这篇论文基础上[Applying Deep Learnin ...

  3. 最新图神经网络论文笔记汇总(附pdf下载)

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要15分钟 Follow小博主,每天更新前沿干货 [导读]近年来,图神经网络变得非常火热,每年顶会在该领域内都会出现大量的研究论文,本文为大家提 ...

  4. [论文笔记] Fast Quality Driven Selection of Composite Web Services (ECOWS, 2006)

    Time: 4.0 hours Jae-Ho Jang, Dong-Hoon Shin, Kyong-Ho Lee, "Fast Quality Driven Selection of Co ...

  5. 论文笔记之:Action-Decision Networks for Visual Tracking with Deep Reinforcement Learning

    论文笔记之:Action-Decision Networks for Visual Tracking with Deep Reinforcement Learning  2017-06-06  21: ...

  6. 光流 速度_[论文笔记] FlowNet 光流估计

    [论文笔记] FlowNet: Learning Optical Flow with Convolutional Networks 说在前面 个人心得: 1. CNN的光流估计主要是速度上快,之后的v ...

  7. 论文笔记 《Maxout Networks》 《Network In Network》

    原文出处:http://zhangliliang.com/2014/09/22/paper-note-maxout-and-nin/ 论文笔记 <Maxout Networks> & ...

  8. 论文笔记:HKMF-T: Recover From Blackouts in TaggedTime Series With Hankel Matrix Factorization

    论文笔记:Hankel Matrix Factorization for Tagged Time Series to Recover Missing Values during Blackouts_U ...

  9. 论文笔记 A Spatial-Temporal Decomposition Based Deep Neural Network for TimeSeries Forecasting

    0 abstract 空间时间序列预测问题出现在广泛的应用中,如环境和交通问题.由于存在特定的空间.短期和长期模式,以及维度的诅咒,这些问题具有挑战性. 在本文中,我们提出了一个用于大规模空间时间序列 ...

最新文章

  1. 收藏!美国博士明确给出Python的高效学习技巧
  2. 从主数据的角度看一个零售ERP系统
  3. B-树的插入、查找、删除
  4. python提取word参考文献_写作相关 | word中参考文献转化为.bib格式全流程
  5. NSLog的常用格式说明小释
  6. laravel 邮件配置
  7. 复杂背景下计算机视觉模型害虫识别的比较研究(像素语义分割网络SegNet)
  8. mysql数据自定义随机_MySQL 利用事务自定义插入随机数据
  9. 深度学习入门者选择开源框架丨硬创公开课群友问答
  10. 深入理解java虚拟机系列文章:类的加载、连接与初始化
  11. java远程获取linux文件_Java远程连接操作linux服务器,scp获取文件
  12. 【SEO优化,网络营销】刘克亚《利润腾挪》,一分钟销售51000元的书
  13. [经验]iOS开发-记录下在开发过程中遇到的问题的解决方案及经验总结-1
  14. c语言最大乘积问题,利用C语言来求最大连续子序列乘积的方法
  15. Vue移动端系列 => [06] 文章搜索
  16. 联通(上海)产互一面
  17. private和protected的区别_学习笔记
  18. 浅析敏捷项目管理中的5大阶段
  19. php安全新闻早八点-Microdoor-第四季
  20. linux 卸载license,卸载 Network License Manager 的步骤

热门文章

  1. 开发新产品离不开CRM需求分析
  2. 肘方法确定聚类数k_肘方法确定KMeans聚类的最佳K值
  3. hexo+yilia添加live2d看板娘
  4. Android相关资源下载
  5. 基于python的电影数据可视化分析与推荐系统
  6. 电子邮件头跟踪_如何以正确的方式发送电子邮件:跟踪,跟进并获得回复。
  7. [Unity]EasyTouch手指滑动返回距离值
  8. mobiscroll 破解
  9. Oracle DB 性能管理
  10. 使用Python进行微信公众号开发(三)回复消息