StatScores的原理与使用

  • Confusion matrix (混淆矩阵)
    • 四分类定义
      • 关系图
      • Precision(准确率) 与 Recall (召回率)
  • StatScores类
    • 继承关系
    • 四类任务
    • Update与Compute方法
      • 1. update
        • _input_format_classification的四个参数
      • **_stat_scores**
        • 举个例子
      • 2. compute
        • **_stat_scores_compute**

Confusion matrix (混淆矩阵)

在介绍StatScores之前,我们先复习以下Confusion matrix。

我们有两组数据,分别为真实分布预测分布
预测为真定义为Possitive,预测为假定义为Negetive

四分类定义

  1. 如果预测Possitive与真实一致,则为True Possitive,简写为TP
  2. 如果预测Negetive与真实一致,则为True Negetive,简写为TN
  3. 如果预测Possitive与真实不一致,则为False Possitive,简写为FP
  4. 如果预测Negetive与真实不一致,则为False Negetive,简写为FN

关系图


StatScores类实际上就是统计一组预测数据的这四个分类。

额外提一下Precision与Recall

Precision(准确率) 与 Recall (召回率)

Precision=TPTP+FPPrecision = \cfrac {TP} {TP+FP} Precision=TP+FPTP​
Recall=TPTP+FNRecall = \cfrac {TP} {TP+FN} Recall=TP+FNTP​


StatScores类

继承关系

直接继承与Metrics

class StatScores(Metric)

四类任务

它将处理的case分为了四类

  1. Binary 二分类
  2. MultiClass 多分类
  3. MultiLabel 多标签
  4. MultiClass&MultiLabel

没有入参指定所属的任务case,代码中是根据pred张量来判断的。逻辑如下,

因为笔者暂时只使用第1和2中,所以其他暂不介绍了。

Update与Compute方法

所有继承Metrics的子类都需要实现Update和Compute方法。

1. update


update方法中调用内部方法 _stat_scores_update

在该方法内部,首先将根据输入的数据做分类 _input_format_classification

该方法主要作用是将preds和target做one hot化,所属分类任务的case也在该方法中识别的。

_input_format_classification的四个参数

这里有三个参数注意以下:

  • threshold
    它仅仅作用与Binary的任务,作用是preds张量中,如果元素大于threshold,则规整为1,否则规整为0
  • num_classes
    指明分类种类,如果不指明的话,代码中根据元素值的最大值来判断。这个值同时也会影响one_hot后的数据长度。
  • multiclass
    如果multiclass=False,则强制认为所属任务为Binary。True或者不设置(None)则根据入参自行判断
  • topk
    在多分类任务中,在做one_hot转换时,需要返回的最大前k个位置。
    比如[0.1,0.5,0.4], 在topk=1(默认时),返回的是 [0,1,0],
    如果topk=2,则返回的是[0,1,1]

_stat_scores

_stat_scores是真实计算tp, fp, tn, fn四个值的地方。

举个例子

假设我们有如下

preds  = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])

首先,在 _input_format_classification方法处理后,这两个张量会转换为one_hot形式如下,

preds = [[1,0], [0,1], [1,0]]
target= [[0,1], [0,1], [1,0]]

然后, 进入**_stat_scores**
第64,65行的计算结果如下:

# 预测true是正确的预测值和预测是false是正确的预测值
true_pred, false_pred = [[False,False], [True, True], [True, True]] , [ [True True], [False, False] [False, False]
# 预测是Ture的预测值与预测是False的预测值
pos_pred, neg_pred = [[False, True] [False, True] [True, False]] , [[True False] [True False] [True False]]

这两者再两两相乘,得到tp fp tn fn

    tp = (true_pred * pos_pred).sum(dim=dim)fp = (false_pred * pos_pred).sum(dim=dim)tn = (true_pred * neg_pred).sum(dim=dim)fn = (false_pred * neg_pred).sum(dim=dim)

2. compute

compute调用内部方法 _stat_scores_compute

_stat_scores_compute

该方法返回一个数组, [tp, fp, tn, fn, tp_fn]

这个就是StatScores的返回结果。

【pytorch】StatScores的原理与使用相关推荐

  1. Pytorch学习 - Task5 PyTorch卷积层原理和使用

    Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...

  2. 【分布式】Pytorch分布式训练原理和实战

    [分布式]基于Horovod的Pytorch分布式训练原理和实战 并行方法: 1. 模型并行 2. 数据并行 3. 两者之间的联系 更新方法: 1. 同步更新 2. 异步更新 分布式算法: 1. Pa ...

  3. Pytorch分布式训练原理简介

    1. 引言 分布式训练就是指将模型放置在很多台机器并且在每台机器上的多个GPU上进行训练,之所以使用分布式训练的原因一般来说有两种:其一是模型在一块GPU上放不下,其二使用多块GPU进行并行计算能够加 ...

  4. Pytorch - 弹性训练原理

    Pytorch在1.9.0引入了torchrun,用其替代1.9.0以前版本的torch.distributed.launch.torchrun在torch.distributed.launch 功能 ...

  5. pytorch深度学习原理实战-rightness函数

    在pytorch深度学习实战(集智俱乐部)书中卷积神经网络一章没有定义rightness函数.下面是自己实现代码 def rightness(output, target):preds = outpu ...

  6. pytorch反向传播原理

    假设我们有一组数据如下表所示,,表示学习时长和分数的关系,我们要推断出当x=4时,y为多少 x(hours) y(points) 1 2 2 4 3 6 4 ? 通常情况下我们假设上述模型是一个线性模 ...

  7. pytorch图片分割原理

    自从transformer应用到cv领域以后,对图片的分割需求便越加重了,但是图像分割说起来容易,实际操作起来还是有很多地方不懂(主要还是code能力太弱). 我们知道,对张量的处理一般又两种,一种是 ...

  8. 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...

    (文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...

  9. multi task训练torch_手把手教你使用PyTorch(2)-requires_gradamp;computation graph

    import torch 1. Requires_grad 但是,模型毕竟不是人,它的智力水平还不足够去自主辨识那些量的梯度需要计算,既然如此,就需要手动对其进行标记. 在PyTorch中,通用的数据 ...

最新文章

  1. Jenkins 无法捕获构建脚本错误问题
  2. Matlab实现线性回归和逻辑回归: Linear Regression Logistic Regression
  3. Oracle的启动机制
  4. Linux添加授信根证书,linux系统添加根证书 linux证书信任列表
  5. 找出数组中只出现1次的两个元素
  6. IE8/9的几个前端bug解决方案
  7. 浅谈assert()函数的用法
  8. 3.修改和编译XposedBridge.jar 和 api.jar
  9. 数字信号处理课程设计---带通滤波器的设计及其matlab实,数字信号处理课程设计---带通滤波器的设计及其MATLAB实现...
  10. Linux进程间通信
  11. Android中矢量图形的相关知识
  12. 安卓和win环境下扫描局域网下设备IP的工具
  13. 银行卡查询银行卡类型查询及归属地查询
  14. 计算机图形学画简单图形,计算机图形学 基本图形绘制 Koch雪花绘制
  15. 三点组成的三角形的面积计算公式(海伦公式)
  16. C++/MFC修行之路(5)Ribbon(功能区)的使用
  17. 【rzxt】巧用电池小工具 电量问题全掌握
  18. 超级实习生计划学习笔记——Redis字符串
  19. 必备技能:图解用电烙铁焊接电路
  20. Service Mesh(服务网格)——后 Kubernetes 时代的微服务

热门文章

  1. 深入浅出CMake(一):基础篇
  2. Drug Discovery Today | 频繁命中化合物机制探究:PAINS规则的局限性
  3. GRADE:联合学习演化节点和社区表示的概率生成模型
  4. 靶向新冠状病毒(COVID-19)的药物靶点
  5. mysql必备技能_Mysql常用技能(1)
  6. 市场有变,中小型基因测序机构机会来了
  7. NBT:牛瘤胃微生物组的4941个宏基因组组装基因组(MAG)
  8. MPB:亚热带生态所谭支良、焦金真等-​反刍动物瘤胃样品采集与保存
  9. 北科院分子互作实战专题培训班(10月底/11月底班)(生物医药与营养健康协同创新中心)...
  10. Science-2018-微生物群落的构建过程具有趋简性