2020-08-20 02:03:22 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2021-01-31 11:08:16 +00:00
|
|
|
import numbers
|
2021-01-13 19:35:42 +00:00
|
|
|
import warnings
|
2021-01-31 11:08:16 +00:00
|
|
|
from typing import Any
|
2019-06-25 23:42:15 +00:00
|
|
|
|
2019-06-25 23:52:26 +00:00
|
|
|
import torch
|
2021-01-31 11:08:16 +00:00
|
|
|
from torch.nn import DataParallel
|
2019-10-22 08:32:40 +00:00
|
|
|
from torch.nn.parallel import DistributedDataParallel
|
2020-08-07 22:33:51 +00:00
|
|
|
|
2021-01-13 19:35:42 +00:00
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
2021-01-31 11:08:16 +00:00
|
|
|
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
|
|
|
|
from pytorch_lightning.overrides.distributed import LightningDistributedModule
|
|
|
|
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
2020-09-23 04:19:46 +00:00
|
|
|
|
|
|
|
|
2019-06-25 23:42:15 +00:00
|
|
|
class LightningDataParallel(DataParallel):
    """Deprecated :class:`~torch.nn.DataParallel` subclass kept for backward compatibility.

    On construction it emits a ``DeprecationWarning`` and then wraps ``module`` in
    ``LightningParallelModule`` before handing it to the ``DataParallel`` constructor.

    Args:
        module: the LightningModule to wrap
        *args: positional arguments forwarded unchanged to ``DataParallel.__init__``
        **kwargs: keyword arguments forwarded unchanged to ``DataParallel.__init__``
    """

    def __init__(self, module: LightningModule, *args, **kwargs):
        warnings.warn(
            "The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."
            " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.",
            DeprecationWarning,
            # Attribute the warning to the code instantiating this class: with the default
            # stacklevel it is attributed to this module, and Python's default filters
            # hide DeprecationWarning unless it is triggered from ``__main__``.
            stacklevel=2,
        )
        super().__init__(LightningParallelModule(module), *args, **kwargs)
|
2019-06-25 23:52:26 +00:00
|
|
|
|
2019-07-03 20:44:18 +00:00
|
|
|
|
|
|
|
class LightningDistributedDataParallel(DistributedDataParallel):
    """Deprecated :class:`~torch.nn.parallel.DistributedDataParallel` subclass kept for backward compatibility.

    On construction it emits a ``DeprecationWarning`` and then wraps ``module`` in
    ``LightningDistributedModule`` before handing it to the ``DistributedDataParallel``
    constructor.

    Args:
        module: the LightningModule to wrap
        *args: positional arguments forwarded unchanged to ``DistributedDataParallel.__init__``
        **kwargs: keyword arguments forwarded unchanged to ``DistributedDataParallel.__init__``
    """

    def __init__(self, module: LightningModule, *args, **kwargs):
        warnings.warn(
            "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
            " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
            DeprecationWarning,
            # Attribute the warning to the code instantiating this class: with the default
            # stacklevel it is attributed to this module, and Python's default filters
            # hide DeprecationWarning unless it is triggered from ``__main__``.
            stacklevel=2,
        )
        super().__init__(LightningDistributedModule(module), *args, **kwargs)
|
2019-07-03 20:44:18 +00:00
|
|
|
|
2019-07-03 20:43:05 +00:00
|
|
|
|
2021-01-31 11:08:16 +00:00
|
|
|
class LightningParallelModule(_LightningModuleWrapperBase):
    """
    Wrapper around the user's LightningModule meant to be passed to
    :class:`~torch.nn.parallel.DataParallel`.

    The forward call is redirected to the appropriate method of the wrapped module, either
    ``training_step``, ``validation_step``, ``test_step`` or ``predict``. Afterwards the
    output is post-processed as required by :class:`~torch.nn.parallel.DataParallel`:
    Python scalars are converted to Tensors and 0-dimensional Tensors are un-squeezed.

    Example:

        dp_model = torch.nn.DataParallel(
            module=LightningParallelModule(lightning_module),
            device_ids=[3, 4],
            ...
        )

    Args:
        pl_module: the model to wrap

    """

    def __init__(self, pl_module: LightningModule):
        super().__init__(pl_module)

    def forward(self, *inputs, **kwargs):
        def _make_gatherable(value: Any):
            # Numbers become 1-element tensors on the wrapped module's device;
            # 0-dim tensors gain a leading dimension. The device is looked up
            # lazily, only when an element is actually transformed.
            as_tensor = python_scalar_to_tensor(value, self.module.device)
            return unsqueeze_scalar_tensor(as_tensor)

        step_output = super().forward(*inputs, **kwargs)

        # Only numbers and tensors found in the output collection are transformed.
        return apply_to_collection(
            step_output,
            dtype=(numbers.Number, torch.Tensor),
            function=_make_gatherable,
        )
|
2020-04-02 15:46:20 +00:00
|
|
|
|
|
|
|
|
2021-01-31 11:08:16 +00:00
|
|
|
def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
    """Wrap a Python number in a 1-element tensor on ``device``.

    Args:
        data: any object; only instances of :class:`numbers.Number` are converted
        device: device on which the new tensor is created (CPU by default)

    Returns:
        A tensor of shape ``(1,)`` holding ``data`` if it is a number,
        otherwise ``data`` itself, untouched.
    """
    # Guard clause: anything that is not a number passes straight through.
    if not isinstance(data, numbers.Number):
        return data
    return torch.tensor([data], device=device)
|
2020-09-23 04:19:46 +00:00
|
|
|
|
2020-04-02 15:46:20 +00:00
|
|
|
|
2021-01-31 11:08:16 +00:00
|
|
|
def unsqueeze_scalar_tensor(data: Any) -> Any:
    """Give a 0-dimensional tensor a leading dimension of size one.

    Args:
        data: any object; only 0-dim :class:`torch.Tensor` instances are changed

    Returns:
        ``data.unsqueeze(0)`` if ``data`` is a 0-dim tensor, otherwise ``data`` itself.
    """
    is_scalar_tensor = isinstance(data, torch.Tensor) and data.dim() == 0
    return data.unsqueeze(0) if is_scalar_tensor else data
|