pytorch 对特征进行mean_Pytorch的mean和std调查实例
如下所示:
# coding: utf-8
from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms
import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio
# if model has LSTM
# torch.backends.cudnn.enabled = False
imgpath = "D:/ck/files_detected_face224/"
imgname = "S055_002_00000025.png" # anger
image_path = imgpath + imgname
mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=mean_file,
std =std_file,
#mean = mean_file,
#std = std_file,
)
])(raw_image).unsqueeze(0)
print(image.shape)
convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)
convert_image1 = convert_image1 * 255
diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()
结论:
input_image = (raw_image / 255 - mean) ./ std
下面调查均值文件和方差文件是如何生成的:
mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms
dataset_names = ("cifar10","cifar100","mnist")
parser = argparse.ArgumentParser(description="PyTorchLab")
parser.add_argument("-d", "--dataset", metavar="DATA", default="cifar10", choices=dataset_names,
help="dataset to be used: " + " | ".join(dataset_names) + " (default: cifar10)")
args = parser.parse_args()
data_dir = os.path.join(".", args.dataset)
print(args.dataset)
args.dataset = "cifar10"
if args.dataset == "cifar10":
train_transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
#print(vars(train_set))
print(train_set.train_data.shape)
print(train_set.train_data.mean(axis=(0,1,2))/255)
print(train_set.train_data.std(axis=(0,1,2))/255)
# imshow image
train_data = train_set.train_data
ind = 100
img0 = train_data[ind,...]
## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
# error produce
#b,g,r=cv2.split(img0)
#img0=cv2.merge([r,g,b])
print(img0.shape)
print(type(img0))
plt.imshow(img0)
plt.show() # in ship in sea
#img0 = cv2.resize(img0,(224,224))
#cv2.imshow("img0",img0)
#cv2.waitKey()
elif args.dataset == "cifar100":
train_transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
#print(vars(train_set))
print(train_set.train_data.shape)
print(np.mean(train_set.train_data, axis=(0,1,2))/255)
print(np.std(train_set.train_data, axis=(0,1,2))/255)
elif args.dataset == "mnist":
train_transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
#print(vars(train_set))
print(list(train_set.train_data.size()))
print(train_set.train_data.float().mean()/255)
print(train_set.train_data.float().std()/255)
结果:
cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
使用matlab检测是如何计算mean_file和std_file的:
% load cifar10 dataset
data = load("cifar10_train_data.mat");
train_data = data.train_data;
disp(size(train_data));
temp = mean(train_data,1);
disp(size(temp));
train_data = double(train_data);
% compute mean_file
mean_val = mean(mean(mean(train_data,1),2),3)/255;
% compute std_file
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;
temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;
temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;
mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];
disp(mean_val);
disp(std_val);
% result: mean_val: [0.4914, 0.4822, 0.4465]
% std_val: [0.2470, 0.2435, 0.2616]
均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:
mean_val = mean(mean(mean(train_data,1),2),3)/255;
标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。
以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持云海天教程。
pytorch 对特征进行mean_Pytorch的mean和std调查实例相关推荐
- pytorch 对特征进行mean_pytorch常用normalization函数
参考:https://blog.csdn.net/liuxiao214/article/details/81037416 归一化层,目前主要有这几个方法,Batch Normalization(201 ...
- 可视化pytorch网络特征图
0. 背景 在目标检测任务中,我们会使用多尺度的特征图进行预测,背后的常识是:浅层特征图包含丰富的边缘信息有利于定位小目标,高层特征图中包含大量的语义信息有利于大目标的定位和识别.为了进一步了解特征图 ...
- torch标准化_计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
pytorch做标准化利用transforms.Normalize(mean_vals, std_vals),其中常用数据集的均值方差有: if 'coco' in args.dataset: mea ...
- CPU上跑到 100 fps 的高精度PyTorch人脸特征点检测库
视学算法分享 作者 | cunjian 编译 | CV君 转自 | 我爱计算机视觉 [导读]向大家推荐一款基于PyTorch实现的快速高精度人脸特征点检测库,其在CPU上的运行速度可达100 ...
- pytorch以特征图的输入方式训练LSTM模型
训练的时候总是会遇到这样的任务: 特征有很多维度,每个维度都有相同的embedding或长度 训练的时候想按照类似这样的二维图,训练LSTM模型,最后得出这张图对应的1个或多个结果 文章目录 步骤一: ...
- pytorch实现特征图可视化,代码简洁,包教包会
是不是要这样的效果 技术要点 1.选择一层网络,将图片的tensor放进去 2.将网络的输出plt.imshow 代码可直接复制使用,需要改的就是你的图片位置 import torch from to ...
- pytorch 神经网络特征可视化
可参考博客 Pytorch可视化模型任意中间层的类激活热力图(Grad-CAM)_潜行隐耀的博客-CSDN博客_pytorch热力图 Pytorch输出网络中间层特征可视化_Joker-Tong的博客 ...
- PyTorch 可视化特征
这个也可以参考: https://blog.csdn.net/LEILEI18A/article/details/80389229 这篇博客主要记录了如何提取特定层的特征,然后对它进行可视化 二 主要 ...
- pytorch根据特征图训练LSTM Stacked AutoEncoder
文章目录 步骤一:构造训练数据 步骤二:构造LSTM模型 构造三:构建训练数据 构造建模三件套 步骤四:训练模型并保存 全部代码 步骤一:构造训练数据 def get_train_data(clust ...
最新文章
- 深度学习目标检测(object detection)系列(一) R-CNN
- hihocoder #1333 : 平衡树·Splay2
- Linux程序包管理
- AutoLayout框架之序言
- 每天一小时python官方文档学习(四)————数据结构之列表
- jenkins未授权访问漏洞记录(端口:7001,80,8080,50000)
- 主定理(Master Theorem)与时间复杂度
- 为什么c语言一用windows.h就报错_C代码里面加一行网址依然可以运行,并不会报错,为何...
- 虚拟化qemu-img的简单用法。
- python 原理 pdf_《深入浅出深度学习:原理剖析与Python实践》.pdf
- SM3算法 C语言 (从OpenSSL库中分离算法:六)
- 微信卡包系列-核销微信卡券优惠券
- Chrome应用商店打不开问题
- 从源码分析Redis分布式锁的原子性保证
- 使用Map集合来做一个不同姓氏人数的统计 有一个String数组保存着10个人的姓名{“张三“,“李四“,“王二“...} 通过程序设计,把不同姓氏的姓氏和人数保存到Map集合中
- 学习PPT,这些制作设计技巧需先掌握
- java编辑遗忘曲线代码_通过excel vba 实现艾宾浩斯遗忘曲线的复习提醒
- redis的过期键删除策略
- 安装和卸载.deb包
- 长尾分布,重尾分布(Heavy-tailed Distribution)
热门文章
- AWT_Swing_JTextField (Java)
- ReactNative 启动js server报错:Metro Bundler can't listen on port 8081
- THREEJS - 点击/拾取
- mysql 5.7参数配置_MySQL 5.7-新增配置参数
- python入门系列:迭代器和生成器
- SLS机器学习介绍(05):时间序列预测
- win10屏蔽自动更新方法
- php中的三元运算符
- 【Computer Organization笔记22】虚拟存储器:段式存储,页式存储
- 二维平面坐标系中,判断某点是否在正六边形内 | python 实现 + 数学推导(已知正六边形六个顶点坐标)