chore: Syntax Highlighting
This commit is contained in:
@@ -20,7 +20,7 @@ PyTorch 中的 Dataset 类是一个抽象类,它可以用来表示数据集。
|
||||
|
||||
下面我们来编写一个简单的例子,看下如何使用 Dataset 类定义一个 Tensor 类型的数据集。
|
||||
|
||||
```
|
||||
```python
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
@@ -43,7 +43,7 @@ class MyDataset(Dataset):
|
||||
|
||||
然后我们来看一下如何调用刚才定义的数据集。首先随机生成一个 10*3 维的数据 Tensor,然后生成 10 维的标签 Tensor,与数据 Tensor 相对应。利用这两个 Tensor,生成一个 MyDataset 的对象。查看数据集的大小可以直接用 len() 函数,索引调用数据可以直接使用下标。
|
||||
|
||||
```
|
||||
```python
|
||||
# 生成数据
|
||||
data_tensor = torch.randn(10, 3)
|
||||
target_tensor = torch.randint(2, (10,)) # 标签是0或1
|
||||
@@ -67,7 +67,7 @@ DataLoader 是一个迭代器,最基本的使用方法就是传入一个 Datas
|
||||
|
||||
DataLoader 类的调用方式如下:
|
||||
|
||||
```
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数
|
||||
batch_size=2, # 输出的batch大小
|
||||
@@ -143,7 +143,7 @@ MNIST 数据集是 ubyte 格式存储,我们先将“训练集图片”解析
|
||||
|
||||
以 MNIST 为例,我们可以用如下方式调用:
|
||||
|
||||
```
|
||||
```python
|
||||
# 以MNIST为例
|
||||
import torchvision
|
||||
mnist_dataset = torchvision.datasets.MNIST(root='./data',
|
||||
@@ -173,7 +173,7 @@ torchvision.datasets.MNIST 是一个类,对它进行实例化,即可返回
|
||||
|
||||
如果想要查看 mnist_dataset 中的具体内容,我们需要把它转化为列表。(如果 IOPub data rate 超限,可以只加载测试集数据,令 train=False)
|
||||
|
||||
```
|
||||
```python
|
||||
mnist_dataset_list = list(mnist_dataset)
|
||||
print(mnist_dataset_list)
|
||||
```
|
||||
@@ -182,7 +182,7 @@ print(mnist_dataset_list)
|
||||
|
||||
这里图像数据是 PIL.Image.Image 类型的,这种类型可以直接在 Jupyter 中显示出来。显示一条数据的代码如下。
|
||||
|
||||
```
|
||||
```python
|
||||
display(mnist_dataset_list[0][0])
|
||||
print("Image label is:", mnist_dataset_list[0][1])
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user