文章目录

  • 图像分类案例2:
    • 获取数据集
    • 数据增强
    • 读取数据
    • 定义模型
    • 定义训练函数
    • 略微调参
    • 训练模型
    • 测试 提交结果
    • GAN
    • DCGAN

图像分类案例2:

kaggle 狗识别

获取数据集

比赛的网址是https://www.kaggle.com/c/dog-breed-identification 在这项比赛中,我们尝试确定120种不同的狗。该比赛中使用的数据集实际上是著名的ImageNet数据集的子集。
我们可以从比赛网址上下载数据集,其目录结构为:

| Dog Breed Identification| train|   | 000bec180eb18c7604dcecc8fe0dba07.jpg|   | 00a338a92e4e7bf543340dc849230e75.jpg|   | ...| test|   | 00a3edd22dc7859c487a64777fc8d093.jpg|   | 00a6892e5c7f92c1f465e213fd904582.jpg|   | ...| labels.csv| sample_submission.csv

train和test目录下分别是训练集和测试集的图像,训练集包含10,222张图像,测试集包含10,357张图像,图像格式都是JPEG,每张图像的文件名是一个唯一的id。labels.csv包含训练集图像的标签,文件包含10,222行,每行包含两列,第一列是图像id,第二列是狗的类别。狗的类别一共有120种。

我们希望对数据进行整理,方便后续的读取,我们的主要目标是:

  • 从训练集中划分出验证数据集,用于调整超参数。划分之后,数据集应该包含4个部分:划分后的训练集、划分后的验证集、完整训练集、完整测试集
  • 对于4个部分,建立4个文件夹:train, valid, train_valid, test。在上述文件夹中,对每个类别都建立一个文件夹,在其中存放属于该类别的图像。前三个部分的标签已知,所以各有120个子文件夹,而测试集的标签未知,所以仅建立一个名为unknown的子文件夹,存放所有测试数据。

我们希望整理后的数据集目录结构为:

| train_valid_test| train|   | affenpinscher|   |   | 00ca18751837cd6a22813f8e221f7819.jpg|   |   | ...|   | afghan_hound|   |   | 0a4f1e17d720cdff35814651402b7cf4.jpg|   |   | ...|   | ...| valid|   | affenpinscher|   |   | 56af8255b46eb1fa5722f37729525405.jpg|   |   | ...|   | afghan_hound|   |   | 0df400016a7e7ab4abff824bf2743f02.jpg|   |   | ...|   | ...| train_valid|   | affenpinscher|   |   | 00ca18751837cd6a22813f8e221f7819.jpg|   |   | ...|   | afghan_hound|   |   | 0a4f1e17d720cdff35814651402b7cf4.jpg|   |   | ...|   | ...| test|   | unknown|   |   | 00a3edd22dc7859c487a64777fc8d093.jpg|   |   | ...

划分数据集:

data_dir = '/home/kesci/input/Kaggle_Dog6357/dog-breed-identification'  # 数据集目录
label_file, train_dir, test_dir = 'labels.csv', 'train', 'test'  # data_dir中的文件夹、文件
new_data_dir = './train_valid_test'  # 整理之后的数据存放的目录
valid_ratio = 0.1  # 验证集所占比例
def mkdir_if_not_exist(path):# 若目录path不存在,则创建目录if not os.path.exists(os.path.join(*path)):os.makedirs(os.path.join(*path))def reorg_dog_data(data_dir, label_file, train_dir, test_dir, new_data_dir, valid_ratio):# 读取训练数据标签labels = pd.read_csv(os.path.join(data_dir, label_file))id2label = {Id: label for Id, label in labels.values}  # (key: value): (id: label)# 随机打乱训练数据train_files = os.listdir(os.path.join(data_dir, train_dir))random.shuffle(train_files)    # 原训练集valid_ds_size = int(len(train_files) * valid_ratio)  # 验证集大小for i, file in enumerate(train_files):img_id = file.split('.')[0]  # file是形式为id.jpg的字符串img_label = id2label[img_id]if i < valid_ds_size:mkdir_if_not_exist([new_data_dir, 'valid', img_label])shutil.copy(os.path.join(data_dir, train_dir, file),os.path.join(new_data_dir, 'valid', img_label))else:mkdir_if_not_exist([new_data_dir, 'train', img_label])shutil.copy(os.path.join(data_dir, train_dir, file),os.path.join(new_data_dir, 'train', img_label))mkdir_if_not_exist([new_data_dir, 'train_valid', img_label])shutil.copy(os.path.join(data_dir, train_dir, file),os.path.join(new_data_dir, 'train_valid', img_label))# 测试集mkdir_if_not_exist([new_data_dir, 'test', 'unknown'])for test_file in os.listdir(os.path.join(data_dir, test_dir)):shutil.copy(os.path.join(data_dir, test_dir, test_file),os.path.join(new_data_dir, 'test', 'unknown'))reorg_dog_data(data_dir, label_file, train_dir, test_dir, new_data_dir, valid_ratio)

数据增强

transform_train = transforms.Compose([# 随机对图像裁剪出面积为原图像面积0.08~1倍、且高和宽之比在3/4~4/3的图像,再放缩为高和宽均为224像素的新图像transforms.RandomResizedCrop(224, scale=(0.08, 1.0),  ratio=(3.0/4.0, 4.0/3.0)),# 以0.5的概率随机水平翻转transforms.RandomHorizontalFlip(),# 随机更改亮度、对比度和饱和度transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),transforms.ToTensor(),# 对各个通道做标准化,(0.485, 0.456, 0.406)和(0.229, 0.224, 0.225)是在ImageNet上计算得的各通道均值与方差transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet上的均值和方差
])# 在测试集上的图像增强只做确定性的操作
transform_test = transforms.Compose([transforms.Resize(256),# 将图像中央的高和宽均为224的正方形区域裁剪出来transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

读取数据

# new_data_dir目录下有train, valid, train_valid, test四个目录
# 这四个目录中,每个子目录表示一种类别,目录中是属于该类别的所有图像
train_ds = torchvision.datasets.ImageFolder(root=os.path.join(new_data_dir, 'train'),transform=transform_train)
valid_ds = torchvision.datasets.ImageFolder(root=os.path.join(new_data_dir, 'valid'),transform=transform_test)
train_valid_ds = torchvision.datasets.ImageFolder(root=os.path.join(new_data_dir, 'train_valid'),transform=transform_train)
test_ds = torchvision.datasets.ImageFolder(root=os.path.join(new_data_dir, 'test'),transform=transform_test)
batch_size = 128
train_iter = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size=batch_size, shuffle=True)
train_valid_iter = torch.utils.data.DataLoader(train_valid_ds, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)  # shuffle=False

定义模型

def get_net(device):finetune_net = models.resnet34(pretrained=False)  # 预训练的resnet34网络finetune_net.load_state_dict(torch.load('/home/kesci/input/resnet347742/resnet34-333f7ec4.pth'))for param in finetune_net.parameters():  # 冻结参数param.requires_grad = False# 原finetune_net.fc是一个输入单元数为512,输出单元数为1000的全连接层# 替换掉原finetune_net.fc,新finetuen_net.fc中的模型参数会记录梯度finetune_net.fc = nn.Sequential(nn.Linear(in_features=512, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=120)  # 120是输出类别数)return finetune_net

定义训练函数

def train(net, train_iter, valid_iter, num_epochs, lr, wd, device, lr_period,lr_decay):loss = nn.CrossEntropyLoss()optimizer = optim.SGD(net.fc.parameters(), lr=lr, momentum=0.9, weight_decay=wd)net = net.to(device)for epoch in range(num_epochs):train_l_sum, n, start = 0.0, 0, time.time()if epoch > 0 and epoch % lr_period == 0:  # 每lr_period个epoch,学习率衰减一次lr = lr * lr_decayfor param_group in optimizer.param_groups:param_group['lr'] = lrfor X, y in train_iter:X, y = X.to(device), y.to(device)optimizer.zero_grad()y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()train_l_sum += l.item() * y.shape[0]n += y.shape[0]time_s = "time %.2f sec" % (time.time() - start)if valid_iter is not None:valid_loss, valid_acc = evaluate_loss_acc(valid_iter, net, device)epoch_s = ("epoch %d, train loss %f, valid loss %f, valid acc %f, "% (epoch + 1, train_l_sum / n, valid_loss, valid_acc))else:epoch_s = ("epoch %d, train loss %f, "% (epoch + 1, train_l_sum / n))print(epoch_s + time_s + ', lr ' + str(lr))

略微调参

num_epochs, lr_period, lr_decay = 20, 10, 0.1
lr, wd = 0.03, 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = get_net(device)
train(net, train_iter, valid_iter, num_epochs, lr, wd, device, lr_period, lr_decay)

训练模型

net = get_net(device)
train(net, train_valid_iter, None, num_epochs, lr, wd, device, lr_period, lr_decay)

测试 提交结果

preds = []
for X, _ in test_iter:X = X.to(device)output = net(X)output = torch.softmax(output, dim=1)preds += output.tolist()
ids = sorted(os.listdir(os.path.join(new_data_dir, 'test/unknown')))
with open('submission.csv', 'w') as f:f.write('id,' + ','.join(train_valid_ds.classes) + '\n')for i, output in zip(ids, preds):f.write(i.split('.')[0] + ',' + ','.join([str(num) for num in output]) + '\n')

GAN

待续

DCGAN

待续

小结8:图像分类案例2,GAN、DCGAN相关推荐

  1. 动手学深度学习: 图像分类案例2,GAN,DCGAN

    动手学深度学习: 图像分类案例2,GAN,DCGAN 内容摘自伯禹人工智能AI公益课程 图像分类案例2 1.关于整理数据集后得到的train.valid.train_valid和test数据集: 1) ...

  2. GAN小结(BEGAN EBGAN WGAN CycleGAN conditional GAN DCGAN PGGAN VAEGAN)

    断断续续看了生成对抗网络一些日子,下面把我比较感兴趣也算是我认为效果比较好的GAN进行简单梳理,其中会参考众多前辈的文章,主要包括 1.EBGAN 原文 https://arxiv.org/pdf/1 ...

  3. [Python图像处理] 二十六.图像分类原理及基于KNN、朴素贝叶斯算法的图像分类案例

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

  4. [Python人工智能] 十三.如何评价神经网络、loss曲线图绘制、图像分类案例的F值计算

    从本专栏开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇文章详细讲解了循环神经网络RNN和长短期记忆网络LSTM的原理知识,并采用TensorFlow实现手写数字识别的R ...

  5. [Python人工智能] 十.Tensorflow+Opencv实现CNN自定义图像分类案例及与机器学习KNN图像分类算法对比

    从本专栏开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇详细讲解了gensim词向量Word2Vec安装.基础用法,并实现<庆余年>中文短文本相似度计算及多个 ...

  6. 《动手学深度学习》Task09:目标检测基础+图像风格迁移+图像分类案例1

    1 目标检测基础 1.1 目标检测和边界框(9.3) %matplotlib inline from PIL import Imageimport sys sys.path.append('/home ...

  7. Deep Convolutional GAN (DCGAN)

    使用MNIST数据集创建一个GAN.实现Deep Convolutiona GAN (DCGAN),DCGAN是2015年开发的非常成功和有影响力的GAN模型(论文地址https://arxiv.or ...

  8. PaperNotes(6)-GAN/DCGAN/WGAN/WGAN-GP/WGAN-SN-网络结构/实验效果

    GAN模型网络结构+实验效果演化 1.GAN 1.1网络结构 1.2实验结果 2.DCGAN 2.1网络结构 2.2实验结果 3.WGAN 3.1网络结构 3.2实验结果 4.WGAN-GP 4.1网 ...

  9. 【图像分类案例】(1) ResNeXt 交通标志四分类,附Tensorflow完整代码

    各位同学好,今天和大家分享一下如何使用 Tensorflow 构建 ResNeXt 神经网络模型,通过案例实战 ResNeXt 的训练以及预测过程.每个小节的末尾有网络.训练.预测的完整代码.想要数据 ...

最新文章

  1. 单链表-单链表拆分为两个线性表(尾插法+尾插法)
  2. 自动装配——@Resource(JSR250)和@Inject(JSR330)---[java规范的注解]
  3. java 中通过 Lettuce 来操作 Redis
  4. code128条码c语言,C#生成code128条形码的方法
  5. Spring Boot中@ConfigurationProperties与@PropertySource的基本使用(读取指定的properties文件)
  6. FinTech:一个单体系统足以撑起银行持续交付全球大项目
  7. HTMLCSSJavaScript个人入门自学笔记
  8. 基于矩阵分解的隐因子模型
  9. 算法熟记-排序系列-归并排序
  10. 拓端tecdat|R语言使用限制平均生存时间RMST比较两条生存曲线分析肝硬化患者
  11. 中国计算机学会(CCF)推荐国际学术会议和期刊目录(2019年版,官网转载)
  12. 【MacOs系统-M2安装2022新版AWVS渗透工具】-保姆级安装教程
  13. Eclipse输入或创建txt文件位置
  14. 电子元器件检测与维修从入门到精通视频教程
  15. OpenJudge超详细题解,动画图文题解
  16. 英尺、英寸、厘米的转化:C语言
  17. C#控制语音卡实现呼叫、录音以及来电弹屏
  18. vite 框架 打包修改主页名字得方法
  19. 千寻位置48小时“复活”伽利略卫星定位系统
  20. 【AWS云从业者基础知识笔记】——模块1:AWS服务简介

热门文章

  1. 【风力】基于Matlab模拟风力涡轮机的雷达信号
  2. node批量抓取并下载小姐姐照片
  3. Vector Commitment Techniques and Applications to Verifiable Decentralized Storage代码解析
  4. Windows资源管理器浏览svg图片
  5. 如何快速交付高价值的软件
  6. SparkContext的初始化(伯篇)——执行环境与元数据清理器
  7. 天宫初级认证答案_梦幻西游:资深玩家知识问答答案
  8. 解决iis dllhost.exe问题
  9. 会声会影12(X2)中文版入门视频教程(高清)
  10. bes2300之led配置(三)