Resiliency with DDict Checkpointing
The Dragon distributed dictionary, referred to as the DDict, provides APIs for users to add resiliency to their applications. DDict.checkpoint() is the foundation that enables this capability: it allows users to create consistent snapshots of the DDict's state at specific points in time. This is particularly useful for applications that require fault tolerance, iterative computations, or long-running processes where intermediate results need to be saved and potentially restored later.
Some features of this tutorial are still experimental and under development. These features may not be ready for production use and could change in future releases. Please refer to the Dragon documentation and release notes for the most up-to-date information on the status of these features.
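At its core the workflow is simple: store key-value pairs in the DDict, call DDict.checkpoint() to advance to a new checkpoint, and call DDict.rollback() to step back to the previous one if something goes wrong. The sketch below illustrates just that pattern on a single process, using the same constructor arguments that appear in the full examples in this tutorial; it is a minimal illustration, not a complete resilient application.

from dragon.native.machine import System
from dragon.data.ddict import DDict

my_alloc = System()
# one manager per node, a small total memory pool, and a working set of two checkpoints
d = DDict(1, my_alloc.nnodes, my_alloc.nnodes * int(4 * 1024 * 1024), working_set_size=2)

d["result"] = 1.0  # stored under the current checkpoint (id 0)
d.checkpoint()     # advance to the next checkpoint
d["result"] = 2.0  # a new copy of the key, stored under checkpoint 1

d.rollback()       # step back one checkpoint
print(d.checkpoint_id, d["result"], flush=True)  # reads the value saved at the earlier checkpoint

d.destroy()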
Checkpointing and Rolling Back
Checkpointing allows applications to save their state at specific points in time, and DDict.rollback() lets them return to a previous checkpoint when needed, enabling recovery from errors or failures. The example below demonstrates how to use the DDict to create checkpoints and roll back to a previous state when corruption of the current state is detected.
import random
import socket
from dragon.native.machine import System
from dragon.data.ddict import DDict
from dragon.native.process_group import ProcessGroup
from dragon.native.process import ProcessTemplate
from dragon.native.barrier import Barrier
import dragon.infrastructure.parameters as di_param


def biased_sampler(ddict, barrier):

    my_puid = di_param.this_process.my_puid
    local_sample_agg = 0.0
    num_samples = 0
    rollback = 0

    while num_samples < 1000:
        # some work
        sample = random.normalvariate()
        local_sample_agg += sample
        num_samples += 1

        # condition suggesting state has been corrupted and we need to roll back
        if abs(sample) >= 4 and num_samples > 100:
            # checkpoint current state before rolling back
            ddict.checkpoint()
            rollback_chkpt = ddict.checkpoint_id - 1
            print(f"Process {my_puid} rolling back to checkpoint {rollback_chkpt}", flush=True)
            ddict.rollback()
            rollback += 1
            local_sample_agg = ddict[f"puid_{my_puid}"]
            num_samples = rollback_chkpt * 100
            continue

        # periodically checkpoint state of application
        if num_samples % 100 == 0:
            ddict[f"puid_{my_puid}"] = local_sample_agg
            ddict.checkpoint()  # proceed to the next checkpoint
            barrier.wait()


if __name__ == "__main__":

    my_alloc = System()
    nnodes = my_alloc.nnodes
    procs_per_node = 30
    ddict = DDict(1, nnodes, nnodes * int(4 * 1024 * 1024), working_set_size=4)
    barrier = Barrier(nnodes * procs_per_node)
    pg = ProcessGroup()
    temp_proc = ProcessTemplate(target=biased_sampler, args=(ddict, barrier))
    pg.add_process(template=temp_proc, nproc=nnodes * procs_per_node)

    pg.init()
    pg.start()
    pg.join()

    avg = 0.0
    # update current process to the latest checkpoint
    ddict.sync_to_newest_checkpoint()
    for puid, _ in pg.inactive_puids:
        avg += ddict[f"puid_{puid}"]
    avg /= nnodes * procs_per_node * 1000
    print(f"Final average from all processes is {avg}", flush=True)
    pg.close()
    ddict.destroy()
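Like any Dragon program, this example is launched with the dragon launcher. Assuming the listing above is saved as rollback_example.py (a hypothetical file name chosen for illustration), it can be run on an allocation with:

dragon rollback_example.py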
Using Checkpointing to Persist Training State
Checkpointing, combined with persistence, allows users to save the state of their application to more permanent storage, such as disk. This is particularly useful for long-running training jobs where hardware and software failures can occur. The example below demonstrates how to checkpoint and persist the state of a distributed data parallel training process using a DDict.
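Relative to the previous example, the main additions are the persistence-related arguments passed to the DDict constructor and the restore logic used on restart. Before walking through the full listing, the sketch below isolates the constructor call so the new arguments are easier to see; the values and the comments describing them are illustrative, and the full example below shows them in context.

from dragon.native.machine import System
from dragon.data.ddict import DDict, PosixCheckpointPersister

my_alloc = System()
nnodes = my_alloc.nnodes

ddict = DDict(
    1,                                         # managers per node (the full example uses one per GPU)
    nnodes,
    nnodes * int(4 * 1024 * 1024 * 1024),      # total DDict memory across the allocation
    wait_for_keys=True,
    working_set_size=2,                        # checkpoints kept in node-local memory
    persister_class=PosixCheckpointPersister,  # how checkpoints are written to permanent storage
    persist_freq=2,                            # how often checkpoints are persisted
    persist_count=2,                           # how many persisted checkpoints are kept
    persist_path="",                           # where persisted checkpoints are written
    name="my_persistent_ddict",                # name used to identify the DDict and its persisted checkpoints
)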
import dragon
import getpass
import os
import argparse

from dragon.ai.collective_group import CollectiveGroup, RankInfo
from dragon.data.ddict import DDict, PosixCheckpointPersister, DAOSCheckpointPersister
from dragon.native.machine import System

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


def training_fn(ddict, restart):
    rank_info = RankInfo()
    rank = rank_info.my_rank
    local_rank = rank_info.my_local_rank
    master_addr = rank_info.master_addr
    master_port = rank_info.master_port
    world_size = rank_info.world_size
    print(
        f"Rank Info: rank {rank}, local_rank {local_rank}, master_addr {master_addr}, master_port {master_port}, world_size {world_size}"
    )

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_addr}:{master_port}",
        world_size=world_size,
        rank=rank,
    )

    input_size = 10
    hidden_size = 20
    output_size = 1
    num_samples = 100 * nnodes
    batch_size = 10
    learning_rate = 0.01
    num_epochs = 20
    device = torch.device("cuda:0")
    criterion = nn.MSELoss().to(device)

    if restart:
        # get globally checkpointed variables
        train_loader = ddict["train_loader"]
        # get my local ddict shard to recover local checkpoints
        my_manager_id = ddict.local_managers[local_rank]
        ddict = ddict.manager(my_manager_id)
        # get loader, model, optimizer state from ddict
        model = SimpleNN(input_size, hidden_size, output_size)
        model.load_state_dict(ddict[f"model_state_dict_{rank}"])
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        optimizer.load_state_dict(ddict[f"optimizer_state_dict_{rank}"])

        # move optimizer state to device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        ddict.checkpoint()  # proceed to the next checkpoint

        first_chkpt = ddict.checkpoint_id
        last_chkpt = first_chkpt + num_epochs
    else:
        # generate random data
        X_train = torch.randn(num_samples, input_size)
        y_train = torch.randn(num_samples, output_size)
        train_dataset = TensorDataset(X_train, y_train)
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=False, sampler=DistributedSampler(train_dataset)
        )
        model = SimpleNN(input_size, hidden_size, output_size)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # store loader to ddict. only needs to be done once with a persistent put
        if rank == 0:
            ddict.pput("train_loader", train_loader)

        first_chkpt = 0
        last_chkpt = num_epochs
        # get my local ddict so checkpoints are to node-local memory
        my_manager_id = ddict.local_managers[local_rank]
        ddict = ddict.manager(my_manager_id)

    # device_ids is an index; we control GPU affinity using a policy so each process only sees its own GPU
    model.to(device)
    model = DDP(model, device_ids=[0])

    print(f"Rank {rank} starting training", flush=True)
    # training
    for epoch in range(num_epochs):
        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # checkpoint model and optimizer state
        ddict[f"model_state_dict_{rank}"] = model.module.state_dict()
        ddict[f"optimizer_state_dict_{rank}"] = optimizer.state_dict()
        # if not every ddict manager is guaranteed to get a key, a small bput can keep managers
        # in sync for checkpointing. In this case we have as many managers as GPUs so each manager gets a key.
        # ddict.bput("epoch", epoch)
        ddict.checkpoint()
        if rank == 0:
            print(f"Epoch [{first_chkpt + epoch + 1}/{last_chkpt}], Loss: {loss.item():.4f}", flush=True)

    print(f"Rank {rank} finished training!", flush=True)
    dist.destroy_process_group()


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Training with DDict Checkpoint Persistence")
    parser.add_argument(
        "--restart",
        action="store_true",
        help="Whether to restart from the last persisted checkpoint",
    )
    args = parser.parse_args()

    my_alloc = System()
    nnodes = my_alloc.nnodes

    # checkpointing parameters
    managers_per_node = my_alloc.primary_node.num_gpus
    working_set_size = 2
    persist_frequency = 2
    persist_count = 2
    persist_path = ""
    restart = bool(args.restart or my_alloc.restarted)
    persister = PosixCheckpointPersister  # switch to DAOSCheckpointPersister if using DAOS
    name = f"chkpt_persistence_example_{getpass.getuser()}"

    ddict = DDict(
        managers_per_node,
        nnodes,
        nnodes * int(4 * 1024 * 1024 * 1024),
        wait_for_keys=True,
        working_set_size=working_set_size,
        persister_class=persister,
        persist_freq=persist_frequency,
        persist_count=persist_count,
        persist_path=persist_path,
        name=name,
    )

    if restart:
        # if it's a restart, we recover the ddict and find the last persisted checkpoint
        available_persisted_chkpt = ddict.persisted_ids()
        print(f"available persisted checkpoints: {available_persisted_chkpt}", flush=True)

        # restore from the last complete checkpoint
        latest_chkpt = available_persisted_chkpt[-1]
        ddict.restore(latest_chkpt)

    # launch one training process per GPU
    num_gpus_per_node = my_alloc.primary_node.num_gpus
    policies = my_alloc.gpu_policies()
    training_group = CollectiveGroup(
        training_fn,
        training_args=(
            ddict,
            restart,
        ),
        policies=policies,
    )

    training_group.init()
    training_group.start()
    training_group.join()
    training_group.close()

    ddict.destroy()
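A typical workflow for this script, assuming it is saved as train_persist.py (a hypothetical file name), is to launch it with the dragon launcher and, after a failure or interruption, relaunch it with the script's --restart flag (or rely on my_alloc.restarted when the runtime itself restarts the job) so the DDict is restored from the newest persisted checkpoint before training resumes:

dragon train_persist.py            # initial run
dragon train_persist.py --restart  # resume from the last persisted checkpoint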
In this example, all checkpoints are first written to node-local memory for performance. At the specified frequency, checkpoints are asynchronously persisted to disk using the chosen persister class, and upon restart the application can restore from the last persisted checkpoint. Although this example saves duplicate state (each DDP rank holds a full replica of the model), the same pattern covers the more general case where each rank holds only a shard of the state, as in fully sharded data parallel training or when running an ensemble with different hyperparameters. The replicated state also provides insurance: if a node fails, the system can recover from another node's in-memory checkpoint. The next section covers how to automatically recover from failures using only in-memory checkpoints rather than persistent storage.
Recovering from Failures with In-Memory Checkpoints
In progress…