Resiliency with DDict Checkpointing

The Dragon distributed dictionary, referred to as the DDict, provides APIs for adding resiliency to applications. DDict.checkpoint() is the foundation of this capability: it lets users create consistent snapshots of the DDict's state at specific points in time. This is particularly useful for applications that require fault tolerance, iterative computation, or long-running processes where intermediate results need to be saved and potentially restored later.
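
As a minimal sketch of how these pieces fit together (a single node, one manager, and an illustrative 32 MB pool; run under the Dragon launcher like the listings below), a DDict is created with a working set of checkpoints, advanced with checkpoint(), and stepped back with rollback():

from dragon.data.ddict import DDict

# One manager on one node with a small, illustrative pool size.
d = DDict(1, 1, 32 * 1024 * 1024, working_set_size=2)

d["state"] = 1.0        # written into checkpoint 0
d.checkpoint()          # advance the working set to checkpoint 1
d["state"] = 2.0        # written into checkpoint 1

d.rollback()            # step back one checkpoint
print(d.checkpoint_id)  # 0
print(d["state"])       # 1.0, since checkpoint 0 is still in the working set
d.destroy()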

Some features of this tutorial are still experimental and under development. These features may not be ready for production use and could change in future releases. Please refer to the Dragon documentation and release notes for the most up-to-date information on the status of these features.

Checkpointing and Rolling Back

Checkpointing lets an application save its state at specific points in time, and DDict.rollback() lets it return to a previous checkpoint when needed, enabling recovery from errors or failures. The example below demonstrates how to use the DDict to create checkpoints and roll back to a previous state if corruption of the current state is detected.

Listing 45 Use the DDict with checkpointing to compute the average of samples, rolling back if a value falls outside the allowed range
import random

from dragon.native.machine import System
from dragon.data.ddict import DDict
from dragon.native.process_group import ProcessGroup
from dragon.native.process import ProcessTemplate
from dragon.native.barrier import Barrier
import dragon.infrastructure.parameters as di_param


def biased_sampler(ddict, barrier):

    my_puid = di_param.this_process.my_puid
    local_sample_agg = 0.0
    num_samples = 0
    rollback = 0

    while num_samples < 1000:
        # some work
        sample = random.normalvariate(0, 1)
        local_sample_agg += sample
        num_samples += 1

        # condition suggesting state has been corrupted and we need to roll back
        if abs(sample) >= 4 and num_samples > 100:
            # advance to a new checkpoint so rollback() returns us to the last saved checkpoint
            ddict.checkpoint()
            rollback_chkpt = ddict.checkpoint_id - 1
            print(f"Process {my_puid} rolling back to checkpoint {rollback_chkpt}", flush=True)
            ddict.rollback()
            rollback += 1
            local_sample_agg = ddict[f"puid_{my_puid}"]
            num_samples = rollback_chkpt * 100
            continue

        # periodically checkpoint state of application
        if num_samples % 100 == 0:
            ddict[f"puid_{my_puid}"] = local_sample_agg
            ddict.checkpoint()  # proceed to the next checkpoint
            barrier.wait()


if __name__ == "__main__":

    my_alloc = System()
    nnodes = my_alloc.nnodes
    procs_per_node = 30
    ddict = DDict(1, nnodes, nnodes * int(4 * 1024 * 1024), working_set_size=4)
    barrier = Barrier(nnodes * procs_per_node)
    pg = ProcessGroup()
    temp_proc = ProcessTemplate(target=biased_sampler, args=(ddict, barrier))
    pg.add_process(template=temp_proc, nproc=nnodes * procs_per_node)

    pg.init()
    pg.start()
    pg.join()

    avg = 0.0
    # update current process to the latest checkpoint
    ddict.sync_to_newest_checkpoint()
    for puid, _ in pg.inactive_puids:
        avg += ddict[f"puid_{puid}"]
    avg /= nnodes * procs_per_node * 1000
    print(f"Final average from all processes is {avg}", flush=True)
    pg.close()
    ddict.destroy()

Using Checkpointing to Persist Training State

Checkpointing combined with persistence allows users to save the state of their application to more permanent storage, such as disk. This is particularly useful for long-running training jobs where hardware and software failures can occur. The example below demonstrates how to checkpoint and persist the state of a distributed data parallel training process using a DDict.

Listing 46 Checkpoint model and optimizer state to a DDict and persist the state to disk
import dragon
import getpass
import argparse

from dragon.ai.collective_group import CollectiveGroup, RankInfo
from dragon.data.ddict import DDict, PosixCheckpointPersister, DAOSCheckpointPersister
from dragon.native.machine import System

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


def training_fn(ddict, restart):
    rank_info = RankInfo()
    rank = rank_info.my_rank
    local_rank = rank_info.my_local_rank
    master_addr = rank_info.master_addr
    master_port = rank_info.master_port
    world_size = rank_info.world_size
    print(
        f"Rank Info: rank {rank}, local_rank {local_rank}, master_addr {master_addr}, master_port {master_port}, world_size {world_size}"
    )

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_addr}:{master_port}",
        world_size=world_size,
        rank=rank,
    )

    input_size = 10
    hidden_size = 20
    output_size = 1
    num_samples = 100 * nnodes
    batch_size = 10
    learning_rate = 0.01
    num_epochs = 20
    device = torch.device("cuda:0")
    criterion = nn.MSELoss().to(device)

    if restart:
        # get globally checkpointed variables
        train_loader = ddict["train_loader"]
        # get my local ddict shard to recover local checkpoints
        my_manager_id = ddict.local_managers[local_rank]
        ddict = ddict.manager(my_manager_id)
        # get model and optimizer state from the ddict
        model = SimpleNN(input_size, hidden_size, output_size)
        model.load_state_dict(ddict[f"model_state_dict_{rank}"])
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        optimizer.load_state_dict(ddict[f"optimizer_state_dict_{rank}"])

        # move optimizer state to the device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        ddict.checkpoint()  # proceed to the next checkpoint

        first_chkpt = ddict.checkpoint_id
        last_chkpt = first_chkpt + num_epochs
    else:
        # generate random data
        X_train = torch.randn(num_samples, input_size)
        y_train = torch.randn(num_samples, output_size)
        train_dataset = TensorDataset(X_train, y_train)
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=False, sampler=DistributedSampler(train_dataset)
        )
        model = SimpleNN(input_size, hidden_size, output_size)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # store the loader in the ddict. This only needs to be done once, with a persistent put
        if rank == 0:
            ddict.pput("train_loader", train_loader)

        first_chkpt = 0
        last_chkpt = num_epochs
        # get my local ddict so checkpoints go to node-local memory
        my_manager_id = ddict.local_managers[local_rank]
        ddict = ddict.manager(my_manager_id)

    # device_ids takes a local index. GPU affinity is controlled through the policy, so each process sees only its own GPU.
    model.to(device)
    model = DDP(model, device_ids=[0])

    print(f"Rank {rank} starting training", flush=True)
    # training
    for epoch in range(num_epochs):
        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # checkpoint model and optimizer state
        ddict[f"model_state_dict_{rank}"] = model.module.state_dict()
        ddict[f"optimizer_state_dict_{rank}"] = optimizer.state_dict()
        # If some ddict managers might not receive a key at a checkpoint, a small bput
        # keeps the managers in sync. Here there are as many managers as GPUs, so every manager gets a key.
        # ddict.bput("epoch", epoch)
        ddict.checkpoint()
        if rank == 0:
            print(f"Epoch [{first_chkpt + epoch + 1}/{last_chkpt}], Loss: {loss.item():.4f}", flush=True)

    print(f"Rank {rank} finished training!", flush=True)
    dist.destroy_process_group()


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Training with DDict Checkpoint Persistence")
    parser.add_argument(
        "--restart",
        action="store_true",
        help="Whether to restart from the last persisted checkpoint",
    )
    args = parser.parse_args()

    my_alloc = System()
    nnodes = my_alloc.nnodes

    # Checkpointing parameters
    managers_per_node = my_alloc.primary_node.num_gpus
    working_set_size = 2
    persist_frequency = 2
    persist_count = 2
    persist_path = ""
    restart = bool(args.restart or my_alloc.restarted)
    persister = PosixCheckpointPersister  # switch to DAOSCheckpointPersister if using DAOS
    name = f"chkpt_persistence_example_{getpass.getuser()}"

    ddict = DDict(
        managers_per_node,
        nnodes,
        nnodes * int(4 * 1024 * 1024 * 1024),
        wait_for_keys=True,
        working_set_size=working_set_size,
        persister_class=persister,
        persist_freq=persist_frequency,
        persist_count=persist_count,
        persist_path=persist_path,
        name=name,
    )

    if restart:
        # on a restart, recover the ddict and find the last persisted checkpoint
        available_persisted_chkpt = ddict.persisted_ids()
        print(f"available persisted checkpoints: {available_persisted_chkpt}", flush=True)

        # restore from the last complete checkpoint
        latest_chkpt = available_persisted_chkpt[-1]
        ddict.restore(latest_chkpt)

    # launch one training process per GPU
    policies = my_alloc.gpu_policies()
    training_group = CollectiveGroup(
        training_fn,
        training_args=(
            ddict,
            restart,
        ),
        policies=policies,
    )

    training_group.init()
    training_group.start()
    training_group.join()
    training_group.close()

    ddict.destroy()

In this example, all checkpoints are first written to node-local memory for performance. At a specified frequency, checkpoints are asynchronously persisted to disk using the chosen persister class, and upon restart the application can restore from the last persisted checkpoint. Although every rank in this example saved a duplicate copy of the same state, the same pattern covers the more general case where each rank holds a shard of the state, as occurs in fully sharded data parallel training or when running an ensemble with different hyperparameters. Replicated state also provides insurance: if a node fails, the system can recover from another node's in-memory checkpoint. In the next section we will cover how to automatically recover from failures using only in-memory checkpoints rather than persistent storage.
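
For reference, a condensed sketch of the persistence flow in Listing 46 is shown below. The parameter meanings noted in the comments are assumptions based on the listing above, and the sizes, manager count, and dictionary name are illustrative rather than prescriptive:

import getpass

from dragon.data.ddict import DDict, PosixCheckpointPersister
from dragon.native.machine import System

alloc = System()
nnodes = alloc.nnodes

ddict = DDict(
    1,                                         # managers per node (illustrative)
    nnodes,
    nnodes * int(1 * 1024 * 1024 * 1024),      # total managed memory (illustrative)
    wait_for_keys=True,
    working_set_size=2,                        # checkpoints held in node-local memory per manager
    persister_class=PosixCheckpointPersister,  # writes checkpoints to disk in the background
    persist_freq=2,                            # assumed: persist every 2nd checkpoint
    persist_count=2,                           # assumed: number of persisted checkpoints retained on disk
    persist_path="",                           # assumed: empty string uses the persister's default location
    name=f"persist_sketch_{getpass.getuser()}",  # a stable name lets a restart find the persisted checkpoints
)

# On a restart (detected here via the allocation, or with a flag such as --restart),
# restore the newest complete persisted checkpoint before launching the workers.
if alloc.restarted:
    persisted = ddict.persisted_ids()
    if persisted:
        ddict.restore(persisted[-1])

# ... launch the training CollectiveGroup as in Listing 46 ...

ddict.destroy()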

Recovering from Failures with In-Memory Checkpoints

In progress…