TL;DR

  • SLURM monitors the total resident set size (RSS) consumed by all of the task's processes, including DataLoader workers
  • Setting pin_memory=True increases RSS significantly and may cause leaks with mmap-based LMDB, pushing the job to its memory limit sooner
  • The FastDataLoader, or a PyTorch DataLoader created with persistent_workers=True, accumulates RSS because its workers never reset MMAP-based storage, such as an LMDB env, across epochs

When it comes to training deep learning models, I/O storage capacity and transfer bandwidth are usually the bottleneck. HDF5 is efficient for loading the entire dataset for training, but it is limited by system memory capacity, which is typically up to hundreds of GB. Memory-mapped (MMAP) storage, on the other hand, allows data access beyond system memory constraints. One popular implementation is LMDB, which provides bindings for numerous languages, including Python. It is tempting to replace HDF5 with LMDB for loading and accessing very large datasets.

When running locally, potential memory allocation issues may never surface, because modern systems fall back to disk swap space when a process consumes more than the available memory. However, training deep models can take days or longer, and it is common to run the training job on a high performance computing cluster managed by SLURM. To submit a job to a SLURM cluster, the memory requirement must be specified beforehand. At runtime, memory usage is monitored through Resident Set Size (RSS) statistics. Unfortunately, the same training process on SLURM may now fail with an out-of-memory error. Swap space may not be available to SLURM tasks, and MMAP-based LMDB may grow the RSS over time beyond the limit. The following pytest snippet demonstrates that the task RSS increases over epochs when a huge dataset is accessed through LMDB:

import os
import sys

import psutil

KB = 2**10
MB = 2**10 * KB
GB = 2**10 * MB

def rss_usage(breakdown=False):
    """Return the total RSS (bytes) of this process and all of its descendants,
    e.g. DataLoader workers; with breakdown=True also return per-PID values."""
    proc = psutil.Process(os.getpid())
    RSS = [(proc.pid, proc.memory_info().rss)]
    for child in proc.children(recursive=True):
        try:
            RSS.append((child.pid, child.memory_info().rss))
        except psutil.NoSuchProcess:
            # A worker may exit between being listed and being queried
            continue

    rss = sum(mem for _, mem in RSS)
    return (rss, RSS) if breakdown else rss

def test_rss():
    print(sys.argv)
    argv = sys.argv
    # parse_args() reads sys.argv, so temporarily hide the pytest arguments
    # to let it fall back to its defaults; always restore them afterwards
    sys.argv = [argv[1]]
    try:
        from utils.utils import FastDataLoader
        from dataset.lmdb_dataset import UCF101LMDB_2CLIP
        from main_nce import parse_args, get_transform
        args = parse_args()
    finally:
        sys.argv = argv
    lmdb_root = "/mnt/ssd/dataset/ucf101/lmdb"
    lmdb_path = f"{lmdb_root}/UCF101/ucf101_frame.lmdb"
    trans = get_transform('train', args)
    ucf101 = UCF101LMDB_2CLIP(db_path=lmdb_path, mode='train', transform=trans,
                              num_frames=32, ds=1, return_label=True)
    print(f"Created UCF101 2clip dataset of size {len(ucf101)}")

    dataloader = FastDataLoader(ucf101,
                                batch_size=32, shuffle=True,
                                num_workers=4, persistent_workers=False,
                                pin_memory=False,  # pin_memory=True grows RSS further
                                sampler=None, drop_last=True)
    batches = 8  # only sample the first few batches of each epoch
    for epoch in range(3):
        rss = rss_usage()
        print(f"[e{epoch:02d}] RSS: {rss / GB:.2f} GB")
        for idx, (input_seq, label) in enumerate(dataloader):
            # Every 4 batches, report the RSS of each process and the total
            if idx % 4 == 0:
                rss, RSS = rss_usage(True)
                for pid, mem in RSS:
                    print(f"[e{epoch:02d}][b{idx:02d}][{pid}] consumes {mem / GB:.2f} GB")
                print(f"[e{epoch:02d}][b{idx:02d}] RSS: {rss / GB:.2f} GB")
            if idx == batches:
                break

Sample output of the test:

[e00] RSS: 2.56 GB
[e00][b00][14023] consumes 1.08 GB
[e00][b00][14055] consumes 0.49 GB
[e00][b00][14071] consumes 0.81 GB
[e00][b00][14087] consumes 0.83 GB
[e00][b00][14103] consumes 0.84 GB
[e00][b00] RSS: 4.07 GB
[e00][b04][14023] consumes 1.08 GB
[e00][b04][14055] consumes 0.78 GB
[e00][b04][14071] consumes 0.90 GB
[e00][b04][14087] consumes 0.64 GB
[e00][b04][14103] consumes 0.80 GB
[e00][b04] RSS: 4.20 GB
[e00][b08][14023] consumes 1.08 GB
[e00][b08][14055] consumes 0.97 GB
[e00][b08][14071] consumes 1.00 GB
[e00][b08][14087] consumes 0.66 GB
[e00][b08][14103] consumes 1.24 GB
[e00][b08] RSS: 4.95 GB
[e01] RSS: 4.97 GB
[e01][b00][14023] consumes 1.08 GB
[e01][b00][14055] consumes 0.66 GB
[e01][b00][14071] consumes 0.92 GB
[e01][b00][14087] consumes 0.66 GB
[e01][b00][14103] consumes 0.80 GB
[e01][b00] RSS: 4.12 GB
[e01][b04][14023] consumes 1.08 GB
[e01][b04][14055] consumes 0.80 GB
[e01][b04][14071] consumes 0.99 GB
[e01][b04][14087] consumes 1.04 GB
[e01][b04][14103] consumes 1.03 GB
[e01][b04] RSS: 4.93 GB
[e01][b08][14023] consumes 1.08 GB
[e01][b08][14055] consumes 0.87 GB
[e01][b08][14071] consumes 1.05 GB
[e01][b08][14087] consumes 1.19 GB
[e01][b08][14103] consumes 1.07 GB
[e01][b08] RSS: 5.26 GB
[e02] RSS: 5.29 GB
[e02][b00][14023] consumes 1.08 GB
[e02][b00][14055] consumes 0.85 GB
[e02][b00][14071] consumes 1.06 GB
[e02][b00][14087] consumes 1.09 GB
[e02][b00][14103] consumes 1.09 GB
[e02][b00] RSS: 5.17 GB
[e02][b04][14023] consumes 1.08 GB
[e02][b04][14055] consumes 0.92 GB
[e02][b04][14071] consumes 1.12 GB
[e02][b04][14087] consumes 0.86 GB
[e02][b04][14103] consumes 1.14 GB
[e02][b04] RSS: 5.12 GB
[e02][b08][14023] consumes 1.08 GB
[e02][b08][14055] consumes 0.97 GB
[e02][b08][14071] consumes 1.19 GB
[e02][b08][14087] consumes 0.93 GB
[e02][b08][14103] consumes 1.23 GB
[e02][b08] RSS: 5.39 GB
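In the log above, each worker reports its own RSS. All workers map the same LMDB file, so pages of that file that are resident in several workers may be counted more than once when per-process RSS values are summed, as rss_usage() does. One hedged way to check how much of the total is shared is to compare summed RSS with summed proportional set size (PSS), which splits each shared page among the processes mapping it. The sketch below reads /proc directly instead of using psutil, so it is Linux-only. The helper names are hypothetical. Whether a given cluster's accounting uses RSS or PSS may depend on its configuration.

```python
import os


def _parent_map():
    """Map every visible PID to its parent PID by scanning /proc/<pid>/stat."""
    parents = {}
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue
        try:
            with open(f"/proc/{entry}/stat") as f:
                stat = f.read()
        except OSError:
            continue  # the process exited (or is unreadable) while scanning
        # The command name (field 2) may contain spaces, so split after its ')'
        fields = stat.rsplit(")", 1)[1].split()
        parents[int(entry)] = int(fields[1])  # fields[0] is state, fields[1] is ppid
    return parents


def process_tree(root_pid):
    """Return root_pid plus all of its descendants (e.g. DataLoader workers)."""
    children = {}
    for pid, ppid in _parent_map().items():
        children.setdefault(ppid, []).append(pid)
    tree, stack = [], [root_pid]
    while stack:
        pid = stack.pop()
        tree.append(pid)
        stack.extend(children.get(pid, []))
    return tree


def rss_pss_bytes(pid):
    """Return (RSS, PSS) in bytes for one process, parsed from its smaps data."""
    path = f"/proc/{pid}/smaps_rollup"
    if not os.path.exists(path):
        path = f"/proc/{pid}/smaps"  # older kernels: sum the per-mapping entries
    rss_kb = pss_kb = 0
    with open(path) as f:
        for line in f:
            fields = line.split()
            if len(fields) >= 2 and fields[0] == "Rss:":
                rss_kb += int(fields[1])
            elif len(fields) >= 2 and fields[0] == "Pss:":
                pss_kb += int(fields[1])
    return rss_kb * 1024, pss_kb * 1024


def tree_rss_pss(root_pid=None):
    """Sum RSS and PSS (bytes) over a process and all of its descendants."""
    total_rss = total_pss = 0
    for pid in process_tree(root_pid or os.getpid()):
        try:
            rss, pss = rss_pss_bytes(pid)
        except OSError:
            continue  # exited, or not readable by this user
        total_rss += rss
        total_pss += pss
    return total_rss, total_pss
```

Calling tree_rss_pss() from the training process next to rss_usage() shows how far the summed RSS exceeds the summed PSS. That gap is the portion of memory counted more than once.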

The root cause is that accessing database records through the LMDB Python API, as follows, may not release the mapped memory promptly once a read completes. The mapped pages therefore keep counting toward the runtime RSS:

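import os

import lmdb
import msgpack

class UCF101LMDB_2CLIP(object):
    def __init__(self, ...):  # constructor arguments elided
        ...
        print('Loading LMDB from %s, split:%d' % (self.db_path, self.which_split))
        # Read-only env; readahead=False disables OS readahead on the map,
        # meminit=False skips zero-initializing write buffers
        self.env = lmdb.open(self.db_path, subdir=os.path.isdir(self.db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        ...

    def __getitem__(self, index):
        vpath, vlen, vlabel, vname = self.video_subset.iloc[index]
        env = self.env
        # Each read touches pages of the memory map, which then count
        # toward the RSS of the process doing the read
        with env.begin(write=False) as txn:
            raw = msgpack.loads(txn.get(self.get_video_id[vname].encode('ascii')))
        ...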
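To see the mechanism in isolation, here is a small sketch that uses Python's standard mmap module as a stand-in for LMDB's read-only map. This is an assumption: it is not LMDB itself. Reading one byte per page of a mapped file pulls those pages into the process's RSS. They stay there until the mapping is closed, or until the kernel reclaims them under memory pressure. The 64 MiB file size and the helper names are hypothetical, and reading /proc/self/statm makes the sketch Linux-only.

```python
import mmap
import tempfile


def current_rss_bytes():
    """Resident set size of this process, read from /proc/self/statm (Linux only)."""
    with open("/proc/self/statm") as f:
        resident_pages = int(f.read().split()[1])  # 2nd field: resident pages
    return resident_pages * mmap.PAGESIZE


def mmap_rss_demo(size_mb=64):
    """Return RSS in bytes before touching, after touching, and after closing a file mapping."""
    chunk = b"\x01" * 2**20
    with tempfile.TemporaryFile() as f:
        for _ in range(size_mb):
            f.write(chunk)
        f.flush()  # make sure the data is in the file before mapping it
        mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
        before = current_rss_bytes()  # mapping alone does not make pages resident
        for offset in range(0, len(mm), mmap.PAGESIZE):
            _ = mm[offset]  # reading one byte faults the whole page into this process
        touched = current_rss_bytes()
        mm.close()  # unmapping drops those pages from this process's RSS
        after = current_rss_bytes()
    return before, touched, after


if __name__ == "__main__":
    before, touched, after = mmap_rss_demo()
    print(f"before: {before / 2**20:.1f} MiB, touched: {touched / 2**20:.1f} MiB, "
          f"closed: {after / 2**20:.1f} MiB")
```

The file's data is already in the page cache after writing, yet it only shows up in this process's RSS once the mapped pages are actually read. This mirrors how a long-lived worker's RSS grows as it reads more LMDB records.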

Worse, the FastDataLoader never recreates its dataset iterator workers, which hold the LMDB env, so RSS grows over epochs as more of the map is accessed. The worker PIDs in the log above stay the same across all three epochs even though persistent_workers=False. If using the vanilla DataLoader, make sure to set persistent_workers=False so that the workers are recreated every epoch and a similar memory leak is avoided. Nonetheless, sufficient memory must still be allocated for at least the peak usage within one epoch. This serves as the workaround.
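Turning "peak usage in one epoch" into a job request is simple arithmetic. As a hedged sketch, the hypothetical helper below takes sampled totals such as those returned by rss_usage(). It pads the peak with a safety margin and rounds up to whole GB, in the G-suffix form accepted by SLURM's --mem option. The 20% headroom is an assumption, not a value from this post or a SLURM default.

```python
import math

GB = 2**30  # binary GB, matching the GB constant in the test above


def slurm_mem_request(rss_samples, headroom=1.2):
    """Turn sampled total RSS values (bytes) into a --mem value in whole GB.

    headroom=1.2 (a 20% safety margin) is an assumption, not a SLURM default.
    """
    if not rss_samples:
        raise ValueError("need at least one RSS sample")
    peak = max(rss_samples)
    return f"{math.ceil(peak * headroom / GB)}G"


if __name__ == "__main__":
    # Totals taken from the log above, converted back to bytes
    samples = [s * GB for s in (2.56, 4.07, 4.95, 5.39)]
    print(slurm_mem_request(samples))  # prints 7G
```

With the logged peak of 5.39 GB, a 20% margin gives 6.47 GB, which rounds up to --mem=7G.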