# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

import torch
from torch.nn import Module


class LayerSync(ABC):
    """Abstract base class for creating plugins that wrap layers of a model with synchronization logic for
    multiprocessing."""

    @abstractmethod
    def apply(self, model: Module) -> Module:
        """Override this method to apply synchronization to the layers of this model."""

    @abstractmethod
    def revert(self, model: Module) -> Module:
        """Override this method to undo all modifications made in :meth:`apply`."""
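

# Illustrative sketch (added for clarity; not part of the original module): a minimal
# no-op ``LayerSync`` subclass showing the contract a concrete plugin must fulfil --
# ``apply`` installs synchronization wrappers on the model, ``revert`` undoes them.
# The name ``_NoOpLayerSync`` is hypothetical and only serves as an example.
class _NoOpLayerSync(LayerSync):
    def apply(self, model: Module) -> Module:
        # Nothing to wrap: return the model unchanged.
        return model

    def revert(self, model: Module) -> Module:
        # Nothing was wrapped, so there is nothing to undo.
        return model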
class NativeSyncBatchNorm(LayerSync):
    """A plugin that wraps all batch normalization layers of a model with synchronization logic for
    multiprocessing.

    This plugin has no effect in single-device operation.
    """

    def apply(self, model: Module) -> Module:
        """Add global batchnorm for a model spread across multiple GPUs and nodes.

        Override this method to synchronize batchnorm layers between specific process groups instead
        of the whole world.

        Args:
            model: Reference to the current LightningModule

        Return:
            LightningModule with batchnorm layers synchronized within the process groups.
        """
        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    def revert(self, model: Module) -> Module:
        """Convert the wrapped batchnorm layers back to regular batchnorm layers.

        Args:
            model: Reference to the current LightningModule

        Return:
            LightningModule with regular batchnorm layers that will no longer sync across processes.
        """
        # Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547
        # Original author: Kapil Yedidi (@kapily)
        converted_module = model
        if isinstance(model, torch.nn.modules.batchnorm.SyncBatchNorm):
            # Unfortunately, LayerSync does not store the original class - if it did
            # we could return the one that was originally created.
            converted_module = _BatchNormXd(
                model.num_features, model.eps, model.momentum, model.affine, model.track_running_stats
            )
            if model.affine:
                with torch.no_grad():
                    converted_module.weight = model.weight
                    converted_module.bias = model.bias
            converted_module.running_mean = model.running_mean
            converted_module.running_var = model.running_var
            converted_module.num_batches_tracked = model.num_batches_tracked
            if hasattr(model, "qconfig"):
                converted_module.qconfig = model.qconfig
        for name, child in model.named_children():
            converted_module.add_module(name, self.revert(child))
        del model
        return converted_module


class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    def _check_input_dim(self, input: torch.Tensor) -> None:
        # The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
        # is this method that is overwritten by the subclass.
        # Here, we are bypassing some tensor sanity checks and trusting that the user
        # provides the right input dimensions at inference.
        return
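

# Minimal usage sketch (added for illustration; the tiny ``Sequential`` model below is an
# assumption, not part of Lightning). It shows the apply/revert round trip: ``apply``
# converts every batchnorm layer to ``SyncBatchNorm``, and ``revert`` turns it back into a
# plain, non-syncing batchnorm module. Note that ``SyncBatchNorm`` layers only synchronize
# statistics when training under an initialized ``torch.distributed`` process group.
if __name__ == "__main__":
    demo_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.BatchNorm2d(8))
    plugin = NativeSyncBatchNorm()

    # BatchNorm2d -> SyncBatchNorm
    synced_model = plugin.apply(demo_model)
    assert isinstance(synced_model[1], torch.nn.SyncBatchNorm)

    # SyncBatchNorm -> _BatchNormXd (regular batchnorm behavior, no cross-process sync)
    reverted_model = plugin.revert(synced_model)
    assert not isinstance(reverted_model[1], torch.nn.SyncBatchNorm)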