Refactor LightningDataParallel (#5670)
* module
* fix model access
* scalar conversion
* refactor
* kwargs
* auto unsqueeze
* refactor code duplication
* clean up
* docs
* update dp docs
* changelog
* generalize test
* test
* rename
* warning cache
* isort
* unsqueezing test
* device
* device
* scalar test
* device
* device
* include coverage of overrides
* clear
* add deprecation test
* docs
* improve coverage
* increase coverage
* fix merge
* extend test
* rename base class
* mention the predict method in docs
* combine iteration over collection
* remove override
* move
* line
* Apply suggestions from code review
* fix running stage
* f401
* fix cyclic import

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent 5d239ccd70
commit 692f77b8a7
@@ -120,6 +120,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Moved accelerators and plugins to its `legacy` pkg ([#5645](https://github.com/PyTorchLightning/pytorch-lightning/pull/5645))

- Deprecated `LightningDistributedDataParallel` in favor of new wrapper module `LightningDistributedModule` ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))

- Deprecated `LightningDataParallel` in favor of new wrapper module `LightningParallelModule` ([#5670](https://github.com/PyTorchLightning/pytorch-lightning/pull/5670))

### Removed

- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))
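Note: the two deprecation entries above summarize the user-facing change of this commit. A minimal before/after sketch of the wrapping pattern (hedged: `model` is assumed to be a `LightningModule`, and the device ids are illustrative):

    # deprecated since v1.2, scheduled for removal in v1.4
    from pytorch_lightning.overrides.data_parallel import LightningDataParallel
    dp_model = LightningDataParallel(model, device_ids=[0, 1])

    # replacement: plain torch.nn.DataParallel around the new wrapper module
    import torch
    from pytorch_lightning.overrides.data_parallel import LightningParallelModule
    dp_model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=[0, 1])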
@@ -21,7 +21,7 @@ from pytorch_lightning.cluster_environments import ClusterEnvironment

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed import LightningDistributed
-from pytorch_lightning.overrides.data_parallel import LightningDataParallel
+from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -74,7 +74,7 @@ class DataParallelAccelerator(Accelerator):

        # set dp device
        torch.cuda.set_device(self.trainer.root_gpu)
-        model = LightningDataParallel(model, device_ids=device_ids)
+        model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=device_ids)
        return model

    def __init_half_precision(self, model):
@@ -181,8 +181,10 @@ class DataParallelAccelerator(Accelerator):
            scheduler.load_state_dict(state)

    def get_reference_model(self, model) -> LightningModule:
-        if isinstance(model, LightningDataParallel):
-            return model.module
+        if isinstance(model, torch.nn.DataParallel):
+            model = model.module
+        if isinstance(model, LightningParallelModule):
+            model = model.module
        return model

    @property
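Note: the model handed back by the accelerator is now wrapped twice (`DataParallel` around `LightningParallelModule`), so `get_reference_model` unwraps in two steps. A tiny sketch of the invariant it restores (names as in the diff above):

    dp_model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=[0])
    # one .module hop reaches the wrapper, a second reaches the LightningModule
    assert dp_model.module.module is model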
@@ -0,0 +1,2 @@
+from pytorch_lightning.overrides.data_parallel import LightningParallelModule  # noqa: F401
+from pytorch_lightning.overrides.distributed import LightningDistributedModule  # noqa: F401
@@ -0,0 +1,63 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import torch
+
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.warnings import WarningCache
+
+warning_cache = WarningCache()
+
+
+class _LightningModuleWrapperBase(torch.nn.Module):
+
+    def __init__(self, pl_module: LightningModule):
+        """
+        Wraps the user's LightningModule and redirects the forward call to the appropriate
+        method, either ``training_step``, ``validation_step`` or ``test_step``.
+        If the LightningModule is in none of the states `training`, `testing` or `validation`,
+        the inputs will be redirected to the
+        :meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method.
+        Inheriting classes may also modify the inputs or outputs of forward.
+
+        Args:
+            pl_module: the model to wrap
+        """
+        super().__init__()
+        self.module = pl_module
+
+    def forward(self, *inputs, **kwargs):
+        running_stage = self.module.running_stage
+
+        if running_stage == RunningStage.TRAINING:
+            output = self.module.training_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "training_step")
+        elif running_stage == RunningStage.TESTING:
+            output = self.module.test_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "test_step")
+        elif running_stage == RunningStage.EVALUATING:
+            output = self.module.validation_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "validation_step")
+        else:
+            output = self.module.predict(*inputs, **kwargs)
+
+        return output
+
+
+def warn_if_output_is_none(output: Any, method_name: str) -> None:
+    """ Warns user about which method returned None. """
+    if output is None:
+        warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
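Note: the base wrapper dispatches on `running_stage`, so one wrapper type can serve the training, test, validation, and predict loops. A hedged usage sketch (`BoringModel` is the test-suite model used elsewhere in this PR; the batch shape follows the tests further below):

    import torch
    from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
    from pytorch_lightning.trainer.states import RunningStage
    from tests.base import BoringModel

    model = BoringModel()
    model.running_stage = RunningStage.TRAINING
    wrapped = _LightningModuleWrapperBase(model)
    output = wrapped(torch.rand(2, 32), 0)  # forwards to model.training_step(batch, batch_idx)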
@@ -11,154 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

-import itertools
-import threading
import numbers
import warnings
-from collections.abc import Iterable, Mapping
-from itertools import chain
-from typing import Any, Optional
+from typing import Any

import torch
-from torch import Tensor
-from torch.cuda._utils import _get_device_index
-from torch.nn import DataParallel, Module
+from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
-from torch.nn.parallel._functions import Gather

from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities.warnings import WarningCache
-
-
-def _find_tensors(obj):  # pragma: no-cover
-    r"""
-    Recursively find all tensors contained in the specified object.
-    """
-    if isinstance(obj, torch.Tensor):
-        return [obj]
-    if isinstance(obj, (list, tuple)):
-        return itertools.chain(*map(_find_tensors, obj))
-    if isinstance(obj, dict):
-        return itertools.chain(*map(_find_tensors, obj.values()))
-    return []
-
-
-def get_a_var(obj):  # pragma: no-cover
-    if isinstance(obj, torch.Tensor):
-        return obj
-
-    if isinstance(obj, (list, tuple)):
-        for result in map(get_a_var, obj):
-            if isinstance(result, torch.Tensor):
-                return result
-    if isinstance(obj, dict):
-        for result in map(get_a_var, obj.items()):
-            if isinstance(result, torch.Tensor):
-                return result
-    return None
-
-
-warning_cache = WarningCache()
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.overrides.distributed import LightningDistributedModule
+from pytorch_lightning.utilities.apply_func import apply_to_collection


class LightningDataParallel(DataParallel):
-    """
-    Override the forward call in lightning so it goes to training and validation step respectively
-    """
-
-    def forward(self, *inputs, **kwargs):
-        if not self.device_ids:
-            return self.module(*inputs, **kwargs)
-
-        for t in chain(self.module.parameters(), self.module.buffers()):
-            if t.device != self.src_device_obj:
-                raise RuntimeError(
-                    f"module must have its parameters and buffers on device {self.src_device_obj} (device_ids[0])"
-                    f" but found one of them on device: {t.device}"
-                )
-
-        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
-
-        if len(self.device_ids) == 1:
-
-            running_stage = self.module.running_stage
-
-            if running_stage == RunningStage.TRAINING:
-                return self.module.training_step(*inputs[0], **kwargs[0])
-
-            elif running_stage == RunningStage.TESTING:
-                return self.module.test_step(*inputs[0], **kwargs[0])
-
-            elif running_stage == RunningStage.EVALUATING:
-                return self.module.validation_step(*inputs[0], **kwargs[0])
-
-            else:
-                return self.module.predict(*inputs[0], **kwargs[0])
-
-        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
-        outputs = self.parallel_apply(replicas, inputs, kwargs)
-
-        if isinstance(outputs[0], Result):
-            outputs = self.__gather_structured_result(outputs)
-        else:
-            outputs = self.gather(outputs)
-        return outputs
-
-    def __gather_structured_result(self, outputs):
-        prototype_output = outputs[0]
-        original_class = prototype_output.__class__
-        outputs = [dict(x) for x in outputs]
-
-        # remove all the meta info
-        meta = outputs[0]['meta']
-        for i, output in enumerate(outputs):
-            del output['meta']
-
-        outputs = self.gather(outputs)
-
-        result = original_class()
-        result.update(outputs)
-        result['meta'] = meta
-        return result
-
-    def gather(self, outputs):
-        r"""
-        Override the gather method to support python scalars as well.
-        """
-
-        def gather_map(outputs):
-            elem = outputs[0]
-            elem_type = type(elem)
-
-            if isinstance(elem, torch.Tensor):
-                return Gather.apply(self.output_device, self.dim, *outputs)
-
-            if elem is None:
-                return None
-
-            if isinstance(elem, Mapping):
-                if not all((len(elem) == len(d) for d in outputs)):
-                    raise ValueError('All dicts must have the same number of keys')
-                return elem_type(((k, gather_map([d[k] for d in outputs])) for k in elem))
-
-            if isinstance(elem, Iterable) and not isinstance(elem, str):
-                return elem_type(map(gather_map, zip(*outputs)))
-
-            return outputs
-
-        # Recursive function calls like this create reference cycles.
-        # Setting the function to None clears the refcycle.
-        try:
-            res = gather_map(outputs)
-        finally:
-            gather_map = None
-        return res
-
-    def parallel_apply(self, replicas, inputs, kwargs):
-        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+    def __init__(self, module: LightningModule, *args, **kwargs):
+        warnings.warn(
+            "The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."
+            " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.",
+            DeprecationWarning
+        )
+        super().__init__(LightningParallelModule(module), *args, **kwargs)


class LightningDistributedDataParallel(DistributedDataParallel):
@@ -166,209 +41,60 @@ class LightningDistributedDataParallel(DistributedDataParallel):

    def __init__(self, module: LightningModule, *args, **kwargs):
        warnings.warn(
            "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
-            " From now on we recommend to directly sublcass `torch.nn.parallel.DistributedDataParallel`.",
+            " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
            DeprecationWarning
        )
        super().__init__(LightningDistributedModule(module), *args, **kwargs)


-class LightningDistributedModule(torch.nn.Module):
-
-    def __init__(self, pl_module: LightningModule):
-        """
-        Wraps the user's LightningModule and redirects the forward call to the appropriate
-        method, either ``training_step``, ``validation_step`` or ``test_step``.
-        This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
-        shown in the example.
-
-        Example:
-
-            ddp_model = DistributedDataParallel(
-                module=LightningDistributedModule(lightning_module),
-                device_ids=[local_rank],
-                ...
-            )
-
-        Args:
-            pl_module: the model to wrap
-
-        """
-        super().__init__()
-        self.module = pl_module
-
-    def forward(self, *inputs, **kwargs):
-
-        running_stage = self.module.running_stage
-
-        if running_stage == RunningStage.TRAINING:
-            output = self.module.training_step(*inputs, **kwargs)
-            warn_if_output_is_none(output, "training_step")
-
-        elif running_stage == RunningStage.TESTING:
-            output = self.module.test_step(*inputs, **kwargs)
-            warn_if_output_is_none(output, "test_step")
-
-        elif running_stage == RunningStage.EVALUATING:
-            output = self.module.validation_step(*inputs, **kwargs)
-            warn_if_output_is_none(output, "validation_step")
-
-        else:
-            output = self.module.predict(*inputs, **kwargs)
-
-        return output
-
-
-# In manual_optimization, we need to call reducer prepare_for_backward.
-# Note: Keep track of Pytorch DDP and update if there is a change
-# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
-def prepare_for_backward(model: DistributedDataParallel, output: Any):
-    if torch.is_grad_enabled() and model.require_backward_grad_sync:
-        model.require_forward_param_sync = True
-        # We'll return the output object verbatim since it is a freeform
-        # object. We need to find any tensors in this object, though,
-        # because we need to figure out which parameters were used during
-        # this forward pass, to ensure we short circuit reduction for any
-        # unused parameters. Only if `find_unused_parameters` is set.
-        if model.find_unused_parameters:
-            model.reducer.prepare_for_backward(list(_find_tensors(output)))
-        else:
-            model.reducer.prepare_for_backward([])
-    else:
-        model.require_forward_param_sync = False
-
-
-def warn_if_output_is_none(output: Any, method_name: str) -> None:
-    if output is None:
-        warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
-
-
-def warn_missing_output(fx_called):
-    if fx_called == 'training_step':
-        warning_cache.warn("Your training_step returned None. Make sure that was your intention!")
-
-
-def parallel_apply(
-    modules: Module,
-    inputs: Tensor,
-    kwargs_tup: Optional[tuple] = None,
-    devices: Optional[list] = None,
-):  # pragma: no-cover
-    r"""Applies each `module` in :attr:`modules` in parallel on arguments
-    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
-    on each of :attr:`devices`.
-
-    Args:
-        modules: modules to be parallelized
-        inputs: inputs to the modules
-        devices: CUDA devices
-
-    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
-    :attr:`devices` (if given) should all have same length. Moreover, each
-    element of :attr:`inputs` can either be a single object as the only argument
-    to a module, or a collection of positional arguments.
-    """
-    assert len(modules) == len(inputs)
-    if kwargs_tup is not None:
-        assert len(modules) == len(kwargs_tup)
-    else:
-        kwargs_tup = ({}, ) * len(modules)
-    if devices is not None:
-        assert len(modules) == len(devices)
-    else:
-        devices = [None] * len(modules)
-    devices = list(map(lambda x: _get_device_index(x, True), devices))
-    lock = threading.Lock()
-    results = {}
-    grad_enabled = torch.is_grad_enabled()
-
-    def _worker(i, module, input, kwargs, device=None):
-        torch.set_grad_enabled(grad_enabled)
-        if device is None:
-            device = get_a_var(input).get_device()
-        try:
-            with torch.cuda.device(device):
-                # this also avoids accidental slicing of `input` if it is a Tensor
-                if not isinstance(input, (list, tuple)):
-                    input = (input, )
-
-                module = module.to(device)
-
-                # ---------------
-                # CHANGE
-                if module.running_stage == RunningStage.TRAINING:
-                    output = module.training_step(*input, **kwargs)
-                    fx_called = 'training_step'
-
-                elif module.running_stage == RunningStage.TESTING:
-                    output = module.test_step(*input, **kwargs)
-                    fx_called = 'test_step'
-
-                elif module.running_stage == RunningStage.EVALUATING:
-                    output = module.validation_step(*input, **kwargs)
-                    fx_called = 'validation_step'
-
-                else:
-                    output = module.predict(*input, **kwargs)
-                    fx_called = 'predict'
-
-                if output is None:
-                    warn_missing_output(fx_called)
-
-                if output is not None and module._distrib_type in ('dp', 'ddp2'):
-                    auto_squeeze_dim_zeros(output)
-                # ---------------
-
-            with lock:
-                results[i] = output
-        # todo: specify the possible exception
-        except Exception as ex:
-            with lock:
-                results[i] = ex
-
-    # TODO: fix hack (maybe not a hack)
-    # make sure each module knows what training state it's in...
-    # fixes weird bug where copies are out of sync
-    root_m = modules[0]
-    for m in modules[1:]:
-        m.training = root_m.training
-        m.testing = root_m.testing
-
-    if len(modules) > 1:
-        threads = [
-            threading.Thread(target=_worker, args=(i, module, input, kwargs, device))
-            for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices))
-        ]
-
-        for thread in threads:
-            thread.start()
-        for thread in threads:
-            thread.join()
-    else:
-        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
-
-    outputs = []
-    for i in range(len(inputs)):
-        output = results[i]
-        if isinstance(output, Exception):
-            raise output
-        outputs.append(output)
-    return outputs
-
-
-def auto_squeeze_dim_zeros(output):
-    """
-    In DP or DDP2 we need to unsqueeze dim 0
-    :param output:
-    :return:
-    """
-    if isinstance(output, torch.Tensor):
-        output = output.unsqueeze(0)
-
-    for k, v in output.items():
-        if not isinstance(v, torch.Tensor):
-            continue
-
-        is_scalar = v.dim() == 0
-        if is_scalar:
-            output[k] = output[k].unsqueeze(0)
+class LightningParallelModule(_LightningModuleWrapperBase):
+    """
+    Wraps the user's LightningModule and redirects the forward call to the appropriate
+    method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``.
+    This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as
+    shown in the example. It also takes care of converting Python scalars to Tensors and
+    un-squeezes 0-dimensional Tensors as it is required by :class:`~torch.nn.parallel.DataParallel`.
+
+    Example:
+
+        dp_model = torch.nn.DataParallel(
+            module=LightningParallelModule(lightning_module),
+            device_ids=[3, 4],
+            ...
+        )
+
+    Args:
+        pl_module: the model to wrap
+
+    """
+
+    def __init__(self, pl_module: LightningModule):
+        super().__init__(pl_module)
+
+    def forward(self, *inputs, **kwargs):
+        output = super().forward(*inputs, **kwargs)
+
+        def output_transform(data: Any):
+            data = python_scalar_to_tensor(data, self.module.device)
+            data = unsqueeze_scalar_tensor(data)
+            return data
+
+        output = apply_to_collection(
+            output,
+            dtype=(numbers.Number, torch.Tensor),
+            function=output_transform,
+        )
+        return output
+
+
+def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
+    """ Converts a Python scalar number to a torch tensor and places it on the given device. """
+    if isinstance(data, numbers.Number):
+        data = torch.tensor([data], device=device)
+    return data
+
+
+def unsqueeze_scalar_tensor(data: Any) -> Any:
+    """ Un-squeezes a 0-dim tensor. """
+    if isinstance(data, torch.Tensor) and data.dim() == 0:
+        data = data.unsqueeze(0)
+    return data
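Note: the `output_transform` closure exists because `torch.nn.DataParallel` gathers per-replica outputs along dim 0, which fails for Python scalars and warns on 0-dim tensors. A standalone sketch of the two helpers applied to a step output (CPU-only, for illustration):

    import numbers
    import torch
    from pytorch_lightning.overrides.data_parallel import python_scalar_to_tensor, unsqueeze_scalar_tensor
    from pytorch_lightning.utilities.apply_func import apply_to_collection

    def output_transform(data):
        data = python_scalar_to_tensor(data)   # 12.3 -> tensor([12.3])
        data = unsqueeze_scalar_tensor(data)   # tensor(0.5) -> tensor([0.5])
        return data

    out = {"loss": torch.tensor(0.5), "python scalar": 12.3}
    out = apply_to_collection(out, dtype=(numbers.Number, torch.Tensor), function=output_transform)
    assert out["loss"].shape == (1,) and out["python scalar"].shape == (1,)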
@@ -0,0 +1,77 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+from typing import Any
+
+import torch
+from torch.nn.parallel import DistributedDataParallel
+
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+
+
+class LightningDistributedModule(_LightningModuleWrapperBase):
+
+    def __init__(self, pl_module: LightningModule):
+        """
+        Wraps the user's LightningModule and redirects the forward call to the appropriate
+        method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``.
+        This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
+        shown in the example.
+
+        Example:
+
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                module=LightningDistributedModule(lightning_module),
+                device_ids=[local_rank],
+                ...
+            )
+
+        Args:
+            pl_module: the model to wrap
+
+        """
+        super().__init__(pl_module)
+
+
+def _find_tensors(obj):  # pragma: no-cover
+    r"""
+    Recursively find all tensors contained in the specified object.
+    """
+    if isinstance(obj, torch.Tensor):
+        return [obj]
+    if isinstance(obj, (list, tuple)):
+        return itertools.chain(*map(_find_tensors, obj))
+    if isinstance(obj, dict):
+        return itertools.chain(*map(_find_tensors, obj.values()))
+    return []
+
+
+# In manual_optimization, we need to call reducer prepare_for_backward.
+# Note: Keep track of Pytorch DDP and update if there is a change
+# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
+def prepare_for_backward(model: DistributedDataParallel, output: Any):
+    if torch.is_grad_enabled() and model.require_backward_grad_sync:
+        model.require_forward_param_sync = True
+        # We'll return the output object verbatim since it is a freeform
+        # object. We need to find any tensors in this object, though,
+        # because we need to figure out which parameters were used during
+        # this forward pass, to ensure we short circuit reduction for any
+        # unused parameters. Only if `find_unused_parameters` is set.
+        if model.find_unused_parameters:
+            model.reducer.prepare_for_backward(list(_find_tensors(output)))
+        else:
+            model.reducer.prepare_for_backward([])
+    else:
+        model.require_forward_param_sync = False
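Note: `prepare_for_backward` mirrors what DDP's own forward does before a backward pass; the plugin calls it in manual optimization, where Lightning cannot rely on DDP's forward hook. A hedged sketch of the intended call order (`lightning_module`, `local_rank`, `batch`, and `batch_idx` are illustrative, not the plugin's actual code):

    wrapped = LightningDistributedModule(lightning_module)
    ddp_model = DistributedDataParallel(wrapped, device_ids=[local_rank])

    output = ddp_model(batch, batch_idx)      # runs training_step on this rank
    prepare_for_backward(ddp_model, output)   # registers output tensors with the reducer
    output["loss"].backward()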
@@ -21,7 +21,7 @@ from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.overrides.data_parallel import LightningDistributedModule, prepare_for_backward
+from pytorch_lightning.overrides.distributed import LightningDistributedModule, prepare_for_backward
from pytorch_lightning.plugins.legacy.plugin import LightningPlugin
from pytorch_lightning.utilities import DeviceType
@@ -23,3 +23,6 @@ class WarningCache:
        if m not in self.warnings:
            self.warnings.add(m)
            rank_zero_warn(m)
+
+    def clear(self):
+        self.warnings.clear()
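Note: `clear()` lets tests reset the de-duplication set, since `warn()` emits each distinct message only once per process. For example:

    cache = WarningCache()
    cache.warn("Your training_step returned None. ...")  # emitted
    cache.warn("Your training_step returned None. ...")  # suppressed, already seen
    cache.clear()
    cache.warn("Your training_step returned None. ...")  # emitted again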
@@ -49,7 +49,6 @@ omit =
    pytorch_lightning/accelerators/dp_*.py
    pytorch_lightning/accelerators/tpu_*.py
    pytorch_lightning/cluster_environments/*.py
-    pytorch_lightning/overrides/data_parallel.py
    pytorch_lightning/utilities/xla_device_utils.py
    pytorch_lightning/utilities/distributed.py
    pytorch_lightning/tuner/auto_gpu_select.py
@@ -18,7 +18,12 @@ import pytest
import torch

from pytorch_lightning import Trainer
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
+from pytorch_lightning.overrides.data_parallel import (
+    LightningDataParallel,
+    LightningDistributedDataParallel,
+    LightningParallelModule,
+)
+from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin
from tests.base import BoringModel
from tests.deprecated_api import _soft_unimport_module
@@ -165,6 +170,8 @@ class CustomDDPPlugin(DDPPlugin):
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
+        assert isinstance(model, torch.nn.parallel.DistributedDataParallel)
+        assert isinstance(model.module, LightningDistributedModule)
        return model
@@ -180,3 +187,14 @@ def test_v1_4_0_deprecated_lightning_distributed_data_parallel(tmpdir):
        plugins=[CustomDDPPlugin()]
    )
    trainer.fit(model)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_v1_4_0_deprecated_lightning_data_parallel():
+    model = BoringModel()
+    with pytest.deprecated_call(
+        match="`LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."
+    ):
+        dp_model = LightningDataParallel(model, device_ids=[0])
+    assert isinstance(dp_model, torch.nn.DataParallel)
+    assert isinstance(dp_model.module, LightningParallelModule)
@@ -393,7 +393,7 @@ def test_dp_resume(tmpdir):
    # haven't trained with the new loaded model
    dp_model = new_trainer.model
    dp_model.eval()
-    dp_model.module.running_stage = RunningStage.EVALUATING
+    dp_model.module.module.running_stage = RunningStage.EVALUATING

    dataloader = trainer.train_dataloader
    tpipes.run_prediction(dp_model, dataloader, dp=True)
@@ -2,36 +2,57 @@ from unittest.mock import MagicMock

import pytest
import torch
+from torch.nn import DataParallel

-from pytorch_lightning.overrides.data_parallel import LightningDistributedModule
+from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import warning_cache
+from pytorch_lightning.overrides.data_parallel import (
+    LightningParallelModule,
+    python_scalar_to_tensor,
+    unsqueeze_scalar_tensor,
+)
from pytorch_lightning.trainer.states import RunningStage
+from tests.base import BoringModel


-def test_lightning_distributed_module_methods():
-    """ Test that the LightningDistributedModule redirects .forward() to the LightningModule methods. """
+@pytest.mark.parametrize("wrapper_class", [
+    LightningParallelModule,
+    LightningDistributedModule,
+])
+def test_lightning_wrapper_module_methods(wrapper_class):
+    """ Test that the LightningWrapper redirects .forward() to the LightningModule methods. """
    pl_module = MagicMock()
-    dist_module = LightningDistributedModule(pl_module)
+    wrapped_module = wrapper_class(pl_module)

    batch = torch.rand(5)
    batch_idx = 3

    pl_module.running_stage = RunningStage.TRAINING
-    dist_module(batch, batch_idx)
+    wrapped_module(batch, batch_idx)
    pl_module.training_step.assert_called_with(batch, batch_idx)

    pl_module.running_stage = RunningStage.TESTING
-    dist_module(batch, batch_idx)
+    wrapped_module(batch, batch_idx)
    pl_module.test_step.assert_called_with(batch, batch_idx)

    pl_module.running_stage = RunningStage.EVALUATING
-    dist_module(batch, batch_idx)
+    wrapped_module(batch, batch_idx)
    pl_module.validation_step.assert_called_with(batch, batch_idx)

+    pl_module.running_stage = None
+    wrapped_module(batch)
+    pl_module.predict.assert_called_with(batch)

-def test_lightning_distributed_module_warn_none_output():
-    """ Test that the LightningDistributedModule warns about forgotten return statement. """
+
+@pytest.mark.parametrize("wrapper_class", [
+    LightningParallelModule,
+    LightningDistributedModule,
+])
+def test_lightning_wrapper_module_warn_none_output(wrapper_class):
+    """ Test that the LightningWrapper module warns about forgotten return statement. """
+    warning_cache.clear()
    pl_module = MagicMock()
-    dist_module = LightningDistributedModule(pl_module)
+    wrapped_module = wrapper_class(pl_module)

    pl_module.training_step.return_value = None
    pl_module.validation_step.return_value = None
@@ -39,12 +60,95 @@ def test_lightning_distributed_module_warn_none_output():

    with pytest.warns(UserWarning, match="Your training_step returned None"):
        pl_module.running_stage = RunningStage.TRAINING
-        dist_module()
+        wrapped_module()

    with pytest.warns(UserWarning, match="Your test_step returned None"):
        pl_module.running_stage = RunningStage.TESTING
-        dist_module()
+        wrapped_module()

    with pytest.warns(UserWarning, match="Your validation_step returned None"):
        pl_module.running_stage = RunningStage.EVALUATING
-        dist_module()
+        wrapped_module()
+
+    with pytest.warns(None) as record:
+        pl_module.running_stage = None
+        wrapped_module()
+    assert not record
+
+
+@pytest.mark.parametrize("inp,expected", [
+    [torch.tensor(1.0), torch.tensor([1.0])],
+    [torch.tensor([2.0]), torch.tensor([2.0])],
+    [torch.ones(3, 4, 5), torch.ones(3, 4, 5)],
+])
+def test_unsqueeze_scalar_tensor(inp, expected):
+    """ Test that the utility function unsqueezes only scalar tensors. """
+    assert torch.all(unsqueeze_scalar_tensor(inp).eq(expected))
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-gpu machine")
+def test_lightning_parallel_module_unsqueeze_scalar():
+    """ Test that LightningParallelModule takes care of un-squeezing 0-dim tensors. """
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            output = super().training_step(batch, batch_idx)
+            loss = output["loss"]
+            loss = loss.squeeze()
+            assert loss.dim() == 0
+            # PyTorch usually warns about 0-dim tensors returned in DP
+            return {"loss": loss}
+
+    model = TestModel()
+    model.running_stage = RunningStage.TRAINING
+    batch = torch.rand(2, 32).cuda()
+    batch_idx = 0
+
+    wrapped_model = LightningParallelModule(model).cuda()
+    dp_module = DataParallel(wrapped_model, device_ids=[0, 1])
+
+    output = wrapped_model(batch, batch_idx)
+    assert output["loss"].dim() == 1
+
+    with pytest.warns(None) as record:
+        output = dp_module(batch, batch_idx)
+
+    assert output["loss"].dim() == 1
+    assert not record
+
+
+@pytest.mark.parametrize("inp,expected", [
+    [1.0, torch.tensor([1.0])],
+    [2, torch.tensor([2.0])],
+    [True, torch.tensor([True])],
+])
+def test_python_scalar_to_tensor(inp, expected):
+    assert torch.all(python_scalar_to_tensor(inp).eq(expected))
+
+
+@pytest.mark.parametrize("device", [
+    torch.device("cpu"),
+    torch.device("cuda", 0),
+])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_lightning_parallel_module_python_scalar_conversion(device):
+    """ Test that LightningParallelModule can convert Python scalars to tensors. """
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            output = super().training_step(batch, batch_idx)
+            # PyTorch DP does not support Python scalars, Lightning converts them to tensors
+            output.update({"python scalar": 12.3})
+            return output
+
+    model = TestModel()
+    model.to(device)
+    model.running_stage = RunningStage.TRAINING
+    batch = torch.rand(2, 32).to(device)
+    batch_idx = 0
+
+    wrapped_model = LightningParallelModule(model)
+    output = wrapped_model(batch, batch_idx)
+    assert output["python scalar"] == torch.tensor([12.3], device=device)