PyTorch Lightning template

Back in the old days, I studied how to implement highly efficient PyTorch pipelines for multi-GPU training [1]. DistributedDataParallel is the way to go, but it is cumbersome: you need boilerplate for spawning worker processes and constructing per-rank data readers.
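
For a sense of what that boilerplate looks like, here is a minimal raw-DDP sketch (not the exact pipeline from [1]; the init_method address, toy dataset, and model are placeholders):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def worker(rank, world_size):
    # Every rank must join the same process group before wrapping the model in DDP.
    dist.init_process_group("nccl", init_method="tcp://127.0.0.1:23456",
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = DDP(torch.nn.Linear(5, 1).cuda(rank), device_ids=[rank])
    dataset = TensorDataset(torch.randn(1024, 5), torch.randn(1024, 1))
    # Each rank reads its own shard of the data, hence the DistributedSampler.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)
    for x, y in loader:
        opt.zero_grad()
        torch.nn.functional.mse_loss(model(x.cuda(rank)), y.cuda(rank)).backward()
        opt.step()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)  # spawn one training process per GPU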

Now, PyTorch Lightning offers a clean API for setting up multi-GPU training easily. Here is a template I designed, which I will stick to for prototyping models for the rest of my life : )

import os
from typing import List, Any

from dataclasses import dataclass
import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, IterableDataset
import pytorch_lightning as pl


TOTAL_NUM_BATCHES = 320
BATCH_SIZE = 32
STATE_DIM = 5
NUM_GPUS = 2
NUM_WORKERS = 2
UPDATE_FREQ = 1
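# Ground-truth linear weights; EpisodicDataset below synthesizes reward = state · WEIGHTS,
# so the MLP is fit to a noiseless linear-regression target.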
WEIGHTS = torch.tensor([2.0, 3.1, 2.1, -1.5, -1.7])

@dataclass
class PolicyGradientInput:
    pid: int
    state: torch.Tensor
    reward: torch.Tensor

    def __len__(self):
        return self.state.shape[0]

    @classmethod
    def from_batch(cls, x):
        return cls(
            pid=x[0][0].item(),
            state=x[1][0].reshape(BATCH_SIZE, STATE_DIM),
            reward=x[2][0].reshape(BATCH_SIZE, 1),
        )

class EpisodicDataset(IterableDataset):
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        pid = os.getpid()
        if worker_info is None:
            total_num_batches = TOTAL_NUM_BATCHES // NUM_GPUS
            worker_id = -1
        else:
            total_workers = worker_info.num_workers
            total_num_batches = int(TOTAL_NUM_BATCHES // NUM_GPUS // total_workers)
            worker_id = worker_info.id
        # You will see that we have an EpisodicDataset on each of NUM_GPUS processes
        print(f"{worker_info}, pid={pid}, total_num_batches={total_num_batches}")

        for _ in range(total_num_batches):
            state = torch.randn(BATCH_SIZE, STATE_DIM)
            reward = torch.sum(state * WEIGHTS, dim=1)
            yield (pid, state, reward)


class PPO(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(STATE_DIM, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 1)
        )
        self.traj_buffer = []
        self.update_freq = UPDATE_FREQ
        self.step = 0

    def training_step(self, batch: List[Any], batch_idx):
        batch: PolicyGradientInput = PolicyGradientInput.from_batch(batch)
        self.traj_buffer.append(batch)
        self.step += 1
        rank = dist.get_rank()
        # use the first three trajectories' pids as a signature to quickly check whether all trainers share the same data
        # the answer: each trainer maintains different trajectories
        traj_buffer_signature = ','.join([str(traj.pid) for traj in self.traj_buffer[:3]])
        print(f"rank={rank}, traj_buffer_len={len(self.traj_buffer)}, step={self.step}, signature={traj_buffer_signature}")
        if self.step % self.update_freq == 0:
            model_params = list(self.model.parameters())[0][0].detach().cpu().numpy()
            print(f"before {self.step} step training: rank={rank}, model_params={model_params}")
            return self.update_model()

    def configure_optimizers(self):
        # Somehow, Adam doesn't work for linear regression
        # optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-2)
        return optimizer

    def update_model(self):
        traj = self.traj_buffer[-1]
        loss = F.mse_loss(self.model(traj.state), traj.reward)
        rank = dist.get_rank()
        print(f"rank={rank}, step={self.step}, loss={loss}")
        return loss


ds = EpisodicDataset()
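# batch_size=1 because EpisodicDataset already yields full (BATCH_SIZE, STATE_DIM) batches;
# PolicyGradientInput.from_batch strips the extra leading dimension the DataLoader adds.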
dl = DataLoader(ds, batch_size=1, num_workers=NUM_WORKERS, pin_memory=True)
ppo = PPO()
trainer = pl.Trainer(gpus=NUM_GPUS, max_epochs=1, progress_bar_refresh_rate=1, accelerator='ddp')
trainer.fit(ppo, dl)

As you can see, you only need to define a Dataset, a DataLoader with an appropriate NUM_WORKERS, and a pytorch-lightning Trainer in which you specify the number of GPUs. Each of the NUM_GPUS GPUs then uses NUM_WORKERS processes for reading data plus one main process for training the model.
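
To make the sharding concrete, here is the arithmetic with the constants above (a sanity check, not part of the template itself):

# 320 total batches are split first across GPUs, then across dataloader workers.
batches_per_gpu = TOTAL_NUM_BATCHES // NUM_GPUS       # 320 // 2 = 160
batches_per_worker = batches_per_gpu // NUM_WORKERS   # 160 // 2 = 80
steps_per_trainer = batches_per_worker * NUM_WORKERS  # each trainer runs 160 training steps per epoch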

The example shows that each trainer on a GPU maintains its own list of trajectories, which is not shared with the trainers on the other GPUs. However, model parameters stay in sync across trainers, because the underlying mechanism is still DistributedDataParallel: gradients are all-reduced on every backward pass, so each optimizer step applies the same update on every GPU.
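
If you want to verify the sync claim yourself, one way (my own addition, not part of the template) is to compare a parameter checksum across ranks after each batch, e.g. in an on_train_batch_end hook added to the PPO module. I take *args/**kwargs because the hook signature differs slightly across Lightning versions:

    def on_train_batch_end(self, *args, **kwargs):
        # Sum of all parameters as a cheap checksum; all-reduce its min and max across ranks.
        checksum = torch.stack([p.detach().sum() for p in self.model.parameters()]).sum()
        lo, hi = checksum.clone(), checksum.clone()
        dist.all_reduce(lo, op=dist.ReduceOp.MIN)
        dist.all_reduce(hi, op=dist.ReduceOp.MAX)
        if dist.get_rank() == 0:
            # If DDP keeps parameters identical, the spread between min and max is 0.
            print(f"param checksum spread across ranks: {(hi - lo).item()}")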

 

References

[1] https://czxttkl.com/2020/10/03/analyze-distributeddataparallels-behavior/ 
