在设计网络进行训练前必不可少的部分就是获取文件夹下的图片进行训练

持续更新.............

使用os模块:

参考:Python os.path() 模块 | 菜鸟教程 (runoob.com)

import osprint( os.path.basename('/root/runoob.txt') )   # 返回文件名————————out:runoob.txt
print( os.path.dirname('/root/runoob.txt') )    # 返回目录路径——————out:/root
print( os.path.split('/root/runoob.txt') )      # 分割文件名与路径——————out:('/root','runoob.txt')
print( os.path.join('root','test','runoob.txt') )  # 将目录和文件名合成一个路径——————out:root/test/runoob.txt

os.listdir() 方法 : 返回指定文件夹包含的文件或文件夹名字的列表。该列表顺序以字母排序。

import os
for name in os.listdir(root):print(name)

os.path.join() : 将多个路径组合后返回

#类中定义
self.namelabel={}
for name in sorted(os.listdir(os.path.join(root))):#遍历根目录下所有文件夹if not os.path.isdir(os.path.join(root,name)):#listdir可能会把文件夹下文件都包含进来,先把文件过滤掉continueself.namelabel[name]=len(self.namelabel.keys())

存取csv文件,若无csv制作csv,有csv则直接读取

def load_csv(self,file):#若csv不存在,则需要创建if not os.path.exists(os.path.join(self.root,file)):images=[]for name in self.namelabel.keys():#glob模块,获取到文件夹下不同后缀的图片images+=glob.glob(os.path.join(self.root,name,'*png'))images+=glob.glob(os.path.join(self.root,name,'*jpg'))print(len(images),images)#此时images: 数据集文件夹\\不同类别文件夹\\00000000.png,排列完png后排列jpg例如dataset\\dog\\00000000.png#开始生成csv文件random.shuffle(images)with open(os.path.join(self.root,file),mode='w',newline='') as f:writer=csv.writer(f)for img in images:name=img.split(os.sep)[-2]#[-2]即不同类别文件夹名称,例如doglabel=self.namelabel[name]writer.writerow([img,label])print('write in:',file)"""mode='w':打开一个文件只用于写入"""#若csv文件存在,则进行读取images,labels=[],[]with open(os.path.join(self.root,file)) as f:reader=csv.reader(f)for row in reader:img,label=row    label=int(label)images.append(img)labels.append(label)assert len(images)==len(labels)return images,labels

可视化图片——利用visdom库

#利用Dataloader
import visdom
import time
#以Pokemon为例,数据集文件夹名:pokeman
viz=visdom.Visdom()
db=Pokemon('pokeman',64,'w')
x,y=next(iter(db))
print('sample:',x.shape,y.shape,y)
#单图
viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)
for x,y in loader:#多图展示viz.images,单图viz.imageviz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))viz.text(str(y.numpy()),win='label',opts=dict(title='batch_y'))time.sleep(10)
#利用torchvision.datasets.ImageFolder
#需要文件夹的排列非常规范
tf = transforms.Compose([transforms.Resize((64,64)),transforms.ToTensor(),])
db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
loader = DataLoader(db, batch_size=32, shuffle=True)print(db.class_to_idx)

以宝可梦数据集为例:(全)

import torch
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Imageclass Pokemon(Dataset):def __init__(self,root,resize,mode):super(Pokemon, self).__init__()self.root=rootself.resize=resizeself.name2label={} #定义字典,创建映射表for name in sorted(os.listdir(os.path.join(root))):#遍历根目录下所有文件夹#listdir可能会把文件夹下文件都包含进来,先把文件过滤掉if not os.path.isdir(os.path.join(root,name)):continue# 保存在表中;将最长的映射作为最新的元素的label的值self.name2label[name]=len(self.name2label.keys())print(self.name2label)self.images,self.labels=self.load_csv('images.csv')#进行裁剪,划分训练集验证集与测试集if mode=='train':#60%self.images=self.images[:int(0.6*len(self.images))]self.labels=self.labels[:int(0.6*len(self.labels))]elif mode=='val':#20%self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]else:#20%self.images = self.images[int(0.8 * len(self.images)):]self.labels = self.labels[int(0.8 * len(self.labels)):]def load_csv(self,filename):#此csv文件不存在则进行创建if not os.path.exists(os.path.join(self.root,filename)):images=[]for name in self.name2label.keys():images+=glob.glob(os.path.join(self.root,name,'*.png'))images+=glob.glob(os.path.join(self.root, name, '*.jpg'))print(len(images),images)random.shuffle(images)with open(os.path.join(self.root,filename),mode='w',newline='') as f:writer = csv.writer(f)for img in images:name=img.split(os.sep)[-2]label=self.name2label[name]writer.writerow([img,label])print('write in :',filename)#存在则进行读取images,labels=[],[]with open(os.path.join(self.root,filename)) as f:reader=csv.reader(f)for row in reader:img,label=rowlabel=int(label)images.append(img)labels.append(label)#保证长度一致assert len(images)==len(labels)return images,labelsdef __len__(self):return len(self.images)def denormalize(self,x_hat):mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225]# x_hot=(x-mean)/std# x=x_hot*std+meanmean=torch.tensor(mean).unsqueeze(1).unsqueeze(1)std=torch.tensor(std).unsqueeze(1).unsqueeze(1)x=x_hat*std+meanreturn xdef __getitem__(self, idx):#idx:[0-len(images)]img,label=self.images[idx],self.labels[idx]tf=transforms.Compose([lambda x:Image.open(x).convert('RGB'),#resizetransforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),transforms.RandomRotation(50),transforms.CenterCrop(self.resize),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])img=tf(img)label=torch.tensor(label)return img,labeldef main():import visdomimport timeviz=visdom.Visdom()db=Pokemon('pokeman',64,'w')x,y=next(iter(db))print('sample:',x.shape,y.shape,y)viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)for x,y in loader:#多图展示viz.images,单图viz.imageviz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))viz.text(str(y.numpy()),win='label',opts=dict(title='batch_y'))time.sleep(10)if __name__ == '__main__':main()

训练第一步:python获取文件下图片相关推荐

  1. python 抓取目录树_python 获取文件下所有文件或目录os.walk()的实例

    在python3.6版本中去掉了os.path.walk()函数 os.walk() 函数声明:walk(top,topdown=True,oneerror=None) 1.参数top表示需要遍历的目 ...

  2. python获取文件路径下的文件_python 获取文件下所有文件或目录os.walk()的实例

    在python3.6版本中去掉了os.path.walk()函数 os.walk() 函数声明:walk(top,topdown=True,oneerror=None) 1.参数top表示需要遍历的目 ...

  3. python 获取文件夹所有文件列表_python获取文件夹下所有文件及os模块方法

    python获取文件夹下所有文件 方法一:使用os.listdir import os for filename in os.listdir(r'c:\windows'): print filenam ...

  4. Python 获取文件夹下所有文件

    前言 使用Python获取文件夹下的所有文件时,存在多种方式. 1. os.listdir os.listdir:参数为文件夹路径,可以返回文件夹下的所有子文件夹.文件名称. 示例: import o ...

  5. python获取文件路径名_python文件名获取文件路径

    如何使用Python获取文件所在目录和文件名 python中如何根据文件名找他的路径.现在我遍历到怎么才能将某一个文件对应的路径找到呢? 遍历用os.walk: import osfrom os.pa ...

  6. 微信支付接口升级(开通微信代金券)第一步:获取微信沙盒签名

    吐槽:微信官方文档写得简直是高山流水,望而却步,让人看得头皮发麻. ps:如果是没有后台代码或开发人员的朋友,请联系我qq2294974790,可以帮忙开通(收费80:需要商户号和微信秘钥) 好了,言 ...

  7. python获取文件夹里有什么文件+查看特定格式的文件

    python获取文件夹里有什么文件+查看特定格式的文件 功能 程序 效果 后续 功能 获取文件夹的的文件+获取文件夹里的特定格式的文件,比如.png等 程序 import os folder = r& ...

  8. python 获取文件CRC值

    python 获取文件CRC值 crc值在文件改变之前是唯一的 import zlib def crc(fileName):hash = 0for eachLine in open(fileName, ...

  9. python 获取文件夹名称大全_python 获取指定文件夹下所有文件名称并写入列表的实例...

    如下所示: import os import os.path rootdir = "./pic_data" file_object = open('train_list.txt', ...

最新文章

  1. 计算机教师资格证报考科目,还在纠结报考教师资格证该选哪个科目呢?看完这篇,你不再迷茫...
  2. 浅谈tomcat中间件的优化【转】
  3. iOS框架介绍之coreImage
  4. 关于c语言的符号常量以下叙述中正确的是,关于C语言的符号常量,以下叙述中正确的是...
  5. hdu 6086 Rikka with String(AC自动机+状压dp)
  6. spring用的很开心的标签(随时增加)
  7. mysql导入数据库注释乱码_source命令 导入.sql文件时,中文乱码 或者是注释乱码...
  8. WiFi---AP+STA共存模式(ESP8266)
  9. openCV之图像基础(笔记02)
  10. php 修改 apk名称6,反编译sencha toucha打包的apk文件,修改应用名称支持中文以及去除应用标题栏...
  11. C# 6.0 新特性
  12. 项目中的设计模式【工厂方法模式】
  13. C++ concurrent_queue::try_pop 方法
  14. FileUpload文件上传控件
  15. 传奇客户端DATA文件详细说明
  16. 10年程序员私单的经历,送你3个找客户的关键技巧
  17. 爬mei紫图最后代码2015-2019-1-14全部
  18. 在ESNP中还原内网私接小路由器导致用户无法上网场景
  19. 小程序复用公众号资质快速认证
  20. css如何导入特殊字体

热门文章

  1. 移动云mas 通过HTTP请求发送普通短信和 模板短信
  2. 四川职业技术学院linux,2019年四川交通职业技术学院单招中职(信息技术一类)专业技能测试大纲...
  3. 三分钟带你领路Java-JFrame窗体美化
  4. android 虚拟按键自定义,如何适配Android底部虚拟按键
  5. android 虚拟按键源码流程分析
  6. vue——router更改路由地址,但是页面不能跳转
  7. 微信小程序Day4学习笔记
  8. tensorflow conv2d()参数解析
  9. vue之表格数据渲染,实现点击表格某列按钮弹出框显示剩余数据(模态框知识点)
  10. 人工智能数学基础03之:隐函数推导