Back in the old days, I studied how to implement highly efficient PyTorch pipelines for multi-GPU training [1]. DistributedDataParallel is the way to go, but it is cumbersome because we need boilerplate code for spawning workers and constructing data readers.
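To recall what that boilerplate looks like, here is a rough sketch of a bare DistributedDataParallel setup. MyModel and MyDataset are placeholders, and the backend/port choices are just examples, not part of any template in this post:

# Sketch of raw DDP boilerplate (MyModel / MyDataset are placeholders),
# shown only to illustrate what has to be written by hand without Lightning.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def train(rank, world_size):
    # every process must join the same process group
    dist.init_process_group(
        "nccl", rank=rank, world_size=world_size,
        init_method="tcp://127.0.0.1:23456",
    )
    torch.cuda.set_device(rank)
    model = DDP(MyModel().to(rank), device_ids=[rank])
    # each process must read a disjoint shard of the data
    dataset = MyDataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    for epoch in range(10):
        sampler.set_epoch(epoch)
        for batch in loader:
            ...  # forward / backward / optimizer step

if __name__ == "__main__":
    world_size = 2
    mp.spawn(train, args=(world_size,), nprocs=world_size)  # one process per GPU

Process spawning, process-group initialization, device placement, and per-rank data sharding all have to be wired up manually.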
Now, PyTorch Lightning offers a clean API for setting up multi-GPU training easily. Here is a template I designed, which I will stick to for prototyping models for the rest of my life : )
import os
from typing import List, Any
from dataclasses import dataclass

import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, IterableDataset
import pytorch_lightning as pl

TOTAL_NUM_BATCHES = 320
BATCH_SIZE = 32
STATE_DIM = 5
NUM_GPUS = 2
NUM_WORKERS = 2
UPDATE_FREQ = 1
WEIGHTS = torch.tensor([2.0, 3.1, 2.1, -1.5, -1.7])


@dataclass
class PolicyGradientInput:
    pid: int
    state: torch.Tensor
    reward: torch.Tensor

    def __len__(self):
        return self.state.shape[0]

    @classmethod
    def from_batch(cls, x):
        return cls(
            pid=x[0][0].item(),
            state=x[1][0].reshape(BATCH_SIZE, STATE_DIM),
            reward=x[2][0].reshape(BATCH_SIZE, 1),
        )


class EpisodicDataset(IterableDataset):
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        pid = os.getpid()
        if worker_info is None:
            total_num_batches = TOTAL_NUM_BATCHES // NUM_GPUS
            worker_id = -1
        else:
            total_workers = worker_info.num_workers
            total_num_batches = int(TOTAL_NUM_BATCHES // NUM_GPUS // total_workers)
            worker_id = worker_info.id
        # You will see that we have an EpisodicDataset on each of NUM_GPUS processes
        print(f"{worker_info}, pid={pid}, total_num_batches={total_num_batches}")
        for _ in range(total_num_batches):
            state = torch.randn(BATCH_SIZE, STATE_DIM)
            reward = torch.sum(state * WEIGHTS, dim=1)
            yield (pid, state, reward)


class PPO(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(STATE_DIM, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        self.traj_buffer = []
        self.update_freq = UPDATE_FREQ
        self.step = 0

    def training_step(self, batch: List[Any], batch_idx):
        batch: PolicyGradientInput = PolicyGradientInput.from_batch(batch)
        self.traj_buffer.append(batch)
        self.step += 1
        rank = dist.get_rank()
        # use the first three trajectories' pids as a signature to quickly check
        # whether all trainers share the same data;
        # the answer is that each trainer maintains different trajectories
        traj_buffer_signature = ','.join([str(traj.pid) for traj in self.traj_buffer[:3]])
        print(
            f"rank={rank}, traj_buffer_len={len(self.traj_buffer)}, "
            f"step={self.step}, signature={traj_buffer_signature}"
        )
        if self.step % self.update_freq == 0:
            model_params = list(self.model.parameters())[0][0].detach().cpu().numpy()
            print(f"before {self.step} step training: rank={rank}, model_params={model_params}")
            return self.update_model()

    def configure_optimizers(self):
        # Somehow, Adam doesn't work for linear regression
        # optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-2)
        return optimizer

    def update_model(self):
        traj = self.traj_buffer[-1]
        loss = F.mse_loss(self.model(traj.state), traj.reward)
        rank = dist.get_rank()
        print(f"rank={rank}, step={self.step}, loss={loss}")
        return loss


ds = EpisodicDataset()
dl = DataLoader(ds, batch_size=1, num_workers=NUM_WORKERS, pin_memory=True)
ppo = PPO()
trainer = pl.Trainer(gpus=NUM_GPUS, max_epochs=1, progress_bar_refresh_rate=1, accelerator='ddp')
trainer.fit(ppo, dl)
As you can see, you only need to define a Dataset, a DataLoader with an appropriate NUM_WORKERS, and a pytorch-lightning Trainer in which you specify the number of GPUs. Each of the NUM_GPUS GPUs will then use NUM_WORKERS processes for reading data and one main process for training the model.
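As a side note, the same torch.utils.data.get_worker_info() trick generalizes to datasets backed by a fixed pool of episodes instead of randomly generated ones. Here is a minimal sketch; the class name and the episode list are my own illustration, not part of the template above:

import torch
from torch.utils.data import IterableDataset

class ShardedEpisodicDataset(IterableDataset):
    """Sketch: shard a fixed list of episodes across dataloader workers."""

    def __init__(self, episodes):
        self.episodes = episodes  # e.g. a list of pre-collected trajectories

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:  # single-process data loading
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = info.id, info.num_workers
        # worker i yields episodes i, i + num_workers, i + 2*num_workers, ...
        for idx in range(worker_id, len(self.episodes), num_workers):
            yield self.episodes[idx]

In the multi-GPU case you would additionally divide the work by dist.get_rank(), just as the template divides TOTAL_NUM_BATCHES by NUM_GPUS.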
The example shows that each trainer on a GPU maintains its own list of trajectories, which is not shared with the trainers on other GPUs. However, model parameters are synced after every loss computation, because the underlying mechanism is still DistributedDataParallel: gradients are all-reduced across trainers during the backward pass, so every trainer applies the same update.
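If you want to verify the syncing yourself, one option (a rough sketch using the collective APIs, not something the template above does) is to all-gather a cheap fingerprint of the parameters from every rank and compare:

import torch
import torch.distributed as dist

def check_params_synced(model) -> bool:
    """Sketch: check that a model's parameters are identical on all DDP ranks."""
    # a cheap scalar fingerprint of the parameters on this rank
    local = torch.stack([p.detach().float().sum() for p in model.parameters()]).sum()
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    return all(torch.allclose(gathered[0], g) for g in gathered)

Summing parameters is a coarse fingerprint, but it is usually enough to catch ranks whose weights have drifted apart.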
References
[1] https://czxttkl.com/2020/10/03/analyze-distributeddataparallels-behavior/