lightning/tests/parity_fabric/test_parity_ddp.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

167 lines
5.2 KiB
Python
Raw Normal View History

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from copy import deepcopy
import torch
import torch.distributed
import torch.nn.functional
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from lightning.fabric.fabric import Fabric
from parity_fabric.models import ConvNet
from parity_fabric.utils import (
cuda_reset,
is_cuda_memory_close,
is_state_dict_equal,
is_timing_close,
make_deterministic,
)
def train_torch_ddp(
rank,
world_size,
device=torch.device("cpu"),
backend="nccl",
):
make_deterministic()
memory_stats = {}
os.environ["LOCAL_RANK"] = str(rank)
torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)
model = ConvNet().to(device)
initial_state_dict = deepcopy(model.state_dict())
ddp_model = DistributedDataParallel(model, device_ids=([rank] if device.type == "cuda" else None))
dataloader = model.get_dataloader()
sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False)
dataloader = DataLoader(dataloader.dataset, sampler=sampler, batch_size=model.batch_size)
optimizer = model.get_optimizer()
loss_fn = model.get_loss_function()
memory_stats["start"] = torch.cuda.memory_stats()
ddp_model.train()
iteration_timings = []
iterator = iter(dataloader)
for _ in range(model.num_steps):
t0 = time.perf_counter()
inputs, labels = next(iterator)
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = ddp_model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
t1 = time.perf_counter()
iteration_timings.append(t1 - t0)
memory_stats["end"] = torch.cuda.memory_stats()
# check that the model has changed
assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict())
return ddp_model.module.state_dict(), torch.tensor(iteration_timings), memory_stats
def train_fabric_ddp(fabric):
make_deterministic()
memory_stats = {}
model = ConvNet()
initial_state_dict = deepcopy(model.state_dict())
optimizer = model.get_optimizer()
model, optimizer = fabric.setup(model, optimizer)
dataloader = model.get_dataloader()
dataloader = fabric.setup_dataloaders(dataloader)
loss_fn = model.get_loss_function()
memory_stats["start"] = torch.cuda.memory_stats()
model.train()
iteration_timings = []
iterator = iter(dataloader)
for _ in range(model.num_steps):
t0 = time.perf_counter()
inputs, labels = next(iterator)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
fabric.backward(loss)
optimizer.step()
t1 = time.perf_counter()
iteration_timings.append(t1 - t0)
memory_stats["end"] = torch.cuda.memory_stats()
# check that the model has changed
assert not is_state_dict_equal(initial_state_dict, model.state_dict())
return model.state_dict(), torch.tensor(iteration_timings), memory_stats
def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float = 0.02):
cuda_reset()
# Launch processes with Fabric and re-use them for the PyTorch training for convenience
fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices)
fabric.launch()
# Train with Fabric
state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric)
fabric.barrier()
cuda_reset()
torch.distributed.destroy_process_group()
# sleep for a bit to avoid race conditions, since the very first call in `train_torch_ddp`
# is initializing a new process group
time.sleep(3)
# Train with raw PyTorch
state_dict_torch, timings_torch, memory_torch = train_torch_ddp(
rank=fabric.global_rank,
world_size=fabric.world_size,
device=fabric.device,
backend=fabric.strategy._process_group_backend,
)
# Compare the final weights
assert all(fabric.all_gather(is_state_dict_equal(state_dict_torch, state_dict_fabric)))
# Compare the time per iteration
assert all(fabric.all_gather(is_timing_close(timings_torch, timings_fabric, rtol=tolerance, atol=tolerance)))
# Compare memory usage
if accelerator == "cuda":
assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["start"], memory_fabric["start"])))
assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["end"], memory_fabric["end"])))
if __name__ == "__main__":
from jsonargparse.cli import CLI
CLI(run_parity_test)