322 lines
12 KiB
Python
322 lines
12 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import ast
import inspect
import os
from abc import ABC
from argparse import _ArgumentGroup, ArgumentParser, Namespace
from contextlib import suppress
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str
|
|
|
|
|
|
class ParseArgparserDataType(ABC):
    """Placeholder base type used to annotate ``cls`` in ``from_argparse_args``.

    Both members are stubs: the constructor accepts and discards any arguments, and
    ``parse_argparser`` returns ``None``.
    """

    def __init__(self, *_: Any, **__: Any) -> None:
        """Accept and ignore every positional and keyword argument."""

    @classmethod
    def parse_argparser(cls, args: "ArgumentParser") -> Any:
        """Stub hook; this base implementation returns ``None``."""
|
|
|
|
|
|
def from_argparse_args(
    cls: Type["ParseArgparserDataType"], args: Union[Namespace, ArgumentParser], **kwargs: Any
) -> "ParseArgparserDataType":
    """Create an instance from CLI arguments.

    Only the entries of ``args`` whose names match parameters of ``cls.__init__`` are forwarded; any other
    (user-specific) entries are ignored. Depending on ``cls``, its constructor may additionally pick up defaults
    from environment variables named ``PL_<CLASS_NAME>_<CLASS_ARGUMENT_NAME>``; this function itself does not
    read the environment.

    Args:
        cls: Lightning class to instantiate.
        args: The parser or namespace to take arguments from. A parser is first turned into a namespace with
            ``cls.parse_argparser``. Only known arguments will be passed to ``cls``.
        **kwargs: Additional keyword arguments that override ones in the parser or namespace.
            These must be valid arguments of ``cls``.

    Returns:
        The newly constructed ``cls`` instance.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> parser = ArgumentParser(add_help=False)
        >>> parser = Trainer.add_argparse_args(parser)
        >>> parser.add_argument('--my_custom_arg', default='something')  # doctest: +SKIP
        >>> args = Trainer.parse_argparser(parser.parse_args(""))
        >>> trainer = Trainer.from_argparse_args(args, logger=False)
    """
    if isinstance(args, ArgumentParser):
        args = cls.parse_argparser(args)

    params = vars(args)

    # we only want to pass in valid constructor args, the rest may be user specific
    valid_kwargs = inspect.signature(cls.__init__).parameters
    init_kwargs = {name: params[name] for name in valid_kwargs if name in params}
    # explicit keyword arguments take precedence over values from ``args``
    init_kwargs.update(kwargs)

    return cls(**init_kwargs)
|
|
|
|
|
|
def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
    """Parse CLI arguments, required for custom bool types.

    A parser is evaluated with ``parse_args()`` (no explicit argument list); a ``Namespace`` is used as given.
    Every ``None`` value belonging to a constructor argument of ``cls`` whose annotation includes ``bool`` and
    whose default is a ``bool`` is treated as a bare flag and replaced by ``True``. So passing nothing means
    ``True``, and the only way to turn off an argument that defaults to ``True`` is to pass ``False``
    explicitly. All other values are copied unchanged.

    Returns:
        A new ``Namespace`` holding the adjusted values.
    """
    if isinstance(arg_parser, ArgumentParser):
        args = arg_parser.parse_args()
    else:
        args = arg_parser

    signature_info = {
        name: (arg_types, arg_default) for name, arg_types, arg_default in get_init_arguments_and_types(cls)
    }

    def _is_bare_bool_flag(name: str, value: Any) -> bool:
        # Only a ``None`` for a known bool-typed argument with a bool default counts as "flag given without value".
        if value is not None or name not in signature_info:
            return False
        arg_types, arg_default = signature_info[name]
        return bool in arg_types and isinstance(arg_default, bool)

    return Namespace(
        **{name: (True if _is_bare_bool_flag(name, value) else value) for name, value in vars(args).items()}
    )
|
|
|
|
|
|
def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
    """Parse environment arguments if they are defined.

    For every constructor argument of ``cls`` the environment variable named by ``template`` is looked up, with
    ``cls_name`` and ``cls_argument`` filled in upper-cased. Non-empty values are converted to native Python
    literals (int/float/bool/None/tuples/lists/...) with :func:`ast.literal_eval`. Values that are not valid
    literals are kept as plain strings.

    Args:
        cls: Lightning class whose constructor arguments are looked up.
        template: ``%``-style format string with ``cls_name`` and ``cls_argument`` keys.

    Returns:
        A ``Namespace`` containing only the arguments whose environment variable is set and non-empty.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> parse_env_variables(Trainer)
        Namespace()
        >>> import os
        >>> os.environ["PL_TRAINER_GPUS"] = '42'
        >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
        >>> parse_env_variables(Trainer)
        Namespace(gpus=42)
        >>> del os.environ["PL_TRAINER_GPUS"]
        >>> del os.environ["PL_TRAINER_BLABLABLA"]
    """
    cls_name = cls.__name__.upper()
    env_args = {}
    for arg_name, _, _ in get_init_arguments_and_types(cls):
        env = template % {"cls_name": cls_name, "cls_argument": arg_name.upper()}
        val = os.environ.get(env)
        if not val:
            # unset or empty -> not provided
            continue
        # SECURITY: environment variables are external input. The previous ``eval`` would execute arbitrary
        # code; ``literal_eval`` only accepts Python literals, which covers the intended int/float/bool
        # conversions. These are the exceptions ``literal_eval`` documents for malformed input; on any of
        # them the raw string is kept.
        with suppress(ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            val = ast.literal_eval(val)
        env_args[arg_name] = val
    return Namespace(**env_args)
|
|
|
|
|
|
def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]:
    r"""Scan the signature of ``cls`` and describe each of its parameters.

    Args:
        cls: Any callable accepted by :func:`inspect.signature` (e.g. a class or its ``__init__``).

    Returns:
        List with one ``(argument name, tuple of argument types, argument default value)`` entry per parameter,
        in signature order. The types tuple holds the annotation's ``__args__`` when the annotation has them.
        This applies to ``Union``/``Optional``, but also to generics such as ``List[int]``. Otherwise the tuple
        holds the annotation itself. Missing annotations/defaults appear as ``inspect.Parameter.empty``.

    Examples:

        >>> from pytorch_lightning import Trainer
        >>> args = get_init_arguments_and_types(Trainer)

    """

    def _annotation_types(annotation: Any) -> Tuple:
        # Typing constructs such as ``Union[int, str]`` expose their members through ``__args__``.
        try:
            return tuple(annotation.__args__)
        except AttributeError:
            return (annotation,)

    return [
        (name, _annotation_types(param.annotation), param.default)
        for name, param in inspect.signature(cls).parameters.items()
    ]
|
|
|
|
|
|
def _get_abbrev_qualified_cls_name(cls: Any) -> str:
|
|
assert isinstance(cls, type), repr(cls)
|
|
if cls.__module__.startswith("pytorch_lightning."):
|
|
# Abbreviate.
|
|
return f"pl.{cls.__name__}"
|
|
# Fully qualified.
|
|
return f"{cls.__module__}.{cls.__qualname__}"
|
|
|
|
|
|
def add_argparse_args(
    cls: Type["pl.Trainer"], parent_parser: ArgumentParser, *, use_argument_group: bool = True
) -> Union[_ArgumentGroup, ArgumentParser]:
    r"""Extends existing argparse by default attributes for ``cls``.

    One ``--<name>`` option is added per constructor argument of ``cls`` whose annotation includes at least one
    of the allowed types (str, float, int, bool). Arguments with only other types are skipped. The option's
    default is the constructor default. Its help text comes from the ``Args:`` section of
    ``cls.__init__.__doc__``, falling back to ``cls.__doc__``. The names ``self``, ``args``, ``kwargs`` are
    skipped. If ``cls`` defines ``get_deprecated_arg_names``, the names it returns are skipped as well.

    Args:
        cls: Lightning class
        parent_parser:
            The custom cli arguments parser, which will be extended by
            the class's default arguments.
        use_argument_group:
            By default, this is True, and uses ``add_argument_group`` to add
            a new group named after ``cls``.
            If False, this will use old behavior: a new ``ArgumentParser`` with
            ``parent_parser`` as its parent is created and extended instead.

    Returns:
        If use_argument_group is True, returns ``parent_parser`` to keep old
        workflows. If False, will return the new ArgumentParser object.

    Raises:
        RuntimeError:
            If ``parent_parser`` is an argument group (``_ArgumentGroup``)
            instead of an ``ArgumentParser``.

    Examples:

        # Option 1: Default usage.
        >>> import argparse
        >>> from pytorch_lightning import Trainer
        >>> parser = argparse.ArgumentParser()
        >>> parser = Trainer.add_argparse_args(parser)
        >>> args = parser.parse_args([])

        # Option 2: Disable use_argument_group (old behavior).
        >>> import argparse
        >>> from pytorch_lightning import Trainer
        >>> parser = argparse.ArgumentParser()
        >>> parser = Trainer.add_argparse_args(parser, use_argument_group=False)
        >>> args = parser.parse_args([])
    """
    if isinstance(parent_parser, _ArgumentGroup):
        raise RuntimeError("Please only pass an ArgumentParser instance.")
    if use_argument_group:
        group_name = _get_abbrev_qualified_cls_name(cls)
        parser: Union[_ArgumentGroup, ArgumentParser] = parent_parser.add_argument_group(group_name)
    else:
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

    ignore_arg_names = ["self", "args", "kwargs"]
    if hasattr(cls, "get_deprecated_arg_names"):
        ignore_arg_names += cls.get_deprecated_arg_names()

    allowed_types = (str, int, float, bool)

    # Get symbols from cls or init function: use the first one exposing any non-ignored argument.
    for symbol in (cls, cls.__init__):
        args_and_types = get_init_arguments_and_types(symbol)
        args_and_types = [x for x in args_and_types if x[0] not in ignore_arg_names]
        if len(args_and_types) > 0:
            break

    args_help = _parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "")

    for arg, arg_types, arg_default in args_and_types:
        # keep only the supported types, in the fixed order of ``allowed_types``
        arg_types = tuple(at for at in allowed_types if at in arg_types)
        if not arg_types:
            # skip argument with not supported type
            continue

        arg_kwargs: Dict[str, Any] = {}
        if bool in arg_types:
            # a bare ``--flag`` (no value) evaluates to ``const``, i.e. True
            arg_kwargs.update(nargs="?", const=True)

        use_type = _get_argparse_type(arg, arg_types)
        parser.add_argument(
            f"--{arg}", dest=arg, default=arg_default, type=use_type, help=args_help.get(arg), **arg_kwargs
        )

    if use_argument_group:
        return parent_parser
    return parser


def _get_argparse_type(arg: str, arg_types: Tuple[type, ...]) -> Callable[[str], Any]:
    """Return the argparse ``type=`` converter that ``add_argparse_args`` uses for ``--{arg}``.

    ``arg_types`` must already be filtered to the supported types in the order (str, int, float, bool).
    Precedence of the special cases: the named overrides ``precision`` and ``track_grad_norm`` win over
    everything. An exact ``(int, float)`` pair wins over the ``gpus``/``tpu_cores`` override, which in turn wins
    over the generic rule.
    """
    # hack for precision
    if arg == "precision":
        return _precision_allowed_type
    # hack for track_grad_norm
    if arg == "track_grad_norm":
        return float
    # hack for types in (int, float)
    if len(arg_types) == 2 and int in arg_types and float in arg_types:
        return _int_or_float_type
    if arg in ("gpus", "tpu_cores"):
        return _gpus_allowed_type
    if bool in arg_types:
        # if the only arg type is bool
        if len(arg_types) == 1:
            return str_to_bool
        if int in arg_types:
            return str_to_bool_or_int
        if str in arg_types:
            return str_to_bool_or_str
        # only (float, bool) is left: filter out the bool as we need the more general type
        return next(at for at in arg_types if at is not bool)
    return arg_types[0]
|
|
|
|
|
|
def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
|
|
arg_block_indent = None
|
|
current_arg = ""
|
|
parsed = {}
|
|
for line in docstring.split("\n"):
|
|
stripped = line.lstrip()
|
|
if not stripped:
|
|
continue
|
|
line_indent = len(line) - len(stripped)
|
|
if stripped.startswith(("Args:", "Arguments:", "Parameters:")):
|
|
arg_block_indent = line_indent + 4
|
|
elif arg_block_indent is None:
|
|
continue
|
|
elif line_indent < arg_block_indent:
|
|
break
|
|
elif line_indent == arg_block_indent:
|
|
current_arg, arg_description = stripped.split(":", maxsplit=1)
|
|
parsed[current_arg] = arg_description.lstrip()
|
|
elif line_indent > arg_block_indent:
|
|
parsed[current_arg] += f" {stripped}"
|
|
return parsed
|
|
|
|
|
|
def _gpus_allowed_type(x: str) -> Union[int, str]:
|
|
if "," in x:
|
|
return str(x)
|
|
return int(x)
|
|
|
|
|
|
def _gpus_arg_default(x: str) -> Union[int, str]: # pragma: no-cover
|
|
# unused, but here for backward compatibility with old checkpoints that need to be able to
|
|
# unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
|
|
# see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
|
|
pass
|
|
|
|
|
|
def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
|
|
if "." in str(x):
|
|
return float(x)
|
|
return int(x)
|
|
|
|
|
|
def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
|
|
"""
|
|
>>> _precision_allowed_type("32")
|
|
32
|
|
>>> _precision_allowed_type("bf16")
|
|
'bf16'
|
|
"""
|
|
try:
|
|
return int(x)
|
|
except ValueError:
|
|
return x
|