diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 5b37aae978..755cbf5a0d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -23,21 +23,25 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.distributed as torch_distrib -from torch import Tensor, ScriptModule +from pytorch_lightning import _logger as log +from pytorch_lightning.core.grads import GradInformation +from pytorch_lightning.core.hooks import DataHooks, ModelHooks +from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO +from pytorch_lightning.core.step_result import EvalResult, TrainResult +from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from pytorch_lightning.utilities.parsing import ( + AttributeDict, + collect_init_args, + get_init_args, +) +from torch import ScriptModule, Tensor from torch.nn import Module from torch.nn.parallel import DistributedDataParallel from torch.optim.optimizer import Optimizer -from pytorch_lightning import _logger as log -from pytorch_lightning.core.grads import GradInformation -from pytorch_lightning.core.hooks import ModelHooks, DataHooks -from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args -from pytorch_lightning.core.step_result import TrainResult, EvalResult try: import torch_xla.core.xla_model as xm @@ -47,7 +51,9 @@ else: XLA_AVAILABLE = True -class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, DataHooks, Module): +class LightningModule( + ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, DataHooks, Module +): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -111,7 +117,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod True if your model is currently running on GPUs. Useful to set flags around the LightningModule for different CPU vs GPU behavior. """ - return self.device.type == 'cuda' + return self.device.type == "cuda" def print(self, *args, **kwargs) -> None: r""" @@ -270,7 +276,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step. """ - rank_zero_warn('`training_step` must be implemented to be used with the Lightning Trainer') + rank_zero_warn( + "`training_step` must be implemented to be used with the Lightning Trainer" + ) def training_step_end(self, *args, **kwargs): """ @@ -338,9 +346,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod See the :ref:`multi_gpu` guide for more details. """ - def training_epoch_end( - self, outputs: Union[TrainResult, List[TrainResult]] - ): + def training_epoch_end(self, outputs: Union[TrainResult, List[TrainResult]]): """ Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs for every training_step. @@ -542,7 +548,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod """ def validation_epoch_end( - self, outputs: Union[EvalResult, List[EvalResult]] + self, outputs: Union[EvalResult, List[EvalResult]] ) -> EvalResult: """ Called at the end of the validation epoch with the outputs of all validation steps. @@ -743,7 +749,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod """ def test_epoch_end( - self, outputs: Union[EvalResult, List[EvalResult]] + self, outputs: Union[EvalResult, List[EvalResult]] ) -> EvalResult: """ @@ -798,7 +804,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod return results """ - def configure_ddp(self, model: 'LightningModule', device_ids: List[int]) -> DistributedDataParallel: + def configure_ddp( + self, model: "LightningModule", device_ids: List[int] + ) -> DistributedDataParallel: r""" Override to init DDP in your own way or with your own wrapper. The only requirements are that: @@ -828,7 +836,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod return model """ - model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True) + model = LightningDistributedDataParallel( + model, device_ids=device_ids, find_unused_parameters=True + ) return model def _init_slurm_connection(self) -> None: @@ -841,7 +851,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod # guarantees unique ports across jobs from same grid search try: # use the last 4 numbers in the job id as the id - default_port = os.environ['SLURM_JOB_ID'] + default_port = os.environ["SLURM_JOB_ID"] default_port = default_port[-4:] # all ports should be in the 10k+ range @@ -852,20 +862,22 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod # if user gave a port number, use that one instead try: - default_port = os.environ['MASTER_PORT'] + default_port = os.environ["MASTER_PORT"] except Exception: - os.environ['MASTER_PORT'] = str(default_port) + os.environ["MASTER_PORT"] = str(default_port) # figure out the root node addr try: - root_node = os.environ['SLURM_NODELIST'].split(' ')[0] + root_node = os.environ["SLURM_NODELIST"].split(" ")[0] except Exception: - root_node = '127.0.0.1' + root_node = "127.0.0.1" root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node) - os.environ['MASTER_ADDR'] = root_node + os.environ["MASTER_ADDR"] = root_node - def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True) -> None: + def init_ddp_connection( + self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True + ) -> None: """ Override to define your custom way of setting up a distributed environment. @@ -880,27 +892,35 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod if is_slurm_managing_tasks: self._init_slurm_connection() - if 'MASTER_ADDR' not in os.environ: - rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost") - os.environ['MASTER_ADDR'] = '127.0.0.1' + if "MASTER_ADDR" not in os.environ: + rank_zero_warn( + "MASTER_ADDR environment variable is not defined. Set as localhost" + ) + os.environ["MASTER_ADDR"] = "127.0.0.1" log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - if 'MASTER_PORT' not in os.environ: - rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910") - os.environ['MASTER_PORT'] = '12910' + if "MASTER_PORT" not in os.environ: + rank_zero_warn( + "MASTER_PORT environment variable is not defined. Set as 12910" + ) + os.environ["MASTER_PORT"] = "12910" log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size: + if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) != world_size: rank_zero_warn( f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) " f"is not equal to the computed world size ({world_size}). Ignored." ) torch_backend = "nccl" if self.trainer.on_gpu else "gloo" - log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}") - torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) + log.info( + f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}" + ) + torch_distrib.init_process_group( + torch_backend, rank=global_rank, world_size=world_size + ) - def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule': + def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule": """ Add global batchnorm for a model spread across multiple GPUs and nodes. @@ -918,8 +938,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod return model def configure_apex( - self, amp: object, model: 'LightningModule', optimizers: List[Optimizer], amp_level: str - ) -> Tuple['LightningModule', List[Optimizer]]: + self, + amp: object, + model: "LightningModule", + optimizers: List[Optimizer], + amp_level: str, + ) -> Tuple["LightningModule", List[Optimizer]]: r""" Override to init AMP your own way. Must return a model and list of optimizers. @@ -950,7 +974,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod def configure_optimizers( self, - ) -> Optional[Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]]: + ) -> Optional[ + Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]] + ]: r""" Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. @@ -1064,7 +1090,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod } """ - rank_zero_warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer') + rank_zero_warn( + "`configure_optimizers` must be implemented to be used with the Lightning Trainer" + ) def optimizer_step( self, @@ -1152,7 +1180,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod else: optimizer.step() - def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + def optimizer_zero_grad( + self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int + ): optimizer.zero_grad() def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: @@ -1198,9 +1228,15 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod Each returned batch split is passed separately to :meth:`training_step`. """ - time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))] + time_dims = [ + len(x[0]) + for x in batch + if isinstance(x, (torch.Tensor, collections.Sequence)) + ] assert len(time_dims) >= 1, "Unable to determine batch time dimension" - assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous" + assert all( + x == time_dims[0] for x in time_dims + ), "Batch time dimension length is ambiguous" splits = [] for t in range(0, time_dims[0], split_size): @@ -1221,7 +1257,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary: model_summary = ModelSummary(self, mode=mode) - log.info('\n' + str(model_summary)) + log.info("\n" + str(model_summary)) return model_summary def freeze(self) -> None: @@ -1323,17 +1359,21 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod """ # call .item() only once but store elements without graphs running_train_loss = self.trainer.train_loop.running_loss.mean() - avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN') - tqdm_dict = {'loss': '{:.3f}'.format(avg_training_loss)} + avg_training_loss = ( + running_train_loss.cpu().item() + if running_train_loss is not None + else float("NaN") + ) + tqdm_dict = {"loss": "{:.3f}".format(avg_training_loss)} if self.trainer.truncated_bptt_steps is not None: - tqdm_dict['split_idx'] = self.trainer.split_idx + tqdm_dict["split_idx"] = self.trainer.split_idx if self.trainer.logger is not None and self.trainer.logger.version is not None: version = self.trainer.logger.version # show last 4 places of long version strings version = version[-4:] if isinstance(version, str) else version - tqdm_dict['v_num'] = version + tqdm_dict["v_num"] = version return tqdm_dict @@ -1434,10 +1474,10 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod if not frame: frame = inspect.currentframe().f_back init_args = get_init_args(frame) - assert init_args, 'failed to inspect the self init' + assert init_args, "failed to inspect the self init" if not args: hp = init_args - self._hparams_name = 'kwargs' if hp else None + self._hparams_name = "kwargs" if hp else None else: isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] if len(isx_non_str) == 1: @@ -1446,7 +1486,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod self._hparams_name = cand_names[0] if cand_names else None else: hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} - self._hparams_name = 'kwargs' + self._hparams_name = "kwargs" # `hparams` are expected here if hp: @@ -1458,9 +1498,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod if isinstance(hp, dict): hp = AttributeDict(hp) elif isinstance(hp, PRIMITIVE_TYPES): - raise ValueError(f'Primitives {PRIMITIVE_TYPES} are not allowed.') + raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.") elif not isinstance(hp, ALLOWED_CONFIG_TYPES): - raise ValueError(f'Unsupported config type of {type(hp)}.') + raise ValueError(f"Unsupported config type of {type(hp)}.") if isinstance(hp, dict) and isinstance(self.hparams, dict): self.hparams.update(hp) @@ -1498,19 +1538,25 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod input_data = self.example_input_array else: if input_sample is not None: - raise ValueError(f'Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`') + raise ValueError( + f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`" + ) else: - raise ValueError('Could not export to ONNX since neither `input_sample` nor' - ' `model.example_input_array` attribute is set.') + raise ValueError( + "Could not export to ONNX since neither `input_sample` nor" + " `model.example_input_array` attribute is set." + ) input_data = input_data.to(self.device) - if 'example_outputs' not in kwargs: + if "example_outputs" not in kwargs: self.eval() with torch.no_grad(): - kwargs['example_outputs'] = self(input_data) + kwargs["example_outputs"] = self(input_data) torch.onnx.export(self, input_data, file_path, **kwargs) - def to_torchscript(self, file_path: Optional[str] = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]: + def to_torchscript( + self, file_path: Optional[str] = None, **kwargs + ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you would like to customize the modules that are scripted or you want to use tracing @@ -1560,7 +1606,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod @property def hparams(self) -> Union[AttributeDict, str]: - if not hasattr(self, '_hparams'): + if not hasattr(self, "_hparams"): self._hparams = AttributeDict() return self._hparams @@ -1578,12 +1624,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod """ try: class_code = inspect.getsource(self.__class__) - lines = class_code.split('\n') + lines = class_code.split("\n") for line in lines: line = re.sub(r"\s+", "", line, flags=re.UNICODE) - if '.hparams=' in line: - return line.split('=')[1] + if ".hparams=" in line: + return line.split("=")[1] except Exception as e: - return 'hparams' + return "hparams" return None