Contents

bmp Dataset.from_tensor_slices:

Basic Dataset usage

PNG: this test works:

Read an image, resize, and predict

Building a Dataset from PNG files for training:


bmp Dataset.from_tensor_slices:

    augfiles = ['test_images/532_img_.bmp']
    gtfiles = ['test_images/532_img_.bmp']
    augImages = tf.constant(augfiles)
    gtImages = tf.constant(gtfiles)
    dataset = tf.data.Dataset.from_tensor_slices((augImages, gtImages))
    # dataset = dataset.shuffle(len(augImages))
    # dataset = dataset.repeat()
    dataset = dataset.map(parse_function).batch(1)
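The `parse_function` used with this dataset appears later in the post and calls `tf.image.decode_png`, which does not decode `.bmp` input. A minimal sketch of a format-agnostic variant (an assumption on my part, not the post's original code), using `tf.io.decode_image`, which autodetects BMP/PNG/JPEG/GIF:

```python
import tensorflow as tf

def parse_function(filename, label):
    """Load an (input, label) image pair; works for .bmp as well as .png."""
    def _load(path):
        data = tf.io.read_file(path)
        # decode_image autodetects the format; expand_animations=False
        # guarantees a 3-D (height, width, channels) tensor.
        img = tf.io.decode_image(data, channels=3, expand_animations=False)
        return tf.image.convert_image_dtype(img, tf.float32)
    return _load(filename), _load(label)
```

This drops in for the `decode_png`-based version without any other pipeline changes.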

Basic Dataset usage

1. Using Dataset
# from_tensor_slices: builds a dataset whose elements are slices of the given tensor.
# make_one_shot_iterator(): reads the data once, then discards it (a TF1 API; in TF2 eager mode you can iterate the dataset directly).
input_data = [1, 2, 3, 5, 8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)
for e in dataset:
    print(e)
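Beyond iterating, a dataset can be transformed before use; for example, `map` applies a function to each element and `batch` groups consecutive elements, which is exactly the pattern the training code below relies on:

```python
import tensorflow as tf

input_data = [1, 2, 3, 5, 8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)
# map runs per element; batch groups consecutive elements.
dataset = dataset.map(lambda x: x * 2).batch(2)
batches = [b.numpy().tolist() for b in dataset]
print(batches)  # [[2, 4], [6, 10], [16]]
```

The final batch is smaller than the rest; pass `drop_remainder=True` to `batch` to discard it.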

PNG: this test works:

    img_byte = tf.compat.v1.read_file(filename='test_images/532_img_.png')
    img_data_jpg = tf.image.decode_png(img_byte)  # decode the image
    img_data_jpg = tf.image.convert_image_dtype(img_data_jpg, dtype=tf.uint8)  # change the image data type
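One detail worth knowing about `tf.image.convert_image_dtype`: converting to the dtype the image already has is a no-op, while converting between integer and float types rescales the values (uint8 [0, 255] maps to float32 [0.0, 1.0]). A tiny check:

```python
import tensorflow as tf

raw = tf.constant([[[0, 128, 255]]], dtype=tf.uint8)  # a 1x1 RGB "image"
as_float = tf.image.convert_image_dtype(raw, tf.float32)
print(as_float.numpy())  # values rescaled into [0.0, 1.0]
```

This is why the training pipeline below can feed `decode_png` output straight into a model expecting [0, 1] floats.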

Read an image, resize, and predict

filename_image_string = tf.io.read_file(imgfile)
filename_image = tf.image.decode_png(filename_image_string, channels=3)
filename_image = tf.image.convert_image_dtype(filename_image, tf.float32)
filename_image = tf.image.resize(filename_image, (256, 256))
l, w, c = filename_image.shape
filename_image = tf.reshape(filename_image, [1, l, w, c])
output = model.predict(filename_image)
output = output.reshape((l, w, c)) * 255
cv2.imwrite(out_dir + os.path.basename(imgfile), output)
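The reshape-to-batch step above can be checked on a synthetic tensor, with no file on disk; `tf.expand_dims` is an equivalent way to add the batch dimension without unpacking the shape (a sketch with a random stand-in image, not the post's file-based pipeline):

```python
import tensorflow as tf

image = tf.random.uniform((300, 400, 3), dtype=tf.float32)  # stand-in for the decoded file
image = tf.image.resize(image, (256, 256))
batch = tf.expand_dims(image, axis=0)  # same effect as tf.reshape(image, [1, h, w, c])
print(batch.shape)  # (1, 256, 256, 3)
```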

Building a Dataset from PNG files for training:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, SpatialDropout2D, ReLU, Input, Concatenate, Add
from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError
from tensorflow.keras.optimizers import Adam
import os
import pandas as pd
import cv2


class UWCNN(tf.keras.Model):
    def __init__(self):
        super(UWCNN, self).__init__()
        self.conv1 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze1")
        self.relu1 = ReLU()
        self.conv2 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze2")
        self.relu2 = ReLU()
        self.conv3 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze3")
        self.relu3 = ReLU()
        self.concat1 = Concatenate(axis=3)
        self.conv4 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze4")
        self.relu4 = ReLU()
        self.conv5 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze5")
        self.relu5 = ReLU()
        self.conv6 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze6")
        self.relu6 = ReLU()
        self.concat2 = Concatenate(axis=3)
        self.conv7 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze7")
        self.relu7 = ReLU()
        self.conv8 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze8")
        self.relu8 = ReLU()
        self.conv9 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze9")
        self.relu9 = ReLU()
        self.concat3 = Concatenate(axis=3)
        self.conv10 = Conv2D(3, 3, (1, 1), 'same', name="conv2d_dehaze10")
        self.add1 = Add()

    def call(self, inputs):
        # Three blocks of three convolutions; each block's outputs are
        # concatenated with everything that came before it.
        image_conv1 = self.relu1(self.conv1(inputs))
        image_conv2 = self.relu2(self.conv2(image_conv1))
        image_conv3 = self.relu3(self.conv3(image_conv2))
        dehaze_concat1 = self.concat1([image_conv1, image_conv2, image_conv3, inputs])
        image_conv4 = self.relu4(self.conv4(dehaze_concat1))
        image_conv5 = self.relu5(self.conv5(image_conv4))
        image_conv6 = self.relu6(self.conv6(image_conv5))
        dehaze_concat2 = self.concat2([dehaze_concat1, image_conv4, image_conv5, image_conv6])
        image_conv7 = self.relu7(self.conv7(dehaze_concat2))
        image_conv8 = self.relu8(self.conv8(image_conv7))
        image_conv9 = self.relu9(self.conv9(image_conv8))
        dehaze_concat3 = self.concat3([dehaze_concat2, image_conv7, image_conv8, image_conv9])
        image_conv10 = self.conv10(dehaze_concat3)
        # Global residual connection: the network learns a correction to the input.
        out = self.add1([inputs, image_conv10])
        return out


def parse_function(filename, label):
    filename_image_string = tf.io.read_file(filename)
    label_image_string = tf.io.read_file(label)
    # Decode the filename_image_string
    filename_image = tf.image.decode_png(filename_image_string, channels=3)
    filename_image = tf.image.convert_image_dtype(filename_image, tf.float32)
    # Decode the label_image_string
    label_image = tf.image.decode_png(label_image_string, channels=3)
    label_image = tf.image.convert_image_dtype(label_image, tf.float32)
    return filename_image, label_image


def combloss(y_actual, y_predicted):
    '''Custom loss function for the Keras model: 4 * MSE + (1 - mean SSIM).'''
    lssim = tf.constant(1, dtype=tf.float32) - tf.reduce_mean(
        tf.image.ssim(y_actual, y_predicted, max_val=1, filter_size=13))
    # lmse = MeanSquaredError(reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)(y_actual, y_predicted)
    lmse = MeanSquaredError(reduction='sum_over_batch_size')(y_actual, y_predicted)
    lmse = tf.math.multiply(lmse, 4)
    return tf.math.add(lmse, lssim)


def train(ckptpath="./train_type1/cp.ckpt", type='type1'):
    # df = pd.read_csv(datafile)
    augfiles = ['test_images/532_img_.png']
    gtfiles = ['test_images/532_label_.png']
    augImages = tf.constant(augfiles)
    gtImages = tf.constant(gtfiles)
    dataset = tf.data.Dataset.from_tensor_slices((augImages, gtImages))
    # dataset = dataset.shuffle(len(augImages))
    # dataset = dataset.repeat()
    dataset = dataset.map(parse_function).batch(1)

    # Callbacks
    # checkpoint_path = "./train_type1/cp.ckpt"
    checkpoint_path = ckptpath
    checkpoint_dir = os.path.dirname(checkpoint_path)
    # Create a callback that saves the model's weights
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                     save_weights_only=True, verbose=1)

    model = UWCNN()
    model.compile(optimizer=Adam(), loss=combloss)
    model.fit(dataset, epochs=1, callbacks=[cp_callback])
    # os.listdir(checkpoint_dir)
    # model.save('saved_model/my_model')
    model.save('save_model/' + type)
    # model.sample_weights('model_weight.h5')


def model_test(imgdir="./test_images/", imgfile="12433.png", ckdir="./train_type1/cp.ckpt",
               outdir="./results/", type='type1'):
    model = UWCNN()
    # model.load_weights('model_weight.h5')
    # model = tf.keras.models.load_model('save_model/' + type, custom_objects={'loss': combloss}, compile=False)
    # One fit() pass on a dummy dataset builds the model's variables so that
    # load_weights() below has something to restore into.
    augfiles = ['test_images/532_img_.bmp']
    gtfiles = ['test_images/532_img_.bmp']
    augImages = tf.constant(augfiles)
    gtImages = tf.constant(gtfiles)
    dataset = tf.data.Dataset.from_tensor_slices((augImages, gtImages))
    dataset = dataset.map(parse_function).batch(1)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=ckdir, save_weights_only=True, verbose=1)
    model.compile(optimizer=Adam(), loss=combloss)
    model.fit(dataset, epochs=1, callbacks=[cp_callback])
    model.summary()
    model.load_weights(ckdir)

    filename_image_string = tf.io.read_file(imgdir + imgfile)
    filename_image = tf.image.decode_png(filename_image_string, channels=3)
    filename_image = tf.image.convert_image_dtype(filename_image, tf.float32)
    filename_image = tf.image.resize(filename_image, (460, 620))
    l, w, c = filename_image.shape
    filename_image = tf.reshape(filename_image, [1, l, w, c])
    output = model.predict(filename_image)
    output = output.reshape((l, w, c)) * 255
    cv2.imwrite(outdir + type + "_" + imgfile, output)


if __name__ == "__main__":
    train(ckptpath="./train_type1/cp.ckpt", type='type1')
    exit(0)
    # The calls below are unreachable after exit(0); remove it to run them.
    type = "type1"
    ckdir = "./train_type1/cp.ckpt"
    model_test(imgdir="./test_images/", imgfile="532_img_.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="602_img_.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="617_img_.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="12422.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="12433.png", ckdir=ckdir, outdir="./results/", type=type)
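The combined loss in the script can be sanity-checked in isolation: for two identical images SSIM is exactly 1 and MSE is 0, so the loss should come out to ~0. This is a standalone restatement of `combloss` for testing, not the training script itself:

```python
import tensorflow as tf

def combloss(y_actual, y_predicted):
    # 4 * MSE + (1 - mean SSIM), as defined in the training script above.
    lssim = 1.0 - tf.reduce_mean(
        tf.image.ssim(y_actual, y_predicted, max_val=1.0, filter_size=13))
    lmse = tf.keras.losses.MeanSquaredError()(y_actual, y_predicted)
    return 4.0 * lmse + lssim

identical = tf.ones((1, 32, 32, 3))
loss = float(combloss(identical, identical))  # expect ~0.0
```

Note the image side must be at least `filter_size` (13 here) pixels, or `tf.image.ssim` will fail.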
