From dc4c3171fc0618fb463a846ccbde2e558902303b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 26 Feb 2022 04:51:57 +0100
Subject: [PATCH] add parity test for sync batchnorm (#12021)

---
 .../benchmarks/test_sync_batchnorm_parity.py | 112 +++++++++++++++
 tests/models/test_sync_batchnorm.py          | 131 ------------------
 2 files changed, 112 insertions(+), 131 deletions(-)
 create mode 100644 tests/benchmarks/test_sync_batchnorm_parity.py
 delete mode 100644 tests/models/test_sync_batchnorm.py

diff --git a/tests/benchmarks/test_sync_batchnorm_parity.py b/tests/benchmarks/test_sync_batchnorm_parity.py
new file mode 100644
index 0000000000..6c02ca0efc
--- /dev/null
+++ b/tests/benchmarks/test_sync_batchnorm_parity.py
@@ -0,0 +1,112 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, DistributedSampler
+
+from pytorch_lightning import LightningModule, seed_everything, Trainer
+from tests.helpers.runif import RunIf
+
+
+class SyncBNModule(LightningModule):
+    def __init__(self, batch_size):
+        super().__init__()
+        self.batch_size = batch_size
+        self.bn_layer = nn.BatchNorm1d(1)
+        self.linear = nn.Linear(1, 10)
+        self.bn_outputs = []
+
+    def on_train_start(self) -> None:
+        assert isinstance(self.bn_layer, torch.nn.modules.batchnorm.SyncBatchNorm)
+
+    def training_step(self, batch, batch_idx):
+        with torch.no_grad():
+            out_bn = self.bn_layer(batch)
+            self.bn_outputs.append(out_bn.detach())
+        out = self.linear(out_bn)
+        return out.sum()
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.02)
+
+    def train_dataloader(self):
+        dataset = torch.arange(64, dtype=torch.float).view(-1, 1)
+        # we need to set a distributed sampler ourselves to force shuffle=False
+        sampler = DistributedSampler(
+            dataset, num_replicas=self.trainer.world_size, rank=self.trainer.global_rank, shuffle=False
+        )
+        return DataLoader(dataset, sampler=sampler, batch_size=self.batch_size)
+
+
+@RunIf(min_gpus=2, standalone=True)
+def test_sync_batchnorm_parity(tmpdir):
+    """Test parity between 1) Training a synced batch-norm layer on 2 GPUs with batch size B per device 2) Training
+    a batch-norm layer on CPU with twice the batch size."""
+    seed_everything(3)
+    # 2 GPUS, batch size = 4 per GPU => total batch size = 8
+    model = SyncBNModule(batch_size=4)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        accelerator="gpu",
+        strategy="ddp",
+        devices=2,
+        max_steps=3,
+        sync_batchnorm=True,
+        num_sanity_val_steps=0,
+        replace_sampler_ddp=False,
+        deterministic=True,
+        benchmark=False,
+    )
+    trainer.fit(model)
+
+    # the strategy is responsible for tearing down the batchnorm wrappers
+    assert not isinstance(model.bn_layer, torch.nn.modules.batchnorm.SyncBatchNorm)
+    assert isinstance(model.bn_layer, torch.nn.modules.batchnorm._BatchNorm)
+
+    bn_outputs = torch.stack(model.bn_outputs)  # 2 x 4 x 1 on each GPU
+    bn_outputs_multi_device = trainer.strategy.all_gather(bn_outputs).cpu()  # 2 x 2 x 4 x 1
+
+    if trainer.global_rank == 0:
+        # pretend we are now training on a single GPU/process
+        # (we are reusing the rank 0 from the previous training)
+
+        # 1 GPU, batch size = 8 => total batch size = 8
+        bn_outputs_single_device = _train_single_process_sync_batchnorm(batch_size=8, num_steps=3)
+
+        gpu0_outputs = bn_outputs_multi_device[0]  # 2 x 4 x 1
+        gpu1_outputs = bn_outputs_multi_device[1]  # 2 x 4 x 1
+        slice0 = bn_outputs_single_device[:, 0::2]
+        slice1 = bn_outputs_single_device[:, 1::2]
+
+        assert torch.allclose(gpu0_outputs, slice0)
+        assert torch.allclose(gpu1_outputs, slice1)
+
+
+def _train_single_process_sync_batchnorm(batch_size, num_steps):
+    seed_everything(3)
+    dataset = torch.arange(64, dtype=torch.float).view(-1, 1)
+    train_dataloader = DataLoader(dataset, batch_size=batch_size)
+    model = SyncBNModule(batch_size=batch_size)
+    optimizer = model.configure_optimizers()
+    model.train()
+    for batch_idx, batch in enumerate(train_dataloader):
+        optimizer.zero_grad()
+        loss = model.training_step(batch, batch)
+        loss.backward()
+        optimizer.step()
+        if batch_idx == num_steps - 1:
+            break
+
+    return torch.stack(model.bn_outputs)  # num_steps x batch_size x 1
diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py
deleted file mode 100644
index cb1cba0646..0000000000
--- a/tests/models/test_sync_batchnorm.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import pytest
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from pytorch_lightning import LightningModule, seed_everything, Trainer
-from pytorch_lightning.plugins.environments import LightningEnvironment
-from pytorch_lightning.strategies import DDPSpawnStrategy
-from pytorch_lightning.utilities import FLOAT16_EPSILON
-from tests.helpers.datamodules import MNISTDataModule
-from tests.helpers.runif import RunIf
-from tests.helpers.utils import set_random_main_port
-
-
-class SyncBNModule(LightningModule):
-    def __init__(self, gpu_count=1, **kwargs):
-        super().__init__()
-
-        self.gpu_count = gpu_count
-        self.bn_targets = None
-        if "bn_targets" in kwargs:
-            self.bn_targets = kwargs["bn_targets"]
-
-        self.linear = nn.Linear(28 * 28, 10)
-        self.bn_layer = nn.BatchNorm1d(28 * 28)
-
-    def on_train_start(self) -> None:
-        assert isinstance(self.bn_layer, torch.nn.modules.batchnorm.SyncBatchNorm)
-
-    def forward(self, x, batch_idx):
-        with torch.no_grad():
-            out_bn = self.bn_layer(x.view(x.size(0), -1))
-
-            if self.bn_targets:
-                bn_target = self.bn_targets[batch_idx]
-
-                # executes on both GPUs
-                bn_target = bn_target[self.trainer.local_rank :: self.gpu_count]
-                bn_target = bn_target.to(out_bn.device)
-                assert torch.sum(torch.abs(bn_target - out_bn)) < FLOAT16_EPSILON
-
-        out = self.linear(out_bn)
-
-        return out, out_bn
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-
-        y_hat, _ = self(x, batch_idx)
-        loss = F.cross_entropy(y_hat, y)
-
-        return loss
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(self.linear.parameters(), lr=0.02)
-
-
-# TODO: Fatal Python error: Bus error
-@pytest.mark.skip(reason="Fatal Python error: Bus error")
-@RunIf(min_gpus=2, standalone=True)
-def test_sync_batchnorm_ddp(tmpdir):
-    seed_everything(234)
-    set_random_main_port()
-
-    # define datamodule and dataloader
-    dm = MNISTDataModule()
-    dm.prepare_data()
-    dm.setup(stage=None)
-
-    train_dataloader = dm.train_dataloader()
-    model = SyncBNModule()
-
-    bn_outputs = []
-
-    # shuffle is false by default
-    for batch_idx, batch in enumerate(train_dataloader):
-        x, _ = batch
-
-        _, out_bn = model.forward(x, batch_idx)
-        bn_outputs.append(out_bn)
-
-        # get 3 steps
-        if batch_idx == 2:
-            break
-
-    bn_outputs = [x.cuda() for x in bn_outputs]
-
-    # reset datamodule
-    # batch-size = 16 because 2 GPUs in DDP
-    dm = MNISTDataModule(batch_size=16, dist_sampler=True)
-    dm.prepare_data()
-    dm.setup(stage=None)
-
-    model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)
-    ddp = DDPSpawnStrategy(
-        parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
-        num_nodes=1,
-        sync_batchnorm=True,
-        cluster_environment=LightningEnvironment(),
-        find_unused_parameters=True,
-    )
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        gpus=2,
-        num_nodes=1,
-        strategy=ddp,
-        max_epochs=1,
-        max_steps=3,
-        sync_batchnorm=True,
-        num_sanity_val_steps=0,
-        replace_sampler_ddp=False,
-    )
-
-    trainer.fit(model, dm)
-    # the strategy is responsible for tearing down the batchnorm wrappers
-    assert not isinstance(model.bn_layer, torch.nn.modules.batchnorm.SyncBatchNorm)
-    assert isinstance(model.bn_layer, torch.nn.modules.batchnorm._BatchNorm)
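
Note (not part of the patch above): the parity check compares each GPU's batch-norm outputs against the interleaved 0::2 / 1::2 slices of a single-process run with twice the batch size. That slicing works because DistributedSampler with shuffle=False deals samples round-robin across the two ranks. Below is a minimal, CPU-only sketch of just that sampler behaviour; variable names are illustrative and it only uses the standard torch.utils.data API.

import torch
from torch.utils.data import DataLoader, DistributedSampler

# Same synthetic dataset as in the test: 64 scalar samples.
dataset = torch.arange(64, dtype=torch.float).view(-1, 1)

# Recreate the per-rank samplers built in SyncBNModule.train_dataloader for 2 ranks.
# Passing num_replicas/rank explicitly means no process group is required here.
sampler_rank0 = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
sampler_rank1 = DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=False)
loader_rank0 = DataLoader(dataset, batch_size=4, sampler=sampler_rank0)
loader_rank1 = DataLoader(dataset, batch_size=4, sampler=sampler_rank1)
# The single-process reference run uses twice the per-device batch size.
loader_single = DataLoader(dataset, batch_size=8)

batch_rank0 = next(iter(loader_rank0))    # samples 0, 2, 4, 6
batch_rank1 = next(iter(loader_rank1))    # samples 1, 3, 5, 7
batch_single = next(iter(loader_single))  # samples 0..7

# shuffle=False assigns samples round-robin to the ranks, hence the 0::2 / 1::2 slices
# used when comparing bn_outputs_multi_device with bn_outputs_single_device.
assert torch.equal(batch_rank0, batch_single[0::2])
assert torch.equal(batch_rank1, batch_single[1::2])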