Black format pytorch_lightning/core/lightning.py (#3576)

* Black format pytorch_lightning/core/lightning.py

Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter

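For anyone reproducing or verifying this locally: the diff is (presumably) the output of a stock Black run over the file. The exact invocation is not recorded in the PR; the commands below are standard Black CLI usage, not taken from this commit:

    black pytorch_lightning/core/lightning.py          # rewrite the file in place
    black --check pytorch_lightning/core/lightning.py  # verify only; exits non-zero if the file would be reformatted
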
ananthsub 2020-09-20 19:59:21 -07:00 committed by GitHub
parent 21cfdf6874
commit cf1b946d4a
1 changed file with 112 additions and 66 deletions


@@ -23,21 +23,25 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 import torch.distributed as torch_distrib
-from torch import Tensor, ScriptModule
+from pytorch_lightning import _logger as log
+from pytorch_lightning.core.grads import GradInformation
+from pytorch_lightning.core.hooks import DataHooks, ModelHooks
+from pytorch_lightning.core.memory import ModelSummary
+from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
+from pytorch_lightning.core.step_result import EvalResult, TrainResult
+from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.utilities.parsing import (
+    AttributeDict,
+    collect_init_args,
+    get_init_args,
+)
+from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim.optimizer import Optimizer
-from pytorch_lightning import _logger as log
-from pytorch_lightning.core.grads import GradInformation
-from pytorch_lightning.core.hooks import ModelHooks, DataHooks
-from pytorch_lightning.core.memory import ModelSummary
-from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
-from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
-from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
-from pytorch_lightning.core.step_result import TrainResult, EvalResult
 
 try:
     import torch_xla.core.xla_model as xm
@@ -47,7 +51,9 @@ else:
     XLA_AVAILABLE = True
 
 
-class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, DataHooks, Module):
+class LightningModule(
+    ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, DataHooks, Module
+):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -111,7 +117,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         True if your model is currently running on GPUs.
         Useful to set flags around the LightningModule for different CPU vs GPU behavior.
         """
-        return self.device.type == 'cuda'
+        return self.device.type == "cuda"
 
     def print(self, *args, **kwargs) -> None:
         r"""
@@ -270,7 +276,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         The loss value shown in the progress bar is smoothed (averaged) over the last values,
         so it differs from the actual loss returned in train/validation step.
         """
-        rank_zero_warn('`training_step` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn(
+            "`training_step` must be implemented to be used with the Lightning Trainer"
+        )
 
     def training_step_end(self, *args, **kwargs):
         """
@@ -338,9 +346,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         See the :ref:`multi_gpu` guide for more details.
         """
 
-    def training_epoch_end(
-        self, outputs: Union[TrainResult, List[TrainResult]]
-    ):
+    def training_epoch_end(self, outputs: Union[TrainResult, List[TrainResult]]):
         """
         Called at the end of the training epoch with the outputs of all training steps.
         Use this in case you need to do something with all the outputs for every training_step.
@@ -542,7 +548,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         """
 
     def validation_epoch_end(
-            self, outputs: Union[EvalResult, List[EvalResult]]
+        self, outputs: Union[EvalResult, List[EvalResult]]
     ) -> EvalResult:
         """
         Called at the end of the validation epoch with the outputs of all validation steps.
@@ -743,7 +749,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         """
 
     def test_epoch_end(
-            self, outputs: Union[EvalResult, List[EvalResult]]
+        self, outputs: Union[EvalResult, List[EvalResult]]
     ) -> EvalResult:
         """
@@ -798,7 +804,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             return results
         """
 
-    def configure_ddp(self, model: 'LightningModule', device_ids: List[int]) -> DistributedDataParallel:
+    def configure_ddp(
+        self, model: "LightningModule", device_ids: List[int]
+    ) -> DistributedDataParallel:
         r"""
         Override to init DDP in your own way or with your own wrapper.
         The only requirements are that:
@@ -828,7 +836,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             return model
         """
-        model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True)
+        model = LightningDistributedDataParallel(
+            model, device_ids=device_ids, find_unused_parameters=True
+        )
         return model
 
     def _init_slurm_connection(self) -> None:
@@ -841,7 +851,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         # guarantees unique ports across jobs from same grid search
         try:
             # use the last 4 numbers in the job id as the id
-            default_port = os.environ['SLURM_JOB_ID']
+            default_port = os.environ["SLURM_JOB_ID"]
             default_port = default_port[-4:]
 
             # all ports should be in the 10k+ range
@@ -852,20 +862,22 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         # if user gave a port number, use that one instead
         try:
-            default_port = os.environ['MASTER_PORT']
+            default_port = os.environ["MASTER_PORT"]
         except Exception:
-            os.environ['MASTER_PORT'] = str(default_port)
+            os.environ["MASTER_PORT"] = str(default_port)
 
         # figure out the root node addr
         try:
-            root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
+            root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
         except Exception:
-            root_node = '127.0.0.1'
+            root_node = "127.0.0.1"
 
         root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node)
-        os.environ['MASTER_ADDR'] = root_node
+        os.environ["MASTER_ADDR"] = root_node
 
-    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True) -> None:
+    def init_ddp_connection(
+        self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
+    ) -> None:
         """
         Override to define your custom way of setting up a distributed environment.
@@ -880,27 +892,35 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         if is_slurm_managing_tasks:
             self._init_slurm_connection()
 
-        if 'MASTER_ADDR' not in os.environ:
-            rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
-            os.environ['MASTER_ADDR'] = '127.0.0.1'
+        if "MASTER_ADDR" not in os.environ:
+            rank_zero_warn(
+                "MASTER_ADDR environment variable is not defined. Set as localhost"
+            )
+            os.environ["MASTER_ADDR"] = "127.0.0.1"
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
 
-        if 'MASTER_PORT' not in os.environ:
-            rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
-            os.environ['MASTER_PORT'] = '12910'
+        if "MASTER_PORT" not in os.environ:
+            rank_zero_warn(
+                "MASTER_PORT environment variable is not defined. Set as 12910"
+            )
+            os.environ["MASTER_PORT"] = "12910"
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
-        if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
+        if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) != world_size:
             rank_zero_warn(
                 f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
                 f"is not equal to the computed world size ({world_size}). Ignored."
             )
 
         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
-        log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
-        torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
+        log.info(
+            f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}"
+        )
+        torch_distrib.init_process_group(
+            torch_backend, rank=global_rank, world_size=world_size
+        )
 
-    def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
+    def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
         """
         Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -918,8 +938,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         return model
 
     def configure_apex(
-        self, amp: object, model: 'LightningModule', optimizers: List[Optimizer], amp_level: str
-    ) -> Tuple['LightningModule', List[Optimizer]]:
+        self,
+        amp: object,
+        model: "LightningModule",
+        optimizers: List[Optimizer],
+        amp_level: str,
+    ) -> Tuple["LightningModule", List[Optimizer]]:
         r"""
         Override to init AMP your own way.
         Must return a model and list of optimizers.
@@ -950,7 +974,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
     def configure_optimizers(
         self,
-    ) -> Optional[Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]]:
+    ) -> Optional[
+        Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]
+    ]:
         r"""
         Choose what optimizers and learning-rate schedulers to use in your optimization.
         Normally you'd need one. But in the case of GANs or similar you might have multiple.
@@ -1064,7 +1090,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             }
 
         """
-        rank_zero_warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn(
+            "`configure_optimizers` must be implemented to be used with the Lightning Trainer"
+        )
 
     def optimizer_step(
         self,
@@ -1152,7 +1180,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         else:
             optimizer.step()
 
-    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
+    def optimizer_zero_grad(
+        self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
+    ):
         optimizer.zero_grad()
 
     def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
@@ -1198,9 +1228,15 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             Each returned batch split is passed separately to :meth:`training_step`.
 
         """
-        time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
+        time_dims = [
+            len(x[0])
+            for x in batch
+            if isinstance(x, (torch.Tensor, collections.Sequence))
+        ]
         assert len(time_dims) >= 1, "Unable to determine batch time dimension"
-        assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"
+        assert all(
+            x == time_dims[0] for x in time_dims
+        ), "Batch time dimension length is ambiguous"
 
         splits = []
         for t in range(0, time_dims[0], split_size):
@@ -1221,7 +1257,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
 
     def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
         model_summary = ModelSummary(self, mode=mode)
-        log.info('\n' + str(model_summary))
+        log.info("\n" + str(model_summary))
         return model_summary
 
     def freeze(self) -> None:
@@ -1323,17 +1359,21 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         """
         # call .item() only once but store elements without graphs
        running_train_loss = self.trainer.train_loop.running_loss.mean()
-        avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
-        tqdm_dict = {'loss': '{:.3f}'.format(avg_training_loss)}
+        avg_training_loss = (
+            running_train_loss.cpu().item()
+            if running_train_loss is not None
+            else float("NaN")
+        )
+        tqdm_dict = {"loss": "{:.3f}".format(avg_training_loss)}
 
         if self.trainer.truncated_bptt_steps is not None:
-            tqdm_dict['split_idx'] = self.trainer.split_idx
+            tqdm_dict["split_idx"] = self.trainer.split_idx
 
         if self.trainer.logger is not None and self.trainer.logger.version is not None:
             version = self.trainer.logger.version
             # show last 4 places of long version strings
             version = version[-4:] if isinstance(version, str) else version
-            tqdm_dict['v_num'] = version
+            tqdm_dict["v_num"] = version
 
         return tqdm_dict
@@ -1434,10 +1474,10 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         if not frame:
             frame = inspect.currentframe().f_back
         init_args = get_init_args(frame)
-        assert init_args, 'failed to inspect the self init'
+        assert init_args, "failed to inspect the self init"
         if not args:
             hp = init_args
-            self._hparams_name = 'kwargs' if hp else None
+            self._hparams_name = "kwargs" if hp else None
         else:
             isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
             if len(isx_non_str) == 1:
@@ -1446,7 +1486,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
                 self._hparams_name = cand_names[0] if cand_names else None
             else:
                 hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
-                self._hparams_name = 'kwargs'
+                self._hparams_name = "kwargs"
 
         # `hparams` are expected here
         if hp:
@@ -1458,9 +1498,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         if isinstance(hp, dict):
             hp = AttributeDict(hp)
         elif isinstance(hp, PRIMITIVE_TYPES):
-            raise ValueError(f'Primitives {PRIMITIVE_TYPES} are not allowed.')
+            raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
         elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
-            raise ValueError(f'Unsupported config type of {type(hp)}.')
+            raise ValueError(f"Unsupported config type of {type(hp)}.")
 
         if isinstance(hp, dict) and isinstance(self.hparams, dict):
             self.hparams.update(hp)
@@ -1498,19 +1538,25 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             input_data = self.example_input_array
         else:
             if input_sample is not None:
-                raise ValueError(f'Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`')
+                raise ValueError(
+                    f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`"
+                )
             else:
-                raise ValueError('Could not export to ONNX since neither `input_sample` nor'
-                                 ' `model.example_input_array` attribute is set.')
+                raise ValueError(
+                    "Could not export to ONNX since neither `input_sample` nor"
+                    " `model.example_input_array` attribute is set."
+                )
         input_data = input_data.to(self.device)
-        if 'example_outputs' not in kwargs:
+        if "example_outputs" not in kwargs:
             self.eval()
             with torch.no_grad():
-                kwargs['example_outputs'] = self(input_data)
+                kwargs["example_outputs"] = self(input_data)
         torch.onnx.export(self, input_data, file_path, **kwargs)
 
-    def to_torchscript(self, file_path: Optional[str] = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]:
+    def to_torchscript(
+        self, file_path: Optional[str] = None, **kwargs
+    ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
         """
         By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
         If you would like to customize the modules that are scripted or you want to use tracing
@@ -1560,7 +1606,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
 
     @property
     def hparams(self) -> Union[AttributeDict, str]:
-        if not hasattr(self, '_hparams'):
+        if not hasattr(self, "_hparams"):
             self._hparams = AttributeDict()
         return self._hparams
@@ -1578,12 +1624,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         """
         try:
             class_code = inspect.getsource(self.__class__)
-            lines = class_code.split('\n')
+            lines = class_code.split("\n")
             for line in lines:
                 line = re.sub(r"\s+", "", line, flags=re.UNICODE)
-                if '.hparams=' in line:
-                    return line.split('=')[1]
+                if ".hparams=" in line:
+                    return line.split("=")[1]
         except Exception as e:
-            return 'hparams'
+            return "hparams"
 
         return None