Building a custom dataset for PyTorch

Preparing a dataset is one of the most important parts of any machine learning or deep learning workflow. This post shows how to build a custom dataset in PyTorch for any training task. The dataset is also where transformations are applied: augmentations on images, or other transformations on individual data points, happen at this stage. Augmentation often helps a model generalize and can improve accuracy.

The PyTorch documentation for torch.utils.data describes it like this:

The most important argument of the DataLoader constructor is the dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:

 - map-style datasets,

 - iterable-style datasets.
  • A DataLoader is what groups samples into batches. It is covered later in this post.
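The difference between the two dataset types is easiest to see side by side. Below is a minimal sketch built on two hypothetical toy classes that both yield the squares 0, 1, 4, 9, 16. The map-style version supports random access through `__getitem__` and `__len__`. The iterable-style version (`torch.utils.data.IterableDataset`) only produces samples in order through `__iter__`.

```python
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset


class SquaresMapDataset(Dataset):
    """Hypothetical map-style dataset: random access by index."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx * idx)


class SquaresIterableDataset(IterableDataset):
    """Hypothetical iterable-style dataset: samples are streamed by __iter__."""

    def __init__(self, n: int):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield torch.tensor(i * i)


map_ds = SquaresMapDataset(5)
iter_ds = SquaresIterableDataset(5)

print(int(map_ds[3]))                 # 9, random access works
print([int(x) for x in iter_ds])      # [0, 1, 4, 9, 16]

# Both can be wrapped in a DataLoader. Note that shuffle=True is rejected for
# an IterableDataset, because there is no index for a sampler to shuffle.
print(next(iter(DataLoader(map_ds, batch_size=5))).tolist())   # [0, 1, 4, 9, 16]
print(next(iter(DataLoader(iter_ds, batch_size=5))).tolist())  # [0, 1, 4, 9, 16]
```

The rest of this post uses the map-style approach, because random access is what lets the DataLoader shuffle samples.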

We'll use a map-style dataset, which means implementing the __getitem__ and __len__ methods. The torchvision package provides built-in dataset classes, but they only cover certain layouts. For example, torchvision.datasets.ImageFolder can build an image dataset as long as the images are arranged in one subfolder per class:

    root/dog/xxx.png
    root/dog/xxy.png
    root/cat/123.png
    root/cat/nsdf3.png

If your images are already sorted into class folders like this, ImageFolder may be all you need. Otherwise, follow the rest of this guide.

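We'll use the MNIST dataset distributed as a CSV file. After a header row, each row holds one image: the label is in the first column, followed by the image's 784 (28x28) pixel values. Download the dataset here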
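To see what the loading code below works with, here is a tiny in-memory stand-in for that CSV layout. It is an assumption for illustration: only 4 "pixels" per row instead of 784, with a made-up header. Note that `np.loadtxt` returns float64. That is why the dataset later casts the label to an integer type.

```python
import io

import numpy as np

# Hypothetical mini CSV: header row, then label followed by pixel values.
csv_text = """label,p0,p1,p2,p3
5,0,128,255,0
0,12,0,0,34
"""

data = np.loadtxt(io.StringIO(csv_text), delimiter=",", skiprows=1)
print(data.shape)           # (2, 5): 2 images, 1 label + 4 pixels each
print(data.dtype)           # float64, labels included
print(data[:, 0].tolist())  # labels: [5.0, 0.0]
print(data[0, 1:].tolist()) # pixels of the first image: [0.0, 128.0, 255.0, 0.0]
```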

  • Every custom dataset we build should subclass torch.utils.data.Dataset. Some of a dataset's basic functionality is already defined there, and we just extend it.

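A Dataset has three main components:

  • __init__: sets up the dataset, as in any class (for example, loads the data and stores the transform)
  • __getitem__: returns a single data point for a given index
  • __len__: returns the number of data points in the dataset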
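As a template, here is a minimal sketch of those three methods on data that is already in memory. The class name `TensorPairDataset` is hypothetical. PyTorch already ships `torch.utils.data.TensorDataset` for exactly this case, so this class only illustrates the structure.

```python
import torch
from torch.utils.data import Dataset


class TensorPairDataset(Dataset):
    """Hypothetical minimal dataset holding features and labels in memory."""

    def __init__(self, features: torch.Tensor, labels: torch.Tensor):
        # __init__: store (or load) everything the other two methods need
        assert len(features) == len(labels)
        self.features = features
        self.labels = labels

    def __len__(self):
        # __len__: number of samples; DataLoader's samplers rely on it
        return len(self.labels)

    def __getitem__(self, idx):
        # __getitem__: return exactly one (features, label) sample
        return self.features[idx], self.labels[idx]


ds = TensorPairDataset(torch.arange(12.0).reshape(6, 2), torch.tensor([0, 1, 0, 1, 1, 0]))
x, y = ds[2]
print(len(ds))               # 6
print(x.tolist(), y.item())  # [4.0, 5.0] 0
```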
def __init__(self, path: str, transform=None):
    # load the whole CSV into memory, skipping the header row
    self.data = np.loadtxt(path, delimiter=',', skiprows=1)
    self.transform = transform

def __len__(self):
    return len(self.data)

Now let's define __getitem__. It should return exactly one data point (the image data and its label).

def __getitem__(self, idx):
    data_point = self.data[idx, :]  # get the row at index=idx

    label = data_point[0].astype(np.int64)  # MNIST labels are in column 0 of each row; cast to an integer type, since np.loadtxt returns floats

    image_data = data_point[1:].astype(np.float32)  # the remaining 784 values are the pixels, as a flat vector

    image_data = image_data.reshape(28, 28)  # reshape the flat vector into a 28x28 image
    if self.transform:
        image_data = self.transform(image_data)
    return image_data, torch.tensor(label)
    # the data can be returned as a tuple or as a dictionary:
    # return {"image": image_data, "label": torch.tensor(label)}
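If you choose the dictionary form, the DataLoader's default collate function still batches it: it collates each key separately and returns a dictionary of batched tensors. The minimal sketch below uses a hypothetical `DictDataset` filled with random 28x28 arrays to show the result.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Hypothetical dataset that returns each sample as a dictionary."""

    def __init__(self, n: int):
        # fake 28x28 "images" and labels, for illustration only
        self.images = np.random.rand(n, 28, 28).astype(np.float32)
        self.labels = np.arange(n) % 10

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "image": torch.from_numpy(self.images[idx]),
            "label": torch.tensor(self.labels[idx], dtype=torch.int64),
        }


loader = DataLoader(DictDataset(8), batch_size=4)
batch = next(iter(loader))
# The default collate function stacks the values of each key separately.
print(tuple(batch["image"].shape))  # (4, 28, 28)
print(batch["label"].tolist())      # [0, 1, 2, 3]
```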

Whole code

import numpy as np
import torch
from torch.utils.data import Dataset

class MNISTDataset(Dataset):
    def __init__(self, path: str, transform=None):
        self.data = np.loadtxt(path, delimiter=',', skiprows=1)
        self.transform = transform

    def __getitem__(self, idx: int):
        data_point = self.data[idx, :]
        label = data_point[0].astype(np.int64)
        image_data = data_point[1:].astype(np.float32)
        image_data = image_data.reshape(28, 28)
        if self.transform:
            image_data = self.transform(image_data)
        return image_data, torch.tensor(label)

    def __len__(self):
        return len(self.data)
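To check the class without downloading the real file, you can write a small fake CSV in the assumed layout (header row, then label plus 784 pixel values) and load it. The file name `fake_mnist.csv` and the random data are hypothetical. The class is repeated so the snippet runs on its own.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset


class MNISTDataset(Dataset):
    # Same class as above, repeated so this snippet is self-contained.
    def __init__(self, path: str, transform=None):
        self.data = np.loadtxt(path, delimiter=",", skiprows=1)
        self.transform = transform

    def __getitem__(self, idx: int):
        data_point = self.data[idx, :]
        label = data_point[0].astype(np.int64)
        image_data = data_point[1:].astype(np.float32).reshape(28, 28)
        if self.transform:
            image_data = self.transform(image_data)
        return image_data, torch.tensor(label)

    def __len__(self):
        return len(self.data)


# Hypothetical fake data: 5 rows of label + 784 pixels, with a header row.
rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=5)
pixels = rng.integers(0, 256, size=(5, 784))
rows = np.column_stack([labels, pixels])
header = ",".join(["label"] + [f"pixel{i}" for i in range(784)])

path = os.path.join(tempfile.mkdtemp(), "fake_mnist.csv")
np.savetxt(path, rows, fmt="%d", delimiter=",", header=header, comments="")

dataset = MNISTDataset(path)
image, label = dataset[0]
print(len(dataset))              # 5
print(image.shape, image.dtype)  # (28, 28) float32, no transform was given
print(label.dtype)               # torch.int64
```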

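Now that the dataset exists, let's create batches. Batching is handled by DataLoader. It draws indices (shuffled, if requested), calls __getitem__ once per index, and collates the samples into batch tensors. It can also be faster than indexing the dataset yourself (e.g. dataset[i:i+batch_size]). With num_workers > 0, it loads batches in separate worker processes, so the next batch can be prepared while the model trains on the current one. The default, num_workers=0, loads data in the main process.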
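The steps above can be written out by hand. This is a rough sketch of what a DataLoader does for each batch when shuffle=False, not its actual implementation. It uses a hypothetical `Toy` dataset and PyTorch's own `SequentialSampler`, `BatchSampler` and `default_collate`, then compares the result with a real DataLoader.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler, default_collate


class Toy(Dataset):
    # hypothetical dataset: sample i is (a 1x2x2 tensor filled with i, label i)
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.full((1, 2, 2), float(idx)), torch.tensor(idx)


ds = Toy()

# Per batch: get indices from a sampler, call __getitem__ once per index,
# then collate the list of samples into batched tensors.
manual_batches = [
    default_collate([ds[i] for i in indices])
    for indices in BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False)
]

# The DataLoader equivalent. num_workers > 0 would move the __getitem__ calls
# into background worker processes.
loader_batches = list(DataLoader(ds, batch_size=4, shuffle=False, num_workers=0))

print([labels.tolist() for _, labels in loader_batches])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(tuple(loader_batches[0][0].shape))                  # (4, 1, 2, 2)
```

The last batch is smaller because drop_last defaults to False.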

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

dataset = MNISTDataset("train.csv", transform=transforms.ToTensor())
# transforms.ToTensor() converts the 28x28 array to a 1x28x28 tensor. It only rescales to [0.0, 1.0]
# for uint8 arrays or PIL images, so these float32 pixels keep their original [0, 255] range.

sample_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# creates batches of 32 and shuffles the data: the `idx` values passed into `__getitem__` arrive in a random order each epoch

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

images, labels = next(iter(sample_loader))  # get one sample batch
grid = make_grid(images)  # tile the batch into a single image
plt.imshow(grid.permute(1, 2, 0) / 255.0, cmap='gray')  # pixels are in [0, 255]; imshow expects floats in [0, 1]
plt.show()
print(labels)
# PyTorch batches use the BCHW layout (Batch, Channels, Height, Width). make_grid turns the batch into a single
# CHW image, but plt expects HWC, so we permute the dimensions to (1, 2, 0).


tensor([9, 9, 3, 3, 1, 6, 5, 6, 4, 0, 0, 7, 2, 6, 6, 3, 2, 0, 9, 6, 6, 4, 7, 0, 1, 8, 3, 7, 1, 1, 5, 0])

That's how you create a custom dataset. It's that simple.
