dragon.ai.torch.DragonDataset
- class DragonDataset
Bases: Dataset
This is a PyTorch dataset that uses the Dragon distributed dictionary (DDict) to store the training data and labels. It accepts either an iterable of data or an existing Dragon distributed dictionary together with a list of its keys. The PyTorch DataLoader requires three methods to be supported: __getitem__, __len__, and __init__. For use with an arbitrary iterable, a stop() method is also provided that closes the Dragon distributed dictionary. If the user provides a dictionary, the user is expected to manage that dictionary and close it directly.
Example usage:
    import dragon
    import multiprocessing as mp

    import torch
    from dragon.ai.torch import DragonDataset
    from dragon.data import DDict
    from dragon.native.process import Process


    def run_training(dataset):
        # model, device, optimizer, scheduler and train_one_epoch are
        # placeholders that the user is expected to define elsewhere.
        train_kwargs = {"batch_size": 64}
        if torch.cuda.is_available():
            cuda_kwargs = {
                "num_workers": 0,
                "pin_memory": True,
                "shuffle": True,
            }
            train_kwargs.update(cuda_kwargs)

        train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
        for epoch in range(1, 10):
            train_one_epoch(model, device, train_loader, optimizer, epoch)
            scheduler.step()


    if __name__ == "__main__":
        mp.set_start_method("dragon")

        d = DDict(2, 1, 1 * 1024 * 1024 * 1024)
        # fill ddict with data
        dragon_dataset = DragonDataset(d, dataset_keys=d.keys())

        # this process may be on any node in the allocation
        proc = Process(target=run_training, args=(dragon_dataset,))
        proc.start()
        proc.join()

        # the caller created the DDict, so the caller destroys it
        d.destroy()
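The three methods named above, plus stop(), can be illustrated without Dragon or PyTorch installed. The following is a minimal sketch only: DictBackedDataset is a hypothetical stand-in, a plain Python dict replaces the DDict, and the integer keys are an assumption made for the sketch. It is not the DragonDataset implementation.

```python
from typing import Any, Iterable, List, Tuple


class DictBackedDataset:
    """Hypothetical stand-in for a map-style dataset; a dict replaces the DDict."""

    def __init__(self, dataset: Iterable[Tuple[Any, Any]]):
        # Copy every (data, label) pair into the store under an integer key
        # (the key scheme is an assumption of this sketch).
        self.store = {}
        self.keys: List[int] = []
        for key, pair in enumerate(dataset):
            self.store[key] = pair
            self.keys.append(key)

    def __len__(self) -> int:
        # A loader uses this to know which indices are valid.
        return len(self.keys)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        # Translate the position into a key, then fetch the stored pair.
        return self.store[self.keys[idx]]

    def stop(self) -> None:
        # Release the store this object created itself; with a real
        # distributed dictionary this is where it would be closed.
        self.store.clear()
        self.keys.clear()


pairs = [([0.0, 1.0], 0), ([1.0, 0.0], 1), ([0.5, 0.5], 1)]
ds = DictBackedDataset(pairs)
n_items = len(ds)
first = ds[0]
ds.stop()
```

Because the stand-in builds its own store from the iterable, it is also the one that tears it down in stop(). This mirrors the ownership rule described above for datasets built from an arbitrary iterable.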
- __init__(dataset: Iterable[tuple[Any, Any]], *, dataset_keys=None, dragon_dict_args=None)
Construct a Dataset from a DDict usable by a PyTorch DataLoader.
- Parameters:
  dataset (PyTorch Dataset or DDict) – Base PyTorch Dataset
  dataset_keys (list or iterable) – All keys in the dataset (e.g., from DDict.keys())
  dragon_dict_args (dict) – Optional arguments to construct a new DDict
Methods
__init__(dataset, *[, dataset_keys, ...]) – Construct a Dataset from a DDict usable by a PyTorch DataLoader.
stop() – Bring down the Dataset and clean up all resources.
- __len__()
- __getitem__(idx)
Returns a (data, label) pair from the distributed dictionary for an index idx in [0, len(self.dict)). The pair is retrieved using the key self.keys[idx].
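The index-to-key step can be sketched in isolation, which also shows why a key list is needed when the dictionary keys are not integers. In this sketch get_pair, the store contents, and the key names are hypothetical. The explicit bounds check is an addition for illustration, since the documentation does not state how out-of-range indices are handled.

```python
from typing import Any, Dict, List, Tuple


def get_pair(store: Dict[Any, Tuple[Any, Any]], keys: List[Any], idx: int) -> Tuple[Any, Any]:
    """Return the (data, label) pair stored under keys[idx]."""
    # Enforce the half-open range [0, len(keys)); this check is an
    # assumption of the sketch, not documented behavior.
    if not 0 <= idx < len(keys):
        raise IndexError(f"idx {idx} is outside [0, {len(keys)})")
    return store[keys[idx]]


# Hypothetical contents: the keys are strings, so a positional index
# cannot be used to look up the store directly.
store = {"sample-a": ([1, 2], 0), "sample-b": ([3, 4], 1)}
keys = list(store.keys())

data, label = get_pair(store, keys, 1)
```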
- stop()
Bring down the Dataset and clean up all resources, including closing the Dragon distributed dictionary that was created from an iterable.