数据集格式

之前利用resnet从一开始训练,效果比较差,后来利用谷歌的模型进行微调达到了很好的效果

训练代码如下:

from datasets import load_dataset
# /home/huhao/TensorFlow2.0_ResNet/dataset
# /home/huhao/dataset
import numpy as np
from datasets import load_metric
scene = load_dataset("/home/huhao/TensorFlow2.0_ResNet/dataset")
dataset = scene['train']
scene = dataset.train_test_split(test_size=0.2)labels = scene["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):label2id[label] = str(i)id2label[str(i)] = labelfrom transformers import AutoFeatureExtractor
# google/vit-base-patch16-224-in21k
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensornormalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])def transforms(examples):examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]del examples["image"]return examplesscene = scene.with_transform(transforms)from transformers import DefaultDataCollatordata_collator = DefaultDataCollator()from transformers import AutoModelForImageClassification, TrainingArguments, Trainerdef compute_metric(eval_pred):metric = load_metric("accuracy")logits,labels = eval_predprint(logits,labels)print(len(logits),len(labels))predictions = np.argmax(logits,axis=-1)print(len(predictions))print('predictions')print(predictions)return metric.compute(predictions = predictions,references = labels)model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k",num_labels=len(labels),id2label=id2label,label2id=label2id,
)training_args = TrainingArguments(output_dir="./results",overwrite_output_dir = 'True',per_device_train_batch_size=16,evaluation_strategy="steps",num_train_epochs=4,save_steps=100,eval_steps=100,logging_steps=10,learning_rate=2e-4,save_total_limit=2,remove_unused_columns=False,load_best_model_at_end=False,save_strategy='no',
)trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=scene["train"],eval_dataset=scene["test"],tokenizer=feature_extractor,compute_metrics=compute_metric,)trainer.train()
trainer.evaluate()
trainer.save_model('/home/huhao/script/model')

测试代码如下

from transformers import AutoFeatureExtractor, AutoModelForImageClassificationextractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
# 我已经把训练好的模型上传到网上,这里下载即可使用
from datasets import load_dataset
# /home/huhao/TensorFlow2.0_ResNet/dataset
# /home/huhao/dataset
import numpy as np
from datasets import load_metric
# 这个是数据集加载的路径
scene = load_dataset("/home/huhao/script/dataset")
dataset = scene['train']
scene = dataset.train_test_split(test_size=0.2)labels = scene["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):label2id[label] = str(i)id2label[str(i)] = labelfrom transformers import AutoFeatureExtractor
# google/vit-base-patch16-224-in21k
feature_extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensornormalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])def transforms(examples):examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]del examples["image"]return examplesscene = scene.with_transform(transforms)
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()
from transformers import AutoModelForImageClassification, TrainingArguments, Trainertraining_args = TrainingArguments(output_dir="./results",overwrite_output_dir = 'True',per_device_train_batch_size=16,evaluation_strategy="steps",num_train_epochs=4,save_steps=100,eval_steps=100,logging_steps=10,learning_rate=2e-4,save_total_limit=2,remove_unused_columns=False,load_best_model_at_end=False,save_strategy='no',
)model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence",num_labels=len(labels),id2label=id2label,label2id=label2id,
)def compute_metric(eval_pred):metric = load_metric("f1")logits,labels = eval_predprint(len(logits),len(labels))predictions = np.argmax(logits,axis=-1)print('对测试集进行评估')print('labels')print(labels)print('predictions')print(predictions)return metric.compute(predictions = predictions,references = labels,average='macro')trainer = Trainer(model=model,args=training_args,data_collator=data_collator,eval_dataset=scene["test"],tokenizer=feature_extractor,compute_metrics=compute_metric,)
compute_metrics = trainer.evaluate()
# {'eval_loss': 0.04495017230510712, 'eval_accuracy': 0.9943181818181818, 'eval_runtime': 30.8715, 'eval_samples_per_second': 11.402, 'eval_steps_per_second': 1.425}
print('输出最后的结果eval_f1:')
print(compute_metrics['eval_f1'])
from doctest import Example
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, ImageClassificationPipeline
import os
extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
from transformers import pipeline#generator = ImageClassificationPipeline(model=model, tokenizer=extractor)
vision_classifier = pipeline(task="image-classification",model = model,feature_extractor = extractor)
result_dict = {'City_road':0,'fog':1,'rain':2,'snow':3}
val_path = '/home/huhao/script/val/'
all_img = os.listdir(val_path)
for img in all_img:tmp_score = 0end_label = ''img_path = os.path.join(val_path,img)score_list = vision_classifier(img_path)for sample in score_list:score = sample['score']label = sample['label']if tmp_score < score:tmp_score = scoreend_label = labelprint(result_dict[end_label])

深度学习区分不同种类的图片相关推荐

  1. 深度学习之昆虫种类识别

    深度学习之昆虫种类识别 源代码: #导入所需的包 import os from PIL import Image import numpy as np import matplotlib.pyplot ...

  2. 1 图片channels_深度学习中各种图像库的图片读取方式

    深度学习中各种图像库的图片读取方式总结 在数据预处理过程中,经常需要写python代码搭建深度学习模型,不同的深度学习框架会有不同的读取数据方式(eg:Caffe的python接口默认BGR格式,Te ...

  3. (转)深度学习中各种图像库的图片读取方式

    https://blog.csdn.net/u013841196/article/details/81194310 深度学习中各种图像库的图片读取方式总结 在数据预处理过程中,经常需要写python代 ...

  4. 深度学习中各种图像库的图片读取方式

    深度学习中各种图像库的图片读取方式总结 在数据预处理过程中,经常需要写python代码搭建深度学习模型,不同的深度学习框架会有不同的读取数据方式(eg:Caffe的python接口默认BGR格式,Te ...

  5. 深度学习之数据处理——如何将图片和标签打乱并划分为训练集和测试集

    深度学习之数据处理--如何将图片和标签打乱并划分为训练集和测试集 记录我的第一篇CSDN博客 最近我在网上找到Office31数据集,这个数据集中包含了三个子数据集,分别为:Amazon.dslr.w ...

  6. 机器学习与深度学习——通过奇异值分解算法压缩图片

    机器学习与深度学习--通过奇异值分解算法压缩图片 什么是奇异值分解? 奇异值分解(Singular Value Decomposition,SVD)是一种重要的线性代数方法,用于将一个矩阵分解成三个部 ...

  7. 深度学习-第T2周——彩色图片分类

    深度学习-第T2周--彩色图片分类 深度学习-第P1周--实现mnist手写数字识别 一.前言 二.我的环境 三.前期工作 1.导入依赖项并设置GPU 2.导入数据集 3.归一化 4.可视化图片 四. ...

  8. 【深度学习数据集】常用公开图片数据集下载

    1.MNIST MNIST是一个手写数字数据库,它有60000个训练样本集和10000个测试样本集,每个样本图像的宽高为28*28.此数据集是以二进制存储的,不能直接以图像格式查看,不过很容易找到将其 ...

  9. 【动手学深度学习】Softmax 回归 + 损失函数 + 图片分类数据集

    学习资料: 09 Softmax 回归 + 损失函数 + 图片分类数据集[动手学深度学习v2]_哔哩哔哩_bilibili torchvision.transforms.ToTensor详解 | 使用 ...

最新文章

  1. go语言json的使用技巧
  2. Python算法实战系列:栈
  3. python如何安装seaborn模块_seaborn模块的基本使用
  4. linux shell脚本 wget,bash – 在shell脚本中运行wget和其他命令
  5. Python 计算机视觉(一) —— 数字图像处理基础
  6. 神策数据助力海尔落地 6 大智慧厨房在线场景
  7. 【年度重磅】2020华为云社区年度技术精选合集,700页+免费下载!
  8. 备忘录模式(Memento Pattern)
  9. 报错:/BuildRoot/Library/Caches/com.apple.xbs/Sources/UIKit_Sim/UIKit-3512.29.5/UITableView.m:7943解决方法
  10. golang中base64编码_Rust 中的字符集编码 Rust 实践指南
  11. ALSA 中 hw 和 plughw 的区别
  12. 计算机防火墙知识点,防火墙及防火墙的基本概念-信息安全工程师知识点
  13. 关于test eax eax
  14. CALCULATE函数的运算顺序-第一弹
  15. 润乾报表主子报表通过参数控制子报表显示
  16. 低代码、端到端,一小时构建IoT示例场景,声网发布灵隼物联网云平台
  17. 想剑网三妹子最多服务器,玩家有多“疯狂”?为了新门派,提前一年为其准备108套外观...
  18. 秒杀系统的页面静态化
  19. 【计算机网络】分组交换网中的时延、丢包和吞吐量
  20. gx works2 存储器空间或桌面堆栈不足_小户型旧房翻新的8个重点,小家也能住出大空间...

热门文章

  1. 量子+AI应用:量子计算与神经网络
  2. keil stm32f407工程环境搭建
  3. 戴尔r540服务器型号报价,戴尔_PowerEdge R540_机架式服务器参数_服务器推荐购买 | Dell 中国大陆...
  4. vgh电压高了有什么_液晶屏VGH,VGL,VCOM电压值正常为多少-液晶屏vgh电压
  5. Git基础:第七、八章 Git提交规范Github/Gitee(github资料附录表)
  6. CSP-J第二轮真题 分类题单
  7. matlab 符号变量范围,Matlab符号变量
  8. Unity 3D 人形角色动画(Avatar)||Unity 3D 导航系统||Unity 3D 障碍物
  9. AJAX请求是什么?
  10. python的jieba库第一次中文分词记录