formatting 4/n: Trainer (#5720)

* yapf trainer

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* .

* fix

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Jirka Borovec 2021-02-03 10:25:42 +01:00 committed by GitHub
parent aa03b73e60
commit aba212341a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 187 additions and 191 deletions

View File

@ -28,10 +28,6 @@ pytorch_lightning/plugins/legacy/*
pytorch_lightning/profiler/*
# TODO
pytorch_lightning/trainer/*
# TODO
pytorch_lightning/tuner/*

View File

@ -95,10 +95,6 @@ class ConfigValidator(object):
has_step = is_overridden(step_name, model)
if has_loader and not has_step:
rank_zero_warn(
f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop'
)
rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop')
if has_step and not has_loader:
rank_zero_warn(
f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop'
)
rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop')

View File

@ -25,14 +25,14 @@ class CallbackConnector:
self.trainer = trainer
def on_trainer_init(
self,
callbacks,
checkpoint_callback,
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path,
resume_from_checkpoint
self,
callbacks,
checkpoint_callback,
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path,
resume_from_checkpoint,
):
self.trainer.resume_from_checkpoint = resume_from_checkpoint
@ -51,9 +51,7 @@ class CallbackConnector:
self.configure_checkpoint_callbacks(checkpoint_callback)
# init progress bar
self.trainer._progress_bar_callback = self.configure_progress_bar(
progress_bar_refresh_rate, process_position
)
self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]):
if isinstance(checkpoint_callback, ModelCheckpoint):
@ -61,8 +59,7 @@ class CallbackConnector:
rank_zero_warn(
"Passing a ModelCheckpoint instance to Trainer(checkpoint_callbacks=...)"
" is deprecated since v1.1 and will no longer be supported in v1.3."
" Use `callbacks` argument instead.",
DeprecationWarning
" Use `callbacks` argument instead.", DeprecationWarning
)
self.trainer.callbacks.append(checkpoint_callback)

View File

@ -230,7 +230,8 @@ class CheckpointConnector:
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
'warning, `hyper_parameters` dropped from checkpoint.'
f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)
@ -297,9 +298,10 @@ class CheckpointConnector:
checkpoint['lr_schedulers'] = lr_schedulers
# dump amp scaling
if (self.trainer.amp_backend == AMPType.NATIVE
and self.trainer._device_type != DeviceType.TPU
and self.trainer.scaler is not None):
if (
self.trainer.amp_backend == AMPType.NATIVE and self.trainer._device_type != DeviceType.TPU
and self.trainer.scaler is not None
):
checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
elif self.trainer.amp_backend == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()
@ -409,6 +411,7 @@ class CheckpointConnector:
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
'Warning, `hyper_parameters` dropped from checkpoint.'
f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)

View File

@ -36,8 +36,7 @@ class DataConnector(object):
def get_profiled_train_dataloader(self, train_dataloader):
profiled_dl = self.trainer.profiler.profile_iterable(
enumerate(self._with_is_last(train_dataloader)),
"get_train_batch"
enumerate(self._with_is_last(train_dataloader)), "get_train_batch"
)
return profiled_dl

View File

@ -25,13 +25,13 @@ class DebuggingConnector:
self.trainer = trainer
def on_init_start(
self,
limit_train_batches,
limit_val_batches,
limit_test_batches,
val_check_interval,
overfit_batches,
fast_dev_run
self,
limit_train_batches,
limit_val_batches,
limit_test_batches,
val_check_interval,
overfit_batches,
fast_dev_run,
):
if not isinstance(fast_dev_run, (bool, int)):
raise MisconfigurationException(

View File

@ -24,6 +24,7 @@ def overwrite_by_env_vars(fn: Callable) -> Callable:
input arguments should be moved automatically to the correct device.
"""
@wraps(fn)
def overwrite_by_env_vars(self, *args, **kwargs):
# get the class

View File

@ -18,8 +18,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
class CallbackHookNameValidator:
@staticmethod
def check_logging_in_callbacks(current_hook_fx_name: str = None, on_step: bool = None,
on_epoch: bool = None) -> None:
def check_logging_in_callbacks(
current_hook_fx_name: str = None, on_step: bool = None, on_epoch: bool = None
) -> None:
if current_hook_fx_name is None:
return

View File

@ -31,6 +31,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
class LoggerConnector:
def __init__(self, trainer):
self.trainer = trainer
self._callback_metrics = MetricsHolder()
@ -76,14 +77,14 @@ class LoggerConnector:
@property
def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results.get(self.trainer._running_stage) # type: ignore
return self._cached_results.get(self.trainer._running_stage) # type: ignore
def get_metrics(self, key: str) -> Dict:
metrics_holder = getattr(self, f"_{key}", None)
model_ref = self.trainer.get_model()
metrics_holder.convert(
self.trainer._device_type == DeviceType.TPU,
model_ref.device if model_ref is not None else model_ref
model_ref.device if model_ref is not None else model_ref,
)
return metrics_holder.metrics

View File

@ -21,7 +21,6 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE
class MetricsHolder:
"""
This class acts as a dictionary holder.
It holds metrics and implements conversion functions.

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Root module for all distributed operations in Lightning.
Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
@ -21,6 +20,7 @@ from weakref import proxy
class ModelConnector:
def __init__(self, trainer):
self.trainer = trainer

View File

@ -16,6 +16,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
class OptimizerConnector:
def __init__(self, trainer):
self.trainer = trainer
@ -50,9 +51,8 @@ class OptimizerConnector:
if lr_scheduler['reduce_on_plateau']:
monitor_key = lr_scheduler['monitor']
monitor_val = (
monitor_metrics.get(monitor_key)
if monitor_metrics is not None
else self.trainer.logger_connector.callback_metrics.get(monitor_key)
monitor_metrics.get(monitor_key) if monitor_metrics is not None else
self.trainer.logger_connector.callback_metrics.get(monitor_key)
)
if monitor_val is None:
if lr_scheduler.get('strict', True):

View File

@ -49,9 +49,11 @@ class PrecisionConnector:
assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
if amp_type == 'native':
if not _NATIVE_AMP_AVAILABLE:
rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
' Consider upgrading with `pip install torch>=1.6`.'
' We will attempt to use NVIDIA Apex for this session.')
rank_zero_warn(
'You have asked for native AMP but your PyTorch version does not support it.'
' Consider upgrading with `pip install torch>=1.6`.'
' We will attempt to use NVIDIA Apex for this session.'
)
amp_type = 'apex'
else:
self.trainer.amp_backend = AMPType.NATIVE
@ -60,8 +62,10 @@ class PrecisionConnector:
if amp_type == 'apex':
if not _APEX_AVAILABLE:
rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
rank_zero_warn(
'You have asked for Apex AMP but you have not installed it yet.'
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux'
)
else:
log.info('Using APEX 16bit precision.')
self.trainer.amp_backend = AMPType.APEX

View File

@ -24,11 +24,7 @@ from pytorch_lightning.profiler import (
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
PROFILERS = {
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
"pytorch": PyTorchProfiler
}
PROFILERS = {"simple": SimpleProfiler, "advanced": AdvancedProfiler, "pytorch": PyTorchProfiler}
class ProfilerConnector:
@ -40,14 +36,17 @@ class ProfilerConnector:
if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
# TODO: Update exception on removal of bool
raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler`"
" are valid values for `Trainer`'s `profiler` parameter."
f" Received {profiler} which is of type {type(profiler)}.")
raise MisconfigurationException(
"Only None, bool, str and subclasses of `BaseProfiler`"
" are valid values for `Trainer`'s `profiler` parameter."
f" Received {profiler} which is of type {type(profiler)}."
)
if isinstance(profiler, bool):
rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
" and will be removed in v1.3. Use str ('simple' or 'advanced') instead.",
DeprecationWarning)
rank_zero_warn(
"Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
" and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", DeprecationWarning
)
if profiler:
profiler = SimpleProfiler()
elif isinstance(profiler, str):
@ -55,8 +54,10 @@ class ProfilerConnector:
profiler_class = PROFILERS[profiler.lower()]
profiler = profiler_class()
else:
raise ValueError("When passing string value for the `profiler` parameter of"
" `Trainer`, it can only be 'simple' or 'advanced'")
raise ValueError(
"When passing string value for the `profiler` parameter of"
" `Trainer`, it can only be 'simple' or 'advanced'"
)
self.trainer.profiler = profiler or PassThroughProfiler()
def on_train_start(self, trainer):

View File

@ -151,9 +151,5 @@ class SLURMConnector:
torch_backend = "nccl" if self.trainer._device_type == DeviceType.GPU else "gloo"
if not torch.distributed.is_initialized():
log.info(
f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
)
torch_distrib.init_process_group(
torch_backend, rank=global_rank, world_size=world_size
)
log.info(f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

View File

@ -21,12 +21,12 @@ class TrainingTricksConnector:
self.trainer = trainer
def on_trainer_init(
self,
gradient_clip_val,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
terminate_on_nan
self,
gradient_clip_val,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
terminate_on_nan,
):
self.trainer.terminate_on_nan = terminate_on_nan

View File

@ -37,12 +37,12 @@ class TrainerDataLoadingMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
global_rank: int
shown_warnings: ...
shown_warnings:...
val_check_interval: float
tpu_local_core_rank: int
train_dataloader: DataLoader
num_training_batches: Union[int, float]
val_check_batch: ...
val_check_batch:...
val_dataloaders: List[DataLoader]
num_val_batches: List[Union[int, float]]
test_dataloaders: List[DataLoader]
@ -65,22 +65,27 @@ class TrainerDataLoadingMixin(ABC):
using_spawn = self.distributed_backend == "ddp_spawn"
if is_dataloader and not on_windows:
if dataloader.num_workers > 0 and using_spawn:
rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well!'
' Your performance might suffer dramatically.'
' Please consider setting accelerator=ddp to use num_workers > 0'
' (this is a bottleneck of Python .spawn() and PyTorch')
rank_zero_warn(
'Dataloader(num_workers>0) and ddp_spawn do not mix well!'
' Your performance might suffer dramatically.'
' Please consider setting accelerator=ddp to use num_workers > 0'
' (this is a bottleneck of Python .spawn() and PyTorch'
)
elif dataloader.num_workers == 0 and using_spawn:
rank_zero_warn('You are using `accelerator=ddp_spawn` with num_workers=0.'
' For much faster performance, switch to `accelerator=ddp`'
' and set `num_workers>0`')
rank_zero_warn(
'You are using `accelerator=ddp_spawn` with num_workers=0.'
' For much faster performance, switch to `accelerator=ddp` and set `num_workers>0`'
)
elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn:
num_cpus = multiprocessing.cpu_count()
rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
' Consider increasing the value of the `num_workers` argument`'
f' (try {num_cpus} which is the number of cpus on this machine)'
' in the `DataLoader` init to improve performance.')
rank_zero_warn(
f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
' Consider increasing the value of the `num_workers` argument`'
f' (try {num_cpus} which is the number of cpus on this machine)'
f' in the `DataLoader` init to improve performance.'
)
def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
@ -99,7 +104,8 @@ class TrainerDataLoadingMixin(ABC):
'You seem to have configured a sampler in your DataLoader. This will be replaced '
' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
' distributed training. Either remove the sampler from your DataLoader or set'
' `replace_sampler_ddp`=False if you want to use your custom sampler.')
' `replace_sampler_ddp`=False if you want to use your custom sampler.'
)
# replace with distributed sampler
sampler = self._get_distributed_sampler(dataloader, shuffle)
@ -110,9 +116,7 @@ class TrainerDataLoadingMixin(ABC):
def replace_sampler(self, dataloader, sampler):
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
dl_args = {
k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
}
dl_args = {k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys}
dl_args['sampler'] = sampler
dl_args['shuffle'] = False
@ -138,17 +142,21 @@ class TrainerDataLoadingMixin(ABC):
if (self.overfit_batches > 0):
if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
' We are turning it off for you.')
rank_zero_warn(
'You requested to overfit but enabled training dataloader shuffling.'
' We are turning it off for you.'
)
self.train_dataloader = self.replace_sampler(
self.train_dataloader, SequentialSampler(self.train_dataloader.dataset))
self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
)
# debugging
self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])
# automatically add samplers
self.train_dataloader = apply_to_collection(
self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True)
self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
)
# check the workers recursively
apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')
@ -166,7 +174,8 @@ class TrainerDataLoadingMixin(ABC):
raise MisconfigurationException(
'When using an IterableDataset for `limit_train_batches`,'
' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
' `num_training_batches` to use.')
' `num_training_batches` to use.'
)
# determine when to check validation
# if int passed in, val checks that often
@ -177,7 +186,8 @@ class TrainerDataLoadingMixin(ABC):
raise ValueError(
f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
f'to the number of the training batches ({self.num_training_batches}). '
'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
'If you want to disable validation set `limit_val_batches` to 0.0 instead.'
)
else:
if not has_len(self.train_dataloader):
if self.val_check_interval == 1.0:
@ -186,15 +196,16 @@ class TrainerDataLoadingMixin(ABC):
raise MisconfigurationException(
'When using an IterableDataset for `train_dataloader`,'
' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
' checking validation every k training batches.')
' checking validation every k training batches.'
)
else:
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
def _reset_eval_dataloader(
self,
model: LightningModule,
mode: str
self,
model: LightningModule,
mode: str,
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
"""Generic method to reset a dataloader for evaluation.
@ -229,13 +240,17 @@ class TrainerDataLoadingMixin(ABC):
# when overfitting, the dataloader should not have sampler
if self.overfit_batches > 0:
rank_zero_warn('You requested to overfit but enabled test/val dataloader shuffling.'
' We are turning it off for you.')
rank_zero_warn(
'You requested to overfit but enabled test/val dataloader shuffling.'
' We are turning it off for you.'
)
dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
else:
rank_zero_warn(f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
' this off for validation and test dataloaders.')
rank_zero_warn(
f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
' this off for validation and test dataloaders.'
)
if any([dl is None for dl in dataloaders]):
rank_zero_warn("One of given dataloaders is None and it will be skipped.")
@ -264,7 +279,8 @@ class TrainerDataLoadingMixin(ABC):
raise MisconfigurationException(
'When using an IterableDataset for `limit_{mode}_batches`,'
f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
f' `num_{mode}_batches` to use.')
f' `num_{mode}_batches` to use.'
)
if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
min_pct = 1.0 / len(dataloader)

View File

@ -100,16 +100,12 @@ class DeprecatedDistDeviceAttributes:
@property
def use_horovod(self) -> bool:
rank_zero_warn(
"Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning
)
rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
return self._distrib_type == DistributedType.HOROVOD
@use_horovod.setter
def use_horovod(self, val: bool) -> None:
rank_zero_warn(
"Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning
)
rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
if val:
self._distrib_type = DistributedType.HOROVOD
@ -119,14 +115,16 @@ class DeprecatedDistDeviceAttributes:
"Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning
)
# todo, limiting to exclude DDP2 is not clear but it comes from connectors...
return (self._device_type and self._device_type == DeviceType.GPU
and self.num_gpus == 1
and self._distrib_type not in (DistributedType.DDP2, ))
return (
self._device_type and self._device_type == DeviceType.GPU and self.num_gpus == 1
and self._distrib_type != DistributedType.DDP2
)
@use_single_gpu.setter
def use_single_gpu(self, val: bool) -> None:
rank_zero_warn(
"Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning,
"Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.",
DeprecationWarning,
)
if val:
self._device_type = DeviceType.GPU

View File

@ -21,6 +21,7 @@ from pytorch_lightning.utilities.warnings import WarningCache
class EvaluationLoop(object):
def __init__(self, trainer):
self.trainer = trainer
self.testing = False
@ -303,7 +304,8 @@ class EvaluationLoop(object):
def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
# set dataloader_idx to model and track batch_size
self.trainer.logger_connector.on_evaluation_batch_start(
self.testing, batch, dataloader_idx, self.num_dataloaders)
self.testing, batch, dataloader_idx, self.num_dataloaders
)
if self.testing:
self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)

View File

@ -17,9 +17,11 @@ import warnings
def ignore_scalar_return_in_dp():
# Users get confused by this warning so we silence it
warnings.filterwarnings('ignore', message='Was asked to gather along dimension 0, but all'
' input tensors were scalars; will instead unsqueeze'
' and return a vector.')
warnings.filterwarnings(
'ignore',
message='Was asked to gather along dimension 0, but all input tensors were scalars;'
' will instead unsqueeze and return a vector.'
)
ignore_scalar_return_in_dp()

View File

@ -31,14 +31,14 @@ class TrainerLoggingMixin(ABC):
current_epoch: int
_device_type: DeviceType
_distrib_type: DistributedType
log_gpu_memory: ...
log_gpu_memory:...
logger: Union[LightningLoggerBase, bool]
global_step: int
global_rank: int
default_root_dir: str
slurm_job_id: int
num_gpus: int
logged_metrics: ...
logged_metrics:...
def metrics_to_scalars(self, metrics):
new_metrics = {}
@ -66,13 +66,12 @@ class TrainerLoggingMixin(ABC):
for k, v in output.items():
if k in ['log', 'progress_bar']:
m = inspect.cleandoc(
f"""The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0
Please use self.log(...) inside the lightningModule instead.
# log on a step or aggregate epoch metric to the logger and/or progress bar
# (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
""")
f"The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0\n"
" Please use self.log(...) inside the lightningModule instead.\n"
" # log on a step or aggregate epoch metric to the logger and/or progress bar"
" (inside LightningModule)\n"
" self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)"
)
rank_zero_warn(m)
# --------------------------

View File

@ -19,6 +19,7 @@ from pytorch_lightning.core.lightning import LightningModule
class TrainerModelHooksMixin(ABC):
def is_function_implemented(self, f_name, model=None):
if model is None:
model = self.get_model()

View File

@ -26,6 +26,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
class TrainerOptimizersMixin(ABC):
def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
optim_conf = model.configure_optimizers()
if optim_conf is None:
@ -82,6 +83,7 @@ class TrainerOptimizersMixin(ABC):
return optimizers, lr_schedulers, optimizer_frequencies
def convert_to_lightning_optimizers(self):
def _convert_to_lightning_optimizer(trainer, optimizer):
if not isinstance(optimizer, LightningOptimizer):
optimizer = LightningOptimizer(optimizer)
@ -132,9 +134,11 @@ class TrainerOptimizersMixin(ABC):
' For example:'
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor}
)
lr_schedulers.append({
**default_config, 'scheduler': scheduler,
'reduce_on_plateau': True,
'monitor': monitor
})
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, 'scheduler': scheduler})
else:

View File

@ -59,6 +59,7 @@ def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[
"""
def wrapper(fn) -> Callable:
@wraps(fn)
def wrapped_fn(self, *args, **kwargs):
if not isinstance(self, pytorch_lightning.Trainer):

View File

@ -101,10 +101,11 @@ class TensorRunningAccum(object):
if self.rotated:
return getattr(self.memory, how)()
else:
return getattr(self.memory[: self.current_idx], how)()
return getattr(self.memory[:self.current_idx], how)()
class Accumulator(object):
def __init__(self):
self.num_values = 0
self.total = 0
@ -119,6 +120,7 @@ class Accumulator(object):
class PredictionCollection(object):
def __init__(self, global_rank: int, world_size: int):
self.global_rank = global_rank
self.world_size = world_size
@ -131,9 +133,7 @@ class PredictionCollection(object):
elif name not in self.predictions[filename]:
self.predictions[filename][name] = values
elif isinstance(values, Tensor):
self.predictions[filename][name] = torch.cat(
(self.predictions[filename][name], values)
)
self.predictions[filename][name] = torch.cat((self.predictions[filename][name], values))
elif isinstance(values, list):
self.predictions[filename][name].extend(values)
@ -161,10 +161,7 @@ class PredictionCollection(object):
fs.mkdirs(dirpath, exist_ok=True)
# Convert any tensor values to list
predictions = {
k: v if not isinstance(v, Tensor) else v.tolist()
for k, v in predictions.items()
}
predictions = {k: v if not isinstance(v, Tensor) else v.tolist() for k, v in predictions.items()}
# Check if all features for this file add up to same length
feature_lens = {k: len(v) for k, v in predictions.items()}
@ -186,6 +183,7 @@ class CycleIterator(object):
"""
Iterator for restarting a dataloader if it runs out of samples
"""
def __init__(self, loader: Any, length: Optional[int] = None):
"""
@ -296,8 +294,9 @@ class CombinedDataset(object):
raise MisconfigurationException(f"Invalid Mode: {mode}")
# extract the lengths
all_lengths = apply_to_collection(datasets, (Dataset, Iterable, type(None)), get_len,
wrong_dtype=(Sequence, Mapping))
all_lengths = apply_to_collection(
datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping)
)
compute_func = CombinedDataset.COMPUTE_FUNCS[mode]
@ -351,8 +350,9 @@ class CombinedLoader(object):
"""
self.loaders = loaders
datasets = apply_to_collection(self.loaders, Iterable, getattr, 'dataset', None,
wrong_dtype=(Sequence, Mapping))
datasets = apply_to_collection(
self.loaders, Iterable, getattr, 'dataset', None, wrong_dtype=(Sequence, Mapping)
)
# could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
self.dataset = CombinedDataset(datasets, mode)
@ -367,8 +367,7 @@ class CombinedLoader(object):
@property
def sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of samplers extracting from loaders."""
return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None,
wrong_dtype=(Sequence, Mapping))
return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None, wrong_dtype=(Sequence, Mapping))
def _wrap_loaders_max_size_cycle(self) -> Any:
"""
@ -378,8 +377,7 @@ class CombinedLoader(object):
the wrapped loaders
"""
all_lengths = apply_to_collection(self.loaders, Iterable, get_len,
wrong_dtype=(Sequence, Mapping))
all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
if isinstance(all_lengths, (int, float)):
length = all_lengths
@ -391,13 +389,10 @@ class CombinedLoader(object):
length = max(all_lengths)
if isinstance(self.loaders, Mapping):
self.loaders = type(self.loaders)({k: CycleIterator(v, length=length)
for k, v in self.loaders.items()})
self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()})
elif isinstance(self.loaders, Sequence):
self.loaders = type(self.loaders)([
CycleIterator(v, length=length) for v in self.loaders
])
self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders])
# dataloaders are iterable but not sequence
elif isinstance(self.loaders, Iterable):
@ -424,8 +419,7 @@ class CombinedLoader(object):
length: the minimum length of loaders
"""
all_lengths = apply_to_collection(loaders, Iterable, get_len,
wrong_dtype=(Sequence, Mapping))
all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
if isinstance(all_lengths, (int, float)):
return all_lengths
@ -441,6 +435,7 @@ class CombinedLoaderIterator(object):
"""
Custom Iterator returning data from multiple loaders, and allows sampling in parallel
"""
def __init__(self, loaders: Any):
"""

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer to automate the training."""
import os
@ -66,7 +65,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
# warnings to ignore in trainer
warnings.filterwarnings(
'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead'
'ignore', message='torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead'
)
@ -80,6 +79,7 @@ class Trainer(
TrainerDataLoadingMixin,
DeprecatedDistDeviceAttributes,
):
@overwrite_by_env_vars
def __init__(
self,
@ -896,9 +896,7 @@ class Trainer(
# --------------------
# If you supply a datamodule you can't supply dataloaders
if dataloaders and datamodule:
raise MisconfigurationException(
'You cannot pass dataloaders to trainer.predict if you supply a datamodule'
)
raise MisconfigurationException('You cannot pass dataloaders to trainer.predict if you supply a datamodule')
if model is None:
raise MisconfigurationException('You need to pass a model to `trainer.predict`. ')

View File

@ -35,6 +35,7 @@ from pytorch_lightning.utilities.warnings import WarningCache
class TrainLoop:
def __init__(self, trainer, multiple_trainloader_mode):
self.trainer = trainer
self.early_stopping_accumulator = None
@ -154,9 +155,10 @@ class TrainLoop:
self.trainer.model_connector.copy_trainer_model_properties(ref_model)
# init amp. Must be done here instead of __init__ to allow ddp to work
if (self.trainer.amp_backend == AMPType.NATIVE
and self.trainer.precision == 16
and self.trainer._device_type != DeviceType.TPU):
if (
self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16
and self.trainer._device_type != DeviceType.TPU
):
self.trainer.scaler = self.trainer.precision_connector.backend.scaler
# log hyper-parameters
@ -498,7 +500,8 @@ class TrainLoop:
if using_native_amp and is_lbfgs:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
' To request, please file a Github issue in PyTorch and tag @mcarilli'
)
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
@ -641,10 +644,7 @@ class TrainLoop:
# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(
epoch_output,
self.checkpoint_accumulator,
self.early_stopping_accumulator,
self.num_optimizers
epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
)
# when no val loop is present or fast-dev-run still need to call checkpoints
@ -699,11 +699,8 @@ class TrainLoop:
# automatic_optimization=False: don't block synchronization here
with self.block_ddp_sync_behaviour():
self.training_step_and_backward(
split_batch,
batch_idx,
opt_idx,
optimizer,
self.trainer.hiddens)
split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
)
batch_outputs = self._process_closure_result(
batch_outputs=batch_outputs,
@ -720,11 +717,7 @@ class TrainLoop:
def train_step_and_backward_closure():
result = self.training_step_and_backward(
split_batch,
batch_idx,
opt_idx,
optimizer,
self.trainer.hiddens
split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
)
return None if result is None else result.loss
@ -733,10 +726,7 @@ class TrainLoop:
else:
self._curr_step_result = self.training_step(
split_batch,
batch_idx,
opt_idx,
self.trainer.hiddens
split_batch, batch_idx, opt_idx, self.trainer.hiddens
)
if self._curr_step_result is None:
@ -783,9 +773,7 @@ class TrainLoop:
else:
yield None
def _process_closure_result(
self, batch_outputs: list, opt_idx: int
) -> list:
def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list:
opt_closure_result = self._curr_step_result
if opt_closure_result is not None:

View File

@ -29,7 +29,7 @@ class TrainerTrainingTricksMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
default_root_dir: str
progress_bar_callback: ...
progress_bar_callback:...
on_gpu: bool
@abstractmethod
@ -47,9 +47,7 @@ class TrainerTrainingTricksMixin(ABC):
# check if loss is nan
if not torch.isfinite(loss).all():
raise ValueError(
'The loss returned in `training_step` is nan or inf.'
)
raise ValueError('The loss returned in `training_step` is nan or inf.')
# check if a network weight is nan
for name, param in model.named_parameters():
if not torch.isfinite(param).all():