add parsing OS env vars (#4022)
* add parsing OS env vars * fix env * Apply suggestions from code review * overwrite init * Apply suggestions from code review
This commit is contained in:
parent
8a3c800641
commit
baf4f35027
|
@ -0,0 +1,43 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from pytorch_lightning.utilities.argparse_utils import parse_env_variables, get_init_arguments_and_types
|
||||
|
||||
|
||||
def overwrite_by_env_vars(fn: Callable) -> Callable:
|
||||
"""
|
||||
Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
|
||||
input arguments should be moved automatically to the correct device.
|
||||
|
||||
"""
|
||||
@wraps(fn)
|
||||
def overwrite_by_env_vars(self, *args, **kwargs):
|
||||
# get the class
|
||||
cls = self.__class__
|
||||
if args: # inace any args passed move them to kwargs
|
||||
# parse only the argument names
|
||||
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
|
||||
# convert args to kwargs
|
||||
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
|
||||
# update the kwargs by env variables
|
||||
# todo: maybe add a warning that some init args were overwritten by Env arguments
|
||||
kwargs.update(vars(parse_env_variables(cls)))
|
||||
|
||||
# all args were already moved to kwargs
|
||||
return fn(self, **kwargs)
|
||||
|
||||
return overwrite_by_env_vars
|
|
@ -112,6 +112,10 @@ class TrainerProperties(ABC):
|
|||
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
|
||||
return argparse_utils.parse_argparser(cls, arg_parser)
|
||||
|
||||
@classmethod
|
||||
def match_env_arguments(cls) -> Namespace:
|
||||
return argparse_utils.parse_env_variables(cls)
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
|
||||
return argparse_utils.add_argparse_args(cls, parent_parser)
|
||||
|
|
|
@ -28,6 +28,7 @@ from pytorch_lightning.loggers import LightningLoggerBase
|
|||
from pytorch_lightning.profiler import BaseProfiler
|
||||
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
|
||||
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
|
||||
from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
|
||||
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
|
||||
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
|
||||
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
|
||||
|
@ -79,6 +80,7 @@ class Trainer(
|
|||
TrainerTrainingTricksMixin,
|
||||
TrainerDataLoadingMixin,
|
||||
):
|
||||
@overwrite_by_env_vars
|
||||
def __init__(
|
||||
self,
|
||||
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import inspect
|
||||
import os
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Union, List, Tuple, Any
|
||||
from pytorch_lightning.utilities import parsing
|
||||
|
@ -7,6 +8,7 @@ from pytorch_lightning.utilities import parsing
|
|||
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
|
||||
"""
|
||||
Create an instance from CLI arguments.
|
||||
Eventually use varibles from OS environement which are defined as "PL_<CLASS-NAME>_<CLASS_ARUMENT_NAME>"
|
||||
|
||||
Args:
|
||||
args: The parser or namespace to take arguments from. Only known arguments will be
|
||||
|
@ -22,8 +24,11 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
|
|||
>>> args = Trainer.parse_argparser(parser.parse_args(""))
|
||||
>>> trainer = Trainer.from_argparse_args(args, logger=False)
|
||||
"""
|
||||
# fist check if any args are defined in environment for the class and set as default
|
||||
|
||||
if isinstance(args, ArgumentParser):
|
||||
args = cls.parse_argparser(args)
|
||||
# if other arg passed, update parameters
|
||||
params = vars(args)
|
||||
|
||||
# we only want to pass in valid Trainer args, the rest may be user specific
|
||||
|
@ -61,6 +66,35 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
|
|||
return Namespace(**modified_args)
|
||||
|
||||
|
||||
def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
|
||||
"""Parse environment arguments if they are defined.
|
||||
|
||||
Example:
|
||||
>>> from pytorch_lightning import Trainer
|
||||
>>> parse_env_variables(Trainer)
|
||||
Namespace()
|
||||
>>> import os
|
||||
>>> os.environ["PL_TRAINER_GPUS"] = '42'
|
||||
>>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
|
||||
>>> parse_env_variables(Trainer)
|
||||
Namespace(gpus=42)
|
||||
>>> del os.environ["PL_TRAINER_GPUS"]
|
||||
"""
|
||||
cls_arg_defaults = get_init_arguments_and_types(cls)
|
||||
|
||||
env_args = {}
|
||||
for arg_name, _, _ in cls_arg_defaults:
|
||||
env = template % {'cls_name': cls.__name__.upper(), 'cls_argument': arg_name.upper()}
|
||||
val = os.environ.get(env)
|
||||
if not (val is None or val == ''):
|
||||
try: # converting to native types like int/float/bool
|
||||
val = eval(val)
|
||||
except Exception:
|
||||
pass
|
||||
env_args[arg_name] = val
|
||||
return Namespace(**env_args)
|
||||
|
||||
|
||||
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
|
||||
r"""Scans the Trainer signature and returns argument names, types and default values.
|
||||
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
import os
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
|
||||
def test_passing_env_variables(tmpdir):
|
||||
"""Testing overwriting trainer arguments """
|
||||
trainer = Trainer()
|
||||
assert trainer.logger is not None
|
||||
assert trainer.max_steps is None
|
||||
trainer = Trainer(False, max_steps=42)
|
||||
assert trainer.logger is None
|
||||
assert trainer.max_steps == 42
|
||||
|
||||
os.environ['PL_TRAINER_LOGGER'] = 'False'
|
||||
os.environ['PL_TRAINER_MAX_STEPS'] = '7'
|
||||
trainer = Trainer()
|
||||
assert trainer.logger is None
|
||||
assert trainer.max_steps == 7
|
||||
|
||||
os.environ['PL_TRAINER_LOGGER'] = 'True'
|
||||
trainer = Trainer(False, max_steps=42)
|
||||
assert trainer.logger is not None
|
||||
assert trainer.max_steps == 7
|
||||
|
||||
# this has to be cleaned
|
||||
del os.environ['PL_TRAINER_LOGGER']
|
||||
del os.environ['PL_TRAINER_MAX_STEPS']
|
Loading…
Reference in New Issue