如下所示:

# coding: utf-8

from __future__ import print_function

import copy

import click

import cv2

import numpy as np

import torch

from torch.autograd import Variable

from torchvision import models, transforms

import matplotlib.pyplot as plt

import load_caffemodel

import scipy.io as sio

# if model has LSTM

# torch.backends.cudnn.enabled = False

imgpath = "D:/ck/files_detected_face224/"

imgname = "S055_002_00000025.png" # anger

image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]

std_file = [0.229, 0.224, 0.225]

raw_image = cv2.imread(image_path)[..., ::-1]

print(raw_image.shape)

raw_image = cv2.resize(raw_image, (224, ) * 2)

image = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(

mean=mean_file,

std =std_file,

#mean = mean_file,

#std = std_file,

)

])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()

convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W

convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))

convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C

print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1

err = np.max(diff)

print(err)

plt.imshow(np.uint8(convert_image1))

plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]

std_file = [0.229, 0.224, 0.225]

# coding: utf-8

import matplotlib.pyplot as plt

import argparse

import os

import numpy as np

import torchvision

import torchvision.transforms as transforms

dataset_names = ("cifar10","cifar100","mnist")

parser = argparse.ArgumentParser(description="PyTorchLab")

parser.add_argument("-d", "--dataset", metavar="DATA", default="cifar10", choices=dataset_names,

help="dataset to be used: " + " | ".join(dataset_names) + " (default: cifar10)")

args = parser.parse_args()

data_dir = os.path.join(".", args.dataset)

print(args.dataset)

args.dataset = "cifar10"

if args.dataset == "cifar10":

train_transform = transforms.Compose([transforms.ToTensor()])

train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)

#print(vars(train_set))

print(train_set.train_data.shape)

print(train_set.train_data.mean(axis=(0,1,2))/255)

print(train_set.train_data.std(axis=(0,1,2))/255)

# imshow image

train_data = train_set.train_data

ind = 100

img0 = train_data[ind,...]

## test channel number, in total , the correct channel is : RGB,not like BGR in caffe

# error produce

#b,g,r=cv2.split(img0)

#img0=cv2.merge([r,g,b])

print(img0.shape)

print(type(img0))

plt.imshow(img0)

plt.show() # in ship in sea

#img0 = cv2.resize(img0,(224,224))

#cv2.imshow("img0",img0)

#cv2.waitKey()

elif args.dataset == "cifar100":

train_transform = transforms.Compose([transforms.ToTensor()])

train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)

#print(vars(train_set))

print(train_set.train_data.shape)

print(np.mean(train_set.train_data, axis=(0,1,2))/255)

print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":

train_transform = transforms.Compose([transforms.ToTensor()])

train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)

#print(vars(train_set))

print(list(train_set.train_data.size()))

print(train_set.train_data.float().mean()/255)

print(train_set.train_data.float().std()/255)

结果:

cifar10

Files already downloaded and verified

(50000, 32, 32, 3)

[ 0.49139968 0.48215841 0.44653091]

[ 0.24703223 0.24348513 0.26158784]

(32, 32, 3)

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load("cifar10_train_data.mat");

train_data = data.train_data;

disp(size(train_data));

temp = mean(train_data,1);

disp(size(temp));

train_data = double(train_data);

% compute mean_file

mean_val = mean(mean(mean(train_data,1),2),3)/255;

% compute std_file

temp1 = train_data(:,:,:,1);

std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);

std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);

std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);

std_val = [std_val1, std_val2, std_val3];

disp(mean_val);

disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]

% std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持云海天教程。

pytorch 对特征进行mean_Pytorch的mean和std调查实例相关推荐

  1. pytorch 对特征进行mean_pytorch常用normalization函数

    参考:https://blog.csdn.net/liuxiao214/article/details/81037416 归一化层,目前主要有这几个方法,Batch Normalization(201 ...

  2. 可视化pytorch网络特征图

    0. 背景 在目标检测任务中,我们会使用多尺度的特征图进行预测,背后的常识是:浅层特征图包含丰富的边缘信息有利于定位小目标,高层特征图中包含大量的语义信息有利于大目标的定位和识别.为了进一步了解特征图 ...

  3. torch标准化_计算pytorch标准化(Normalize)所需要数据集的均值和方差实例

    pytorch做标准化利用transforms.Normalize(mean_vals, std_vals),其中常用数据集的均值方差有: if 'coco' in args.dataset: mea ...

  4. CPU上跑到 100 fps 的高精度PyTorch人脸特征点检测库

      视学算法分享   作者 | cunjian 编译 | CV君 转自 | 我爱计算机视觉 [导读]向大家推荐一款基于PyTorch实现的快速高精度人脸特征点检测库,其在CPU上的运行速度可达100 ...

  5. pytorch以特征图的输入方式训练LSTM模型

    训练的时候总是会遇到这样的任务: 特征有很多维度,每个维度都有相同的embedding或长度 训练的时候想按照类似这样的二维图,训练LSTM模型,最后得出这张图对应的1个或多个结果 文章目录 步骤一: ...

  6. pytorch实现特征图可视化,代码简洁,包教包会

    是不是要这样的效果 技术要点 1.选择一层网络,将图片的tensor放进去 2.将网络的输出plt.imshow 代码可直接复制使用,需要改的就是你的图片位置 import torch from to ...

  7. pytorch 神经网络特征可视化

    可参考博客 Pytorch可视化模型任意中间层的类激活热力图(Grad-CAM)_潜行隐耀的博客-CSDN博客_pytorch热力图 Pytorch输出网络中间层特征可视化_Joker-Tong的博客 ...

  8. PyTorch 可视化特征

    这个也可以参考: https://blog.csdn.net/LEILEI18A/article/details/80389229 这篇博客主要记录了如何提取特定层的特征,然后对它进行可视化 二 主要 ...

  9. pytorch根据特征图训练LSTM Stacked AutoEncoder

    文章目录 步骤一:构造训练数据 步骤二:构造LSTM模型 构造三:构建训练数据 构造建模三件套 步骤四:训练模型并保存 全部代码 步骤一:构造训练数据 def get_train_data(clust ...

最新文章

  1. 深度学习目标检测(object detection)系列(一) R-CNN
  2. hihocoder #1333 : 平衡树·Splay2
  3. Linux程序包管理
  4. AutoLayout框架之序言
  5. 每天一小时python官方文档学习(四)————数据结构之列表
  6. jenkins未授权访问漏洞记录(端口:7001,80,8080,50000)
  7. 主定理(Master Theorem)与时间复杂度
  8. 为什么c语言一用windows.h就报错_C代码里面加一行网址依然可以运行,并不会报错,为何...
  9. 虚拟化qemu-img的简单用法。
  10. python 原理 pdf_《深入浅出深度学习:原理剖析与Python实践》.pdf
  11. SM3算法 C语言 (从OpenSSL库中分离算法:六)
  12. 微信卡包系列-核销微信卡券优惠券
  13. Chrome应用商店打不开问题
  14. 从源码分析Redis分布式锁的原子性保证
  15. 使用Map集合来做一个不同姓氏人数的统计 有一个String数组保存着10个人的姓名{“张三“,“李四“,“王二“...} 通过程序设计,把不同姓氏的姓氏和人数保存到Map集合中
  16. 学习PPT,这些制作设计技巧需先掌握
  17. java编辑遗忘曲线代码_通过excel vba 实现艾宾浩斯遗忘曲线的复习提醒
  18. redis的过期键删除策略
  19. 安装和卸载.deb包
  20. 长尾分布,重尾分布(Heavy-tailed Distribution)

热门文章

  1. AWT_Swing_JTextField (Java)
  2. ReactNative 启动js server报错:Metro Bundler can't listen on port 8081
  3. THREEJS - 点击/拾取
  4. mysql 5.7参数配置_MySQL 5.7-新增配置参数
  5. python入门系列:迭代器和生成器
  6. SLS机器学习介绍(05):时间序列预测
  7. win10屏蔽自动更新方法
  8. php中的三元运算符
  9. 【Computer Organization笔记22】虚拟存储器:段式存储,页式存储
  10. 二维平面坐标系中,判断某点是否在正六边形内 | python 实现 + 数学推导(已知正六边形六个顶点坐标)