一、如果不出错的话

参考链接:https://github.com/jacobgil/pytorch-grad-cam

1、 先将此github源码clone到本地2、 参考pytorch-grad-cam/tutorials/Class Activation Maps for Semantic Segmentation.ipynb3、 把包都导好。4、注意推理做的归一化与标准化跟自己训练的时候弄成一样的
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
from torchvision.models.segmentation import deeplabv3_resnet50
import torch
import torch.functional as F
import numpy as np
import requests
import torchvision
from PIL import Image
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from models.model_stages_double import BiSeNet
from pytorch_grad_cam.grad_cam import GradCAM# 读入自己的图像
image = np.array(Image.open('/media/wlj/soft_D/WLJ/WJJ/STDC-Seg/camera_4_crop/leftImg8bit/test/nok/NoK_4_leftImg8bit.png'))
rgb_img = np.float32(image) / 255
input_tensor = preprocess_image(rgb_img,mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])# 读入自己的模型并且加载训练好的权重
model = BiSeNet(backbone='STDCNet813',n_classes=6)
model.cuda()
model = model.eval()
save_pth = '/media/wlj/soft_D/WLJ/WJJ/STDC-Seg/checkpoints/camera_4_crop/batch8_11.2_15000it_dublebaseline_left1xSGE2345_right0.5x_DFConv2_SGE3_RGB/model_maxmIOU100.pth'
model.load_state_dict(torch.load(save_pth))if torch.cuda.is_available():model = model.cuda()input_tensor = input_tensor.cuda()# 推理
output = model(input_tensor)[0]
normalized_masks = torch.softmax(output, dim=1).cpu()# 自己的数据集的类别
sem_classes = ['__background__', 'round', 'nok', 'headbroken', 'headdeep', 'shoulderbroken'
]sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
round_category = sem_class_to_idx["nok"]
round_mask = torch.argmax(normalized_masks[0], dim=0).detach().cpu().numpy()
round_mask_uint8 = 255 * np.uint8(round_mask == round_category)
round_mask_float = np.float32(round_mask == round_category)# 推理结果图与原图拼接
# both_images = np.hstack((image, np.repeat(round_mask_uint8[:, :, None], 3, axis=-1)))
# img = Image.fromarray(both_images)
# img.save("./hhhh.png")class SemanticSegmentationTarget:def __init__(self, category, mask):self.category = categoryself.mask = torch.from_numpy(mask)if torch.cuda.is_available():self.mask = self.mask.cuda()def __call__(self, model_output):return (model_output[self.category, :, :] * self.mask).sum()# 自己要放CAM的位置
target_layers = [model.conv_out]
targets = [SemanticSegmentationTarget(round_category, round_mask_float)]with GradCAM(model=model, target_layers=target_layers,use_cuda=torch.cuda.is_available()) as cam:grayscale_cam = cam(input_tensor=input_tensor,targets=targets)[0, :]cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)# 保存CAM的结果
img = Image.fromarray(cam_image)
img.show()
img.save('./result.png')

二、可能出错

我遇到了 如下错误

解决方法:

将base_cam.py的第81行修改为:

就不报错了!

拿下!

grad-CAM用于自己的语义分割网络【亲测】相关推荐

  1. SegNet 语义分割网络以及其变体 基于贝叶斯后验推断的 SegNet

    HomePage: http://mi.eng.cam.ac.uk/projects/segnet/ SegNet Paper: https://www.computer.org/csdl/trans ...

  2. 【24】搭建FCN语义分割网络完成自己数据库图像分割(1)

    [1]batchimageprocess.py #批量图片处理.改名字.改类型 #!/usr/bin/env python # -*- encoding: utf-8 -*- ''' @File : ...

  3. CVPR2020:4D点云语义分割网络(SpSequenceNet)

    CVPR2020:4D点云语义分割网络(SpSequenceNet) SpSequenceNet: Semantic Segmentation Network on 4D Point Clouds 论 ...

  4. 一块GPU就能训练语义分割网络,百度PaddlePaddle是如何优化的?

    [引言]显存不足是训练语义分割网络常常遇见的问题,而显存是GPU计算中的稀缺资源.百度深度学习框架PaddlePaddle中的显存优化,不仅可以让研究人员在相同成本的计算设备上训练更大的模型,还可以在 ...

  5. 北航、旷视联合,打造最强实时语义分割网络

    来源:AI科技评论 编辑:Camel 导语:MSFNet在Cityscapes测试集上达到77.1%mIoU/41FPS(注意是1024*2048),在Camvid测试集上达到75.4 mIoU/97 ...

  6. 深度学习:语义分割网络

    1.考虑采用实例分割或语义分割用于烟支打孔的内孔和外孔的边缘提取上 2.https://www.cnblogs.com/zxj9487/p/11154316.html 直接可以用的Python和Ope ...

  7. Fast-SCNN:多分支结构共享低级特征的语义分割网络

    介绍一篇 BMVC 2019 语义分割论文 Fast-SCNN:Fast Semantic Segmentation Network,谷歌学术显示该文已有62次引用. 论文:https://arxiv ...

  8. 复杂背景下计算机视觉模型害虫识别的比较研究(像素语义分割网络SegNet)

    Abstract 农业被认为是世界各国的经济基础,新技术的发展有助于提高收获效率.自动驾驶汽车在农场用于播种.收获和施用农药等任务.然而,任何一个种植园的主要问题之一是害虫和疾病的鉴定,这对害虫控制和 ...

  9. 计算机视觉算法——语义分割网络总结

    计算机视觉算法--语义分割网络总结 计算机视觉算法--语义分割网络总结 1. FCN 1.1 关键知识点--网络结构及特点 1.2 关键知识点--转置卷积 1.3 关键知识点--语义分割评价指标 2. ...

最新文章

  1. iOS 9应用开发教程之使用开关滑块控件以及滚动部署视图
  2. JVM编译时和运行时状态
  3. 开发项目之考研计划_软件测试之项目测试计划模板
  4. java与数据库的数据交互,Java与数据库初步交互(后续需要进行优化)
  5. py2exe使用方法 (含一些调试技巧,如压缩email 类)(转)
  6. 有关Kubernetes监控的4大常见陷阱,注意避免!
  7. tar命令打包并删除原文件
  8. JSON怎么转成Excel
  9. thinkpad T500开机大于10分钟,黑屏
  10. 怎么修改电脑的ip地址
  11. A graph auto-encoder model for miRNA-disease associations prediction 论文解析
  12. Holder类的作用
  13. java 6u45 no sni 2_sjscxz.taobao.com
  14. windchill安全标签客制化
  15. Java 8: 元空间(Metaspace)
  16. QT UI控件和事件
  17. 用python画美国国旗
  18. 你都这么拼了,面试官TM怎么还是无动于衷?
  19. 一站易购邀请好友第一届PK赛打响,拿千元现金奖励
  20. 规范精细化管理是企业死亡的开始?

热门文章

  1. OllyDBG 入门系列5 消息断点及 RUN 跟踪
  2. 从jieba分词到BERT-wwm——中文自然语言处理(NLP)基础分享系列(7)
  3. Discuz! 6.1 - 自动禁止非公开版面向Home推送事件
  4. 实现QTableWidget表头带Checkbox
  5. 【慕伏白教程】在Vmware中安装Ubuntu流程
  6. Lesson 1 Nehe
  7. 【halcon专题】离散序列的1d函数
  8. 安装卸载EMBY,jellyfin
  9. 初学C语言:计算身体质量指数 BMI,从键盘输入身高(m)和体重(kg),计算身体质量指数 BMI,其公式为: BMI = 体重 / 身高的平方。
  10. 【病虫害识别】基于支持向量机SVM的病虫害识别系统附GUI界面