Refactor LightningDataParallel (#5670)

* module

* fix model access

* scalar conversion

* refactor

* kwargs

* auto unsqueeze

* refactor code duplication

* clean up

* docs

* update dp docs

* changelog

* generalize test

* test

* rename

* warning cache

* isort

* unsqueezing test

* device

* device

* scalar test

* device

* device

* include coverage of overrides

* clear

* add deprecation test

* docs

* improve coverage

* increase coverage

* fix merge

* extend test

* rename base class

* mention the predict method in docs

* combine iteration over collection

* remove override

* move

* line

* Apply suggestions from code review

* fix running stage

* f401

* fix cyclic import

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Adrian Wälchli authored on 2021-01-31 12:08:16 +01:00; committed by GitHub
parent 5d239ccd70
commit 692f77b8a7
12 changed files with 348 additions and 348 deletions

View File

@ -120,6 +120,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved accelerators and plugins to its `legacy` pkg ([#5645](https://github.com/PyTorchLightning/pytorch-lightning/pull/5645))
- Deprecated `LightningDistributedDataParallel` in favor of new wrapper module `LightningDistributedModule` ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))
- Deprecated `LightningDataParallel` in favor of new wrapper module `LightningParallelModule` ([#5670](https://github.com/PyTorchLightning/pytorch-lightning/pull/5670))
### Removed
- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))
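For the two wrapper deprecations listed above, a minimal migration sketch (assuming an existing `LightningModule` instance named `model` and two visible GPUs; `local_rank` is illustrative):

```python
import torch

from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule

# Before (deprecated since v1.2, to be removed in v1.4):
#   dp_model = LightningDataParallel(model, device_ids=[0, 1])
# After: wrap the LightningModule explicitly and use the stock torch.nn class.
dp_model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=[0, 1])

# The same pattern applies to DDP (see the LightningDistributedModule entry above):
#   ddp_model = torch.nn.parallel.DistributedDataParallel(
#       LightningDistributedModule(model), device_ids=[local_rank],
#   )
```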

View File

@ -21,7 +21,7 @@ from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDataParallel
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -74,7 +74,7 @@ class DataParallelAccelerator(Accelerator):
# set dp device
torch.cuda.set_device(self.trainer.root_gpu)
model = LightningDataParallel(model, device_ids=device_ids)
model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=device_ids)
return model
def __init_half_precision(self, model):
@ -181,8 +181,10 @@ class DataParallelAccelerator(Accelerator):
scheduler.load_state_dict(state)
def get_reference_model(self, model) -> LightningModule:
if isinstance(model, LightningDataParallel):
return model.module
if isinstance(model, torch.nn.DataParallel):
model = model.module
if isinstance(model, LightningParallelModule):
model = model.module
return model
@property
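A short sketch of the unwrapping performed by `get_reference_model` above; `unwrap_lightning_module` is a hypothetical stand-in for the accelerator method, and `BoringModel` from the test suite is assumed as the user model:

```python
import torch

from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from tests.base import BoringModel


def unwrap_lightning_module(model):
    # Peel off torch.nn.DataParallel first, then the Lightning wrapper.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    if isinstance(model, LightningParallelModule):
        model = model.module
    return model


pl_module = BoringModel()
wrapped = torch.nn.DataParallel(LightningParallelModule(pl_module), device_ids=[0])
assert unwrap_lightning_module(wrapped) is pl_module    # both wrapper layers removed
assert unwrap_lightning_module(pl_module) is pl_module  # bare module passes through unchanged
```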

View File

@ -0,0 +1,2 @@
from pytorch_lightning.overrides.data_parallel import LightningParallelModule # noqa: F401
from pytorch_lightning.overrides.distributed import LightningDistributedModule # noqa: F401

View File

@ -0,0 +1,63 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.warnings import WarningCache
warning_cache = WarningCache()
class _LightningModuleWrapperBase(torch.nn.Module):
def __init__(self, pl_module: LightningModule):
"""
Wraps the user's LightningModule and redirects the forward call to the appropriate
method, either ``training_step``, ``validation_step`` or ``test_step``.
If the LightningModule is in none of the states `training`, `testing` or `validation`,
the inputs will be redirected to the
:meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method.
Inheriting classes may also modify the inputs or outputs of forward.
Args:
pl_module: the model to wrap
"""
super().__init__()
self.module = pl_module
def forward(self, *inputs, **kwargs):
running_stage = self.module.running_stage
if running_stage == RunningStage.TRAINING:
output = self.module.training_step(*inputs, **kwargs)
warn_if_output_is_none(output, "training_step")
elif running_stage == RunningStage.TESTING:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")
elif running_stage == RunningStage.EVALUATING:
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")
else:
output = self.module.predict(*inputs, **kwargs)
return output
def warn_if_output_is_none(output: Any, method_name: str) -> None:
""" Warns user about which method returned None. """
if output is None:
warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
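As a quick illustration of the dispatch implemented by `_LightningModuleWrapperBase.forward`, here is a small sketch mirroring the unit tests added in this commit (a `MagicMock` stands in for a real LightningModule):

```python
from unittest.mock import MagicMock

import torch

from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.trainer.states import RunningStage

pl_module = MagicMock()
wrapper = _LightningModuleWrapperBase(pl_module)
batch, batch_idx = torch.rand(5), 0

pl_module.running_stage = RunningStage.TRAINING
wrapper(batch, batch_idx)   # redirected to pl_module.training_step(batch, batch_idx)

pl_module.running_stage = RunningStage.EVALUATING
wrapper(batch, batch_idx)   # redirected to pl_module.validation_step(batch, batch_idx)

pl_module.running_stage = None
wrapper(batch, batch_idx)   # no train/val/test stage set -> pl_module.predict(batch, batch_idx)
```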

View File

@ -11,154 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import threading
import numbers
import warnings
from collections.abc import Iterable, Mapping
from itertools import chain
from typing import Any, Optional
from typing import Any
import torch
from torch import Tensor
from torch.cuda._utils import _get_device_index
from torch.nn import DataParallel, Module
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel._functions import Gather
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.warnings import WarningCache
def _find_tensors(obj): # pragma: no-cover
r"""
Recursively find all tensors contained in the specified object.
"""
if isinstance(obj, torch.Tensor):
return [obj]
if isinstance(obj, (list, tuple)):
return itertools.chain(*map(_find_tensors, obj))
if isinstance(obj, dict):
return itertools.chain(*map(_find_tensors, obj.values()))
return []
def get_a_var(obj): # pragma: no-cover
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, (list, tuple)):
for result in map(get_a_var, obj):
if isinstance(result, torch.Tensor):
return result
if isinstance(obj, dict):
for result in map(get_a_var, obj.items()):
if isinstance(result, torch.Tensor):
return result
return None
warning_cache = WarningCache()
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.utilities.apply_func import apply_to_collection
class LightningDataParallel(DataParallel):
"""
Override the forward call in lightning so it goes to training and validation step respectively
"""
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
f"module must have its parameters and buffers on device {self.src_device_obj} (device_ids[0])"
f" but found one of them on device: {t.device}"
def __init__(self, module: LightningModule, *args, **kwargs):
warnings.warn(
"The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."
" From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.",
DeprecationWarning
)
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
running_stage = self.module.running_stage
if running_stage == RunningStage.TRAINING:
return self.module.training_step(*inputs[0], **kwargs[0])
elif running_stage == RunningStage.TESTING:
return self.module.test_step(*inputs[0], **kwargs[0])
elif running_stage == RunningStage.EVALUATING:
return self.module.validation_step(*inputs[0], **kwargs[0])
else:
return self.module.predict(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
if isinstance(outputs[0], Result):
outputs = self.__gather_structured_result(outputs)
else:
outputs = self.gather(outputs)
return outputs
def __gather_structured_result(self, outputs):
prototype_output = outputs[0]
original_class = prototype_output.__class__
outputs = [dict(x) for x in outputs]
# remove all the meta info
meta = outputs[0]['meta']
for i, output in enumerate(outputs):
del output['meta']
outputs = self.gather(outputs)
result = original_class()
result.update(outputs)
result['meta'] = meta
return result
def gather(self, outputs):
r"""
Override the gather method to support python scalars as well.
"""
def gather_map(outputs):
elem = outputs[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
return Gather.apply(self.output_device, self.dim, *outputs)
if elem is None:
return None
if isinstance(elem, Mapping):
if not all((len(elem) == len(d) for d in outputs)):
raise ValueError('All dicts must have the same number of keys')
return elem_type(((k, gather_map([d[k] for d in outputs])) for k in elem))
if isinstance(elem, Iterable) and not isinstance(elem, str):
return elem_type(map(gather_map, zip(*outputs)))
return outputs
# Recursive function calls like this create reference cycles.
# Setting the function to None clears the refcycle.
try:
res = gather_map(outputs)
finally:
gather_map = None
return res
def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
super().__init__(LightningParallelModule(module), *args, **kwargs)
class LightningDistributedDataParallel(DistributedDataParallel):
@ -166,26 +41,25 @@ class LightningDistributedDataParallel(DistributedDataParallel):
def __init__(self, module: LightningModule, *args, **kwargs):
warnings.warn(
"The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
" From now on we recommend to directly sublcass `torch.nn.parallel.DistributedDataParallel`.",
" From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
DeprecationWarning
)
super().__init__(LightningDistributedModule(module), *args, **kwargs)
class LightningDistributedModule(torch.nn.Module):
def __init__(self, pl_module: LightningModule):
class LightningParallelModule(_LightningModuleWrapperBase):
"""
Wraps the user's LightningModule and redirects the forward call to the appropriate
method, either ``training_step``, ``validation_step`` or ```test_step``.
This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
shown in the example.
method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``.
This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as
shown in the example. It also takes care of converting Python scalars to Tensors and
un-squeezes 0-dimensional Tensors as it is required by :class:`~torch.nn.parallel.DataParallel`.
Example:
ddp_model = DistributedDataParallel(
module=LightningDistributedModule(lightning_module),
device_ids=[local_rank],
dp_model = torch.nn.DataParallel(
module=LightningParallelModule(lightning_module),
device_ids=[3, 4],
...
)
@ -193,182 +67,34 @@ class LightningDistributedModule(torch.nn.Module):
pl_module: the model to wrap
"""
super().__init__()
self.module = pl_module
def __init__(self, pl_module: LightningModule):
super().__init__(pl_module)
def forward(self, *inputs, **kwargs):
output = super().forward(*inputs, **kwargs)
running_stage = self.module.running_stage
if running_stage == RunningStage.TRAINING:
output = self.module.training_step(*inputs, **kwargs)
warn_if_output_is_none(output, "training_step")
elif running_stage == RunningStage.TESTING:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")
elif running_stage == RunningStage.EVALUATING:
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")
else:
output = self.module.predict(*inputs, **kwargs)
def output_transform(data: Any):
data = python_scalar_to_tensor(data, self.module.device)
data = unsqueeze_scalar_tensor(data)
return data
output = apply_to_collection(
output,
dtype=(numbers.Number, torch.Tensor),
function=output_transform,
)
return output
# In manual_optimization, we need to call reducer prepare_for_backward.
# Note: Keep track of Pytorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
if torch.is_grad_enabled() and model.require_backward_grad_sync:
model.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if model.find_unused_parameters:
model.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
model.reducer.prepare_for_backward([])
else:
model.require_forward_param_sync = False
def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
""" Converts a Python scalar number to a torch tensor and places it on the given device. """
if isinstance(data, numbers.Number):
data = torch.tensor([data], device=device)
return data
def warn_if_output_is_none(output: Any, method_name: str) -> None:
if output is None:
warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
def warn_missing_output(fx_called):
if fx_called == 'training_step':
warning_cache.warn("Your training_step returned None. Make sure that was your intention!")
def parallel_apply(
modules: Module,
inputs: Tensor,
kwargs_tup: Optional[tuple] = None,
devices: Optional[list] = None,
): # pragma: no-cover
r"""Applies each `module` in :attr:`modules` in parallel on arguments
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
on each of :attr:`devices`.
Args:
modules: modules to be parallelized
inputs: inputs to the modules
devices: CUDA devices
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
:attr:`devices` (if given) should all have same length. Moreover, each
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
"""
assert len(modules) == len(inputs)
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = ({}, ) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
devices = list(map(lambda x: _get_device_index(x, True), devices))
lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input, kwargs, device=None):
torch.set_grad_enabled(grad_enabled)
if device is None:
device = get_a_var(input).get_device()
try:
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input, )
module = module.to(device)
# ---------------
# CHANGE
if module.running_stage == RunningStage.TRAINING:
output = module.training_step(*input, **kwargs)
fx_called = 'training_step'
elif module.running_stage == RunningStage.TESTING:
output = module.test_step(*input, **kwargs)
fx_called = 'test_step'
elif module.running_stage == RunningStage.EVALUATING:
output = module.validation_step(*input, **kwargs)
fx_called = 'validation_step'
else:
output = module.predict(*input, **kwargs)
fx_called = 'predict'
if output is None:
warn_missing_output(fx_called)
if output is not None and module._distrib_type in ('dp', 'ddp2'):
auto_squeeze_dim_zeros(output)
# ---------------
with lock:
results[i] = output
# todo: specify the possible exception
except Exception as ex:
with lock:
results[i] = ex
# TODO: fix hack (maybe not a hack)
# make sure each module knows what training state it's in...
# fixes weird bug where copies are out of sync
root_m = modules[0]
for m in modules[1:]:
m.training = root_m.training
m.testing = root_m.testing
if len(modules) > 1:
threads = [
threading.Thread(target=_worker, args=(i, module, input, kwargs, device))
for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices))
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, Exception):
raise output
outputs.append(output)
return outputs
def auto_squeeze_dim_zeros(output):
"""
In DP or DDP2 we need to unsqueeze dim 0
:param output:
:return:
"""
if isinstance(output, torch.Tensor):
output = output.unsqueeze(0)
return output
for k, v in output.items():
if not isinstance(v, torch.Tensor):
continue
is_scalar = v.dim() == 0
if is_scalar:
output[k] = output[k].unsqueeze(0)
def unsqueeze_scalar_tensor(data: Any) -> Any:
""" Un-squeezes a 0-dim tensor. """
if isinstance(data, torch.Tensor) and data.dim() == 0:
data = data.unsqueeze(0)
return data
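The two helper functions above are what `LightningParallelModule` applies to every element of the step output so that `torch.nn.DataParallel` can gather results across devices. A minimal sketch of their behavior on CPU:

```python
import torch

from pytorch_lightning.overrides.data_parallel import (
    python_scalar_to_tensor,
    unsqueeze_scalar_tensor,
)

# Python scalars become 1-element tensors on the requested device (CPU by default).
assert torch.equal(python_scalar_to_tensor(12.3), torch.tensor([12.3]))

# 0-dim tensors are un-squeezed to 1-dim so DataParallel can concatenate them.
assert torch.equal(unsqueeze_scalar_tensor(torch.tensor(1.0)), torch.tensor([1.0]))

# Tensors that already have a batch dimension are left untouched.
assert unsqueeze_scalar_tensor(torch.ones(3, 4)).shape == (3, 4)
```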

View File

@ -0,0 +1,77 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any
import torch
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
class LightningDistributedModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: LightningModule):
"""
Wraps the user's LightningModule and redirects the forward call to the appropriate
method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``.
This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
shown in the example.
Example:
ddp_model = torch.nn.parallel.DistributedDataParallel(
module=LightningDistributedModule(lightning_module),
device_ids=[local_rank],
...
)
Args:
pl_module: the model to wrap
"""
super().__init__(pl_module)
def _find_tensors(obj): # pragma: no-cover
r"""
Recursively find all tensors contained in the specified object.
"""
if isinstance(obj, torch.Tensor):
return [obj]
if isinstance(obj, (list, tuple)):
return itertools.chain(*map(_find_tensors, obj))
if isinstance(obj, dict):
return itertools.chain(*map(_find_tensors, obj.values()))
return []
# In manual_optimization, we need to call reducer prepare_for_backward.
# Note: Keep track of Pytorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
if torch.is_grad_enabled() and model.require_backward_grad_sync:
model.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if model.find_unused_parameters:
model.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
model.reducer.prepare_for_backward([])
else:
model.require_forward_param_sync = False
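For context, a rough sketch of where `prepare_for_backward` is meant to be called during manual optimization; `manual_backward_step` is a hypothetical helper, and the shape of `output` (a dict with a `"loss"` key) is an assumption:

```python
import torch
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.overrides.distributed import prepare_for_backward


def manual_backward_step(ddp_model: DistributedDataParallel, batch, batch_idx) -> torch.Tensor:
    # Forward is dispatched to training_step via the wrapped LightningDistributedModule.
    output = ddp_model(batch, batch_idx)
    # Let the DDP reducer register the tensors used in this forward pass
    # before the backward call triggers gradient reduction.
    prepare_for_backward(ddp_model, output)
    loss = output["loss"] if isinstance(output, dict) else output
    loss.backward()
    return loss
```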

View File

@ -21,7 +21,7 @@ from torch.optim import Optimizer
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule, prepare_for_backward
from pytorch_lightning.overrides.distributed import LightningDistributedModule, prepare_for_backward
from pytorch_lightning.plugins.legacy.plugin import LightningPlugin
from pytorch_lightning.utilities import DeviceType

View File

@ -23,3 +23,6 @@ class WarningCache:
if m not in self.warnings:
self.warnings.add(m)
rank_zero_warn(m)
def clear(self):
self.warnings.clear()
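A minimal sketch of the `WarningCache` behavior including the new `clear()` method (assuming the existing `warn(message)` method shown in the context above):

```python
from pytorch_lightning.utilities.warnings import WarningCache

cache = WarningCache()
cache.warn("Your training_step returned None.")  # emitted once
cache.warn("Your training_step returned None.")  # suppressed: message already seen
cache.clear()                                    # forget previously seen messages
cache.warn("Your training_step returned None.")  # emitted again after clearing
```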

View File

@ -49,7 +49,6 @@ omit =
pytorch_lightning/accelerators/dp_*.py
pytorch_lightning/accelerators/tpu_*.py
pytorch_lightning/cluster_environments/*.py
pytorch_lightning/overrides/data_parallel.py
pytorch_lightning/utilities/xla_device_utils.py
pytorch_lightning/utilities/distributed.py
pytorch_lightning/tuner/auto_gpu_select.py

View File

@ -18,7 +18,12 @@ import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.overrides.data_parallel import (
LightningDataParallel,
LightningDistributedDataParallel,
LightningParallelModule,
)
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin
from tests.base import BoringModel
from tests.deprecated_api import _soft_unimport_module
@ -165,6 +170,8 @@ class CustomDDPPlugin(DDPPlugin):
device_ids=device_ids,
**self._ddp_kwargs,
)
assert isinstance(model, torch.nn.parallel.DistributedDataParallel)
assert isinstance(model.module, LightningDistributedModule)
return model
@ -180,3 +187,14 @@ def test_v1_4_0_deprecated_lightning_distributed_data_parallel(tmpdir):
plugins=[CustomDDPPlugin()]
)
trainer.fit(model)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_v1_4_0_deprecated_lightning_data_parallel():
model = BoringModel()
with pytest.deprecated_call(
match="`LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."
):
dp_model = LightningDataParallel(model, device_ids=[0])
assert isinstance(dp_model, torch.nn.DataParallel)
assert isinstance(dp_model.module, LightningParallelModule)

View File

@ -393,7 +393,7 @@ def test_dp_resume(tmpdir):
# haven't trained with the new loaded model
dp_model = new_trainer.model
dp_model.eval()
dp_model.module.running_stage = RunningStage.EVALUATING
dp_model.module.module.running_stage = RunningStage.EVALUATING
dataloader = trainer.train_dataloader
tpipes.run_prediction(dp_model, dataloader, dp=True)

View File

@ -2,36 +2,57 @@ from unittest.mock import MagicMock
import pytest
import torch
from torch.nn import DataParallel
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import warning_cache
from pytorch_lightning.overrides.data_parallel import (
LightningParallelModule,
python_scalar_to_tensor,
unsqueeze_scalar_tensor,
)
from pytorch_lightning.trainer.states import RunningStage
from tests.base import BoringModel
def test_lightning_distributed_module_methods():
""" Test that the LightningDistributedModule redirects .forward() to the LightningModule methods. """
@pytest.mark.parametrize("wrapper_class", [
LightningParallelModule,
LightningDistributedModule,
])
def test_lightning_wrapper_module_methods(wrapper_class):
""" Test that the LightningWrapper redirects .forward() to the LightningModule methods. """
pl_module = MagicMock()
dist_module = LightningDistributedModule(pl_module)
wrapped_module = wrapper_class(pl_module)
batch = torch.rand(5)
batch_idx = 3
pl_module.running_stage = RunningStage.TRAINING
dist_module(batch, batch_idx)
wrapped_module(batch, batch_idx)
pl_module.training_step.assert_called_with(batch, batch_idx)
pl_module.running_stage = RunningStage.TESTING
dist_module(batch, batch_idx)
wrapped_module(batch, batch_idx)
pl_module.test_step.assert_called_with(batch, batch_idx)
pl_module.running_stage = RunningStage.EVALUATING
dist_module(batch, batch_idx)
wrapped_module(batch, batch_idx)
pl_module.validation_step.assert_called_with(batch, batch_idx)
pl_module.running_stage = None
wrapped_module(batch)
pl_module.predict.assert_called_with(batch)
def test_lightning_distributed_module_warn_none_output():
""" Test that the LightningDistributedModule warns about forgotten return statement. """
@pytest.mark.parametrize("wrapper_class", [
LightningParallelModule,
LightningDistributedModule,
])
def test_lightning_wrapper_module_warn_none_output(wrapper_class):
""" Test that the LightningWrapper module warns about forgotten return statement. """
warning_cache.clear()
pl_module = MagicMock()
dist_module = LightningDistributedModule(pl_module)
wrapped_module = wrapper_class(pl_module)
pl_module.training_step.return_value = None
pl_module.validation_step.return_value = None
@ -39,12 +60,95 @@ def test_lightning_distributed_module_warn_none_output():
with pytest.warns(UserWarning, match="Your training_step returned None"):
pl_module.running_stage = RunningStage.TRAINING
dist_module()
wrapped_module()
with pytest.warns(UserWarning, match="Your test_step returned None"):
pl_module.running_stage = RunningStage.TESTING
dist_module()
wrapped_module()
with pytest.warns(UserWarning, match="Your validation_step returned None"):
pl_module.running_stage = RunningStage.EVALUATING
dist_module()
wrapped_module()
with pytest.warns(None) as record:
pl_module.running_stage = None
wrapped_module()
assert not record
@pytest.mark.parametrize("inp,expected", [
[torch.tensor(1.0), torch.tensor([1.0])],
[torch.tensor([2.0]), torch.tensor([2.0])],
[torch.ones(3, 4, 5), torch.ones(3, 4, 5)],
])
def test_unsqueeze_scalar_tensor(inp, expected):
""" Test that the utility function unsqueezes only scalar tensors. """
assert torch.all(unsqueeze_scalar_tensor(inp).eq(expected))
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-gpu machine")
def test_lightning_parallel_module_unsqueeze_scalar():
""" Test that LightningParallelModule takes care of un-squeezeing 0-dim tensors. """
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
output = super().training_step(batch, batch_idx)
loss = output["loss"]
loss = loss.squeeze()
assert loss.dim() == 0
# PyTorch usually warns about 0-dim tensors returned in DP
return {"loss": loss}
model = TestModel()
model.running_stage = RunningStage.TRAINING
batch = torch.rand(2, 32).cuda()
batch_idx = 0
wrapped_model = LightningParallelModule(model).cuda()
dp_module = DataParallel(wrapped_model, device_ids=[0, 1])
output = wrapped_model(batch, batch_idx)
assert output["loss"].dim() == 1
with pytest.warns(None) as record:
output = dp_module(batch, batch_idx)
assert output["loss"].dim() == 1
assert not record
@pytest.mark.parametrize("inp,expected", [
[1.0, torch.tensor([1.0])],
[2, torch.tensor([2.0])],
[True, torch.tensor([True])],
])
def test_python_scalar_to_tensor(inp, expected):
assert torch.all(python_scalar_to_tensor(inp).eq(expected))
@pytest.mark.parametrize("device", [
torch.device("cpu"),
torch.device("cuda", 0)
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_lightning_parallel_module_python_scalar_conversion(device):
""" Test that LightningParallelModule can convert Python scalars to tensors. """
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
output = super().training_step(batch, batch_idx)
# PyTorch DP does not support Python scalars, Lightning converts them to tensors
output.update({"python scalar": 12.3})
return output
model = TestModel()
model.to(device)
model.running_stage = RunningStage.TRAINING
batch = torch.rand(2, 32).to(device)
batch_idx = 0
wrapped_model = LightningParallelModule(model)
output = wrapped_model(batch, batch_idx)
assert output["python scalar"] == torch.tensor([12.3], device=device)