pytorch实现图像RGB和HSV色彩空间的相互转换,可直接加入神经网络中,支持反向传播,支持cuda。

  今天在设计一个pytorch神经网络结构时需要把RGB图像转换到HSV空间,因为HSV空间更适合做一些色彩平滑过渡的图像渐变处理。因此我需要一个pytorch版的RGB和HSV相互转换函数,而且要求它可微,即可通过反向传播计算梯度。在github上找了一个"Differentiable-RGB-to-HSV-convertion-pytorch",然而这个代码中的HSV-to-RGB部分是不能用的,所以我补充了后部分,作为一个完整功能分享出来并备忘,以后再用到的时候可方便的找到。

一、代码

"""
Pytorch implementation of RGB convert to HSV, and HSV convert to RGB,
RGB or HSV's shape: (B * C * H * W)
RGB or HSV's range: [0, 1)
"""
import torch
from torch import nnclass RGB_HSV(nn.Module):def __init__(self, eps=1e-8):super(RGB_HSV, self).__init__()self.eps = epsdef rgb_to_hsv(self, img):hue = torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(img.device)hue[ img[:,2]==img.max(1)[0] ] = 4.0 + ( (img[:,0]-img[:,1]) / ( img.max(1)[0] - img.min(1)[0] + self.eps) ) [ img[:,2]==img.max(1)[0] ]hue[ img[:,1]==img.max(1)[0] ] = 2.0 + ( (img[:,2]-img[:,0]) / ( img.max(1)[0] - img.min(1)[0] + self.eps) ) [ img[:,1]==img.max(1)[0] ]hue[ img[:,0]==img.max(1)[0] ] = (0.0 + ( (img[:,1]-img[:,2]) / ( img.max(1)[0] - img.min(1)[0] + self.eps) ) [ img[:,0]==img.max(1)[0] ]) % 6hue[img.min(1)[0]==img.max(1)[0]] = 0.0hue = hue/6saturation = ( img.max(1)[0] - img.min(1)[0] ) / ( img.max(1)[0] + self.eps )saturation[ img.max(1)[0]==0 ] = 0value = img.max(1)[0]hue = hue.unsqueeze(1)saturation = saturation.unsqueeze(1)value = value.unsqueeze(1)hsv = torch.cat([hue, saturation, value],dim=1)return hsvdef hsv_to_rgb(self, hsv):h,s,v = hsv[:,0,:,:],hsv[:,1,:,:],hsv[:,2,:,:]#对出界值的处理h = h%1s = torch.clamp(s,0,1)v = torch.clamp(v,0,1)r = torch.zeros_like(h)g = torch.zeros_like(h)b = torch.zeros_like(h)hi = torch.floor(h * 6)f = h * 6 - hip = v * (1 - s)q = v * (1 - (f * s))t = v * (1 - ((1 - f) * s))hi0 = hi==0hi1 = hi==1hi2 = hi==2hi3 = hi==3hi4 = hi==4hi5 = hi==5r[hi0] = v[hi0]g[hi0] = t[hi0]b[hi0] = p[hi0]r[hi1] = q[hi1]g[hi1] = v[hi1]b[hi1] = p[hi1]r[hi2] = p[hi2]g[hi2] = v[hi2]b[hi2] = t[hi2]r[hi3] = p[hi3]g[hi3] = q[hi3]b[hi3] = v[hi3]r[hi4] = t[hi4]g[hi4] = p[hi4]b[hi4] = v[hi4]r[hi5] = v[hi5]g[hi5] = p[hi5]b[hi5] = q[hi5]r = r.unsqueeze(1)g = g.unsqueeze(1)b = b.unsqueeze(1)rgb = torch.cat([r, g, b], dim=1)return rgb

二、验证

matplotlib.colors中也有rgb和hsv相互转换的代码,我们用它和我上面的代码对比:

import torch
import cv2
import matplotlib.pyplot as plt
from rgb_hsv import RGB_HSV
import matplotlib.colors as mcolorsimg = cv2.imread('../images/0.jpg')
rgb = img[:,:,::-1]  #注意opencv是BGR顺序,必须转换成RGB
rgb = rgb / 255rgb_tensor = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0).float()
convertor = RGB_HSV()hsv_tensor = convertor.rgb_to_hsv(rgb_tensor)
rgb1 = convertor.hsv_to_rgb(hsv_tensor)hsv_arr = hsv_tensor[0].permute(1,2,0).numpy()
rgb1_arr = rgb1[0].permute(1,2,0).numpy()hsv_m = mcolors.rgb_to_hsv(rgb)
rgb1_m = mcolors.hsv_to_rgb(hsv_m)print('mse of my code and matplotlib:',((rgb1_arr - rgb)**2).mean())
plt.figure()
plt.imshow(rgb)
plt.title('origin image')
plt.figure()
plt.imshow(hsv_arr)
plt.title('visual to hsv')
plt.figure()
plt.imshow(rgb1_arr)
plt.title('convert back: my code')
plt.figure()
plt.imshow(rgb1_m)
plt.title('convert back: matplotlib method')

打印出的mse是非常接近0的一个小浮点数(因加入的防除零的eps导致)。画出的转换效果图如下。可见和matplotlib结果是一样的,证实代码没有问题。


在神经网络中加入此代码后也证实确实可以反向传播,可以在cuda上运行,代码略。

pytorch版 RGB_to_HSV和HSV_to_RGB相关推荐

  1. PyTorch 版 EfficientDet 比官方 TF 实现快 25 倍?这个 GitHub 项目数天狂揽千星

    来源:机器之心 本文约3646字,建议阅读8分钟. 本文介绍在 Github 项目中,开发者 zylo117 开源了 PyTorch 版本的 EfficientDet,速度比原版高 20 余倍.如今, ...

  2. PyTorch版EfficientDet比官方TF实现快25倍?这个GitHub项目数天狂揽千星

    点上方蓝字视学算法获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 编辑:Sophia 计算机视觉联盟  报道  | 公众号 CVLianMeng 转载于 :机器之心 EfficientDe ...

  3. 全网第一SoTA成绩却朴实无华的PyTorch版EfficientDet

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文作者:Zylo117 https://zhuanlan.zhih ...

  4. 《动手学深度学习》PyTorch版GitHub资源

    之前,偶然间看到过这个PyTorch版<动手学深度学习>,当时留意了一下,后来,着手学习pytorch,发现找不到这个资源了.今天又看到了,赶紧保存下来. <动手学深度学习>P ...

  5. 364 页 PyTorch 版《动手学深度学习》PDF 开源了(全中文,支持 Jupyter 运行)

    点击上方"AI有道",选择"星标"公众号 重磅干货,第一时间送达 李沐,亚马逊 AI 主任科学家,名声在外!半年前,由李沐.Aston Zhang 等人合力打造 ...

  6. gorm 密码字段隐藏_【财富密码】第1期:《LSTM大战上证指数-PyTorch版》

    前言: Hello大家好,我是瑟林洞仙人!这里是[财富密码]系列第1期:<LSTM大战上证指数-PyTorch版>.在这里,我将用我的"意识流"代码,手把手教会大家如何 ...

  7. 364 页 PyTorch 版《动手学深度学习》分享(全中文,支持 Jupyter 运行)

    1 前言 最近有朋友留言要求分享一下李沐老师的<动手学深度学习>,小汤本着一直坚持的"好资源大家一起分享,共同学习,共同进步"的初衷,于是便去找了资料,而且还是中文版的 ...

  8. 最强NLP模型BERT喜迎PyTorch版!谷歌官方推荐,也会支持中文

    郭一璞 夏乙 发自 凹非寺  量子位 报道 | 公众号 QbitAI 谷歌的最强NLP模型BERT发布以来,一直非常受关注,上周开源的官方TensorFlow实现在GitHub上已经收获了近6000星 ...

  9. Step by Step演示如何训练Pytorch版的EfficientDet

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx Paper:https://arxiv.org/abs/1911.09070 Base Git ...

最新文章

  1. Pycharm+Anacond安装完成后的Python文件创建以及No module named 'bs4'.
  2. jquery找祖先包含_jquery如何获取祖先元素
  3. 今天rpm装glibc和glibc-common版本,出现二个包相互依赖,解决办法
  4. QPS、TPS、PV、UV、GMV、IP、RPS?
  5. 练习四十四:整数的排序
  6. 技术和技术管理人员评价标准
  7. 微型计算机原理及应用是啥,微型计算机原理及应用技术(第3版)
  8. windows安装caffe
  9. 头条小程序服务器设置,今日头条小程序怎么开发?如何注册申请
  10. 智慧图书馆管理系统提升服务水平和工作效率
  11. 图书管理系统——用例图、类图、时序图
  12. 千千静听皮肤急速合成器
  13. java 存根,使用mockito使用三个参数对方法进行存根
  14. Layui数据表格添加时间控件
  15. linux的PS3模拟器下载,PS3模拟器
  16. 使用wine在mac系统上运行windows程序
  17. scrapy下载斗鱼主播图片
  18. A算法和A*算法详解
  19. 【LaTex】 - 对齐符号的用法,换行符\\的用法,Misplaced 错误怎么解决
  20. Ardunio开发实例-数字温度传感器

热门文章

  1. 李宏毅机器学习--self-supervised:BERT、GPT、Auto-encoder
  2. 怎样使微信中打开链接自动打开外部浏览器打开指定URL页面或者直接下载APP(安卓/苹果)文件
  3. CKA原英文考试2019年12月答案
  4. FFmpeg命令行--视频转码
  5. DNS与Bind基本配置实现
  6. ansible程序自动化
  7. 第02章 一个实例初识WorkBench分析流程-卡扣结构的动作分析
  8. 翻录cda文件_翻录电视连续剧DVD并转换为单独的H.264 MP4文件
  9. 中国电信的天翼宽带怎么样才能不用“中国电信无线宽带”客户端
  10. 转载:微信的智能心跳方案