PyTorch 02 - Dataloaders and Transforms

Datasets & DataLoaders

Note

Adapted from the PyTorch Quickstart to use the CIFAR10 dataset instead of FashionMNIST.

  • Code for processing data samples can get messy and hard to maintain.
  • We ideally want our dataset code to be decoupled from our model training code for better readability and modularity.

PyTorch Data Primitives

PyTorch provides two data primitives:

  1. torch.utils.data.DataLoader and
  2. torch.utils.data.Dataset

that allow you to use pre-loaded datasets as well as your own data.

  • Dataset stores the samples and their corresponding labels, and
  • DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
Note

An iterable is a Python object capable of returning its members one at a time. It must implement the __iter__ method or the __getitem__ method. See Iterators for more details.
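
As a small illustration of the two protocols named in the note, here is a sketch in plain Python. The class names `Countdown` and `Squares` are made up for this example and do not come from PyTorch. The second style, indexing through `__getitem__`, is the one that map-style Datasets build on.

```python
class Countdown:
    """Iterable via __iter__: yields n, n-1, ..., 1."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        # Return a fresh iterator on every call.
        return iter(range(self.n, 0, -1))


class Squares:
    """Iterable via __getitem__: Python tries indices 0, 1, 2, ... until IndexError."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx * idx


countdown_items = list(Countdown(3))
square_items = list(Squares(4))
print(countdown_items)
print(square_items)
```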

Pre-loaded Datasets

PyTorch provides a number of pre-loaded datasets (such as FashionMNIST or CIFAR10) that subclass torch.utils.data.Dataset and implement functions specific to the particular data.

They can be used to prototype and benchmark your model.

You can find them here:

  • Image Datasets
    • image classification
    • object detection or segmentation
    • optical flow
    • stereo matching
    • image pairs
    • image captioning
    • video classification
    • video prediction
  • Text Datasets
    • text classification
    • language modeling
    • machine translation
    • sequence tagging
    • question answering
    • unsupervised learning
  • Audio Datasets

Loading a Dataset

Here is an example of how to load the CIFAR10 dataset from TorchVision.

We load the CIFAR10 with the following parameters:

  • root is the path where the train/test data is stored,
  • train selects the training set if True and the test set if False,
  • download=True downloads the data from the internet if it is not already available at root, and
  • transform and target_transform specify the feature and label transformations. More on that later.
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

In this case we simply use the ToTensor transform, which rearranges the image from (H x W x C) layout to (C x H x W) and rescales the pixel values from integers in [0, 255] to a torch.FloatTensor in the range [0.0, 1.0].

img, label = training_data[0]
print(f"img.shape: {img.shape}")
print(f"img.dtype: {img.dtype}")
img.shape: torch.Size([3, 32, 32])
img.dtype: torch.float32
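
To make the layout and range change concrete, the following sketch reproduces the same two steps by hand with plain torch operations on a random stand-in image. It is an illustration of what the conversion amounts to, not the ToTensor implementation itself, and the variable names are made up for this example.

```python
import torch

# Stand-in for a decoded 32x32 RGB image: (H, W, C) layout, uint8 values in [0, 255].
hwc_uint8 = torch.randint(0, 256, (32, 32, 3), dtype=torch.uint8)

# Move channels first, then rescale to floating point in [0.0, 1.0].
chw_float = hwc_uint8.permute(2, 0, 1).to(torch.float32) / 255.0

print(chw_float.shape, chw_float.dtype)
print(chw_float.min().item() >= 0.0, chw_float.max().item() <= 1.0)
```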

Let’s also look at the data directory contents.

% ls -al data
total 333008
drwxr-xr-x@  6 tomg  staff        192 Oct  1 22:14 .
drwxr-xr-x  62 tomg  staff       1984 Oct 17 11:09 ..
drwxr-xr-x@ 10 tomg  staff        320 Jun  4  2009 cifar-10-batches-py
-rw-r--r--@  1 tomg  staff  170498071 Sep  4 19:51 cifar-10-python.tar.gz

And the contents of the cifar-10-batches-py directory.

% ls -al data/cifar-10-batches-py 
total 363752
drwxr-xr-x@ 10 tomg  staff       320 Jun  4  2009 .
drwxr-xr-x@  6 tomg  staff       192 Oct  1 22:14 ..
-rw-r--r--@  1 tomg  staff       158 Mar 31  2009 batches.meta
-rw-r--r--@  1 tomg  staff  31035704 Mar 31  2009 data_batch_1
-rw-r--r--@  1 tomg  staff  31035320 Mar 31  2009 data_batch_2
-rw-r--r--@  1 tomg  staff  31035999 Mar 31  2009 data_batch_3
-rw-r--r--@  1 tomg  staff  31035696 Mar 31  2009 data_batch_4
-rw-r--r--@  1 tomg  staff  31035623 Mar 31  2009 data_batch_5
-rw-r--r--@  1 tomg  staff        88 Jun  4  2009 readme.html
-rw-r--r--@  1 tomg  staff  31035526 Mar 31  2009 test_batch

You see in this case the images aren’t stored individually, but rather as combined batches. The CIFAR10 Dataset class hides these storage details from us, and the DataLoader builds on top of it.
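
For intuition about what the Dataset class does behind the scenes, here is a rough sketch of turning batch-style rows into images. It assumes the layout given in the CIFAR-10 dataset description: each image is a row of 3072 bytes, with 1024 red, then 1024 green, then 1024 blue values, each plane in row-major order. The helper `rows_to_images` and the synthetic data are hypothetical and only stand in for the contents of a real batch file.

```python
import numpy as np


def rows_to_images(flat_rows):
    """Reshape rows of 3072 bytes into (N, 32, 32, 3) images.

    Assumes each row stores the red, green, and blue planes back to back,
    each plane in row-major order.
    """
    flat_rows = np.asarray(flat_rows, dtype=np.uint8)
    planes = flat_rows.reshape(-1, 3, 32, 32)   # (N, C, H, W)
    return planes.transpose(0, 2, 3, 1)         # (N, H, W, C)


# Synthetic stand-in for the pixel array of one batch file (2 fake images).
fake_rows = np.zeros((2, 3072), dtype=np.uint8)
fake_rows[0, :1024] = 255      # first image: red plane fully on
fake_rows[1, 2048:] = 255      # second image: blue plane fully on

images = rows_to_images(fake_rows)
print(images.shape)
print(images[0, 0, 0], images[1, 0, 0])
```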

Iterating and Visualizing the Dataset

It’s very important to understand the data you’re working with. Visually inspecting the dataset is a good way to get started.

We can index Datasets manually like a list: training_data[index]. In this case we randomly sample images from the dataset.

We use matplotlib to visualize some samples in our training data.

labels_map = {
    0: "plane",
    1: "car",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

figure = plt.figure(figsize=(6, 6))
cols, rows = 4, 4

for i in range(1, cols * rows + 1):
    # Randomly choose indices
    sample_idx = torch.randint(len(training_data), size=(1,)).item()

    img, label = training_data[sample_idx]
    #print(f"img.shape: {img.shape}")
    #print(f"label: {label}")
    
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.permute(1, 2, 0))
plt.show()

Tip

Try re-running the above cell a few times to see different samples from the dataset.

Collecting Sample Data to Illustrate Custom Dataset

To illustrate creating a custom dataset, we will collect images from the CIFAR10 dataset and save them to a local directory. We will also save the labels to a CSV file.

import os
import pandas as pd
from torchvision import datasets
from torchvision.transforms import ToTensor
from PIL import Image

# Create directories to store images and annotations
os.makedirs('cifar10_images', exist_ok=True)
annotations_file = 'cifar10_annotations.csv'

# Load CIFAR10 dataset
cifar10 = datasets.CIFAR10(root='data', train=True, download=True, transform=ToTensor())

# Number of images to download
n_images = 10

# Store images and their labels
data = []
for i in range(n_images):
    img, label = cifar10[i]
    img = img.permute(1, 2, 0)  # Convert from (C, H, W) to (H, W, C)
    img = (img * 255).byte().numpy()  # Convert to numpy array and scale to [0, 255]
    img = Image.fromarray(img)  # Convert to PIL Image

    img_filename = f'cifar10_images/img_{i}.png'
    img.save(img_filename)  # Save image

    data.append([img_filename, label])  # Append image path and label to data list

# Write annotations to CSV file
df = pd.DataFrame(data, columns=['image_path', 'label'])
df.to_csv(annotations_file, index=False)

print(f"Saved {n_images} images and their labels to {annotations_file}")
Saved 10 images and their labels to cifar10_annotations.csv
# List the contents of the `cifar10_images` directory
import os
os.listdir('cifar10_images')
['img_9.png',
 'img_6.png',
 'img_7.png',
 'img_2.png',
 'img_0.png',
 'img_8.png',
 'img_5.png',
 'img_1.png',
 'img_4.png',
 'img_3.png']
import pandas as pd

# Read the annotations file
annotations = pd.read_csv('cifar10_annotations.csv')

# Display the first 10 lines of the annotations file
print(annotations.head(10))
                 image_path  label
0  cifar10_images/img_0.png      6
1  cifar10_images/img_1.png      9
2  cifar10_images/img_2.png      9
3  cifar10_images/img_3.png      4
4  cifar10_images/img_4.png      1
5  cifar10_images/img_5.png      1
6  cifar10_images/img_6.png      2
7  cifar10_images/img_7.png      7
8  cifar10_images/img_8.png      8
9  cifar10_images/img_9.png      3

Creating a Custom Dataset Class

A custom Dataset class must implement three functions:

  1. __init__,
  2. __len__, and
  3. __getitem__.

We’ll look at an example implementation.

The CIFAR10 images are stored in a directory and their labels are stored separately in a CSV file.

import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

class MyCustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

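    def __getitem__(self, idx):
        # Column 0 of the CSV already holds the full relative path
        # (e.g. cifar10_images/img_0.png), so self.img_dir is not joined here.
        img_path = self.img_labels.iloc[idx, 0]
        image = read_image(img_path)  # uint8 tensor with shape (C, H, W)
        label = self.img_labels.iloc[idx, 1]

        # Apply the optional feature and label transforms
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label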

__init__

The __init__ function is run once when instantiating the Dataset object. We initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the next section).

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

__len__

The __len__ function returns the number of samples in our dataset. It is what Python calls when you write len(dataset).

Example:

def __len__(self):
    return len(self.img_labels)

__getitem__

The __getitem__ function loads and returns a sample from the dataset at the given index idx. It is what Python calls when you write dataset[idx].

Based on the index, it

  • identifies the image’s location on disk,
  • converts that to a tensor using read_image,
  • retrieves the corresponding label from the csv data in self.img_labels,
  • calls the transform functions on them (if applicable), and
  • returns the tensor image and corresponding label in a tuple.
def __getitem__(self, idx):
    img_path = self.img_labels.iloc[idx, 0]
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]

    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
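
The same three methods work for data that never touches the disk. The sketch below is a hypothetical in-memory dataset (the class name `InMemoryDataset` and the toy tensors are made up for illustration) that follows the same pattern as the class above.

```python
import torch
from torch.utils.data import Dataset


class InMemoryDataset(Dataset):
    """Hypothetical map-style dataset backed by two tensors held in memory."""

    def __init__(self, features, labels, transform=None):
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = features
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        x = self.features[idx]
        if self.transform is not None:
            x = self.transform(x)
        # Return the label as a plain Python int
        return x, self.labels[idx].item()


toy = InMemoryDataset(
    features=torch.arange(12, dtype=torch.float32).reshape(4, 3),
    labels=torch.tensor([0, 1, 0, 1]),
    transform=lambda x: x * 2,
)
x0, y0 = toy[1]
print(len(toy), x0, y0)
```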


Let’s try to use it.

from torchvision.transforms import ToTensor

mydataset = MyCustomImageDataset(annotations_file='cifar10_annotations.csv', img_dir='cifar10_images') # , transform=ToTensor())

img, label = mydataset[0]
print(f"img.shape: {img.shape}")
print(f"label: {label}")
img.shape: torch.Size([3, 32, 32])
label: 6
import matplotlib.pyplot as plt

figure = plt.figure(figsize=(6, 6))
cols, rows = 2, 2

for i in range(1, cols * rows + 1):
    # Randomly choose indices
    sample_idx = torch.randint(len(mydataset), size=(1,)).item()

    img, label = mydataset[sample_idx]
    #print(f"img.shape: {img.shape}")
    #print(f"label: {label}")
    
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.permute(1, 2, 0))
plt.show()


Preparing your data for training with DataLoaders

The Dataset retrieves your dataset’s features and labels one sample at a time.

While training a model, we typically want to

  • pass samples in “batches”,
  • reshuffle the data at every epoch to reduce model overfitting, and
  • use Python’s multiprocessing to speed up data retrieval.

DataLoader is an iterable that abstracts this complexity for us in an easy API.

Because MyCustomImageDataset is a subclass of Dataset, we can use it with a DataLoader just as we would use any other Dataset.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
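
To see how batching behaves when the dataset size is not a multiple of batch_size, here is a small sketch on a toy TensorDataset. The dataset and variable names are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 toy samples with 3 features each, plus integer labels 0..9.
toy_ds = TensorDataset(torch.randn(10, 3), torch.arange(10))

loader = DataLoader(toy_ds, batch_size=4, shuffle=False)
batch_sizes = [len(labels) for _, labels in loader]
print(batch_sizes)  # the last batch is smaller: 10 = 4 + 4 + 2

# drop_last=True discards the incomplete final batch.
loader_drop = DataLoader(toy_ds, batch_size=4, shuffle=False, drop_last=True)
print(len(loader), len(loader_drop))
```

Whether to keep or drop the short final batch is a judgment call; dropping it keeps every batch the same size at the cost of skipping a few samples per epoch.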

Iterate through the DataLoader

We have loaded that dataset into the DataLoader and can iterate through the dataset as needed.

Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels respectively).

Because we specified shuffle=True, the data is reshuffled after we iterate over all batches (for finer-grained control over the data loading order, take a look at Samplers).
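
The following sketch illustrates both points on a toy dataset: shuffle=True draws a fresh order on each pass, and a sampler can take over index selection instead. SubsetRandomSampler is used here as one example of a sampler; the toy data and the choice of even indices are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

toy_ds = TensorDataset(torch.arange(8))

# shuffle=True: every pass over the loader visits all samples in a freshly drawn order.
shuffled = DataLoader(toy_ds, batch_size=8, shuffle=True)
for epoch in range(2):
    (batch,) = next(iter(shuffled))
    print(f"epoch {epoch}: {batch.tolist()}")

# A sampler replaces shuffle; here only the even indices are ever drawn.
even_indices = [0, 2, 4, 6]
sampled = DataLoader(toy_ds, batch_size=2, sampler=SubsetRandomSampler(even_indices))
seen = sorted(int(v) for (batch,) in sampled for v in batch)
print(seen)
```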

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
Feature batch shape: torch.Size([64, 3, 32, 32])
Labels batch shape: torch.Size([64])

Note from the size that we got a batch of 64 images.

img = train_features[0].squeeze() # squeeze() removes dimensions of size 1; it is a no-op here because each image is already (3, 32, 32)
label = train_labels[0]

plt.imshow(img.permute(1, 2, 0))
plt.show()

print(f"Label: {labels_map[label.item()]}")

Label: frog

Transforms and Data Augmentation

Data does not always come in its final processed form that is required for training machine learning algorithms.

We use transforms to perform some manipulation of the data and make it suitable for training as well as augmentation.

All TorchVision dataset classes have two parameters

  • transform to modify the features and
  • target_transform to modify the labels

They accept callables containing the transformation logic.

The CIFAR10 images are in PIL Image format, and the labels are integers.

For training, we need the features as normalized tensors, and the labels as one-hot encoded tensors.

To make these transformations, we use ToTensor and Lambda.

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
)

ToTensor()

ToTensor converts a PIL image or NumPy ndarray into a FloatTensor and scales the image’s pixel intensity values to the range [0., 1.].

Lambda Transforms

Lambda transforms apply any user-defined lambda function.

Here, we define a function to turn the integer into a one-hot encoded tensor.

It

  • first creates a zero tensor of size 10 (the number of labels in our dataset) and
  • calls scatter_ which assigns a value=1 on the index as given by the label y.
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
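
As a quick check of what the lambda produces, the sketch below builds the one-hot vector for an example label and compares it with torch.nn.functional.one_hot, an alternative that is not used in the text above. The label value 3 is arbitrary.

```python
import torch
import torch.nn.functional as F

num_classes = 10
y = 3  # an example integer label

# Same construction as the Lambda transform: start from zeros, write a 1 at index y.
via_scatter = torch.zeros(num_classes, dtype=torch.float).scatter_(
    dim=0, index=torch.tensor(y), value=1
)

# Alternative: F.one_hot returns int64, so cast to float to match.
via_one_hot = F.one_hot(torch.tensor(y), num_classes=num_classes).to(torch.float)

print(via_scatter)
print(torch.equal(via_scatter, via_one_hot))
```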

Transforms V2 and Data Augmentation

The torchvision.transforms module offers many commonly-used transforms for image and video data.

Adapted from Getting started with transforms v2

Transforms V2 Example

We’ll illustrate with an example.

First, a bit of setup

from pathlib import Path
import torch
import matplotlib.pyplot as plt
plt.rcParams["savefig.bbox"] = 'tight'

from torchvision.transforms import v2
from torchvision.io import read_image

import os
import urllib.request
import zipfile

# Check if the directory exists
if not os.path.exists('transform_v2_files'):
    # Download the zip file
    url = 'https://github.com/trgardos/ml-549-fa24/raw/refs/heads/main/transform_v2_files.zip?download='
    zip_path = 'transform_v2_files.zip'
    urllib.request.urlretrieve(url, zip_path)
    
    # Unzip the file
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall('.')
    
    # Remove the zip file
    os.remove(zip_path)


torch.manual_seed(1)

from transform_v2_files.transforms.helpers import plot
img = read_image(str(Path('transform_v2_files/assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")
type(img) = <class 'torch.Tensor'>, img.dtype = torch.uint8, img.shape = torch.Size([3, 512, 512])

The basics

The Torchvision transforms behave like a regular torch.nn.Module:

  • instantiate a transform,
  • pass an input,
  • get a transformed output:
transform = v2.RandomCrop(size=(224, 224))
out = transform(img)

plot([img, out])

This transform takes a random crop of the image of size 224x224. Try re-running the cell a few times to see different samples.

Transforms for image classification

Here’s a basic image classification transform pipeline:

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)

plot([img, out])

Again, try re-running the cell a few times to see different samples.
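
The last two steps of the pipeline are deterministic, so their arithmetic can be written out by hand. The sketch below uses plain torch operations on a tiny made-up image to show what scaling to [0, 1] followed by per-channel normalization computes; it illustrates the formula and is not the torchvision implementation.

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# A 2x2 "image" whose every pixel is 255 in all three channels.
img_uint8 = torch.full((3, 2, 2), 255, dtype=torch.uint8)

scaled = img_uint8.to(torch.float32) / 255.0   # values now in [0, 1]
normalized = (scaled - mean) / std             # per-channel (x - mean) / std

print(normalized[:, 0, 0])
```

Because the mean and std are reshaped to (3, 1, 1), broadcasting applies one pair of statistics to each channel.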

Such a transformation pipeline is typically passed as the transform argument of a dataset class, e.g. ImageNet(..., transform=transforms).

See the Transforms Gallery for examples of how to use augmentation transforms like CutMix and MixUp.

Detection, Segmentation, Videos

Besides images, transforms can also transform bounding boxes, segmentation / detection masks, or videos.

Let’s look at an example with bounding boxes.

from torchvision import tv_tensors  # we'll describe this a bit later

boxes = tv_tensors.BoundingBoxes(
    [
        [15, 10, 370, 510],
        [275, 340, 510, 510],
        [130, 345, 210, 425]
    ],
    format="XYXY", canvas_size=img.shape[-2:])

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))

plot([(img, boxes), (out_img, out_boxes)])
<class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'> <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>

This also works for:

  • masks (torchvision.tv_tensors.Mask) for object segmentation or semantic segmentation, or
  • videos (torchvision.tv_tensors.Video).

We could have passed them to the transforms in exactly the same way.

What are TVTensors?

TVTensors are torch.Tensor subclasses.

The available TVTensors are Image, BoundingBoxes, Mask, and Video.

TVTensors look and feel just like regular tensors - they are tensors.

Everything that is supported on a plain torch.Tensor like .sum() or any torch.* operator will also work on a TVTensor:

img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))

print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
isinstance(img_dp, torch.Tensor) = True
img_dp.dtype = torch.uint8, img_dp.shape = torch.Size([3, 256, 256]), img_dp.sum() = tensor(25087958)

These TVTensor classes are at the core of the transforms: in order to transform a given input, the transforms first look at the class of the object, and dispatch to the appropriate implementation accordingly.

For more info, see the TVTensors FAQ.

What do I pass as input?

Above, we’ve seen two examples:

  • one where we passed a single image as input i.e. out = transforms(img), and
  • one where we passed both an image and bounding boxes, i.e. out_img, out_boxes = transforms(img, boxes).

Transforms support arbitrary input structures.

The input can be e.g. a single image, a tuple, an arbitrarily nested dictionary…

The same structure will be returned as output.

Below, we use the same detection transforms, but pass a tuple (image, target_dict) as input and we’re getting the same structure as output:

target = {
    "boxes": boxes,
    "labels": torch.arange(boxes.shape[0]),
    "this_is_ignored": ("arbitrary", {"structure": "!"})
}

# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")
('arbitrary', {'structure': '!'})

We passed a tuple so we get a tuple back, and the second element is the transformed target dict.

Transforms don’t really care about the structure of the input; as mentioned above, they only care about the type of the objects and transforms them accordingly.

Foreign objects like strings or ints are simply passed through. This can be useful e.g. if you want to associate a path with every single sample when debugging!

Transforms and Datasets intercompatibility

Roughly speaking, the output of the datasets must correspond to the input of the transforms. How to do that depends on whether you’re using the torchvision built-in datasets, or your own custom datasets.

Using built-in datasets

If you’re just doing image classification, you don’t need to do anything. Just use the transform argument of the dataset, e.g. ImageNet(..., transform=transforms).

Compatibility with Older Datasets

Torchvision also supports datasets for object detection or segmentation like torchvision.datasets.CocoDetection. Those datasets predate the existence of the transforms.v2 module and of the TVTensors, so they don’t return TVTensors out of the box.

An easy way to force those datasets to return TVTensors and to make them compatible with v2 transforms is to use the wrap_dataset_for_transforms_v2 function:

from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2

dataset = CocoDetection(..., transforms=my_transforms)
dataset = wrap_dataset_for_transforms_v2(dataset)
# Now the dataset returns TVTensors!

Using your own datasets

If you have a custom dataset, then you’ll need to convert your objects into the appropriate TVTensor classes.

For details on creating TVTensor instances, refer to How do I create TVTensors?

There are two main places where you can implement that conversion logic:

  • At the end of the dataset’s __getitem__ method, before returning the sample (or by sub-classing the dataset).
  • As the very first step of your transforms pipeline.

Either way, the logic will depend on your specific dataset.

V2 Transforms

See V2 API Snapshot for a more complete list.

Further Reading
