Files
fzu-product/4.人工智能/4.6.5.4.1数据读取.md
2023-04-16 03:04:12 +08:00

195 lines
11 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 数据读取
Torchvision 中默认使用的图像加载器是 PIL因此为了确保 Torchvision 正常运行,我们还需要安装一个 Python 的第三方图像处理库——Pillow 库。Pillow 提供了广泛的文件格式支持,强大的图像处理能力,主要包括图像储存、图像显示、格式转换以及基本的图像处理操作等。
我们先介绍 Torchvision 的常用数据集及其读取方法。
PyTorch 为我们提供了一种十分方便的数据读取机制,即使用 Dataset 类与 DataLoader 类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作。
下面我们分别来看下 Dataset 类与 DataLoader 类。
# Dataset 类
PyTorch 中的 Dataset 类是一个抽象类,它可以用来表示数据集。我们通过继承 Dataset 类来自定义数据集的格式、大小和其它属性,后面就可以供 DataLoader 类直接使用。
其实这就表示,无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了 Dataset 类。而在继承 Dataset 类时,至少需要重写以下几个方法:
- __init__():构造函数,可自定义数据读取方法以及进行数据预处理;
- __len__():返回数据集大小;
- __getitem__():索引数据集中的某一个数据。
下面我们来编写一个简单的例子,看下如何使用 Dataset 类定义一个 Tensor 类型的数据集。
```
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 构造函数
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 返回数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 返回索引的数据与标签
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
'''
我们定义了一个名字为 MyDataset 的数据集,在构造函数中,传入 Tensor 类型的数据与标签;
在 __len__ 函数中,直接返回 Tensor 的大小;在 __getitem__ 函数中返回索引的数据与标签。
'''
```
然后我们来看一下如何调用刚才定义的数据集。首先随机生成一个 10*3 维的数据 Tensor然后生成 10 维的标签 Tensor与数据 Tensor 相对应。利用这两个 Tensor生成一个 MyDataset 的对象。查看数据集的大小可以直接用 len() 函数,索引调用数据可以直接使用下标。
```
# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1
# 生成10个随机数随机数的范围只能是0或者1
# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
# 查看数据集大小
print('Dataset size:', len(my_dataset))
# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])
```
# DataLoader 类
在实际项目中如果数据量很大考虑到内存有限、I/O 速度等问题,在训练过程中不可能一次性的将所有数据全部加载到内存中,也不能只用一个进程去加载,所以就需要多进程、迭代加载,而 DataLoader 就是基于这些需要被设计出来的。
DataLoader 是一个迭代器,最基本的使用方法就是传入一个 Dataset 对象,它会根据参数 batch_size 的值生成一个 batch 的数据,节省内存的同时,它还可以实现多进程、数据打乱等处理。
DataLoader 类的调用方式如下:
```
from torch.utils.data import DataLoader
tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数
batch_size=2, # 输出的batch大小
shuffle=True, # 数据是否打乱
num_workers=0) # 进程数, 0表示只有主进程
# 以循环形式输出
for data, target in tensor_dataloader:
print(data, target)
'''
输出:
tensor([[-0.1781, -1.1019, -0.1507],
[-0.6170, 0.2366, 0.1006]]) tensor([0, 0])
tensor([[ 0.9451, -0.4923, -1.8178],
[-0.4046, -0.5436, -1.7911]]) tensor([0, 0])
tensor([[-0.4561, -1.2480, -0.3051],
[-0.9738, 0.9465, 0.4812]]) tensor([1, 0])
tensor([[ 0.0260, 1.5276, 0.1687],
[ 1.3692, -0.0170, -1.6831]]) tensor([1, 0])
tensor([[ 0.0515, -0.8892, -0.1699],
[ 0.4931, -0.0697, 0.4171]]) tensor([1, 0])
'''
# 输出一个batch(用iter()强制类型转换成迭代器的对象next()是输出迭代器下一个元素)
print('One batch tensor data: ', iter(tensor_dataloader).next())
'''
输出:
One batch tensor data: [tensor([[ 0.9451, -0.4923, -1.8178],
[-0.4046, -0.5436, -1.7911]]), tensor([0, 0])]
'''
```
结合代码,我们梳理一下 DataLoader 中的几个参数,它们分别表示:
- datasetDataset 类型,输入的数据集,必须参数;
- batch_sizeint 类型,每个 batch 有多少个样本;
- shufflebool 类型,在每个 epoch 开始的时候,是否对数据进行重新打乱;
- num_workersint 类型加载数据的进程数0 意味着所有的数据都会被加载进主进程,默认为 0。
<strong>思考题</strong>
按照上述代码One batch tensor data 的输出是否正确,若不正确,为什么?
# 利用 Torchvision 读取数据
Torchvision 库中的 torchvision.datasets 包中提供了丰富的图像数据集的接口。常用的图像数据集,例如 MNIST、COCO 等,这个模块都为我们做了相应的封装。
下表中列出了 torchvision.datasets 包所有支持的数据集。各个数据集的说明与接口,详见链接 [https://pytorch.org/vision/stable/datasets.html](https://pytorch.org/vision/stable/datasets.html)。
![](static/boxcnxvqC7FKt1qeCZoI2kVf9yg.png)
注意torchvision.datasets 这个包本身并不包含数据集的文件本身,它的工作方式是先从网络上把数据集下载到用户指定目录,然后再用它的加载器把数据集加载到内存中。最后,把这个加载后的数据集作为对象返回给用户。
为了让你进一步加深对知识的理解,我们以 MNIST 数据集为例,来说明一下这个模块具体的使用方法。
# MNIST 数据集简介
MNIST 数据集是一个著名的手写数字数据集,因为上手简单,在深度学习领域,手写数字识别是一个很经典的学习入门样例。
MNIST 数据集是 NIST 数据集的一个子集MNIST 数据集你可以通过[这里](http://yann.lecun.com/exdb/mnist/)下载。它包含了四个部分。
![](static/boxcnCP2Sp932nPy8Il5Z5d4Aih.png)
MNIST 数据集是 ubyte 格式存储,我们先将“训练集图片”解析成图片格式,来直观地看一看数据集具体是什么样子的。具体怎么解析,我们在后面数据预览再展开。
![](static/boxcnjsG31hhjqdxOnoCGFGR6sh.png)
接下来,我们看一下如何使用 Torchvision 来读取 MNIST 数据集。
对于 torchvision.datasets 所支持的所有数据集,它都内置了相应的数据集接口。例如刚才介绍的 MNIST 数据集torchvision.datasets 就有一个 MNIST 的接口,接口内封装了从下载、解压缩、读取数据、解析数据等全部过程。
这些接口的工作方式差不多,都是先从网络上把数据集下载到指定目录,然后再用加载器把数据集加载到内存中,最后将加载后的数据集作为对象返回给用户。
以 MNIST 为例,我们可以用如下方式调用:
```
# 以MNIST为例
import torchvision
mnist_dataset = torchvision.datasets.MNIST(root='./data',
train=True,
transform=None,
target_transform=None,
download=True)
```
torchvision.datasets.MNIST 是一个类,对它进行实例化,即可返回一个 MNIST 数据集对象。构造函数包括包含 5 个参数:
- root是一个字符串用于指定你想要保存 MNIST 数据集的位置。如果 download 是 Flase则会从目标位置读取数据集
- download布尔类型表示是否下载数据集。如果为 True则会自动从网上下载这个数据集存储到 root 指定的位置。如果指定位置已经存在数据集文件,则不会重复下载;
- train布尔类型表示是否加载训练集数据。如果为 True则只加载训练数据。如果为 False则只加载测试数据集。这里需要注意并不是所有的数据集都做了训练集和测试集的划分这个参数并不一定是有效参数具体需要参考官方接口说明文档
- transform用于对图像进行预处理操作例如数据增强、归一化、旋转或缩放等。这些操作我们会在下节课展开讲解
- target_transform用于对图像标签进行预处理操作。
运行上述的代码后,程序会首先去指定的网址下载了 MNIST 数据集,然后进行了解压缩等操作。如果你再次运行相同的代码,则不会再有下载的过程。
如果你用 type 函数查看一下 mnist_dataset 的类型,就可以得到 torchvision.datasets.mnist.MNIST ,而这个类是之前我们介绍过的 Dataset 类的派生类。相当于 torchvision.datasets ,它已经帮我们写好了对 Dataset 类的继承,完成了对数据集的封装,我们直接使用即可。
这里我们主要以 MNIST 为例进行了说明。其它的数据集使用方法类似调用的时候你只要需要将类名“MNIST”换成其它数据集名字即可。
# 数据预览
完成了数据读取工作,我们得到的是对应的 mnist_dataset刚才已经讲过了这是一个封装了的数据集。
如果想要查看 mnist_dataset 中的具体内容,我们需要把它转化为列表。(如果 IOPub data rate 超限,可以只加载测试集数据,令 train=False
```
mnist_dataset_list = list(mnist_dataset)
print(mnist_dataset_list)
```
转换后的数据集对象变成了一个元组列表,每个元组有两个元素,第一个元素是图像数据,第二个元素是图像的标签。
这里图像数据是 PIL.Image.Image 类型的,这种类型可以直接在 Jupyter 中显示出来。显示一条数据的代码如下。
```
display(mnist_dataset_list[0][0])
print("Image label is:", mnist_dataset_list[0][1])
```
上面介绍了两种读取数据的方法,也就是自定义和读取常用图像数据集。最通用的数据读取方法,就是自己定义一个 Dataset 的派生类。而读取常用的图像数据集,就可以利用 PyTorch 提供的视觉包 Torchvision。
极客时间版权所有: [https://time.geekbang.org/column/article/429826](https://time.geekbang.org/column/article/429826)
(有删改)