当前位置:首页 >> 智能终端演进 >> 【TensorFlow-windows】keras接口——ImageDataGenerator裁剪,诺基亚 5070

【TensorFlow-windows】keras接口——ImageDataGenerator裁剪,诺基亚 5070

cpugpu芯片开发光刻机 智能终端演进 2
文件名:【TensorFlow-windows】keras接口——ImageDataGenerator裁剪,诺基亚 5070 【TensorFlow-windows】keras接口——ImageDataGenerator裁剪 前言

Keras中有一个图像数据处理器ImageDataGenerator,能够很方便地进行数据增强,并且从文件中批量加载图片,避免数据集过大时,一下子加载进内存会崩掉。但是从官方文档发现,并没有一个比较重要的图像增强方式:随机裁剪,本博客就是记录一下如何在对ImageDataGenerator中生成的batch做图像裁剪

国际惯例,参考博客:

官方ImageDataGenerator文档

Keras 在fit_generator训练方式中加入图像random_crop

Extending Keras’ ImageDataGenerator to Support Random Cropping

how to use fit_generator with multiple image inputs

第二个博客比较全,第三个博客只介绍了分类数据的增强,如果是图像分割或者超分辨率,输出仍是一张图像,所以涉及到对image和mask进行同步增强

代码

先介绍一下数据集目录结构:

在test文件夹下,分别有GT和NGT两个文件夹,每个文件夹存储的都是bmp图像文件

其次需要注意,从ImageDataGenerator中取数据用的是next(generator)函数

载入相关包

from keras_preprocessing.image import ImageDataGeneratorimport matplotlib.pyplot as pltimport numpy as np

先使用自带的ImageDataGenerator配合flow_from_director读取数据 创建生成器

train_img_datagen=ImageDataGenerator()#各种预处理train_mask_datagen=ImageDataGenerator()#各种预处理

读取文件

seed=2 #图像会随机打乱即shuffle,但是输入和输出的打乱顺序必须一样batch_size=2target_size=(1080,1920)train_img_gen=train_img_datagen.flow_from_directory('./test',classes=['NGT'],class_mode=None,batch_size=batch_size,target_size=target_size,shuffle=True,seed=seed,interpolation='bicubic')train_mask_gen=train_img_datagen.flow_from_directory('./test',classes=['GT'],class_mode=None,batch_size=batch_size,target_size=target_size,shuffle=True,seed=seed,interpolation='bicubic')

封装打包

train_generator=zip(train_img_gen,train_mask_gen)

定义裁剪器,裁剪图像和对应的mask:

def crop_generator(batch_gen,crop_size=(270,480)):while True:batch_x,batch_y=next(batch_gen)crops_img=np.zeros((batch_x.shape[0],crop_size[0],crop_size[1],3))crops_mask=np.zeros((batch_y.shape[0],crop_size[0],crop_size[1],3))height,width=batch_x.shape[1],batch_x.shape[2]for i in range(batch_x.shape[0]):#裁剪图像x=np.random.randint(0,height-crop_size[0]+1)y=np.random.randint(0,width-crop_size[1]+1)crops_img[i]=batch_x[i,x:x+crop_size[0],y:y+crop_size[1]]crops_mask[i]=batch_y[i,x:x+crop_size[0],y:y+crop_size[1]]yield (crops_img,crops_mask)

使用裁剪器对Generator进行裁剪

train_crops=crop_generator(train_generator)

可视化:

img,mask=next(train_crops)print(img.shape)plt.subplot(2,1,1)plt.imshow(img[0]/255)plt.subplot(2,1,2)plt.imshow(mask[0]/255)

后记

记住要用while(True)死循环,并且yield在while循环内部,和for循环外部,代表每个批次

代码: 链接:https://pan.baidu.com/s/1UNZLke5kygBFHJ8iR8wV2A 提取码:e51e

协助本站SEO优化一下,谢谢!
关键词不能为空
同类推荐
«    2025年12月    »
1234567
891011121314
15161718192021
22232425262728
293031
控制面板
您好,欢迎到访网站!
  查看权限
网站分类
搜索
最新留言
文章归档
网站收藏
友情链接