Promote the CLI out of utilities (#13767)
This commit is contained in:
parent
e2ec51a5a2
commit
4f53e7132f
|
@ -15,7 +15,7 @@
|
|||
pass
|
||||
|
||||
|
||||
class LightningCLI(pl.utilities.cli.LightningCLI):
|
||||
class LightningCLI(pl.cli.LightningCLI):
|
||||
def __init__(self, *args, trainer_class=NoFitTrainer, run=False, **kwargs):
|
||||
super().__init__(*args, trainer_class=trainer_class, run=run, **kwargs)
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
pass
|
||||
|
||||
|
||||
class LightningCLI(pl.utilities.cli.LightningCLI):
|
||||
class LightningCLI(pl.cli.LightningCLI):
|
||||
def __init__(self, *args, trainer_class=NoFitTrainer, run=False, **kwargs):
|
||||
super().__init__(*args, trainer_class=trainer_class, run=run, **kwargs)
|
||||
|
||||
|
@ -88,7 +88,7 @@ Similar to the callbacks, any parameter in :class:`~pytorch_lightning.trainer.tr
|
|||
:class:`~pytorch_lightning.core.module.LightningModule` and
|
||||
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class, can be
|
||||
configured the same way using :code:`class_path` and :code:`init_args`. If the package that defines a subclass is
|
||||
imported before the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class is run, the name can be used instead of
|
||||
imported before the :class:`~pytorch_lightning.cli.LightningCLI` class is run, the name can be used instead of
|
||||
the full import path.
|
||||
|
||||
From command line the syntax is the following:
|
||||
|
@ -117,7 +117,7 @@ callback appended. Here is an example:
|
|||
|
||||
.. note::
|
||||
|
||||
Serialized config files (e.g. ``--print_config`` or :class:`~pytorch_lightning.utilities.cli.SaveConfigCallback`)
|
||||
Serialized config files (e.g. ``--print_config`` or :class:`~pytorch_lightning.cli.SaveConfigCallback`)
|
||||
always have the full ``class_path``'s, even when class name shorthand notation is used in command line or in input
|
||||
config files.
|
||||
|
||||
|
@ -306,7 +306,7 @@ example can be when one wants to add support for multiple optimizers:
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from pytorch_lightning.utilities.cli import instantiate_class
|
||||
from pytorch_lightning.cli import instantiate_class
|
||||
|
||||
|
||||
class MyModel(LightningModule):
|
||||
|
@ -330,7 +330,7 @@ example can be when one wants to add support for multiple optimizers:
|
|||
cli = MyLightningCLI(MyModel)
|
||||
|
||||
The value given to :code:`optimizer*_init` will always be a dictionary including :code:`class_path` and
|
||||
:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class`
|
||||
:code:`init_args` entries. The function :func:`~pytorch_lightning.cli.instantiate_class`
|
||||
takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments,
|
||||
in this case :code:`self.parameters()`, and the :code:`init_args`.
|
||||
Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
pass
|
||||
|
||||
|
||||
class LightningCLI(pl.utilities.cli.LightningCLI):
|
||||
class LightningCLI(pl.cli.LightningCLI):
|
||||
def __init__(self, *args, trainer_class=NoFitTrainer, run=False, **kwargs):
|
||||
super().__init__(*args, trainer_class=trainer_class, run=run, **kwargs)
|
||||
|
||||
|
@ -62,23 +62,23 @@ Eliminate config boilerplate (Advanced)
|
|||
Customize the LightningCLI
|
||||
**************************
|
||||
|
||||
The init parameters of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to customize some
|
||||
The init parameters of the :class:`~pytorch_lightning.cli.LightningCLI` class can be used to customize some
|
||||
things, namely: the description of the tool, enabling parsing of environment variables and additional arguments to
|
||||
instantiate the trainer and configuration parser.
|
||||
|
||||
Nevertheless the init arguments are not enough for many use cases. For this reason the class is designed so that can be
|
||||
extended to customize different parts of the command line tool. The argument parser class used by
|
||||
:class:`~pytorch_lightning.utilities.cli.LightningCLI` is
|
||||
:class:`~pytorch_lightning.utilities.cli.LightningArgumentParser` which is an extension of python's argparse, thus
|
||||
:class:`~pytorch_lightning.cli.LightningCLI` is
|
||||
:class:`~pytorch_lightning.cli.LightningArgumentParser` which is an extension of python's argparse, thus
|
||||
adding arguments can be done using the :func:`add_argument` method. In contrast to argparse it has additional methods to
|
||||
add arguments, for example :func:`add_class_arguments` adds all arguments from the init of a class, though requiring
|
||||
parameters to have type hints. For more details about this please refer to the `respective documentation
|
||||
<https://jsonargparse.readthedocs.io/en/stable/#classes-methods-and-functions>`_.
|
||||
|
||||
The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
|
||||
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include
|
||||
The :class:`~pytorch_lightning.cli.LightningCLI` class has the
|
||||
:meth:`~pytorch_lightning.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include
|
||||
more arguments. After parsing, the configuration is stored in the :code:`config` attribute of the class instance. The
|
||||
:class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two methods that can be used to run code before
|
||||
:class:`~pytorch_lightning.cli.LightningCLI` class also has two methods that can be used to run code before
|
||||
and after the trainer runs: :code:`before_<subcommand>` and :code:`after_<subcommand>`.
|
||||
A realistic example for these would be to send an email before and after the execution.
|
||||
The code for the :code:`fit` subcommand would be something like:
|
||||
|
@ -104,7 +104,7 @@ instantiating the trainer class can be found in :code:`self.config['fit']['train
|
|||
|
||||
.. tip::
|
||||
|
||||
Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other
|
||||
Have a look at the :class:`~pytorch_lightning.cli.LightningCLI` class API reference to learn about other
|
||||
methods that can be extended to customize a CLI.
|
||||
|
||||
----
|
||||
|
@ -211,7 +211,7 @@ A more compact version that avoids writing a dictionary would be:
|
|||
************************
|
||||
Connect two config files
|
||||
************************
|
||||
Another case in which it might be desired to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is that the
|
||||
Another case in which it might be desired to extend :class:`~pytorch_lightning.cli.LightningCLI` is that the
|
||||
model and data module depend on a common parameter. For example in some cases both classes require to know the
|
||||
:code:`batch_size`. It is a burden and error prone giving the same value twice in a config file. To avoid this the
|
||||
parser can be configured so that a value is only given once and then propagated accordingly. With a tool implemented
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
pass
|
||||
|
||||
|
||||
class LightningCLI(pl.utilities.cli.LightningCLI):
|
||||
class LightningCLI(pl.cli.LightningCLI):
|
||||
def __init__(self, *args, trainer_class=NoFitTrainer, run=False, **kwargs):
|
||||
super().__init__(*args, trainer_class=trainer_class, run=run, **kwargs)
|
||||
|
||||
|
@ -65,7 +65,7 @@ there is a failure an exception is raised and the full stack trace printed.
|
|||
Reproducibility with the LightningCLI
|
||||
*************************************
|
||||
The topic of reproducibility is complex and it is impossible to guarantee reproducibility by just providing a class that
|
||||
people can use in unexpected ways. Nevertheless, the :class:`~pytorch_lightning.utilities.cli.LightningCLI` tries to
|
||||
people can use in unexpected ways. Nevertheless, the :class:`~pytorch_lightning.cli.LightningCLI` tries to
|
||||
give a framework and recommendations to make reproducibility simpler.
|
||||
|
||||
When an experiment is run, it is good practice to use a stable version of the source code, either being a released
|
||||
|
@ -85,7 +85,7 @@ For every CLI implemented, users are encouraged to learn how to run it by readin
|
|||
:code:`--help` option and use the :code:`--print_config` option to guide the writing of config files. A few more details
|
||||
that might not be clear by only reading the help are the following.
|
||||
|
||||
:class:`~pytorch_lightning.utilities.cli.LightningCLI` is based on argparse and as such follows the same arguments style
|
||||
:class:`~pytorch_lightning.cli.LightningCLI` is based on argparse and as such follows the same arguments style
|
||||
as many POSIX command line tools. Long options are prefixed with two dashes and its corresponding values should be
|
||||
provided with an empty space or an equal sign, as :code:`--option value` or :code:`--option=value`. Command line options
|
||||
are parsed from left to right, therefore if a setting appears multiple times the value most to the right will override
|
||||
|
|
|
@ -82,7 +82,7 @@ The simplest way to control a model with the CLI is to wrap it in the LightningC
|
|||
|
||||
# main.py
|
||||
import torch
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
|
||||
# simple demo classes for your convenience
|
||||
from pytorch_lightning.demos.boring_classes import DemoModel, BoringDataModule
|
||||
|
|
|
@ -23,9 +23,9 @@ from torch.nn import functional as F
|
|||
from torchmetrics import Accuracy
|
||||
|
||||
from pytorch_lightning import cli_lightning_logo, LightningModule
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.demos.boring_classes import Net
|
||||
from pytorch_lightning.demos.mnist_datamodule import MNIST
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
|
||||
|
||||
|
|
|
@ -23,9 +23,9 @@ from torch.nn import functional as F
|
|||
from torchmetrics import Accuracy
|
||||
|
||||
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.demos.boring_classes import Net
|
||||
from pytorch_lightning.demos.mnist_datamodule import MNIST
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
|
||||
|
||||
|
|
|
@ -24,8 +24,8 @@ from torch import nn
|
|||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
from pytorch_lightning import callbacks, cli_lightning_logo, LightningDataModule, LightningModule, Trainer
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.demos.mnist_datamodule import MNIST
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
|
|
|
@ -23,8 +23,8 @@ from torch.nn import functional as F
|
|||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.demos.mnist_datamodule import MNIST
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
|
|
|
@ -31,8 +31,8 @@ import torchvision.models as models
|
|||
import torchvision.transforms as T
|
||||
|
||||
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.profilers.pytorch import PyTorchProfiler
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
DEFAULT_CMD_LINE = (
|
||||
"fit",
|
||||
|
|
|
@ -56,7 +56,7 @@ from torchvision.datasets.utils import download_and_extract_archive
|
|||
|
||||
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
|
||||
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
|
@ -47,8 +47,8 @@ from torchmetrics import Accuracy
|
|||
|
||||
from pytorch_lightning import LightningModule
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.strategies import ParallelStrategy
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
|
||||
class ImageNetLightningModel(LightningModule):
|
||||
|
|
|
@ -16,9 +16,9 @@ from jsonargparse import lazy_instance
|
|||
from torch.nn import functional as F
|
||||
|
||||
from pytorch_lightning import LightningModule
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.demos.mnist_datamodule import MNISTDataModule
|
||||
from pytorch_lightning.plugins import HPUPrecisionPlugin
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
|
||||
class LitClassifier(LightningModule):
|
||||
|
|
|
@ -23,8 +23,8 @@ from torch.nn import functional as F
|
|||
from torch.utils.data import random_split
|
||||
|
||||
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.demos.mnist_datamodule import MNIST
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.utilities.imports import _DALI_AVAILABLE, _TORCHVISION_AVAILABLE
|
||||
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
|
|
|
@ -12,8 +12,8 @@ import torchvision.transforms as T
|
|||
from PIL import Image as PILImage
|
||||
|
||||
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.serve import ServableModule, ServableModuleValidator
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
|
||||
|
||||
|
|
|
@ -187,6 +187,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Deprecated LightningCLI's registries in favor of importing the respective package ([#13221](https://github.com/PyTorchLightning/pytorch-lightning/pull/13221))
|
||||
|
||||
|
||||
- Deprecated public utilities in `pytorch_lightning.utilities.cli.LightningCLI` in favor of equivalent copies in `pytorch_lightning.cli.LightningCLI` ([#13767](https://github.com/PyTorchLightning/pytorch-lightning/pull/13767))
|
||||
|
||||
|
||||
- Deprecated `pytorch_lightning.profiler` in favor of `pytorch_lightning.profilers` ([#12308](https://github.com/PyTorchLightning/pytorch-lightning/pull/12308))
|
||||
|
||||
|
|
|
@ -0,0 +1,700 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from functools import partial, update_wrapper
|
||||
from types import MethodType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _RequirementAvailable
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
|
||||
|
||||
_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.2")
|
||||
|
||||
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
|
||||
import docstring_parser
|
||||
from jsonargparse import (
|
||||
ActionConfigFile,
|
||||
ArgumentParser,
|
||||
class_from_function,
|
||||
Namespace,
|
||||
register_unresolvable_import_paths,
|
||||
set_config_read_mode,
|
||||
)
|
||||
|
||||
register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
|
||||
set_config_read_mode(fsspec_enabled=True)
|
||||
else:
|
||||
locals()["ArgumentParser"] = object
|
||||
locals()["Namespace"] = object
|
||||
|
||||
|
||||
class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||
def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(optimizer, *args, **kwargs)
|
||||
self.monitor = monitor
|
||||
|
||||
|
||||
# LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch:
|
||||
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau)
|
||||
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]
|
||||
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
|
||||
|
||||
|
||||
class LightningArgumentParser(ArgumentParser):
|
||||
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Initialize argument parser that supports configuration file input.
|
||||
|
||||
For full details of accepted arguments see `ArgumentParser.__init__
|
||||
<https://jsonargparse.readthedocs.io/en/stable/index.html#jsonargparse.ArgumentParser.__init__>`_.
|
||||
"""
|
||||
if not _JSONARGPARSE_SIGNATURES_AVAILABLE:
|
||||
raise ModuleNotFoundError(
|
||||
f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}. Try `pip install -U 'jsonargparse[signatures]'`."
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.add_argument(
|
||||
"-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
|
||||
)
|
||||
self.callback_keys: List[str] = []
|
||||
# separate optimizers and lr schedulers to know which were added
|
||||
self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
|
||||
self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
|
||||
|
||||
def add_lightning_class_args(
|
||||
self,
|
||||
lightning_class: Union[
|
||||
Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]],
|
||||
Type[Trainer],
|
||||
Type[LightningModule],
|
||||
Type[LightningDataModule],
|
||||
Type[Callback],
|
||||
],
|
||||
nested_key: str,
|
||||
subclass_mode: bool = False,
|
||||
required: bool = True,
|
||||
) -> List[str]:
|
||||
"""Adds arguments from a lightning class to a nested key of the parser.
|
||||
|
||||
Args:
|
||||
lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
|
||||
nested_key: Name of the nested namespace to store arguments.
|
||||
subclass_mode: Whether allow any subclass of the given class.
|
||||
required: Whether the argument group is required.
|
||||
|
||||
Returns:
|
||||
A list with the names of the class arguments added.
|
||||
"""
|
||||
if callable(lightning_class) and not isinstance(lightning_class, type):
|
||||
lightning_class = class_from_function(lightning_class)
|
||||
|
||||
if isinstance(lightning_class, type) and issubclass(
|
||||
lightning_class, (Trainer, LightningModule, LightningDataModule, Callback)
|
||||
):
|
||||
if issubclass(lightning_class, Callback):
|
||||
self.callback_keys.append(nested_key)
|
||||
if subclass_mode:
|
||||
return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)
|
||||
return self.add_class_arguments(
|
||||
lightning_class,
|
||||
nested_key,
|
||||
fail_untyped=False,
|
||||
instantiate=not issubclass(lightning_class, Trainer),
|
||||
sub_configs=True,
|
||||
)
|
||||
raise MisconfigurationException(
|
||||
f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
|
||||
"Trainer, LightningModule, LightningDataModule, or Callback."
|
||||
)
|
||||
|
||||
def add_optimizer_args(
|
||||
self,
|
||||
optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,),
|
||||
nested_key: str = "optimizer",
|
||||
link_to: str = "AUTOMATIC",
|
||||
) -> None:
|
||||
"""Adds arguments from an optimizer class to a nested key of the parser.
|
||||
|
||||
Args:
|
||||
optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses.
|
||||
nested_key: Name of the nested namespace to store arguments.
|
||||
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
||||
"""
|
||||
if isinstance(optimizer_class, tuple):
|
||||
assert all(issubclass(o, Optimizer) for o in optimizer_class)
|
||||
else:
|
||||
assert issubclass(optimizer_class, Optimizer)
|
||||
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
|
||||
if isinstance(optimizer_class, tuple):
|
||||
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
|
||||
else:
|
||||
self.add_class_arguments(optimizer_class, nested_key, sub_configs=True, **kwargs)
|
||||
self._optimizers[nested_key] = (optimizer_class, link_to)
|
||||
|
||||
def add_lr_scheduler_args(
|
||||
self,
|
||||
lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple,
|
||||
nested_key: str = "lr_scheduler",
|
||||
link_to: str = "AUTOMATIC",
|
||||
) -> None:
|
||||
"""Adds arguments from a learning rate scheduler class to a nested key of the parser.
|
||||
|
||||
Args:
|
||||
lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use
|
||||
tuple to allow subclasses.
|
||||
nested_key: Name of the nested namespace to store arguments.
|
||||
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
||||
"""
|
||||
if isinstance(lr_scheduler_class, tuple):
|
||||
assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class)
|
||||
else:
|
||||
assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple)
|
||||
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
|
||||
if isinstance(lr_scheduler_class, tuple):
|
||||
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
|
||||
else:
|
||||
self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
|
||||
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
|
||||
|
||||
|
||||
class SaveConfigCallback(Callback):
|
||||
"""Saves a LightningCLI config to the log_dir when training starts.
|
||||
|
||||
Args:
|
||||
parser: The parser object used to parse the configuration.
|
||||
config: The parsed configuration that will be saved.
|
||||
config_filename: Filename for the config file.
|
||||
overwrite: Whether to overwrite an existing config file.
|
||||
multifile: When input is multiple config files, saved config preserves this structure.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parser: LightningArgumentParser,
|
||||
config: Namespace,
|
||||
config_filename: str,
|
||||
overwrite: bool = False,
|
||||
multifile: bool = False,
|
||||
) -> None:
|
||||
self.parser = parser
|
||||
self.config = config
|
||||
self.config_filename = config_filename
|
||||
self.overwrite = overwrite
|
||||
self.multifile = multifile
|
||||
|
||||
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
|
||||
log_dir = trainer.log_dir # this broadcasts the directory
|
||||
assert log_dir is not None
|
||||
config_path = os.path.join(log_dir, self.config_filename)
|
||||
fs = get_filesystem(log_dir)
|
||||
|
||||
if not self.overwrite:
|
||||
# check if the file exists on rank 0
|
||||
file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
|
||||
# broadcast whether to fail to all ranks
|
||||
file_exists = trainer.strategy.broadcast(file_exists)
|
||||
if file_exists:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
|
||||
" results of a previous run. You can delete the previous config file,"
|
||||
" set `LightningCLI(save_config_callback=None)` to disable config saving,"
|
||||
" or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
|
||||
)
|
||||
|
||||
# save the file on rank 0
|
||||
if trainer.is_global_zero:
|
||||
# save only on rank zero to avoid race conditions.
|
||||
# the `log_dir` needs to be created as we rely on the logger to do it usually
|
||||
# but it hasn't logged anything at this point
|
||||
fs.makedirs(log_dir, exist_ok=True)
|
||||
self.parser.save(
|
||||
self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
|
||||
)
|
||||
|
||||
|
||||
class LightningCLI:
|
||||
"""Implementation of a configurable command line tool for pytorch-lightning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None,
|
||||
datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None,
|
||||
save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
|
||||
save_config_filename: str = "config.yaml",
|
||||
save_config_overwrite: bool = False,
|
||||
save_config_multifile: bool = False,
|
||||
trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
|
||||
trainer_defaults: Optional[Dict[str, Any]] = None,
|
||||
seed_everything_default: Union[bool, int] = True,
|
||||
description: str = "pytorch-lightning trainer command line tool",
|
||||
env_prefix: str = "PL",
|
||||
env_parse: bool = False,
|
||||
parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None,
|
||||
subclass_mode_model: bool = False,
|
||||
subclass_mode_data: bool = False,
|
||||
run: bool = True,
|
||||
auto_registry: bool = False,
|
||||
) -> None:
|
||||
"""Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
|
||||
are called / instantiated using a parsed configuration file and / or command line args.
|
||||
|
||||
Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
|
||||
A full configuration yaml would be parsed from ``PL_CONFIG`` if set.
|
||||
Individual settings are so parsed from variables named for example ``PL_TRAINER__MAX_EPOCHS``.
|
||||
|
||||
For more info, read :ref:`the CLI docs <lightning-cli>`.
|
||||
|
||||
.. warning:: ``LightningCLI`` is in beta and subject to change.
|
||||
|
||||
Args:
|
||||
model_class: An optional :class:`~pytorch_lightning.core.module.LightningModule` class to train on or a
|
||||
callable which returns a :class:`~pytorch_lightning.core.module.LightningModule` instance when
|
||||
called. If ``None``, you can pass a registered model with ``--model=MyModel``.
|
||||
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
|
||||
callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
|
||||
called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
|
||||
save_config_callback: A callback class to save the training config.
|
||||
save_config_filename: Filename for the config file.
|
||||
save_config_overwrite: Whether to overwrite an existing config file.
|
||||
save_config_multifile: When input is multiple config files, saved config preserves this structure.
|
||||
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
|
||||
callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
|
||||
trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
|
||||
this argument will not be configurable from a configuration file and will always be present for
|
||||
this particular CLI. Alternatively, configurable callbacks can be added as explained in
|
||||
:ref:`the CLI docs <lightning-cli>`.
|
||||
seed_everything_default: Value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
|
||||
seed argument. Set to True to automatically choose a valid seed.
|
||||
Setting it to False will not call seed_everything.
|
||||
description: Description of the tool shown when running ``--help``.
|
||||
env_prefix: Prefix for environment variables.
|
||||
env_parse: Whether environment variable parsing is enabled.
|
||||
parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
|
||||
subclass_mode_model: Whether model can be any `subclass
|
||||
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
|
||||
of the given class.
|
||||
subclass_mode_data: Whether datamodule can be any `subclass
|
||||
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
|
||||
of the given class.
|
||||
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
|
||||
method. If set to ``False``, the trainer and model classes will be instantiated only.
|
||||
auto_registry: Whether to automatically fill up the registries with all defined subclasses.
|
||||
"""
|
||||
self.save_config_callback = save_config_callback
|
||||
self.save_config_filename = save_config_filename
|
||||
self.save_config_overwrite = save_config_overwrite
|
||||
self.save_config_multifile = save_config_multifile
|
||||
self.trainer_class = trainer_class
|
||||
self.trainer_defaults = trainer_defaults or {}
|
||||
self.seed_everything_default = seed_everything_default
|
||||
|
||||
if self.seed_everything_default is None:
|
||||
rank_zero_deprecation(
|
||||
"Setting `LightningCLI.seed_everything_default` to `None` is deprecated in v1.7 "
|
||||
"and will be removed in v1.9. Set it to `False` instead."
|
||||
)
|
||||
self.seed_everything_default = False
|
||||
|
||||
self.model_class = model_class
|
||||
# used to differentiate between the original value and the processed value
|
||||
self._model_class = model_class or LightningModule
|
||||
self.subclass_mode_model = (model_class is None) or subclass_mode_model
|
||||
|
||||
self.datamodule_class = datamodule_class
|
||||
# used to differentiate between the original value and the processed value
|
||||
self._datamodule_class = datamodule_class or LightningDataModule
|
||||
self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data
|
||||
|
||||
from pytorch_lightning.utilities.cli import _populate_registries
|
||||
|
||||
_populate_registries(auto_registry)
|
||||
|
||||
main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
|
||||
parser_kwargs or {}, # type: ignore # github.com/python/mypy/issues/6463
|
||||
{"description": description, "env_prefix": env_prefix, "default_env": env_parse},
|
||||
)
|
||||
self.setup_parser(run, main_kwargs, subparser_kwargs)
|
||||
self.parse_arguments(self.parser)
|
||||
|
||||
self.subcommand = self.config["subcommand"] if run else None
|
||||
|
||||
self._set_seed()
|
||||
|
||||
self.before_instantiate_classes()
|
||||
self.instantiate_classes()
|
||||
|
||||
if self.subcommand is not None:
|
||||
self._run_subcommand(self.subcommand)
|
||||
|
||||
def _setup_parser_kwargs(
|
||||
self, kwargs: Dict[str, Any], defaults: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
if kwargs.keys() & self.subcommands().keys():
|
||||
# `kwargs` contains arguments per subcommand
|
||||
return defaults, kwargs
|
||||
main_kwargs = defaults
|
||||
main_kwargs.update(kwargs)
|
||||
return main_kwargs, {}
|
||||
|
||||
def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
|
||||
"""Method that instantiates the argument parser."""
|
||||
kwargs.setdefault("dump_header", [f"pytorch_lightning=={pl.__version__}"])
|
||||
return LightningArgumentParser(**kwargs)
|
||||
|
||||
def setup_parser(
|
||||
self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Initialize and setup the parser, subcommands, and arguments."""
|
||||
self.parser = self.init_parser(**main_kwargs)
|
||||
if add_subcommands:
|
||||
self._subcommand_method_arguments: Dict[str, List[str]] = {}
|
||||
self._add_subcommands(self.parser, **subparser_kwargs)
|
||||
else:
|
||||
self._add_arguments(self.parser)
|
||||
|
||||
def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
||||
"""Adds default arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--seed_everything",
|
||||
type=Union[bool, int],
|
||||
default=self.seed_everything_default,
|
||||
help=(
|
||||
"Set to an int to run seed_everything with this value before classes instantiation."
|
||||
"Set to True to use a random seed."
|
||||
),
|
||||
)
|
||||
|
||||
def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
||||
"""Adds arguments from the core classes to the parser."""
|
||||
parser.add_lightning_class_args(self.trainer_class, "trainer")
|
||||
trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
|
||||
parser.set_defaults(trainer_defaults)
|
||||
|
||||
parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
|
||||
|
||||
if self.datamodule_class is not None:
|
||||
parser.add_lightning_class_args(self._datamodule_class, "data", subclass_mode=self.subclass_mode_data)
|
||||
else:
|
||||
# this should not be required because the user might want to use the `LightningModule` dataloaders
|
||||
parser.add_lightning_class_args(
|
||||
self._datamodule_class, "data", subclass_mode=self.subclass_mode_data, required=False
|
||||
)
|
||||
|
||||
def _add_arguments(self, parser: LightningArgumentParser) -> None:
|
||||
# default + core + custom arguments
|
||||
self.add_default_arguments_to_parser(parser)
|
||||
self.add_core_arguments_to_parser(parser)
|
||||
self.add_arguments_to_parser(parser)
|
||||
# add default optimizer args if necessary
|
||||
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
|
||||
parser.add_optimizer_args((Optimizer,))
|
||||
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
|
||||
parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
|
||||
self.link_optimizers_and_lr_schedulers(parser)
|
||||
|
||||
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
||||
"""Implement to add extra arguments to the parser or link arguments.
|
||||
|
||||
Args:
|
||||
parser: The parser object to which arguments can be added
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def subcommands() -> Dict[str, Set[str]]:
|
||||
"""Defines the list of available subcommands and the arguments to skip."""
|
||||
return {
|
||||
"fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
|
||||
"validate": {"model", "dataloaders", "datamodule"},
|
||||
"test": {"model", "dataloaders", "datamodule"},
|
||||
"predict": {"model", "dataloaders", "datamodule"},
|
||||
"tune": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
|
||||
}
|
||||
|
||||
def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
|
||||
"""Adds subcommands to the input parser."""
|
||||
parser_subcommands = parser.add_subcommands()
|
||||
# the user might have passed a builder function
|
||||
trainer_class = (
|
||||
self.trainer_class if isinstance(self.trainer_class, type) else class_from_function(self.trainer_class)
|
||||
)
|
||||
# register all subcommands in separate subcommand parsers under the main parser
|
||||
for subcommand in self.subcommands():
|
||||
subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **kwargs.get(subcommand, {}))
|
||||
fn = getattr(trainer_class, subcommand)
|
||||
# extract the first line description in the docstring for the subcommand help message
|
||||
description = _get_short_description(fn)
|
||||
parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description)
|
||||
|
||||
def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser:
|
||||
parser = self.init_parser(**kwargs)
|
||||
self._add_arguments(parser)
|
||||
# subcommand arguments
|
||||
skip = self.subcommands()[subcommand]
|
||||
added = parser.add_method_arguments(klass, subcommand, skip=skip)
|
||||
# need to save which arguments were added to pass them to the method later
|
||||
self._subcommand_method_arguments[subcommand] = added
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None:
|
||||
"""Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
|
||||
optimizers_and_lr_schedulers = {**parser._optimizers, **parser._lr_schedulers}
|
||||
for key, (class_type, link_to) in optimizers_and_lr_schedulers.items():
|
||||
if link_to == "AUTOMATIC":
|
||||
continue
|
||||
if isinstance(class_type, tuple):
|
||||
parser.link_arguments(key, link_to)
|
||||
else:
|
||||
add_class_path = _add_class_path_generator(class_type)
|
||||
parser.link_arguments(key, link_to, compute_fn=add_class_path)
|
||||
|
||||
def parse_arguments(self, parser: LightningArgumentParser) -> None:
|
||||
"""Parses command line arguments and stores it in ``self.config``."""
|
||||
self.config = parser.parse_args()
|
||||
|
||||
def before_instantiate_classes(self) -> None:
|
||||
"""Implement to run some code before instantiating the classes."""
|
||||
|
||||
def instantiate_classes(self) -> None:
|
||||
"""Instantiates the classes and sets their attributes."""
|
||||
self.config_init = self.parser.instantiate_classes(self.config)
|
||||
self.datamodule = self._get(self.config_init, "data")
|
||||
self.model = self._get(self.config_init, "model")
|
||||
self._add_configure_optimizers_method_to_model(self.subcommand)
|
||||
self.trainer = self.instantiate_trainer()
|
||||
|
||||
def instantiate_trainer(self, **kwargs: Any) -> Trainer:
|
||||
"""Instantiates the trainer.
|
||||
|
||||
Args:
|
||||
kwargs: Any custom trainer arguments.
|
||||
"""
|
||||
extra_callbacks = [self._get(self.config_init, c) for c in self._parser(self.subcommand).callback_keys]
|
||||
trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs}
|
||||
return self._instantiate_trainer(trainer_config, extra_callbacks)
|
||||
|
||||
def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
|
||||
key = "callbacks"
|
||||
if key in config:
|
||||
if config[key] is None:
|
||||
config[key] = []
|
||||
elif not isinstance(config[key], list):
|
||||
config[key] = [config[key]]
|
||||
config[key].extend(callbacks)
|
||||
if key in self.trainer_defaults:
|
||||
value = self.trainer_defaults[key]
|
||||
config[key] += value if isinstance(value, list) else [value]
|
||||
if self.save_config_callback and not config.get("fast_dev_run", False):
|
||||
config_callback = self.save_config_callback(
|
||||
self._parser(self.subcommand),
|
||||
self.config.get(str(self.subcommand), self.config),
|
||||
self.save_config_filename,
|
||||
overwrite=self.save_config_overwrite,
|
||||
multifile=self.save_config_multifile,
|
||||
)
|
||||
config[key].append(config_callback)
|
||||
else:
|
||||
rank_zero_warn(
|
||||
f"The `{self.trainer_class.__qualname__}` class does not expose the `{key}` argument so they will"
|
||||
" not be included."
|
||||
)
|
||||
return self.trainer_class(**config)
|
||||
|
||||
def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
|
||||
if subcommand is None:
|
||||
return self.parser
|
||||
# return the subcommand parser for the subcommand passed
|
||||
action_subcommand = self.parser._subcommands_action
|
||||
return action_subcommand._name_parser_map[subcommand]
|
||||
|
||||
@staticmethod
|
||||
def configure_optimizers(
|
||||
lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
|
||||
) -> Any:
|
||||
"""Override to customize the :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers`
|
||||
method.
|
||||
|
||||
Args:
|
||||
lightning_module: A reference to the model.
|
||||
optimizer: The optimizer.
|
||||
lr_scheduler: The learning rate scheduler (if used).
|
||||
"""
|
||||
if lr_scheduler is None:
|
||||
return optimizer
|
||||
if isinstance(lr_scheduler, ReduceLROnPlateau):
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor},
|
||||
}
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
|
||||
"""Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method
|
||||
if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
|
||||
parser = self._parser(subcommand)
|
||||
|
||||
def get_automatic(
|
||||
class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]
|
||||
) -> List[str]:
|
||||
automatic = []
|
||||
for key, (base_class, link_to) in register.items():
|
||||
if not isinstance(base_class, tuple):
|
||||
base_class = (base_class,)
|
||||
if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class):
|
||||
automatic.append(key)
|
||||
return automatic
|
||||
|
||||
optimizers = get_automatic(Optimizer, parser._optimizers)
|
||||
lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers)
|
||||
|
||||
if len(optimizers) == 0:
|
||||
return
|
||||
|
||||
if len(optimizers) > 1 or len(lr_schedulers) > 1:
|
||||
raise MisconfigurationException(
|
||||
f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
|
||||
f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
|
||||
"is expected to link the argument groups and implement `configure_optimizers`, see "
|
||||
"https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html"
|
||||
"#optimizers-and-learning-rate-schedulers"
|
||||
)
|
||||
|
||||
optimizer_class = parser._optimizers[optimizers[0]][0]
|
||||
optimizer_init = self._get(self.config_init, optimizers[0])
|
||||
if not isinstance(optimizer_class, tuple):
|
||||
optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
|
||||
if not optimizer_init:
|
||||
# optimizers were registered automatically but not passed by the user
|
||||
return
|
||||
|
||||
lr_scheduler_init = None
|
||||
if lr_schedulers:
|
||||
lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
|
||||
lr_scheduler_init = self._get(self.config_init, lr_schedulers[0])
|
||||
if not isinstance(lr_scheduler_class, tuple):
|
||||
lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)
|
||||
|
||||
if is_overridden("configure_optimizers", self.model):
|
||||
_warn(
|
||||
f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
|
||||
f"`{self.__class__.__name__}.configure_optimizers`."
|
||||
)
|
||||
|
||||
optimizer = instantiate_class(self.model.parameters(), optimizer_init)
|
||||
lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
|
||||
fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
|
||||
update_wrapper(fn, self.configure_optimizers) # necessary for `is_overridden`
|
||||
# override the existing method
|
||||
self.model.configure_optimizers = MethodType(fn, self.model)
|
||||
|
||||
def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
|
||||
"""Utility to get a config value which might be inside a subcommand."""
|
||||
return config.get(str(self.subcommand), config).get(key, default)
|
||||
|
||||
def _run_subcommand(self, subcommand: str) -> None:
|
||||
"""Run the chosen subcommand."""
|
||||
before_fn = getattr(self, f"before_{subcommand}", None)
|
||||
if callable(before_fn):
|
||||
before_fn()
|
||||
|
||||
default = getattr(self.trainer, subcommand)
|
||||
fn = getattr(self, subcommand, default)
|
||||
fn_kwargs = self._prepare_subcommand_kwargs(subcommand)
|
||||
fn(**fn_kwargs)
|
||||
|
||||
after_fn = getattr(self, f"after_{subcommand}", None)
|
||||
if callable(after_fn):
|
||||
after_fn()
|
||||
|
||||
def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]:
|
||||
"""Prepares the keyword arguments to pass to the subcommand to run."""
|
||||
fn_kwargs = {
|
||||
k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand]
|
||||
}
|
||||
fn_kwargs["model"] = self.model
|
||||
if self.datamodule is not None:
|
||||
fn_kwargs["datamodule"] = self.datamodule
|
||||
return fn_kwargs
|
||||
|
||||
def _set_seed(self) -> None:
|
||||
"""Sets the seed."""
|
||||
config_seed = self._get(self.config, "seed_everything")
|
||||
if config_seed is False:
|
||||
return
|
||||
if config_seed is True:
|
||||
# user requested seeding, choose randomly
|
||||
config_seed = seed_everything(workers=True)
|
||||
else:
|
||||
config_seed = seed_everything(config_seed, workers=True)
|
||||
self.config["seed_everything"] = config_seed
|
||||
|
||||
|
||||
def _class_path_from_class(class_type: Type) -> str:
|
||||
return class_type.__module__ + "." + class_type.__name__
|
||||
|
||||
|
||||
def _global_add_class_path(
|
||||
class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
if isinstance(init_args, Namespace):
|
||||
init_args = init_args.as_dict()
|
||||
return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}}
|
||||
|
||||
|
||||
def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]:
|
||||
def add_class_path(init_args: Namespace) -> Dict[str, Any]:
|
||||
return _global_add_class_path(class_type, init_args)
|
||||
|
||||
return add_class_path
|
||||
|
||||
|
||||
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
|
||||
"""Instantiates a class with the given args and init.
|
||||
|
||||
Args:
|
||||
args: Positional arguments required for instantiation.
|
||||
init: Dict of the form {"class_path":...,"init_args":...}.
|
||||
|
||||
Returns:
|
||||
The instantiated class object.
|
||||
"""
|
||||
kwargs = init.get("init_args", {})
|
||||
if not isinstance(args, tuple):
|
||||
args = (args,)
|
||||
class_module, class_name = init["class_path"].rsplit(".", 1)
|
||||
module = __import__(class_module, fromlist=[class_name])
|
||||
args_class = getattr(module, class_name)
|
||||
return args_class(*args, **kwargs)
|
||||
|
||||
|
||||
def _get_short_description(component: object) -> Optional[str]:
|
||||
if component.__doc__ is None:
|
||||
return None
|
||||
try:
|
||||
docstring = docstring_parser.parse(component.__doc__)
|
||||
return docstring.short_description
|
||||
except (ValueError, docstring_parser.ParseError) as ex:
|
||||
rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
|
|
@ -11,45 +11,19 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for LightningCLI."""
|
||||
"""Deprecated utilities for LightningCLI."""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from functools import partial, update_wrapper
|
||||
from types import MethodType, ModuleType
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union
|
||||
from types import ModuleType
|
||||
from typing import Any, Generator, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _RequirementAvailable
|
||||
import pytorch_lightning.cli as new_cli
|
||||
from pytorch_lightning.utilities.meta import get_all_subclasses
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
|
||||
|
||||
_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.2")
|
||||
|
||||
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
|
||||
import docstring_parser
|
||||
from jsonargparse import (
|
||||
ActionConfigFile,
|
||||
ArgumentParser,
|
||||
class_from_function,
|
||||
Namespace,
|
||||
register_unresolvable_import_paths,
|
||||
set_config_read_mode,
|
||||
)
|
||||
|
||||
register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
|
||||
set_config_read_mode(fsspec_enabled=True)
|
||||
else:
|
||||
locals()["ArgumentParser"] = object
|
||||
locals()["Namespace"] = object
|
||||
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
|
||||
|
||||
_deprecate_registry_message = (
|
||||
"`LightningCLI`'s registries were deprecated in v1.7 and will be removed "
|
||||
|
@ -130,18 +104,6 @@ DATAMODULE_REGISTRY = _Registry()
|
|||
LOGGER_REGISTRY = _Registry()
|
||||
|
||||
|
||||
class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||
def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(optimizer, *args, **kwargs)
|
||||
self.monitor = monitor
|
||||
|
||||
|
||||
# LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch:
|
||||
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau)
|
||||
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]
|
||||
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
|
||||
|
||||
|
||||
def _populate_registries(subclasses: bool) -> None: # Remove in v1.9
|
||||
if subclasses:
|
||||
rank_zero_deprecation(_deprecate_auto_registry_message)
|
||||
|
@ -167,643 +129,37 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9
|
|||
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback, show_deprecation=False)
|
||||
LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger, show_deprecation=False)
|
||||
# `ReduceLROnPlateau` does not subclass `_LRScheduler`
|
||||
LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau, show_deprecation=False)
|
||||
LR_SCHEDULER_REGISTRY(cls=new_cli.ReduceLROnPlateau, show_deprecation=False)
|
||||
|
||||
|
||||
class LightningArgumentParser(ArgumentParser):
|
||||
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
|
||||
def _deprecation(cls: Type) -> None:
|
||||
rank_zero_deprecation(
|
||||
f"`pytorch_lightning.utilities.cli.{cls.__name__}` has been deprecated in v1.7 and will be removed in v1.9."
|
||||
f" Use the equivalent class in `pytorch_lightning.cli.{cls.__name__}` instead."
|
||||
)
|
||||
|
||||
|
||||
class LightningArgumentParser(new_cli.LightningArgumentParser):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Initialize argument parser that supports configuration file input.
|
||||
|
||||
For full details of accepted arguments see `ArgumentParser.__init__
|
||||
<https://jsonargparse.readthedocs.io/en/stable/index.html#jsonargparse.ArgumentParser.__init__>`_.
|
||||
"""
|
||||
if not _JSONARGPARSE_SIGNATURES_AVAILABLE:
|
||||
raise ModuleNotFoundError(
|
||||
f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}. Try `pip install -U 'jsonargparse[signatures]'`."
|
||||
)
|
||||
_deprecation(type(self))
|
||||
super().__init__(*args, **kwargs)
|
||||
self.add_argument(
|
||||
"-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
|
||||
)
|
||||
self.callback_keys: List[str] = []
|
||||
# separate optimizers and lr schedulers to know which were added
|
||||
self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
|
||||
self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
|
||||
|
||||
def add_lightning_class_args(
|
||||
self,
|
||||
lightning_class: Union[
|
||||
Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]],
|
||||
Type[Trainer],
|
||||
Type[LightningModule],
|
||||
Type[LightningDataModule],
|
||||
Type[Callback],
|
||||
],
|
||||
nested_key: str,
|
||||
subclass_mode: bool = False,
|
||||
required: bool = True,
|
||||
) -> List[str]:
|
||||
"""Adds arguments from a lightning class to a nested key of the parser.
|
||||
|
||||
Args:
|
||||
lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
|
||||
nested_key: Name of the nested namespace to store arguments.
|
||||
subclass_mode: Whether allow any subclass of the given class.
|
||||
required: Whether the argument group is required.
|
||||
class SaveConfigCallback(new_cli.SaveConfigCallback):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
_deprecation(type(self))
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
Returns:
|
||||
A list with the names of the class arguments added.
|
||||
"""
|
||||
if callable(lightning_class) and not isinstance(lightning_class, type):
|
||||
lightning_class = class_from_function(lightning_class)
|
||||
|
||||
if isinstance(lightning_class, type) and issubclass(
|
||||
lightning_class, (Trainer, LightningModule, LightningDataModule, Callback)
|
||||
):
|
||||
if issubclass(lightning_class, Callback):
|
||||
self.callback_keys.append(nested_key)
|
||||
if subclass_mode:
|
||||
return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)
|
||||
return self.add_class_arguments(
|
||||
lightning_class,
|
||||
nested_key,
|
||||
fail_untyped=False,
|
||||
instantiate=not issubclass(lightning_class, Trainer),
|
||||
sub_configs=True,
|
||||
)
|
||||
raise MisconfigurationException(
|
||||
f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
|
||||
"Trainer, LightningModule, LightningDataModule, or Callback."
|
||||
)
|
||||
class LightningCLI(new_cli.LightningCLI):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
_deprecation(type(self))
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def add_optimizer_args(
|
||||
self,
|
||||
optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,),
|
||||
nested_key: str = "optimizer",
|
||||
link_to: str = "AUTOMATIC",
|
||||
) -> None:
|
||||
"""Adds arguments from an optimizer class to a nested key of the parser.
|
||||
|
||||
Args:
|
||||
optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses.
|
||||
nested_key: Name of the nested namespace to store arguments.
|
||||
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
||||
"""
|
||||
if isinstance(optimizer_class, tuple):
|
||||
assert all(issubclass(o, Optimizer) for o in optimizer_class)
|
||||
else:
|
||||
assert issubclass(optimizer_class, Optimizer)
|
||||
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
|
||||
if isinstance(optimizer_class, tuple):
|
||||
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
|
||||
else:
|
||||
self.add_class_arguments(optimizer_class, nested_key, sub_configs=True, **kwargs)
|
||||
self._optimizers[nested_key] = (optimizer_class, link_to)
|
||||
|
||||
def add_lr_scheduler_args(
|
||||
self,
|
||||
lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple,
|
||||
nested_key: str = "lr_scheduler",
|
||||
link_to: str = "AUTOMATIC",
|
||||
) -> None:
|
||||
"""Adds arguments from a learning rate scheduler class to a nested key of the parser.
|
||||
|
||||
Args:
|
||||
lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use
|
||||
tuple to allow subclasses.
|
||||
nested_key: Name of the nested namespace to store arguments.
|
||||
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
||||
"""
|
||||
if isinstance(lr_scheduler_class, tuple):
|
||||
assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class)
|
||||
else:
|
||||
assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple)
|
||||
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
|
||||
if isinstance(lr_scheduler_class, tuple):
|
||||
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
|
||||
else:
|
||||
self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
|
||||
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
|
||||
|
||||
|
||||
class SaveConfigCallback(Callback):
|
||||
"""Saves a LightningCLI config to the log_dir when training starts.
|
||||
|
||||
Args:
|
||||
parser: The parser object used to parse the configuration.
|
||||
config: The parsed configuration that will be saved.
|
||||
config_filename: Filename for the config file.
|
||||
overwrite: Whether to overwrite an existing config file.
|
||||
multifile: When input is multiple config files, saved config preserves this structure.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parser: LightningArgumentParser,
|
||||
config: Namespace,
|
||||
config_filename: str,
|
||||
overwrite: bool = False,
|
||||
multifile: bool = False,
|
||||
) -> None:
|
||||
self.parser = parser
|
||||
self.config = config
|
||||
self.config_filename = config_filename
|
||||
self.overwrite = overwrite
|
||||
self.multifile = multifile
|
||||
|
||||
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
|
||||
log_dir = trainer.log_dir # this broadcasts the directory
|
||||
assert log_dir is not None
|
||||
config_path = os.path.join(log_dir, self.config_filename)
|
||||
fs = get_filesystem(log_dir)
|
||||
|
||||
if not self.overwrite:
|
||||
# check if the file exists on rank 0
|
||||
file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
|
||||
# broadcast whether to fail to all ranks
|
||||
file_exists = trainer.strategy.broadcast(file_exists)
|
||||
if file_exists:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
|
||||
" results of a previous run. You can delete the previous config file,"
|
||||
" set `LightningCLI(save_config_callback=None)` to disable config saving,"
|
||||
" or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
|
||||
)
|
||||
|
||||
# save the file on rank 0
|
||||
if trainer.is_global_zero:
|
||||
# save only on rank zero to avoid race conditions.
|
||||
# the `log_dir` needs to be created as we rely on the logger to do it usually
|
||||
# but it hasn't logged anything at this point
|
||||
fs.makedirs(log_dir, exist_ok=True)
|
||||
self.parser.save(
|
||||
self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
|
||||
)
|
||||
|
||||
|
||||
class LightningCLI:
|
||||
"""Implementation of a configurable command line tool for pytorch-lightning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None,
|
||||
datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None,
|
||||
save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
|
||||
save_config_filename: str = "config.yaml",
|
||||
save_config_overwrite: bool = False,
|
||||
save_config_multifile: bool = False,
|
||||
trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
|
||||
trainer_defaults: Optional[Dict[str, Any]] = None,
|
||||
seed_everything_default: Union[bool, int] = True,
|
||||
description: str = "pytorch-lightning trainer command line tool",
|
||||
env_prefix: str = "PL",
|
||||
env_parse: bool = False,
|
||||
parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None,
|
||||
subclass_mode_model: bool = False,
|
||||
subclass_mode_data: bool = False,
|
||||
run: bool = True,
|
||||
auto_registry: bool = False,
|
||||
) -> None:
|
||||
"""Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
|
||||
are called / instantiated using a parsed configuration file and / or command line args.
|
||||
|
||||
Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
|
||||
A full configuration yaml would be parsed from ``PL_CONFIG`` if set.
|
||||
Individual settings are so parsed from variables named for example ``PL_TRAINER__MAX_EPOCHS``.
|
||||
|
||||
For more info, read :ref:`the CLI docs <lightning-cli>`.
|
||||
|
||||
.. warning:: ``LightningCLI`` is in beta and subject to change.
|
||||
|
||||
Args:
|
||||
model_class: An optional :class:`~pytorch_lightning.core.module.LightningModule` class to train on or a
|
||||
callable which returns a :class:`~pytorch_lightning.core.module.LightningModule` instance when
|
||||
called. If ``None``, you can pass a registered model with ``--model=MyModel``.
|
||||
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
|
||||
callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
|
||||
called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
|
||||
save_config_callback: A callback class to save the training config.
|
||||
save_config_filename: Filename for the config file.
|
||||
save_config_overwrite: Whether to overwrite an existing config file.
|
||||
save_config_multifile: When input is multiple config files, saved config preserves this structure.
|
||||
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
|
||||
callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
|
||||
trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
|
||||
this argument will not be configurable from a configuration file and will always be present for
|
||||
this particular CLI. Alternatively, configurable callbacks can be added as explained in
|
||||
:ref:`the CLI docs <lightning-cli>`.
|
||||
seed_everything_default: Value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
|
||||
seed argument. Set to True to automatically choose a valid seed.
|
||||
Setting it to False will not call seed_everything.
|
||||
description: Description of the tool shown when running ``--help``.
|
||||
env_prefix: Prefix for environment variables.
|
||||
env_parse: Whether environment variable parsing is enabled.
|
||||
parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
|
||||
subclass_mode_model: Whether model can be any `subclass
|
||||
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
|
||||
of the given class.
|
||||
subclass_mode_data: Whether datamodule can be any `subclass
|
||||
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
|
||||
of the given class.
|
||||
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
|
||||
method. If set to ``False``, the trainer and model classes will be instantiated only.
|
||||
auto_registry: Whether to automatically fill up the registries with all defined subclasses.
|
||||
"""
|
||||
self.save_config_callback = save_config_callback
|
||||
self.save_config_filename = save_config_filename
|
||||
self.save_config_overwrite = save_config_overwrite
|
||||
self.save_config_multifile = save_config_multifile
|
||||
self.trainer_class = trainer_class
|
||||
self.trainer_defaults = trainer_defaults or {}
|
||||
self.seed_everything_default = seed_everything_default
|
||||
|
||||
if self.seed_everything_default is None:
|
||||
rank_zero_deprecation(
|
||||
"Setting `LightningCLI.seed_everything_default` to `None` is deprecated in v1.7 "
|
||||
"and will be removed in v1.9. Set it to `False` instead."
|
||||
)
|
||||
self.seed_everything_default = False
|
||||
|
||||
self.model_class = model_class
|
||||
# used to differentiate between the original value and the processed value
|
||||
self._model_class = model_class or LightningModule
|
||||
self.subclass_mode_model = (model_class is None) or subclass_mode_model
|
||||
|
||||
self.datamodule_class = datamodule_class
|
||||
# used to differentiate between the original value and the processed value
|
||||
self._datamodule_class = datamodule_class or LightningDataModule
|
||||
self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data
|
||||
|
||||
_populate_registries(auto_registry)
|
||||
|
||||
main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
|
||||
parser_kwargs or {}, # type: ignore # github.com/python/mypy/issues/6463
|
||||
{"description": description, "env_prefix": env_prefix, "default_env": env_parse},
|
||||
)
|
||||
self.setup_parser(run, main_kwargs, subparser_kwargs)
|
||||
self.parse_arguments(self.parser)
|
||||
|
||||
self.subcommand = self.config["subcommand"] if run else None
|
||||
|
||||
self._set_seed()
|
||||
|
||||
self.before_instantiate_classes()
|
||||
self.instantiate_classes()
|
||||
|
||||
if self.subcommand is not None:
|
||||
self._run_subcommand(self.subcommand)
|
||||
|
||||
def _setup_parser_kwargs(
|
||||
self, kwargs: Dict[str, Any], defaults: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
if kwargs.keys() & self.subcommands().keys():
|
||||
# `kwargs` contains arguments per subcommand
|
||||
return defaults, kwargs
|
||||
main_kwargs = defaults
|
||||
main_kwargs.update(kwargs)
|
||||
return main_kwargs, {}
|
||||
|
||||
def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
|
||||
"""Method that instantiates the argument parser."""
|
||||
kwargs.setdefault("dump_header", [f"pytorch_lightning=={pl.__version__}"])
|
||||
return LightningArgumentParser(**kwargs)
|
||||
|
||||
def setup_parser(
|
||||
self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Initialize and setup the parser, subcommands, and arguments."""
|
||||
self.parser = self.init_parser(**main_kwargs)
|
||||
if add_subcommands:
|
||||
self._subcommand_method_arguments: Dict[str, List[str]] = {}
|
||||
self._add_subcommands(self.parser, **subparser_kwargs)
|
||||
else:
|
||||
self._add_arguments(self.parser)
|
||||
|
||||
def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
||||
"""Adds default arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--seed_everything",
|
||||
type=Union[bool, int],
|
||||
default=self.seed_everything_default,
|
||||
help=(
|
||||
"Set to an int to run seed_everything with this value before classes instantiation."
|
||||
"Set to True to use a random seed."
|
||||
),
|
||||
)
|
||||
|
||||
def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
||||
"""Adds arguments from the core classes to the parser."""
|
||||
parser.add_lightning_class_args(self.trainer_class, "trainer")
|
||||
trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
|
||||
parser.set_defaults(trainer_defaults)
|
||||
|
||||
parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
|
||||
|
||||
if self.datamodule_class is not None:
|
||||
parser.add_lightning_class_args(self._datamodule_class, "data", subclass_mode=self.subclass_mode_data)
|
||||
else:
|
||||
# this should not be required because the user might want to use the `LightningModule` dataloaders
|
||||
parser.add_lightning_class_args(
|
||||
self._datamodule_class, "data", subclass_mode=self.subclass_mode_data, required=False
|
||||
)
|
||||
|
||||
def _add_arguments(self, parser: LightningArgumentParser) -> None:
|
||||
# default + core + custom arguments
|
||||
self.add_default_arguments_to_parser(parser)
|
||||
self.add_core_arguments_to_parser(parser)
|
||||
self.add_arguments_to_parser(parser)
|
||||
# add default optimizer args if necessary
|
||||
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
|
||||
parser.add_optimizer_args((Optimizer,))
|
||||
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
|
||||
parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
|
||||
self.link_optimizers_and_lr_schedulers(parser)
|
||||
|
||||
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
||||
"""Implement to add extra arguments to the parser or link arguments.
|
||||
|
||||
Args:
|
||||
parser: The parser object to which arguments can be added
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def subcommands() -> Dict[str, Set[str]]:
|
||||
"""Defines the list of available subcommands and the arguments to skip."""
|
||||
return {
|
||||
"fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
|
||||
"validate": {"model", "dataloaders", "datamodule"},
|
||||
"test": {"model", "dataloaders", "datamodule"},
|
||||
"predict": {"model", "dataloaders", "datamodule"},
|
||||
"tune": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
|
||||
}
|
||||
|
||||
def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
|
||||
"""Adds subcommands to the input parser."""
|
||||
parser_subcommands = parser.add_subcommands()
|
||||
# the user might have passed a builder function
|
||||
trainer_class = (
|
||||
self.trainer_class if isinstance(self.trainer_class, type) else class_from_function(self.trainer_class)
|
||||
)
|
||||
# register all subcommands in separate subcommand parsers under the main parser
|
||||
for subcommand in self.subcommands():
|
||||
subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **kwargs.get(subcommand, {}))
|
||||
fn = getattr(trainer_class, subcommand)
|
||||
# extract the first line description in the docstring for the subcommand help message
|
||||
description = _get_short_description(fn)
|
||||
parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description)
|
||||
|
||||
def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser:
|
||||
parser = self.init_parser(**kwargs)
|
||||
self._add_arguments(parser)
|
||||
# subcommand arguments
|
||||
skip = self.subcommands()[subcommand]
|
||||
added = parser.add_method_arguments(klass, subcommand, skip=skip)
|
||||
# need to save which arguments were added to pass them to the method later
|
||||
self._subcommand_method_arguments[subcommand] = added
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None:
|
||||
"""Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
|
||||
optimizers_and_lr_schedulers = {**parser._optimizers, **parser._lr_schedulers}
|
||||
for key, (class_type, link_to) in optimizers_and_lr_schedulers.items():
|
||||
if link_to == "AUTOMATIC":
|
||||
continue
|
||||
if isinstance(class_type, tuple):
|
||||
parser.link_arguments(key, link_to)
|
||||
else:
|
||||
add_class_path = _add_class_path_generator(class_type)
|
||||
parser.link_arguments(key, link_to, compute_fn=add_class_path)
|
||||
|
||||
def parse_arguments(self, parser: LightningArgumentParser) -> None:
|
||||
"""Parses command line arguments and stores it in ``self.config``."""
|
||||
self.config = parser.parse_args()
|
||||
|
||||
def before_instantiate_classes(self) -> None:
|
||||
"""Implement to run some code before instantiating the classes."""
|
||||
|
||||
def instantiate_classes(self) -> None:
|
||||
"""Instantiates the classes and sets their attributes."""
|
||||
self.config_init = self.parser.instantiate_classes(self.config)
|
||||
self.datamodule = self._get(self.config_init, "data")
|
||||
self.model = self._get(self.config_init, "model")
|
||||
self._add_configure_optimizers_method_to_model(self.subcommand)
|
||||
self.trainer = self.instantiate_trainer()
|
||||
|
||||
def instantiate_trainer(self, **kwargs: Any) -> Trainer:
|
||||
"""Instantiates the trainer.
|
||||
|
||||
Args:
|
||||
kwargs: Any custom trainer arguments.
|
||||
"""
|
||||
extra_callbacks = [self._get(self.config_init, c) for c in self._parser(self.subcommand).callback_keys]
|
||||
trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs}
|
||||
return self._instantiate_trainer(trainer_config, extra_callbacks)
|
||||
|
||||
def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
|
||||
key = "callbacks"
|
||||
if key in config:
|
||||
if config[key] is None:
|
||||
config[key] = []
|
||||
elif not isinstance(config[key], list):
|
||||
config[key] = [config[key]]
|
||||
config[key].extend(callbacks)
|
||||
if key in self.trainer_defaults:
|
||||
value = self.trainer_defaults[key]
|
||||
config[key] += value if isinstance(value, list) else [value]
|
||||
if self.save_config_callback and not config.get("fast_dev_run", False):
|
||||
config_callback = self.save_config_callback(
|
||||
self._parser(self.subcommand),
|
||||
self.config.get(str(self.subcommand), self.config),
|
||||
self.save_config_filename,
|
||||
overwrite=self.save_config_overwrite,
|
||||
multifile=self.save_config_multifile,
|
||||
)
|
||||
config[key].append(config_callback)
|
||||
else:
|
||||
rank_zero_warn(
|
||||
f"The `{self.trainer_class.__qualname__}` class does not expose the `{key}` argument so they will"
|
||||
" not be included."
|
||||
)
|
||||
return self.trainer_class(**config)
|
||||
|
||||
def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
|
||||
if subcommand is None:
|
||||
return self.parser
|
||||
# return the subcommand parser for the subcommand passed
|
||||
action_subcommand = self.parser._subcommands_action
|
||||
return action_subcommand._name_parser_map[subcommand]
|
||||
|
||||
@staticmethod
|
||||
def configure_optimizers(
|
||||
lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
|
||||
) -> Any:
|
||||
"""Override to customize the :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers`
|
||||
method.
|
||||
|
||||
Args:
|
||||
lightning_module: A reference to the model.
|
||||
optimizer: The optimizer.
|
||||
lr_scheduler: The learning rate scheduler (if used).
|
||||
"""
|
||||
if lr_scheduler is None:
|
||||
return optimizer
|
||||
if isinstance(lr_scheduler, ReduceLROnPlateau):
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor},
|
||||
}
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
|
||||
"""Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method
|
||||
if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
|
||||
parser = self._parser(subcommand)
|
||||
|
||||
def get_automatic(
|
||||
class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]
|
||||
) -> List[str]:
|
||||
automatic = []
|
||||
for key, (base_class, link_to) in register.items():
|
||||
if not isinstance(base_class, tuple):
|
||||
base_class = (base_class,)
|
||||
if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class):
|
||||
automatic.append(key)
|
||||
return automatic
|
||||
|
||||
optimizers = get_automatic(Optimizer, parser._optimizers)
|
||||
lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers)
|
||||
|
||||
if len(optimizers) == 0:
|
||||
return
|
||||
|
||||
if len(optimizers) > 1 or len(lr_schedulers) > 1:
|
||||
raise MisconfigurationException(
|
||||
f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
|
||||
f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
|
||||
"is expected to link the argument groups and implement `configure_optimizers`, see "
|
||||
"https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html"
|
||||
"#optimizers-and-learning-rate-schedulers"
|
||||
)
|
||||
|
||||
optimizer_class = parser._optimizers[optimizers[0]][0]
|
||||
optimizer_init = self._get(self.config_init, optimizers[0])
|
||||
if not isinstance(optimizer_class, tuple):
|
||||
optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
|
||||
if not optimizer_init:
|
||||
# optimizers were registered automatically but not passed by the user
|
||||
return
|
||||
|
||||
lr_scheduler_init = None
|
||||
if lr_schedulers:
|
||||
lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
|
||||
lr_scheduler_init = self._get(self.config_init, lr_schedulers[0])
|
||||
if not isinstance(lr_scheduler_class, tuple):
|
||||
lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)
|
||||
|
||||
if is_overridden("configure_optimizers", self.model):
|
||||
_warn(
|
||||
f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
|
||||
f"`{self.__class__.__name__}.configure_optimizers`."
|
||||
)
|
||||
|
||||
optimizer = instantiate_class(self.model.parameters(), optimizer_init)
|
||||
lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
|
||||
fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
|
||||
update_wrapper(fn, self.configure_optimizers) # necessary for `is_overridden`
|
||||
# override the existing method
|
||||
self.model.configure_optimizers = MethodType(fn, self.model)
|
||||
|
||||
def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
|
||||
"""Utility to get a config value which might be inside a subcommand."""
|
||||
return config.get(str(self.subcommand), config).get(key, default)
|
||||
|
||||
def _run_subcommand(self, subcommand: str) -> None:
|
||||
"""Run the chosen subcommand."""
|
||||
before_fn = getattr(self, f"before_{subcommand}", None)
|
||||
if callable(before_fn):
|
||||
before_fn()
|
||||
|
||||
default = getattr(self.trainer, subcommand)
|
||||
fn = getattr(self, subcommand, default)
|
||||
fn_kwargs = self._prepare_subcommand_kwargs(subcommand)
|
||||
fn(**fn_kwargs)
|
||||
|
||||
after_fn = getattr(self, f"after_{subcommand}", None)
|
||||
if callable(after_fn):
|
||||
after_fn()
|
||||
|
||||
def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]:
|
||||
"""Prepares the keyword arguments to pass to the subcommand to run."""
|
||||
fn_kwargs = {
|
||||
k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand]
|
||||
}
|
||||
fn_kwargs["model"] = self.model
|
||||
if self.datamodule is not None:
|
||||
fn_kwargs["datamodule"] = self.datamodule
|
||||
return fn_kwargs
|
||||
|
||||
def _set_seed(self) -> None:
|
||||
"""Sets the seed."""
|
||||
config_seed = self._get(self.config, "seed_everything")
|
||||
if config_seed is False:
|
||||
return
|
||||
if config_seed is True:
|
||||
# user requested seeding, choose randomly
|
||||
config_seed = seed_everything(workers=True)
|
||||
else:
|
||||
config_seed = seed_everything(config_seed, workers=True)
|
||||
self.config["seed_everything"] = config_seed
|
||||
|
||||
|
||||
def _class_path_from_class(class_type: Type) -> str:
|
||||
return class_type.__module__ + "." + class_type.__name__
|
||||
|
||||
|
||||
def _global_add_class_path(
|
||||
class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
if isinstance(init_args, Namespace):
|
||||
init_args = init_args.as_dict()
|
||||
return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}}
|
||||
|
||||
|
||||
def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]:
|
||||
def add_class_path(init_args: Namespace) -> Dict[str, Any]:
|
||||
return _global_add_class_path(class_type, init_args)
|
||||
|
||||
return add_class_path
|
||||
|
||||
|
||||
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
|
||||
"""Instantiates a class with the given args and init.
|
||||
|
||||
Args:
|
||||
args: Positional arguments required for instantiation.
|
||||
init: Dict of the form {"class_path":...,"init_args":...}.
|
||||
|
||||
Returns:
|
||||
The instantiated class object.
|
||||
"""
|
||||
kwargs = init.get("init_args", {})
|
||||
if not isinstance(args, tuple):
|
||||
args = (args,)
|
||||
class_module, class_name = init["class_path"].rsplit(".", 1)
|
||||
module = __import__(class_module, fromlist=[class_name])
|
||||
args_class = getattr(module, class_name)
|
||||
return args_class(*args, **kwargs)
|
||||
|
||||
|
||||
def _get_short_description(component: object) -> Optional[str]:
|
||||
if component.__doc__ is None:
|
||||
return None
|
||||
try:
|
||||
docstring = docstring_parser.parse(component.__doc__)
|
||||
return docstring.short_description
|
||||
except (ValueError, docstring_parser.ParseError) as ex:
|
||||
rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
|
||||
def instantiate_class(*args: Any, **kwargs: Any) -> Any:
|
||||
rank_zero_deprecation(
|
||||
"`pytorch_lightning.utilities.cli.instantiate_class` has been deprecated in v1.7 and will be removed in v1.9."
|
||||
" Use the equivalent function in `pytorch_lightning.cli.instantiate_class` instead."
|
||||
)
|
||||
return new_cli.instantiate_class(*args, **kwargs)
|
||||
|
|
|
@ -5,9 +5,7 @@ if _is_torch_available():
|
|||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
if _is_pytorch_lightning_available():
|
||||
from pytorch_lightning import LightningDataModule, LightningModule
|
||||
from pytorch_lightning.utilities import cli
|
||||
|
||||
from pytorch_lightning import cli, LightningDataModule, LightningModule
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ if _is_pytorch_lightning_available():
|
|||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from pytorch_lightning import LightningDataModule, LightningModule
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, size, length):
|
||||
|
|
|
@ -9,7 +9,7 @@ from lightning_app.utilities.introspection import Scanner
|
|||
|
||||
if _is_pytorch_lightning_available():
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
|
||||
from tests_app import _PROJECT_ROOT
|
||||
|
||||
|
|
|
@ -13,12 +13,15 @@
|
|||
# limitations under the License.
|
||||
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
import pytorch_lightning.loggers.base as logger_base
|
||||
import pytorch_lightning.utilities.cli as old_cli
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.accelerators.gpu import GPUAccelerator
|
||||
from pytorch_lightning.cli import LightningCLI, SaveConfigCallback
|
||||
from pytorch_lightning.core.module import LightningModule
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel
|
||||
from pytorch_lightning.profiler.advanced import AdvancedProfiler
|
||||
|
@ -27,13 +30,6 @@ from pytorch_lightning.profiler.profiler import Profiler
|
|||
from pytorch_lightning.profiler.pytorch import PyTorchProfiler, RegisterRecordFunction, ScheduleWrapper
|
||||
from pytorch_lightning.profiler.simple import SimpleProfiler
|
||||
from pytorch_lightning.profiler.xla import XLAProfiler
|
||||
from pytorch_lightning.utilities.cli import (
|
||||
_deprecate_auto_registry_message,
|
||||
_deprecate_registry_message,
|
||||
CALLBACK_REGISTRY,
|
||||
LightningCLI,
|
||||
SaveConfigCallback,
|
||||
)
|
||||
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
|
@ -152,19 +148,34 @@ def test_deprecated_dataloader_reset():
|
|||
|
||||
|
||||
def test_lightningCLI_registries_register():
|
||||
with pytest.deprecated_call(match=_deprecate_registry_message):
|
||||
with pytest.deprecated_call(match=old_cli._deprecate_registry_message):
|
||||
|
||||
@CALLBACK_REGISTRY
|
||||
@old_cli.CALLBACK_REGISTRY
|
||||
class CustomCallback(SaveConfigCallback):
|
||||
pass
|
||||
|
||||
|
||||
def test_lightningCLI_registries_register_automatically():
|
||||
with pytest.deprecated_call(match=_deprecate_auto_registry_message):
|
||||
with pytest.deprecated_call(match=old_cli._deprecate_auto_registry_message):
|
||||
with mock.patch("sys.argv", ["any.py"]):
|
||||
LightningCLI(BoringModel, run=False, auto_registry=True)
|
||||
|
||||
|
||||
def test_lightningCLI_old_module_deprecation():
|
||||
with pytest.deprecated_call(match=r"LightningCLI.*deprecated in v1.7.*Use the equivalent class"):
|
||||
with mock.patch("sys.argv", ["any.py"]):
|
||||
old_cli.LightningCLI(BoringModel, run=False)
|
||||
|
||||
with pytest.deprecated_call(match=r"SaveConfigCallback.*deprecated in v1.7.*Use the equivalent class"):
|
||||
old_cli.SaveConfigCallback(Mock(), Mock(), Mock())
|
||||
|
||||
with pytest.deprecated_call(match=r"LightningArgumentParser.*deprecated in v1.7.*Use the equivalent class"):
|
||||
old_cli.LightningArgumentParser()
|
||||
|
||||
with pytest.deprecated_call(match=r"instantiate_class.*deprecated in v1.7.*Use the equivalent function"):
|
||||
assert isinstance(old_cli.instantiate_class(tuple(), {"class_path": "pytorch_lightning.Trainer"}), Trainer)
|
||||
|
||||
|
||||
def test_profiler_deprecation_warning():
|
||||
assert "Profiler` is deprecated in v1.7" in Profiler.__doc__
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
|
@ -27,20 +26,12 @@ from unittest.mock import ANY
|
|||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from packaging import version
|
||||
from torch.optim import SGD
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
|
||||
|
||||
from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
|
||||
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
||||
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
|
||||
from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, TensorBoardLogger
|
||||
from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE
|
||||
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
||||
from pytorch_lightning.utilities.cli import (
|
||||
from pytorch_lightning.cli import (
|
||||
_JSONARGPARSE_SIGNATURES_AVAILABLE,
|
||||
instantiate_class,
|
||||
LightningArgumentParser,
|
||||
|
@ -48,15 +39,18 @@ from pytorch_lightning.utilities.cli import (
|
|||
LRSchedulerTypeTuple,
|
||||
SaveConfigCallback,
|
||||
)
|
||||
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
|
||||
from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, TensorBoardLogger
|
||||
from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE
|
||||
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
from tests_pytorch.helpers.utils import no_warning_call
|
||||
|
||||
torchvision_version = version.parse("0")
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
torchvision_version = version.parse(__import__("torchvision").__version__)
|
||||
|
||||
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
|
||||
from jsonargparse import lazy_instance
|
||||
|
||||
|
@ -525,7 +519,7 @@ def test_lightning_cli_submodules(tmpdir):
|
|||
assert isinstance(cli.model.submodule2, BoringModel)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required")
|
||||
@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="Tests a bug with torchvision, but it's not available")
|
||||
def test_lightning_cli_torch_modules(tmpdir):
|
||||
class TestModule(BoringModel):
|
||||
def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):
|
||||
|
@ -594,7 +588,7 @@ def test_lightning_cli_link_arguments(tmpdir):
|
|||
parser.link_arguments("data.batch_size", "model.init_args.batch_size")
|
||||
parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")
|
||||
|
||||
cli_args[-1] = "--model=tests_pytorch.utilities.test_cli.BoringModelRequiredClasses"
|
||||
cli_args[-1] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
|
||||
|
||||
with mock.patch("sys.argv", ["any.py"] + cli_args):
|
||||
cli = MyLightningCLI(
|
Loading…
Reference in New Issue