GHMC_Loss pytorch实现

  • 前言
  • 存在的问题
  • 代码

前言

做图像分割,想试试GHMC_Loss,但没找到一个完整,复杂度较好的代码,所以写了一个自我感觉良好的代码。贴出来方便小白参考和接受大佬指正错误.
学习原理可以看5分钟理解Focal Loss与GHM——解决样本不平衡利器

存在的问题

训练到后面会出现验证loss飙升但数据指标正常的情况,还是小白不清楚是不是哪里错了.

代码

import torch
import numpy as np
from torch import nnclass GHM_Loss(nn.Module):def __init__(self, bins, alpha, device, is_split_batch=True):super(GHM_Loss, self).__init__()self._bins = binsself._alpha = alphaself._last_bin_count = Noneself._device = deviceself.is_split_batch = is_split_batchself.is_evaluation = Falsedef set_evaluation(self, is_evaluation):"""评估时可以(也可以不管)将is_evaluation设为True,训练时设为False,这样就类似于直接计算CEL:param is_evaluation: bool:return:"""self.is_evaluation = is_evaluationdef _g2bin(self, g, bin):return torch.floor(g * (bin - 0.0001)).long()def _custom_loss(self, x, target, weight):raise NotImplementedErrordef _custom_loss_grad(self, x, target):raise NotImplementedErrordef use_alpha(self, bin_count):if (self._alpha != 0):if (self.is_evaluation):if (self._last_bin_count == None):self._last_bin_count = bin_countelse:bin_count = self._alpha * self._last_bin_count + (1 - self._alpha) * bin_countself._last_bin_count = bin_countreturn bin_countdef forward(self, x, target):""":param x: torch.Tensor,[B,C,*]:param target: torch.Tensor,[B,*]:return: loss"""g = torch.abs(self._custom_loss_grad(x, target)).detach()weight = torch.zeros((x.size(0), x.size(2), x.size(3)))if self.is_split_batch:#是否对每个batch分开统计梯度,我实验时发现分开统计loss会更容易收敛,可能因为模型中用了batch normalization?N = x.size(2) * x.size(3)bin = (int)(N // self._bins)bin_idx = self._g2bin(g, bin)bin_idx = torch.clamp(bin_idx, max=bin - 1)bin_count = torch.zeros((x.size(0), bin))for i in range(x.size(0)):bin_count[i] = torch.from_numpy(np.bincount(torch.flatten(bin_idx[i].cpu()), minlength=bin))bin_count[i] *= (bin_count[i] > 0).sum()bin_count = self.use_alpha(bin_count)gd = torch.clamp(bin_count, min=1)beta = N * 1.0 / gdfor i in range(x.size(0)):weight[i] = beta[i][bin_idx[i]]else:N = x.size(0) * x.size(2) * x.size(3)bin = (int)(N // self._bins)bin_idx = self._g2bin(g, bin)bin_idx = torch.clamp(bin_idx, max=bin - 1)bin_count = torch.from_numpy(np.bincount(torch.flatten(bin_idx.cpu()), minlength=bin))bin_count *= (bin_count > 0).sum()bin_count = self.use_alpha(bin_count)gd = torch.clamp(bin_count, min=1)beta = N * 1.0 / gdweight = beta[bin_idx]return self._custom_loss(x, target, weight)class GHMC_Loss(GHM_Loss):def __init__(self, bins, alpha, device, num_classes, ignore_classes=None, class_weights=None, is_split_batch=True):""":param bins: int。不是bin,这里将取数据[B,C,X,Y]的size计算bin=[B*]X*Y,B不一定乘:param alpha: float。:param device::param num_classes: int。分类数量。:param ignore_classes: [int]。不计算的:param class_weights: torch.Tensor,每个类型的权重:param is_split_batch: bool,是否分离batch统计"""super(GHMC_Loss, self).__init__(bins, alpha, device, is_split_batch)self.num_classes = num_classesself.ignore_classes = ignore_classesself.class_weights = class_weightsdef _custom_loss(self, x, target, weight):"""计算loss:param x: torch.Tensor,[B,C,*]:param target: torch.Tensor,[B,*]:param weight: torch.Tensor,[B,C,*]:return: loss"""if (self.is_evaluation):return torch.mean((torch.nn.NLLLoss(weight=self.class_weights, reduction='none')(torch.log_softmax(x, 1), target)))else:return torch.mean((torch.nn.NLLLoss(weight=self.class_weights, reduction='none')(torch.log_softmax(x, 1), target)).mul(weight.to(self._device).detach()))def _custom_loss_grad(self, x, target):"""统计梯度:param x: torch.Tensor,[B,C,*]:param target: torch.Tensor,[B,*]:return: 梯度信息"""g = (torch.softmax(x, 1).detach() - make_one_hot(target.unsqueeze(1), self.num_classes).to(self._device)). \gather(1, target.unsqueeze(1)).squeeze(1)if self.ignore_classes != None:a = torch.tensor(0.0, dtype=torch.float32).to(self._device)for class_id in self.ignore_classes:g = torch.where(target != class_id, g, a)return gdef make_one_hot(input, num_classes):"""Convert class index tensor to one hot encoding tensor.Args:input: A tensor of shape [N, 1, *]num_classes: An int of number of classReturns:A tensor of shape [N, num_classes, *]"""# input=torch.squeeze(input,dim=-1)shape = np.array(input.shape)shape[1] = num_classesshape = tuple(shape)result = torch.zeros(shape)result = result.scatter_(1, input.cpu(), 1)return result

GHMC_Loss pytorch实现相关推荐

  1. OHEM,Focal loss,GHM loss二分类pytorch代码实现(减轻难易样本不均衡问题)

    https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA 综述:解决目标检测中的样本不均衡问题 该综述主要介绍了OHEM,Focal loss,GHM los ...

  2. 通过anaconda2安装python2.7和安装pytorch

    ①由于官网下载anaconda2太慢,最好去byrbt下载,然后安装就行 ②安装完anaconda2会自动安装了python2.7(如终端输入python即进入python模式) 但是可能没有设置环境 ...

  3. 记录一次简单、高效、无错误的linux上安装pytorch的过程

    1 准备miniconda Miniconda Miniconda 可以理解成Anaconda的免费.浓缩版.它非常小,只包含了conda.python以及它们依赖的一些包.我们可以根据我们的需要再安 ...

  4. 各种注意力机制PyTorch实现

    给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...

  5. PyTorch代码调试利器_TorchSnooper

    GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper 大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch ...

  6. pytorch常用代码

    20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...

  7. API pytorch tensorflow

    pytorch与tensorflow API速查表 方法名称 pytroch tensorflow numpy 裁剪 torch.clamp(x, min, max) tf.clip_by_value ...

  8. tensor转换 pytorch tensorflow

    一.tensorflow的numpy与tensor互转 1.数组(numpy)转tensor 利用tf.convert_to_tensor(numpy),将numpy转成tensor >> ...

  9. tensor和模型 保存与加载 PyTorch

    PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...

最新文章

  1. php模态窗口,php – 如何在yii2中的模态窗口中使用pjax更新小部件
  2. CentOS7中MariaDB重置密码
  3. HealthKit开发快速入门教程之HealthKit框架体系创建健康AppID
  4. invalid dts/pts combination
  5. Python面试题-朋友昨天去面试,这5个Python面试题都被考到了,太神奇了!
  6. CF962E Byteland, Berland and Disputed Cities
  7. 比Redis快5倍的中间件,究竟为什么这么快?
  8. 关于GitHub如何转为中文问题——Google举例
  9. Chrome浏览器打不开网页,连设置都打不开的解决办法
  10. 搭建自己的以图搜图系统 (一):10 行代码以图搜图
  11. Excel 2010 SQL应用052 将英文字母转换为小写字母
  12. 习题8-4 报数 (20分)
  13. 计算机二级编程题题库
  14. 华硕笔记本系统重装之后需要输入用户名和计算机名称是怎么回事,华硕笔记本电脑重装系统【方法详解】...
  15. [redis] 10 种数据结构详解
  16. 分享应用于桌面闹钟的超低成本MG127蓝牙射频前端芯片
  17. linux 蓝牙发送文件,如何在Ubuntu上使用蓝牙进行文件传输
  18. 英国大不列颠百科全书_大不列颠计划通过社区编辑接受维基百科
  19. 华附计算机学神,【学习】时隔13年,华附两牛娃杀进奥数国家队,父母亲述学霸成长史!...
  20. oracle输出实心三角型,C语言帕斯卡三角形打印示例

热门文章

  1. Linux tcp sack_reneging分析
  2. Pytorch实现Top1准确率和Top5准确率
  3. OpenCV UMat类 使用GPU运算
  4. C语言不带头结点的单链表
  5. 修改AD域ladp连接数
  6. npm ERR! Invalid tag name “@vue-cli“: Tags may not have any characters that encodeURIComponent encod
  7. 关于Knuth Shuffle算法
  8. 对于拼多多商家店铺什么最重要?|一度智信
  9. Http客户端请求工具-RestTemplate
  10. 将H5网站转换成原生体验的App