Distributed PyTorch
Using ProcessGroup for PyTorch Distributed Training
Any time you want to do distributed training on GPUs with PyTorch, the PyTorch backend requires some configuration. Doing that with ProcessGroup is straightforward, albeit a little unsightly, as is always the case for distributed training. Future work will provide helper classes that complete most standard configurations. In the meantime, given some PyTorch function training_fn designed for distributed training, the code snippets below show how to use ProcessGroup with a CUDA backend.
import os

from dragon.native.process_group import ProcessGroup
from dragon.native.process import ProcessTemplate
from dragon.native.machine import System
from dragon.infrastructure.policy import Policy
import dragon.ai.torch  # needs to be imported before torch to avoid multiprocessing conflicts

import torch
import torch.distributed as dist


def training_fn():

    # Each worker selects its GPU via LOCAL_RANK and joins the NCCL process
    # group using the MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE
    # environment variables set by configure_training_group below
    device = 'cuda:' + os.getenv("LOCAL_RANK")
    dist.init_process_group('nccl')
    torch.cuda.set_device(device)

    #### Do your training

    dist.destroy_process_group()


def configure_training_group(training_fn, training_args: tuple = None, training_kwargs: dict = None):

    # Get the list of nodes available to the Dragon runtime
    my_alloc = System()
    node_list = my_alloc.nodes

    tasks_per_node = 4  # Set to the number of GPUs on a given node

    num_nodes = len(node_list)
    world_size = num_nodes * tasks_per_node

    # The first node in the allocation acts as the rendezvous point for torch.distributed
    master_node = node_list[0].host_name
    master_port = str(29500)

    pg = ProcessGroup()
    for node_rank, node in enumerate(node_list):
        # Pin every process created in this iteration to the current node
        policy = Policy(placement=Policy.Placement.HOST_NAME, host_name=node.host_name)
        for local_rank in range(tasks_per_node):
            rank = node_rank * tasks_per_node + local_rank

            # Give each worker its own copy of the environment with the
            # variables torch.distributed expects
            env = dict(os.environ).copy()
            env["MASTER_ADDR"] = master_node
            env["MASTER_PORT"] = master_port
            env["RANK"] = str(rank)
            env["LOCAL_RANK"] = str(local_rank)
            env["WORLD_SIZE"] = str(world_size)
            env["LOCAL_WORLD_SIZE"] = str(tasks_per_node)
            env["GROUP_RANK"] = str(node_rank)

            template = ProcessTemplate(target=training_fn,
                                       args=training_args,
                                       kwargs=training_kwargs,
                                       env=env,
                                       policy=policy)

            pg.add_process(nproc=1, template=template)

    pg.init()
    pg.start()

    pg.join()
    pg.close()
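For context, here is a minimal sketch of how the two functions above might be driven. The script name (train.py), the empty args/kwargs, and the hyperparameter-free training_fn are assumptions made for illustration, not requirements of the API.

if __name__ == "__main__":
    # Launch one NCCL worker per GPU across the allocation; training_fn here
    # takes no arguments, so empty args/kwargs are passed through
    configure_training_group(training_fn, training_args=(), training_kwargs={})

The script would then be started under the Dragon runtime, e.g. dragon train.py, so that ProcessGroup can place workers across the nodes in the allocation.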