Black format pytorch_lightning/core/lightning.py (#3576)

Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter.
parent 21cfdf6874
commit cf1b946d4a
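The changes below are purely mechanical: Black normalizes string quotes to double quotes and wraps calls and signatures that exceed the configured line length. As a minimal sketch of that behaviour (assuming the black package is installed; the exact Black version and line length used for this commit are not recorded here, the default of 88 is assumed):

    import black

    # One of the long, single-quoted calls reformatted in this diff.
    SOURCE = (
        "rank_zero_warn('`training_step` must be implemented"
        " to be used with the Lightning Trainer')\n"
    )

    # format_str applies the same rewrite that the black CLI applies to files:
    # quotes become double quotes and the over-long call is wrapped.
    print(black.format_str(SOURCE, mode=black.FileMode(line_length=88)))
    # rank_zero_warn(
    #     "`training_step` must be implemented to be used with the Lightning Trainer"
    # )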
@@ -23,21 +23,25 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.distributed as torch_distrib
from torch import Tensor, ScriptModule
from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import EvalResult, TrainResult
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.parsing import (
    AttributeDict,
    collect_init_args,
    get_init_args,
)
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks, DataHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
from pytorch_lightning.core.step_result import TrainResult, EvalResult

try:
    import torch_xla.core.xla_model as xm
@@ -47,7 +51,9 @@ else:
    XLA_AVAILABLE = True


class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, DataHooks, Module):
class LightningModule(
    ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, DataHooks, Module
):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

@@ -111,7 +117,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        True if your model is currently running on GPUs.
        Useful to set flags around the LightningModule for different CPU vs GPU behavior.
        """
        return self.device.type == 'cuda'
        return self.device.type == "cuda"

    def print(self, *args, **kwargs) -> None:
        r"""
@@ -270,7 +276,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        The loss value shown in the progress bar is smoothed (averaged) over the last values,
        so it differs from the actual loss returned in train/validation step.
        """
        rank_zero_warn('`training_step` must be implemented to be used with the Lightning Trainer')
        rank_zero_warn(
            "`training_step` must be implemented to be used with the Lightning Trainer"
        )

    def training_step_end(self, *args, **kwargs):
        """
@@ -338,9 +346,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        See the :ref:`multi_gpu` guide for more details.
        """

    def training_epoch_end(
        self, outputs: Union[TrainResult, List[TrainResult]]
    ):
    def training_epoch_end(self, outputs: Union[TrainResult, List[TrainResult]]):
        """
        Called at the end of the training epoch with the outputs of all training steps.
        Use this in case you need to do something with all the outputs for every training_step.
@@ -542,7 +548,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        """

    def validation_epoch_end(
        self, outputs: Union[EvalResult, List[EvalResult]]
        self, outputs: Union[EvalResult, List[EvalResult]]
    ) -> EvalResult:
        """
        Called at the end of the validation epoch with the outputs of all validation steps.
@@ -743,7 +749,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        """

    def test_epoch_end(
        self, outputs: Union[EvalResult, List[EvalResult]]
        self, outputs: Union[EvalResult, List[EvalResult]]
    ) -> EvalResult:

        """
@@ -798,7 +804,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
                return results
        """

    def configure_ddp(self, model: 'LightningModule', device_ids: List[int]) -> DistributedDataParallel:
    def configure_ddp(
        self, model: "LightningModule", device_ids: List[int]
    ) -> DistributedDataParallel:
        r"""
        Override to init DDP in your own way or with your own wrapper.
        The only requirements are that:
@@ -828,7 +836,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
                return model

        """
        model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True)
        model = LightningDistributedDataParallel(
            model, device_ids=device_ids, find_unused_parameters=True
        )
        return model

    def _init_slurm_connection(self) -> None:
@@ -841,7 +851,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        # guarantees unique ports across jobs from same grid search
        try:
            # use the last 4 numbers in the job id as the id
            default_port = os.environ['SLURM_JOB_ID']
            default_port = os.environ["SLURM_JOB_ID"]
            default_port = default_port[-4:]

            # all ports should be in the 10k+ range
@@ -852,20 +862,22 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

        # if user gave a port number, use that one instead
        try:
            default_port = os.environ['MASTER_PORT']
            default_port = os.environ["MASTER_PORT"]
        except Exception:
            os.environ['MASTER_PORT'] = str(default_port)
            os.environ["MASTER_PORT"] = str(default_port)

        # figure out the root node addr
        try:
            root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
            root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
        except Exception:
            root_node = '127.0.0.1'
            root_node = "127.0.0.1"

        root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node)
        os.environ['MASTER_ADDR'] = root_node
        os.environ["MASTER_ADDR"] = root_node

    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True) -> None:
    def init_ddp_connection(
        self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
    ) -> None:
        """
        Override to define your custom way of setting up a distributed environment.

@@ -880,27 +892,35 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        if is_slurm_managing_tasks:
            self._init_slurm_connection()

        if 'MASTER_ADDR' not in os.environ:
            rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
            os.environ['MASTER_ADDR'] = '127.0.0.1'
        if "MASTER_ADDR" not in os.environ:
            rank_zero_warn(
                "MASTER_ADDR environment variable is not defined. Set as localhost"
            )
            os.environ["MASTER_ADDR"] = "127.0.0.1"
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

        if 'MASTER_PORT' not in os.environ:
            rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
            os.environ['MASTER_PORT'] = '12910'
        if "MASTER_PORT" not in os.environ:
            rank_zero_warn(
                "MASTER_PORT environment variable is not defined. Set as 12910"
            )
            os.environ["MASTER_PORT"] = "12910"
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
        if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) != world_size:
            rank_zero_warn(
                f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
                f"is not equal to the computed world size ({world_size}). Ignored."
            )

        torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
        log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
        torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
        log.info(
            f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}"
        )
        torch_distrib.init_process_group(
            torch_backend, rank=global_rank, world_size=world_size
        )

    def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
    def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
        """
        Add global batchnorm for a model spread across multiple GPUs and nodes.

@@ -918,8 +938,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        return model

    def configure_apex(
        self, amp: object, model: 'LightningModule', optimizers: List[Optimizer], amp_level: str
    ) -> Tuple['LightningModule', List[Optimizer]]:
        self,
        amp: object,
        model: "LightningModule",
        optimizers: List[Optimizer],
        amp_level: str,
    ) -> Tuple["LightningModule", List[Optimizer]]:
        r"""
        Override to init AMP your own way.
        Must return a model and list of optimizers.
@@ -950,7 +974,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

    def configure_optimizers(
        self,
    ) -> Optional[Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]]:
    ) -> Optional[
        Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]
    ]:
        r"""
        Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.
@@ -1064,7 +1090,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
            }

        """
        rank_zero_warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer')
        rank_zero_warn(
            "`configure_optimizers` must be implemented to be used with the Lightning Trainer"
        )

    def optimizer_step(
        self,
@@ -1152,7 +1180,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        else:
            optimizer.step()

    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
    def optimizer_zero_grad(
        self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
    ):
        optimizer.zero_grad()

    def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
@@ -1198,9 +1228,15 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        Each returned batch split is passed separately to :meth:`training_step`.

        """
        time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
        time_dims = [
            len(x[0])
            for x in batch
            if isinstance(x, (torch.Tensor, collections.Sequence))
        ]
        assert len(time_dims) >= 1, "Unable to determine batch time dimension"
        assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"
        assert all(
            x == time_dims[0] for x in time_dims
        ), "Batch time dimension length is ambiguous"

        splits = []
        for t in range(0, time_dims[0], split_size):
@@ -1221,7 +1257,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

    def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
        model_summary = ModelSummary(self, mode=mode)
        log.info('\n' + str(model_summary))
        log.info("\n" + str(model_summary))
        return model_summary

    def freeze(self) -> None:
@@ -1323,17 +1359,21 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        """
        # call .item() only once but store elements without graphs
        running_train_loss = self.trainer.train_loop.running_loss.mean()
        avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
        tqdm_dict = {'loss': '{:.3f}'.format(avg_training_loss)}
        avg_training_loss = (
            running_train_loss.cpu().item()
            if running_train_loss is not None
            else float("NaN")
        )
        tqdm_dict = {"loss": "{:.3f}".format(avg_training_loss)}

        if self.trainer.truncated_bptt_steps is not None:
            tqdm_dict['split_idx'] = self.trainer.split_idx
            tqdm_dict["split_idx"] = self.trainer.split_idx

        if self.trainer.logger is not None and self.trainer.logger.version is not None:
            version = self.trainer.logger.version
            # show last 4 places of long version strings
            version = version[-4:] if isinstance(version, str) else version
            tqdm_dict['v_num'] = version
            tqdm_dict["v_num"] = version

        return tqdm_dict

@@ -1434,10 +1474,10 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        if not frame:
            frame = inspect.currentframe().f_back
        init_args = get_init_args(frame)
        assert init_args, 'failed to inspect the self init'
        assert init_args, "failed to inspect the self init"
        if not args:
            hp = init_args
            self._hparams_name = 'kwargs' if hp else None
            self._hparams_name = "kwargs" if hp else None
        else:
            isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
            if len(isx_non_str) == 1:
@@ -1446,7 +1486,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
                self._hparams_name = cand_names[0] if cand_names else None
            else:
                hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
                self._hparams_name = 'kwargs'
                self._hparams_name = "kwargs"

        # `hparams` are expected here
        if hp:
@@ -1458,9 +1498,9 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        if isinstance(hp, dict):
            hp = AttributeDict(hp)
        elif isinstance(hp, PRIMITIVE_TYPES):
            raise ValueError(f'Primitives {PRIMITIVE_TYPES} are not allowed.')
            raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
        elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
            raise ValueError(f'Unsupported config type of {type(hp)}.')
            raise ValueError(f"Unsupported config type of {type(hp)}.")

        if isinstance(hp, dict) and isinstance(self.hparams, dict):
            self.hparams.update(hp)
@@ -1498,19 +1538,25 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
            input_data = self.example_input_array
        else:
            if input_sample is not None:
                raise ValueError(f'Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`')
                raise ValueError(
                    f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`"
                )
            else:
                raise ValueError('Could not export to ONNX since neither `input_sample` nor'
                                 ' `model.example_input_array` attribute is set.')
                raise ValueError(
                    "Could not export to ONNX since neither `input_sample` nor"
                    " `model.example_input_array` attribute is set."
                )
        input_data = input_data.to(self.device)
        if 'example_outputs' not in kwargs:
        if "example_outputs" not in kwargs:
            self.eval()
            with torch.no_grad():
                kwargs['example_outputs'] = self(input_data)
                kwargs["example_outputs"] = self(input_data)

        torch.onnx.export(self, input_data, file_path, **kwargs)

    def to_torchscript(self, file_path: Optional[str] = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]:
    def to_torchscript(
        self, file_path: Optional[str] = None, **kwargs
    ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
        """
        By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
        If you would like to customize the modules that are scripted or you want to use tracing
@@ -1560,7 +1606,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

    @property
    def hparams(self) -> Union[AttributeDict, str]:
        if not hasattr(self, '_hparams'):
        if not hasattr(self, "_hparams"):
            self._hparams = AttributeDict()
        return self._hparams

@@ -1578,12 +1624,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
        """
        try:
            class_code = inspect.getsource(self.__class__)
            lines = class_code.split('\n')
            lines = class_code.split("\n")
            for line in lines:
                line = re.sub(r"\s+", "", line, flags=re.UNICODE)
                if '.hparams=' in line:
                    return line.split('=')[1]
                if ".hparams=" in line:
                    return line.split("=")[1]
        except Exception as e:
            return 'hparams'
            return "hparams"

        return None