dragon.ai.torch.DragonDataset

class DragonDataset

Bases: Dataset

A PyTorch dataset that stores the training data and labels in a dragon distributed dictionary (DDict). It accepts either an iterable of data or an existing dragon distributed dictionary together with a list of its keys. The PyTorch DataLoader requires the dataset to support three functions: __getitem__, __len__, and __init__. When the dataset is built from an arbitrary iterable, the provided stop() function closes the dragon distributed dictionary. When the user provides the dictionary, the user is expected to manage it and close it directly.
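The two construction paths and the ownership rule for stop() can be sketched in plain Python. This is a minimal illustration, not the DragonDataset implementation: `KeyedDataset` is a hypothetical name, an ordinary dict stands in for the dragon distributed dictionary, and using integer positions as keys in the iterable case is an assumption.

```python
from typing import Any, Iterable, Optional


class KeyedDataset:
    """Hypothetical map-style dataset backed by a mapping (a dict stands in for a DDict)."""

    def __init__(self, dataset, *, dataset_keys: Optional[Iterable[Any]] = None):
        if dataset_keys is not None:
            # Existing-dictionary path: the caller owns the store and closes it.
            self.dict = dataset
            self.keys = list(dataset_keys)
            self._owns_store = False
        else:
            # Iterable path: copy (data, label) pairs into a store this object owns.
            # Assumption: integer positions are used as keys.
            self.dict = {i: pair for i, pair in enumerate(dataset)}
            self.keys = list(self.dict.keys())
            self._owns_store = True

    def __len__(self) -> int:
        return len(self.keys)

    def __getitem__(self, idx: int):
        return self.dict[self.keys[idx]]

    def stop(self) -> None:
        # Only release a store that this object created.
        if self._owns_store:
            self.dict.clear()


from_pairs = KeyedDataset([("x0", 0), ("x1", 1)])
existing = {"a": ("xa", 1), "b": ("xb", 0)}
from_store = KeyedDataset(existing, dataset_keys=existing.keys())
```

In this sketch, calling stop() on `from_store` leaves `existing` untouched, mirroring the rule that a user-provided dictionary is managed by the user.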

Example usage:

import dragon
import multiprocessing as mp

import torch
# import path assumed from this page's title (dragon.ai.torch.DragonDataset)
from dragon.ai.torch import DragonDataset
from dragon.data import DDict
from dragon.native.process import Process

def train(dataset):
    train_kwargs = {"batch_size": 64}
    if torch.cuda.is_available():
        cuda_kwargs = {
            "num_workers": 0,
            "pin_memory": True,
            "shuffle": True,
        }
        train_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)

    # model, device, optimizer, scheduler, and train_one_epoch are
    # placeholders for user-defined training code
    for epoch in range(1, 10):
        train_one_epoch(model, device, train_loader, optimizer, epoch)
        scheduler.step()

if __name__ == "__main__":
    mp.set_start_method("dragon")

    d = DDict(2, 1, 1 * 1024 * 1024 * 1024)
    # fill ddict with data

    # the user provides the dictionary, so the user also closes it below
    dragon_dataset = DragonDataset(d, dataset_keys=d.keys())

    # this process may be on any node in the allocation
    proc = Process(target=train, args=(dragon_dataset,))
    proc.start()
    proc.join()
    d.destroy()
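
The `# fill ddict with data` step is left open in the example. One way to do it, assuming the dictionary accepts item assignment like an ordinary Python mapping, is sketched below. `fill_store` is a hypothetical helper, and a plain dict stands in for the DDict so the snippet runs without a dragon allocation.

```python
from typing import Any, Iterable, MutableMapping, Tuple


def fill_store(store: MutableMapping, samples: Iterable[Tuple[Any, Any]]) -> list:
    """Write (data, label) pairs into store under integer keys and return the keys."""
    keys = []
    for key, (data, label) in enumerate(samples):
        store[key] = (data, label)
        keys.append(key)
    return keys


store = {}  # stand-in for the DDict `d` in the example above
keys = fill_store(store, [([0.0, 1.0], 0), ([1.0, 0.0], 1)])
```

The returned list can then play the role of the dataset_keys argument.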

__init__(dataset: Iterable[tuple[Any, Any]], *, dataset_keys=None, dragon_dict_args=None)

Construct a Dataset from a DDict usable by a PyTorch DataLoader.

Parameters:

Methods

__init__(dataset, *[, dataset_keys, ...])

Construct a Dataset from a DDict usable by a PyTorch DataLoader.

stop()

Bring down the Dataset and clean up all resources

__len__()
__getitem__(idx)

Gets a (data, label) pair from the distributed dictionary for an index idx in [0, len(self.dict)). It returns the value stored under the key self.keys[idx].

Parameters:

idx (int) – An index into the list of keys, typically generated at random when the DataLoader shuffles

Returns:

Tuple of the data and label with key self.keys[idx]

Return type:

tuple
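
The lookup described above can be illustrated without PyTorch. In this sketch a shuffled permutation of indices stands in for what a DataLoader would generate with shuffle=True; `keys`, `store`, and `get_item` are hypothetical stand-ins for self.keys, the distributed dictionary, and __getitem__.

```python
import random

keys = ["k0", "k1", "k2", "k3"]
store = {k: (f"data-{k}", i % 2) for i, k in enumerate(keys)}


def get_item(idx: int):
    """Return the (data, label) tuple stored under keys[idx]."""
    return store[keys[idx]]


rng = random.Random(0)
indices = list(range(len(keys)))
rng.shuffle(indices)  # a random visiting order over [0, len(keys))

# Each index is translated to a key, and the key to a (data, label) tuple.
batch = [get_item(idx) for idx in indices]
```

The index only selects a position in the key list, so the keys themselves can be of any type the dictionary supports.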

stop()

Bring down the Dataset and clean up all resources