# 数据读取 Torchvision 中默认使用的图像加载器是 PIL,因此为了确保 Torchvision 正常运行,我们还需要安装一个 Python 的第三方图像处理库——Pillow 库。Pillow 提供了广泛的文件格式支持,强大的图像处理能力,主要包括图像储存、图像显示、格式转换以及基本的图像处理操作等。 我们先介绍 Torchvision 的常用数据集及其读取方法。 PyTorch 为我们提供了一种十分方便的数据读取机制,即使用 Dataset 类与 DataLoader 类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作。 下面我们分别来看下 Dataset 类与 DataLoader 类。 # Dataset 类 PyTorch 中的 Dataset 类是一个抽象类,它可以用来表示数据集。我们通过继承 Dataset 类来自定义数据集的格式、大小和其它属性,后面就可以供 DataLoader 类直接使用。 其实这就表示,无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了 Dataset 类。而在继承 Dataset 类时,至少需要重写以下几个方法: - __init__():构造函数,可自定义数据读取方法以及进行数据预处理; - __len__():返回数据集大小; - __getitem__():索引数据集中的某一个数据。 下面我们来编写一个简单的例子,看下如何使用 Dataset 类定义一个 Tensor 类型的数据集。 ``` import torch from torch.utils.data import Dataset class MyDataset(Dataset): # 构造函数 def __init__(self, data_tensor, target_tensor): self.data_tensor = data_tensor self.target_tensor = target_tensor # 返回数据集大小 def __len__(self): return self.data_tensor.size(0) # 返回索引的数据与标签 def __getitem__(self, index): return self.data_tensor[index], self.target_tensor[index] ''' 我们定义了一个名字为 MyDataset 的数据集,在构造函数中,传入 Tensor 类型的数据与标签; 在 __len__ 函数中,直接返回 Tensor 的大小;在 __getitem__ 函数中返回索引的数据与标签。 ''' ``` 然后我们来看一下如何调用刚才定义的数据集。首先随机生成一个 10*3 维的数据 Tensor,然后生成 10 维的标签 Tensor,与数据 Tensor 相对应。利用这两个 Tensor,生成一个 MyDataset 的对象。查看数据集的大小可以直接用 len() 函数,索引调用数据可以直接使用下标。 ``` # 生成数据 data_tensor = torch.randn(10, 3) target_tensor = torch.randint(2, (10,)) # 标签是0或1 # 生成10个随机数,随机数的范围只能是0或者1 # 将数据封装成Dataset my_dataset = MyDataset(data_tensor, target_tensor) # 查看数据集大小 print('Dataset size:', len(my_dataset)) # 使用索引调用数据 print('tensor_data[0]: ', my_dataset[0]) ``` # DataLoader 类 在实际项目中,如果数据量很大,考虑到内存有限、I/O 速度等问题,在训练过程中不可能一次性的将所有数据全部加载到内存中,也不能只用一个进程去加载,所以就需要多进程、迭代加载,而 DataLoader 就是基于这些需要被设计出来的。 DataLoader 是一个迭代器,最基本的使用方法就是传入一个 Dataset 对象,它会根据参数 batch_size 的值生成一个 batch 的数据,节省内存的同时,它还可以实现多进程、数据打乱等处理。 DataLoader 类的调用方式如下: ``` from torch.utils.data import DataLoader tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数 batch_size=2, # 输出的batch大小 shuffle=True, # 数据是否打乱 num_workers=0) # 进程数, 0表示只有主进程 # 以循环形式输出 for data, target in tensor_dataloader: print(data, target) ''' 输出: tensor([[-0.1781, -1.1019, -0.1507], [-0.6170, 0.2366, 0.1006]]) tensor([0, 0]) tensor([[ 0.9451, -0.4923, -1.8178], [-0.4046, -0.5436, -1.7911]]) tensor([0, 0]) tensor([[-0.4561, -1.2480, -0.3051], [-0.9738, 0.9465, 0.4812]]) tensor([1, 0]) tensor([[ 0.0260, 1.5276, 0.1687], [ 1.3692, -0.0170, -1.6831]]) tensor([1, 0]) tensor([[ 0.0515, -0.8892, -0.1699], [ 0.4931, -0.0697, 0.4171]]) tensor([1, 0]) ''' # 输出一个batch(用iter()强制类型转换成迭代器的对象,next()是输出迭代器下一个元素) print('One batch tensor data: ', iter(tensor_dataloader).next()) ''' 输出: One batch tensor data: [tensor([[ 0.9451, -0.4923, -1.8178], [-0.4046, -0.5436, -1.7911]]), tensor([0, 0])] ''' ``` 结合代码,我们梳理一下 DataLoader 中的几个参数,它们分别表示: - dataset:Dataset 类型,输入的数据集,必须参数; - batch_size:int 类型,每个 batch 有多少个样本; - shuffle:bool 类型,在每个 epoch 开始的时候,是否对数据进行重新打乱; - num_workers:int 类型,加载数据的进程数,0 意味着所有的数据都会被加载进主进程,默认为 0。 思考题 按照上述代码,One batch tensor data 的输出是否正确,若不正确,为什么? # 利用 Torchvision 读取数据 Torchvision 库中的 torchvision.datasets 包中提供了丰富的图像数据集的接口。常用的图像数据集,例如 MNIST、COCO 等,这个模块都为我们做了相应的封装。 下表中列出了 torchvision.datasets 包所有支持的数据集。各个数据集的说明与接口,详见链接 [https://pytorch.org/vision/stable/datasets.html](https://pytorch.org/vision/stable/datasets.html)。 ![](https://pic-hdu-cs-wiki-1307923872.cos.ap-shanghai.myqcloud.com/boxcnxvqC7FKt1qeCZoI2kVf9yg.png) 注意,torchvision.datasets 这个包本身并不包含数据集的文件本身,它的工作方式是先从网络上把数据集下载到用户指定目录,然后再用它的加载器把数据集加载到内存中。最后,把这个加载后的数据集作为对象返回给用户。 为了让你进一步加深对知识的理解,我们以 MNIST 数据集为例,来说明一下这个模块具体的使用方法。 # MNIST 数据集简介 MNIST 数据集是一个著名的手写数字数据集,因为上手简单,在深度学习领域,手写数字识别是一个很经典的学习入门样例。 MNIST 数据集是 NIST 数据集的一个子集,MNIST 数据集你可以通过[这里](http://yann.lecun.com/exdb/mnist/)下载。它包含了四个部分。 ![](https://pic-hdu-cs-wiki-1307923872.cos.ap-shanghai.myqcloud.com/boxcnCP2Sp932nPy8Il5Z5d4Aih.png) MNIST 数据集是 ubyte 格式存储,我们先将“训练集图片”解析成图片格式,来直观地看一看数据集具体是什么样子的。具体怎么解析,我们在后面数据预览再展开。 ![](https://pic-hdu-cs-wiki-1307923872.cos.ap-shanghai.myqcloud.com/boxcnjsG31hhjqdxOnoCGFGR6sh.png) 接下来,我们看一下如何使用 Torchvision 来读取 MNIST 数据集。 对于 torchvision.datasets 所支持的所有数据集,它都内置了相应的数据集接口。例如刚才介绍的 MNIST 数据集,torchvision.datasets 就有一个 MNIST 的接口,接口内封装了从下载、解压缩、读取数据、解析数据等全部过程。 这些接口的工作方式差不多,都是先从网络上把数据集下载到指定目录,然后再用加载器把数据集加载到内存中,最后将加载后的数据集作为对象返回给用户。 以 MNIST 为例,我们可以用如下方式调用: ``` # 以MNIST为例 import torchvision mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=None, target_transform=None, download=True) ``` torchvision.datasets.MNIST 是一个类,对它进行实例化,即可返回一个 MNIST 数据集对象。构造函数包括包含 5 个参数: - root:是一个字符串,用于指定你想要保存 MNIST 数据集的位置。如果 download 是 Flase,则会从目标位置读取数据集; - download:布尔类型,表示是否下载数据集。如果为 True,则会自动从网上下载这个数据集,存储到 root 指定的位置。如果指定位置已经存在数据集文件,则不会重复下载; - train:布尔类型,表示是否加载训练集数据。如果为 True,则只加载训练数据。如果为 False,则只加载测试数据集。这里需要注意,并不是所有的数据集都做了训练集和测试集的划分,这个参数并不一定是有效参数,具体需要参考官方接口说明文档; - transform:用于对图像进行预处理操作,例如数据增强、归一化、旋转或缩放等。这些操作我们会在下节课展开讲解; - target_transform:用于对图像标签进行预处理操作。 运行上述的代码后,程序会首先去指定的网址下载了 MNIST 数据集,然后进行了解压缩等操作。如果你再次运行相同的代码,则不会再有下载的过程。 如果你用 type 函数查看一下 mnist_dataset 的类型,就可以得到 torchvision.datasets.mnist.MNIST ,而这个类是之前我们介绍过的 Dataset 类的派生类。相当于 torchvision.datasets ,它已经帮我们写好了对 Dataset 类的继承,完成了对数据集的封装,我们直接使用即可。 这里我们主要以 MNIST 为例,进行了说明。其它的数据集使用方法类似,调用的时候你只要需要将类名“MNIST”换成其它数据集名字即可。 # 数据预览 完成了数据读取工作,我们得到的是对应的 mnist_dataset,刚才已经讲过了,这是一个封装了的数据集。 如果想要查看 mnist_dataset 中的具体内容,我们需要把它转化为列表。(如果 IOPub data rate 超限,可以只加载测试集数据,令 train=False) ``` mnist_dataset_list = list(mnist_dataset) print(mnist_dataset_list) ``` 转换后的数据集对象变成了一个元组列表,每个元组有两个元素,第一个元素是图像数据,第二个元素是图像的标签。 这里图像数据是 PIL.Image.Image 类型的,这种类型可以直接在 Jupyter 中显示出来。显示一条数据的代码如下。 ``` display(mnist_dataset_list[0][0]) print("Image label is:", mnist_dataset_list[0][1]) ``` 上面介绍了两种读取数据的方法,也就是自定义和读取常用图像数据集。最通用的数据读取方法,就是自己定义一个 Dataset 的派生类。而读取常用的图像数据集,就可以利用 PyTorch 提供的视觉包 Torchvision。 极客时间版权所有: [https://time.geekbang.org/column/article/429826](https://time.geekbang.org/column/article/429826) (有删改)