当前位置:首页 >> 编程语言 >> 【Pytorch】深度学习之数据读取,sx30(pytorch读取数据集)

【Pytorch】深度学习之数据读取,sx30(pytorch读取数据集)

0evadmin 编程语言 1
文件名:【Pytorch】深度学习之数据读取,sx30 【Pytorch】深度学习之数据读取

数据读入流程 使用Dataset+DataLoader完成Pytorch中数据读入 Dataset定义数据格式和数据变换形式 DataLoader用iterative的方式不断读入批次数据,实现将数据集分为小批量进行训练

使用PyTorch自带数据集 使用Dataset完成数据格式和数据变换的定义

import torchfrom torchvision import datasetstrain_data = datasets.ImageFolder(train_path, transform=data_transform)val_data = datasets.ImageFolder(val_path, transform=data_transform)

参数说明: transform实现对图像数据的变换处理

使用DataLoader完成按批次读取数据

from torch.utils.data import DataLoadertrain_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)

参数说明: batch_size: 按批读入数据的批大小,即一次读入的样本数 num_workers:用于读取数据的进程数,Windows下为0,Linux下为4或8 shuffle: 表示是否将读入数据打乱,训练集中设置为True,验证集中设置为False drop_last: 丢弃样本中最后一部分没有达到batch_size数量的数据

数据展示

import matplotlib.pyplot as pltimages, labels = next(iter(val_loader))print(images.shape)# 使用transpose()函数改变原始图像的表示形式,从(H,W,C)的表示转换为(C,H,W)的表示plt.imshow(images[0].transpose(1,2,0)) plt.show()

自定义数据集方式

自定义Dataset类继承Dataset类实现三个函数,__init__函数、__getitem__函数、__len__函数 import osimport pandas as pdfrom torchvision.io import read_imageclass MyDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):"""Args:annotations_file (string): Path to the csv file with annotations.img_dir (string): Directory with all the images.transform (callable, optional): Optional transform to be applied on a sample.target_transform (callable, optional): Optional transform to be applied on the target."""self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):"""Args:idx (int): Index"""# 使用path.join()函数构建图像路径,img_labels.iloc[行,列]用于通过行列索引访问DataFrame中的元素img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label
协助本站SEO优化一下,谢谢!
关键词不能为空
同类推荐
«    2025年12月    »
1234567
891011121314
15161718192021
22232425262728
293031
控制面板
您好,欢迎到访网站!
  查看权限
网站分类
搜索
最新留言
文章归档
网站收藏
友情链接