lightning/pytorch_lightning/plugins/layer_sync.py

95 lines
3.8 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import torch
from torch.nn import Module
class LayerSync(ABC):
"""Abstract base class for creating plugins that wrap layers of a model with synchronization logic for
multiprocessing."""
@abstractmethod
def apply(self, model: Module) -> Module:
"""Override this method to apply synchronization to the layers of this model."""
@abstractmethod
def revert(self, model: Module) -> Module:
"""Override this method to undo all modifications made in :meth:`apply`."""
class NativeSyncBatchNorm(LayerSync):
"""A plugin that wraps all batch normalization layers of a model with synchronization logic for
multiprocessing.
This plugin has no effect in single-device operation.
"""
def apply(self, model: Module) -> Module:
"""Add global batchnorm for a model spread across multiple GPUs and nodes.
Override this method to synchronize batchnorm layers between specific process groups instead
of the whole world.
Args:
model: Reference to the current LightningModule
Return:
LightningModule with batchnorm layers synchronized within the process groups.
"""
return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
def revert(self, model: Module) -> Module:
"""Convert the wrapped batchnorm layers back to regular batchnorm layers.
Args:
model: Reference to the current LightningModule
Return:
LightningModule with regular batchnorm layers that will no longer sync across processes.
"""
# Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547
# Original author: Kapil Yedidi (@kapily)
converted_module = model
if isinstance(model, torch.nn.modules.batchnorm.SyncBatchNorm):
# Unfortunately, LayerSync does not store the original class - if it did
# we could return the one that was originally created.
converted_module = _BatchNormXd(
model.num_features, model.eps, model.momentum, model.affine, model.track_running_stats
)
if model.affine:
with torch.no_grad():
converted_module.weight = model.weight
converted_module.bias = model.bias
converted_module.running_mean = model.running_mean
converted_module.running_var = model.running_var
converted_module.num_batches_tracked = model.num_batches_tracked
if hasattr(model, "qconfig"):
converted_module.qconfig = model.qconfig
for name, child in model.named_children():
converted_module.add_module(name, self.revert(child))
del model
return converted_module
class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
def _check_input_dim(self, input: torch.Tensor) -> None:
# The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
# is this method that is overwritten by the subclass.
# Here, we are bypassing some tensor sanity checks and trusting that the user
# provides the right input dimensions at inference.
return