
A complete guide to writing custom Datasets and DataLoader in PyTorch

Table of Contents

An Introduction To PyTorch Dataset and DataLoader
Why Write Good Data Loaders and Datasets?
The Basic PyTorch Dataset Structure
Implementing A Custom Dataset In PyTorch
Best Practices For Creating Custom Datasets
The Basic PyTorch DataLoader Class Structure
Example: Creating A Data Loader From A Dataset
Using Custom Samplers For More Control Over Data Loading

An Introduction To PyTorch Dataset and DataLoader

In this tutorial, we’ll go through the PyTorch data primitives, namely torch.utils.data.DataLoader and torch.utils.data.Dataset, and understand how the pre-loaded datasets work and how to create our own DataLoaders and Datasets by subclassing these classes.

Why Write Good Data Loaders and Datasets?

Why should we learn how to write good data loaders and datasets? Isn’t modeling the most important part of a Deep Learning pipeline?

Your training pipeline should be as modular as possible to aid quick prototyping and keep the code reusable. A poorly written data loader, or no data loader at all (for example a plain Python generator or an ad hoc function), can limit how well your code parallelizes. Dataset processing is an important part of any training pipeline and should be kept separate from modeling.

The same technique won’t work everywhere. Some problems might require you to use image augmentations, so you’d prefer to have an argument (something like data = Dataset(…, fetch = True) ) to test the model’s performance. Or you might need to experiment with different sequence lengths and strides for fine-tuning an NLP model. To these ends, it’s recommended to use custom Datasets and DataLoaders.
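As a rough sketch of the sequence-length and stride idea, a dataset can take both values as constructor arguments, so trying a different windowing becomes a one-line change. The class name WindowDataset and its parameters are hypothetical and only serve the illustration.

```python
import torch
from torch.utils.data import Dataset


class WindowDataset(Dataset):
    """Hypothetical dataset that cuts one long token sequence into fixed-length windows."""

    def __init__(self, tokens, seq_len, stride):
        self.tokens = torch.as_tensor(tokens)
        self.seq_len = seq_len
        self.stride = stride
        # Start offsets of every full window that fits in the sequence.
        self.starts = list(range(0, len(self.tokens) - seq_len + 1, stride))

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        start = self.starts[idx]
        return self.tokens[start:start + self.seq_len]


tokens = list(range(10))
short_windows = WindowDataset(tokens, seq_len=4, stride=2)
long_windows = WindowDataset(tokens, seq_len=6, stride=4)
print(len(short_windows), short_windows[1].tolist())  # 4 [2, 3, 4, 5]
print(len(long_windows), long_windows[1].tolist())    # 2 [4, 5, 6, 7, 8, 9]
```

Because the windowing lives inside the dataset, the training loop does not need to change when the sequence length or stride does.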

The Basic PyTorch Dataset Structure

The following code snippet outlines the Dataset class from PyTorch (method signatures only, bodies omitted). All pre-loaded Datasets inherit from this basic structure.

class Dataset(...):

    # Raises NotImplementedError
    def __getitem__(self, index):

    # Allows us to Add/Concat Datasets
    def __add__(self, other):

    # Returns the Attribute value or raises a AttributeError
    def __getattr__(self, attribute_name):

    # Utility methods to "Register" Functions
    def register_function(cls, ...):

    # Utility methods to "Register" Functions
    def register_datapipe_as_function(cls, ...):

As it has such a simple structure, you don’t always need to inherit from torch.utils.data.Dataset. In most cases, we can get away with writing a few key functions.
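To make that concrete, here is a minimal sketch in which a plain Python class with only __len__ and __getitem__ is passed to a DataLoader. The class name PlainSquares is made up for this example; subclassing torch.utils.data.Dataset remains the more conventional choice.

```python
from torch.utils.data import DataLoader


class PlainSquares:
    """Hypothetical map-style dataset that does not subclass torch.utils.data.Dataset."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx * idx


loader = DataLoader(PlainSquares(6), batch_size=4)
batches = [batch.tolist() for batch in loader]
print(batches)  # [[0, 1, 4, 9], [16, 25]]
```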

Implementing A Custom Dataset In PyTorch

Now, for most purposes, you will need to write your own implementation of a Dataset. So let’s see how you can write a custom dataset by subclassing torch.utils.data.Dataset.

You’ll need to implement 3 functions:

  1. __init__ : This function is called when instantiating the object. It’s typically used to store some essential locations such as file paths and image transforms.
  2. __len__ : This function returns the length of the dataset.
  3. __getitem__ : This is the big kahuna 🏅. This function is responsible for returning a sample from the dataset based on the index provided.

class CustomDataset(torch.utils.data.Dataset):

    # Basic Instantiation
    def __init__(self, ..., *args, **kwargs):

    # Length of the Dataset
    def __len__(self):

    # Fetch an item from the Dataset
    def __getitem__(self, idx):
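Filling in the skeleton with a toy example might look like the sketch below. ToyRegressionDataset, its synthetic data (y = 2x + 1), and the optional transform argument are all invented for illustration.

```python
import torch
from torch.utils.data import Dataset


class ToyRegressionDataset(Dataset):
    """Hypothetical dataset of (x, y) pairs with y = 2x + 1."""

    def __init__(self, n_samples, transform=None):
        # Store whatever __getitem__ will need later.
        self.x = torch.arange(n_samples, dtype=torch.float32)
        self.y = 2 * self.x + 1
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x = self.x[idx]
        if self.transform is not None:
            x = self.transform(x)
        return {"x": x, "y": self.y[idx]}


dataset = ToyRegressionDataset(5, transform=lambda x: x / 10)
sample = dataset[3]
print(len(dataset), sample)
```

Returning a dictionary is a design choice rather than a requirement; a tuple such as (x, y) works just as well with the default collation.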

Let’s walk through some examples of Custom Datasets.

The Flickr Dataset

This code snippet is taken from my Kaggle Kernel on Neural Image Captioning. Let’s walk through the code:

  1. The __init__ method contains a reference to the dataframe containing references to the image paths and a transforms variable containing a list of image augmentations.
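  2. The __len__ method returns the length of the dataframe. (Python’s built-in len function works on pandas DataFrames.)
  3. The __getitem__ method reads the image using PIL, applies the transforms that are needed, encodes the comments, and returns a dictionary with custom key values.

import os

import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset

# Parts of this listing were lost in extraction. The `image_dir` argument, the
# T.ToTensor() step, and the "image" entry of the returned dictionary are
# assumptions; `tokenizer` and `device` are expected to be defined elsewhere
# in the kernel.
class FlickrDataset(Dataset):
    def __init__(self, df, image_dir):
        self.df = df
        self.image_dir = image_dir
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean = [0.5], std = [0.5]),
        ])

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        image_id = self.df.image_name.values[idx]
        image = Image.open(os.path.join(self.image_dir, image_id)).convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        comments = self.df[self.df.image_name == image_id].values.tolist()[0][1:][0]
        encoded_inputs = tokenizer(comments,
                                   return_token_type_ids = False,
                                   return_attention_mask = False,
                                   max_length = 100,
                                   padding = "max_length",
                                   return_tensors = "pt")
        sample = {"image": image.to(device),
                  "captions": encoded_inputs["input_ids"].flatten().to(device)}
        return sample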

RSNA Brain Tumor Competition Dataset

This code snippet is taken from my Custom Wrapper for the RSNA-MICCAI Brain Tumor Radiogenomic Classification Kaggle Competition. Let’s walk through the code:

  1. The __init__ method contains a reference to the paths, targets, the MRI type, and other such information.
  2. The __len__ method returns the number of scan paths held by the dataset.
  3. The __getitem__ method reads the image using a custom function and returns a custom dictionary containing augmented images and targets (if needed).

import torch
from torch.utils import data as torch_data

# `load_dicom_images_3d` and `seq` (the augmentation pipeline) are defined
# elsewhere in the wrapper. The first three constructor arguments are
# reconstructed from the method body.
class Dataset(torch_data.Dataset):
    def __init__(
        self,
        paths,
        targets=None,
        mri_type=None,
        label_smoothing: float = 0.01,
        split: str = "train",
        augment: bool = False,
    ):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.split = split
        self.augment = augment

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        scan_id = self.paths[index]
        if self.targets is None:
            data = load_dicom_images_3d(
                str(scan_id).zfill(5), mri_type=self.mri_type[index], split=self.split
            )
        else:
            data = load_dicom_images_3d(
                str(scan_id).zfill(5), mri_type=self.mri_type[index], split="train"
            )

            if self.augment:
                data = seq(images=data)

        if self.targets is None:
            return {"X": torch.tensor(data).float(), "id": scan_id}
        else:
            y = torch.tensor(
                abs(self.targets[index] - self.label_smoothing), dtype=torch.float
            )
            return {"X": torch.tensor(data).float(), "y": y}

Best Practices For Creating Custom Datasets

There are some general things you need to remember while creating custom datasets.

  • Indices passed to __getitem__ should lie between 0 and len(dataset) - 1, where the length comes from the __len__ function; otherwise, it’ll throw an error.
  • We should override the __len__ function, because many Sampler implementations and the default options of DataLoader expect it to return the size of the dataset. (Reference: PyTorch docs)
  • In case you’re working with data that comes from a stream, you should subclass IterableDataset. For more information refer to the docs.
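For the streaming case, a minimal sketch of an IterableDataset could look like this. CountStream is a hypothetical stand-in for a real stream; it only implements __iter__, and there is no __len__ or __getitem__.

```python
from torch.utils.data import DataLoader, IterableDataset


class CountStream(IterableDataset):
    """Hypothetical stream that yields integers from start (inclusive) to end (exclusive)."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        # A real stream might read from a socket, a log file, or a database cursor.
        return iter(range(self.start, self.end))


stream_loader = DataLoader(CountStream(0, 7), batch_size=3)
stream_batches = [batch.tolist() for batch in stream_loader]
print(stream_batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```

Note that with num_workers greater than 0, each worker receives its own copy of the iterable dataset, so the stream has to be split between workers to avoid duplicated samples.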

What is DataLoader in PyTorch?

Sometimes when working with a big dataset it becomes quite difficult to load the entire data into memory at once. The only way forward is then to load data into memory in batches for processing, which means you may have to write extra code to do this. But do not worry, PyTorch has you covered with its DataLoader class.

The DataLoader class is available in torch.utils.data and supports the following:

  1. Customization of Data Loading Order
  2. Map-Style and Iterable-Style Datasets
  3. Automatic Batching
  4. Data Loading with single and multiple processes
  5. Automatic Memory Pinning

The Basic PyTorch DataLoader Class Structure

The following code snippet outlines the DataLoader class from PyTorch (method signatures only, bodies omitted).

class DataLoader(...):

    # Basic __init__ function
    def __init__(self,..):

    # Returns Either a Single or a Multi Process Iterator
    def _get_iterator(self):

    # Handle Multiprocessing
    def multiprocessing_context(self):

    # Handle Multiprocessing
    def multiprocessing_context(self, multiprocessing_context):

    # Override default __setattr__ method
    def __setattr__(self, attr, val):

    # Override default __iter__ method
    def __iter__(self):

    # Helper Function for collation
    def _auto_collation(self):

    # The Actual Sampler Used for fetching
    def _index_sampler(self):

    # Returns the length of the Index Sampler (in case of map-style dataset)
    def __len__(self) -> int:

    # Checks if the worker number is rational based on system resource
    def check_worker_number_rationality(self):

Now, this does look complicated, but in most cases, we don’t need to understand most of this. But it’s nice to know how PyTorch takes care of multiprocessing and handling different types of Iterators.

Syntax of PyTorch DataLoader

The following section shows the signature of the DataLoader class in PyTorch along with a description of its parameters.

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

  • dataset – A DataLoader must be constructed with a dataset. PyTorch DataLoaders support two kinds of datasets:
    • Map-style datasets – These datasets map keys to data samples. Each item is retrieved through the dataset’s __getitem__() method.
    • Iterable-style datasets – These datasets implement the __iter__() protocol. Such datasets retrieve data as a sequential stream rather than through random reads as in the case of map-style datasets.
  • batch_size – The number of samples in each batch.
  • shuffle – Whether the data should be reshuffled at every epoch.
  • sampler – An optional Sampler instance. A sampler defines the strategy used to draw samples: sequential, random, or any other order. shuffle must be left as False when a sampler is given.
  • batch_sampler – Like sampler, but yields a whole batch of indices at a time. It cannot be combined with batch_size, shuffle, sampler, or drop_last.
  • num_workers – The number of sub-processes used for loading the data. 0 means the data is loaded in the main process.
  • collate_fn – Merges a list of samples into a batch. A custom function can be supplied here.
  • pin_memory – Pinned (page-locked) memory allows faster transfers to the GPU. When set to True, the data loader copies tensors into pinned memory before returning them.
  • drop_last – If the dataset size is not a multiple of batch_size, the last batch has fewer elements than batch_size. This incomplete batch can be dropped by setting this option to True.
  • timeout – The time to wait while collecting a batch from the workers (sub-processes).
  • worker_init_fn – A routine to be called in each worker process when it starts. Allows customized setup.
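To see how batch_size and drop_last interact, here is a small sketch on a 10-element TensorDataset (the data is synthetic and chosen only so that the last batch is incomplete).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10))

keep_last = DataLoader(data, batch_size=4, drop_last=False)
drop_last = DataLoader(data, batch_size=4, drop_last=True)

# TensorDataset samples are tuples, so each batch is a list with one tensor.
kept_sizes = [len(batch[0]) for batch in keep_last]
dropped_sizes = [len(batch[0]) for batch in drop_last]
print(len(keep_last), kept_sizes)     # 3 [4, 4, 2]
print(len(drop_last), dropped_sizes)  # 2 [4, 4]
```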

Example of DataLoader in PyTorch

Example – DataLoaders with Built-in Datasets

This first example showcases how the built-in MNIST dataset can be handled with the PyTorch DataLoader. (MNIST is a famous dataset that contains hand-written digits.)

import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

Here in this example, we are using the transforms module of torchvision. It is generally used when we have to handle image datasets and can help in normalizing, resizing, and cropping of the images.

For this MNIST dataset, we are using normalization. ToTensor first scales the pixel values to the range 0 to 1, and Normalize with a mean of 0.5 and a standard deviation of 0.5 then maps them to the range -1 to +1.

The following code builds the transform used for normalization.

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])
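The arithmetic behind Normalize is simply (value - mean) / std, applied per channel. The sketch below reproduces it with plain tensor operations on a few hand-picked pixel values, without going through torchvision.

```python
import torch

# Pixel values as ToTensor() would produce them, scaled to [0, 1].
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

mean, std = 0.5, 0.5
normalized = (pixels - mean) / std
print(normalized.tolist())  # [-1.0, -0.5, 0.0, 1.0]
```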

The following code snippet loads the desired dataset. We use the PyTorch DataLoader with batch_size = 64 and enable shuffling so the data is reordered at every epoch.

# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

To fetch a single batch of images and labels, we wrap the DataLoader in iter and call next on it.

dataiter = iter(trainloader)
images, labels = next(dataiter)
print(images.shape)
plt.imshow(images[1].numpy().squeeze(), cmap='Greys_r')


torch.Size([64, 1, 28, 28])

Example – DataLoaders on Custom Datasets

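This second example shows how we can use the PyTorch DataLoader on custom datasets. So let us first create a custom dataset.

The code snippet below creates a custom dataset that contains 999 random numbers (range(1, 1000) runs 999 times).

from torch.utils.data import Dataset
import random

class SampleDataset(Dataset):
  def __init__(self, r1, r2):
    # Draw 999 random integers between r1 and r2 (inclusive)
    self.samples = []
    for i in range(1, 1000):
      n = random.randint(r1, r2)
      self.samples.append(n)

  def __len__(self):
      return len(self.samples)

  def __getitem__(self, idx):
      return self.samples[idx]

# The constructor arguments were lost from the original listing; 1 and 500 are placeholder bounds.
dataset = SampleDataset(1, 500)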


Finally, we use the DataLoader on our custom dataset. Notice that we have given the batch_size as 12 and have also enabled parallel multiprocess data loading with num_workers=2.

The output shows that the loaded data is divided into batches of 12 samples each: 999 samples give 84 batches (numbered 0 to 83), the last of which holds the remaining 3 samples. Some of the tensors are displayed for reference.

from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=2)
for i, batch in enumerate(loader):
    print(i, batch)


0 tensor([ 16, 179, 246, 127, 263, 418,  33, 410, 107, 281, 438, 164])
1 tensor([421,  55, 183,  19,  47, 402, 336, 290, 241, 121, 308, 140])
2 tensor([265, 149,  62, 421,  67, 427, 302, 149, 134, 269, 116, 267])
3 tensor([318, 404, 365, 324, 229, 184,  10, 391,  71, 424, 387, 256])
4 tensor([178, 138, 200, 398, 420,  98, 147, 338, 341, 434,  58, 332])
5 tensor([403, 256, 290, 238, 186,  57, 343, 361, 388,  81, 271, 111])
6 tensor([340,  59,  73, 298, 275, 102,  20, 413,  95,  83, 380, 323])
7 tensor([ 71,  15, 443,  44, 394, 252, 103,  11, 383, 292,  57, 109])
8 tensor([398, 406,  84, 369, 272, 409, 367, 205, 353,  24, 305,  21])
9 tensor([280, 200,  79, 424,  26,  58, 233, 194, 362, 379, 228, 428])
10 tensor([316, 225, 231, 272, 382, 132, 306, 295, 150, 365, 420,  17])
11 tensor([280, 432,  51, 123, 356,  29, 172, 225, 143, 147, 226, 262])
12 tensor([208, 366, 267, 389, 135, 398, 359, 365,  52, 210, 152, 214])
...
69 tensor([ 43, 351, 383, 435, 368,  26, 316, 145, 409, 140, 224, 159])
70 tensor([210,  68, 404,  30,  32, 324,  18, 416, 340, 354, 337, 436])
71 tensor([414, 114, 233, 320, 105, 318, 326, 139, 319, 205,  69, 123])
72 tensor([165, 265, 381,  33, 392, 261,  57,  23, 131, 186, 232, 186])
73 tensor([404, 105, 345, 436,  51, 392, 263, 138, 364, 439,  12, 295])
74 tensor([163,  70, 137, 435, 250, 354, 190, 335,  39, 323, 365,  96])
75 tensor([148, 383, 322, 300, 309, 125,  46,  29, 231, 432, 258, 376])
76 tensor([314, 266, 248, 236, 296, 434,  93, 138, 140,  12, 444, 302])
77 tensor([ 41, 257,  13,  64, 295, 330, 396, 251, 379, 232, 108, 364])
78 tensor([ 70, 161, 168,  41, 434, 258, 327, 270,  42, 347, 384, 282])
79 tensor([392,  13, 258, 416, 146, 308,  32, 276, 302, 177, 410, 263])
80 tensor([186, 433, 420,  11, 273, 230, 377, 416, 303,  83,  20, 240])
81 tensor([ 47, 354, 171, 207, 178, 351, 137, 138,  33, 224, 422, 280])
82 tensor([214, 193, 444, 432, 274, 268,  67, 217,  64,  84,  27, 102])
83 tensor([419,  62, 244])

Example: Creating A Data Loader From A Dataset

Most pre-loaded datasets from Torchvision return torch.utils.data.Dataset objects, thus enabling us to feed them directly into the torch.utils.data.DataLoader class and then enumerate through them in our training loop.

For example, this code snippet from the PyTorch tutorials shows how easily we can create data loaders using pre-loaded datasets from torchvision.

from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

Using Custom Samplers For More Control Over Data Loading

The aforementioned code example returns mini-batches of data with the provided batch size.

For even more control over your data loading, use custom Samplers. Every Sampler subclass must provide an __iter__ method that yields dataset indices, and should also provide a __len__ method that returns how many indices it yields. For more information refer to the PyTorch docs.
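As a minimal sketch of a custom sampler, the hypothetical EvenFirstSampler below visits all even indices before the odd ones. The ordering itself is arbitrary; the point is only that the DataLoader fetches samples in whatever order __iter__ yields.

```python
from torch.utils.data import DataLoader, Sampler


class EvenFirstSampler(Sampler):
    """Hypothetical sampler that yields even indices first, then odd ones."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        n = len(self.data_source)
        evens = list(range(0, n, 2))
        odds = list(range(1, n, 2))
        return iter(evens + odds)

    def __len__(self):
        return len(self.data_source)


values = [10, 11, 12, 13, 14, 15, 16]
sampler_loader = DataLoader(values, batch_size=3, sampler=EvenFirstSampler(values))
sampled = [batch.tolist() for batch in sampler_loader]
print(sampled)  # [[10, 12, 14], [16, 11, 13], [15]]
```

Since a sampler is supplied, shuffle is left at its default of False.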

Another Example of Dataset + DataLoader

Import libraries

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

Pandas is not essential to create a Dataset object. However, it’s a powerful tool for managing data so I’m going to use it. torch.utils.data provides the Dataset and DataLoader classes we need.

Create a custom Dataset class

class CustomTextDataset(Dataset):
    def __init__(self, text, labels):
        self.labels = labels
        self.text = text

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        text = self.text[idx]
        sample = {"Text": text, "Class": label}
        return sample

class CustomTextDataset(Dataset): Create a class called ‘CustomTextDataset’; this can be called whatever you want. The class inherits from the Dataset class which we imported earlier.

def __init__(self, text, labels): When you initialize the class you need to pass in two arguments. In this case, the arguments are called ‘text’ and ‘labels’ to match the data which will be added.

self.labels = labels & self.text = text: The passed-in values can now be used in functions within the class by using self.text or self.labels.

def __len__(self): This function just returns the length of the labels when called. E.g., if you had a dataset with 5 labels, then the integer 5 would be returned.

def __getitem__(self, idx): This function is called whenever a sample is requested from the dataset, for example by a DataLoader or by indexing with TD[idx]. Nothing is pre-built at initialization; each sample is created on demand.

  • ‘idx’ passed into the function is a number: the position of the data instance being requested. We use the self.labels and self.text variables mentioned earlier with the ‘idx’ variable passed in to get the current instance of data. These current instances are then saved in variables called ‘label’ and ‘text’.
  • Next, a variable is declared called ‘sample’ containing a dictionary storing the data, and this dictionary is returned. Every sample drawn from the dataset therefore has its values stored under the keys ‘Text’ and ‘Class’. You can name ‘Text’ and ‘Class’ anything.

Initialise the CustomTextDataset class

# define data and class labels
text = ['Happy', 'Amazing', 'Sad', 'Unhapy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']

# create Pandas DataFrame
text_labels_df = pd.DataFrame({'Text': text, 'Labels': labels})

# define data set object
TD = CustomTextDataset(text_labels_df['Text'], text_labels_df['Labels'])

First, we create two lists called ‘text’ and ‘labels’ as an example.

text_labels_df = pd.DataFrame({‘Text’: text, ‘Labels’: labels}): This is not essential, but Pandas is a useful tool for data management and pre-processing and will probably be used in your PyTorch pipeline. In this section, the lists ‘text’ and ‘labels’ containing the data are saved in a Pandas DataFrame.

TD = CustomTextDataset(text_labels_df[‘Text’], text_labels_df[‘Labels’]): This initialises the class we made earlier with the ‘Text’ and ‘Labels’ data being passed in. This data will become ‘self.text’ and ‘self.labels’ within the class. The Dataset is saved under the variable named TD.

The Dataset is now initialized and ready to be used!

Some code to show you what’s going on inside the Dataset

This will show you how the data is stored within the Dataset.

# Display text and label.
print('\nFirst iteration of data set: ', next(iter(TD)), '\n')

# Print how many items are in the data set
print('Length of data set: ', len(TD), '\n')

# Print entire data set
print('Entire data set: ', list(DataLoader(TD)), '\n')


First iteration of data set: {'Text': 'Happy', 'Class': 'Positive'}
Length of data set: 5
Entire data set: [{'Text': ['Happy'], 'Class': ['Positive']}, {'Text': ['Amazing'], 'Class': ['Positive']}, {'Text': ['Sad'], 'Class': ['Negative']}, {'Text': ['Unhapy'], 'Class': ['Negative']}, {'Text': ['Glum'], 'Class': ['Negative']}]

How to pre-process your data using ‘collate_fn’

In machine learning or deep learning, text needs to be cleaned and turned into vectors prior to training. DataLoader has a handy parameter called collate_fn. This parameter lets you supply your own function, which receives the list of samples in a batch and returns the processed batch that the DataLoader outputs.

def collate_batch(batch):
    # Placeholder tensors standing in for a real word vector and class label
    word_tensor = torch.tensor([[1.], [0.], [45.]])
    label_tensor = torch.tensor([[1.]])

    text_list, classes = [], []
    for (_text, _class) in batch:
        text_list.append(word_tensor)
        classes.append(label_tensor)

    text = torch.cat(text_list)
    classes = torch.tensor(classes)
    return text, classes

DL_DS = DataLoader(TD, batch_size=2, collate_fn=collate_batch)

As an example, two tensors are created to represent the word and class. In practice, these could be word vectors passed in through another function. The batch is then unpacked and then we add the word and label tensors to lists.

The word tensors are then concatenated, and the list of class tensors (each holding the value 1 in this case) is combined into a single tensor. The function now returns processed text data ready for training.

To activate this function you simply add the parameter collate_fn=Your_Function_name when initializing the DataLoader object.
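A common practical use of collate_fn is padding variable-length sequences so they can be stacked into one tensor. The sketch below assumes the samples are already tokenized; the token ids, the pad_collate name, and the choice of 0 as padding value are all illustrative.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical, already-tokenized samples: (token ids, class id).
samples = [
    ([4, 8, 15], 1),
    ([16, 23], 0),
    ([42], 1),
]


def pad_collate(batch):
    """Pad token sequences to the longest one in the batch and stack the labels."""
    token_tensors = [torch.tensor(tokens) for tokens, _ in batch]
    labels = torch.tensor([label for _, label in batch])
    padded = pad_sequence(token_tensors, batch_first=True, padding_value=0)
    return padded, labels


pad_loader = DataLoader(samples, batch_size=3, collate_fn=pad_collate)
padded, labels = next(iter(pad_loader))
print(padded.tolist())  # [[4, 8, 15], [16, 23, 0], [42, 0, 0]]
print(labels.tolist())  # [1, 0, 1]
```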

Iterate through the dataset when training a model

We will iterate through the Dataset without using collate_fn because it’s easier to see how the words and classes are being output by DataLoader. If the above function were used with collate_fn then the output would be tensors.

DL_DS = DataLoader(TD, batch_size=2, shuffle=True)

for (idx, batch) in enumerate(DL_DS):
    # Print the 'text' data of the batch
    print(idx, 'Text data: ', batch['Text'])
    # Print the 'class' data of batch
    print(idx, 'Class data: ', batch['Class'], '\n')

DL_DS = DataLoader(TD, batch_size=2, shuffle=True) : This initialises DataLoader with the Dataset object “TD” which we just created. In this example, the batch size is set to 2. This means that when you iterate through the DataLoader, it will output 2 instances of data at a time instead of one. For more information on batches see this article. Shuffle will reshuffle the data at each epoch, which prevents the model from learning the order of the training data.

for (idx, batch) in enumerate(DL_DS): Iterate through the data in the DataLoader object we just created. enumerate(DL_DS) returns the index number of the batch and the batch consisting of two data instances.


When this loop runs, the 5 data instances we created are output in batches of 2. Since we have an odd number of training examples, the last one is output in its own batch. Each index (0, 1, or 2) identifies a batch.
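The batching described here can be checked with a small self-contained sketch. It uses plain Python lists instead of a pandas DataFrame and turns shuffling off so the result is deterministic; TinyTextDataset is a stand-in name for the class above.

```python
from torch.utils.data import DataLoader, Dataset


class TinyTextDataset(Dataset):
    """Stand-in for CustomTextDataset, backed by plain Python lists."""

    def __init__(self, text, labels):
        self.text = text
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {"Text": self.text[idx], "Class": self.labels[idx]}


tiny = TinyTextDataset(
    ["Happy", "Amazing", "Sad", "Unhapy", "Glum"],
    ["Positive", "Positive", "Negative", "Negative", "Negative"],
)
ordered_loader = DataLoader(tiny, batch_size=2, shuffle=False)

# Default collation keeps strings as Python lists, one list per dictionary key.
text_batches = [batch["Text"] for batch in ordered_loader]
print(text_batches)  # [['Happy', 'Amazing'], ['Sad', 'Unhapy'], ['Glum']]
```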

Full code:

# Import libraries
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

# create custom dataset class
class CustomTextDataset(Dataset):
    def __init__(self, text, labels):
        self.labels = labels
        self.text = text

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        data = self.text[idx]
        sample = {"Text": data, "Class": label}
        return sample

# define data and class labels
text = ['Happy', 'Amazing', 'Sad', 'Unhapy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']

# create Pandas DataFrame
text_labels_df = pd.DataFrame({'Text': text, 'Labels': labels})

# define data set object
TD = CustomTextDataset(text_labels_df['Text'], text_labels_df['Labels'])

# Display text and label.
print('\nFirst iteration of data set: ', next(iter(TD)), '\n')

# Print how many items are in the data set
print('Length of data set: ', len(TD), '\n')

# Print entire data set
print('Entire data set: ', list(DataLoader(TD)), '\n')

# collate_fn
def collate_batch(batch):
    word_tensor = torch.tensor([[1.], [0.], [45.]])
    label_tensor = torch.tensor([[1.]])

    text_list, classes = [], []

    for (_text, _class) in batch:
        text_list.append(word_tensor)
        classes.append(label_tensor)

    text = torch.cat(text_list)
    classes = torch.tensor(classes)

    return text, classes

# create DataLoader object of DataSet object
bat_size = 2
DL_DS = DataLoader(TD, batch_size=bat_size, shuffle=True)

# loop through each batch in the DataLoader object
for (idx, batch) in enumerate(DL_DS):

    # Print the 'text' data of the batch
    print(idx, 'Text data: ', batch['Text'], '\n')

    # Print the 'class' data of batch
    print(idx, 'Class data: ', batch['Class'], '\n')


Amir Masoud Sefidian
Machine Learning Engineer
