add parsing OS env vars (#4022)

* add parsing OS env vars

* fix env

* Apply suggestions from code review

* overwrite init

* Apply suggestions from code review
Jirka Borovec 2020-10-10 01:34:09 +02:00 committed by GitHub
parent 8a3c800641
commit baf4f35027
5 changed files with 111 additions and 0 deletions

View File

@@ -0,0 +1,43 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import Callable
from pytorch_lightning.utilities.argparse_utils import parse_env_variables, get_init_arguments_and_types
def overwrite_by_env_vars(fn: Callable) -> Callable:
"""
Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
input arguments can be overwritten by matching ``PL_TRAINER_*`` environment variables.
"""
@wraps(fn)
def overwrite_by_env_vars(self, *args, **kwargs):
# get the class
cls = self.__class__
if args:  # in case any args were passed, move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
# convert args to kwargs
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
# update the kwargs by env variables
# todo: maybe add a warning that some init args were overwritten by environment variables
kwargs.update(vars(parse_env_variables(cls)))
# all args were already moved to kwargs
return fn(self, **kwargs)
return overwrite_by_env_vars
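
For reference, a minimal standalone sketch of the same overwrite mechanism, using a hypothetical Dummy class and PL_DUMMY_STEPS variable (both invented for illustration; only the Trainer is wired up by this commit):

import inspect
import os
from argparse import Namespace

class Dummy:  # hypothetical class, stands in for Trainer
    def __init__(self, steps: int = 10, logger: bool = True):
        self.steps, self.logger = steps, logger

def env_overrides(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
    # mirrors parse_env_variables: look up PL_<CLASS>_<ARG> for every __init__ argument
    env_args = {}
    for name in inspect.signature(cls.__init__).parameters:
        if name == "self":
            continue
        val = os.environ.get(template % {"cls_name": cls.__name__.upper(), "cls_argument": name.upper()})
        if val not in (None, ""):
            try:
                val = eval(val)  # same naive str -> int/float/bool conversion as above
            except Exception:
                pass
            env_args[name] = val
    return Namespace(**env_args)

os.environ["PL_DUMMY_STEPS"] = "7"
kwargs = {"logger": False}
kwargs.update(vars(env_overrides(Dummy)))  # the env value wins, just as in the decorator
dummy = Dummy(**kwargs)
assert (dummy.steps, dummy.logger) == (7, False)
del os.environ["PL_DUMMY_STEPS"]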

View File

@@ -112,6 +112,10 @@ class TrainerProperties(ABC):
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
return argparse_utils.parse_argparser(cls, arg_parser)
@classmethod
def match_env_arguments(cls) -> Namespace:
return argparse_utils.parse_env_variables(cls)
@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
return argparse_utils.add_argparse_args(cls, parent_parser)
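
A usage sketch of the new classmethod, assuming no other matching PL_TRAINER_* variables are set in the environment:

import os
from pytorch_lightning import Trainer

os.environ["PL_TRAINER_MAX_STEPS"] = "7"
overrides = Trainer.match_env_arguments()  # delegates to parse_env_variables(Trainer)
assert vars(overrides) == {"max_steps": 7}
del os.environ["PL_TRAINER_MAX_STEPS"]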

View File

@@ -28,6 +28,7 @@ from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import BaseProfiler
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -79,6 +80,7 @@ class Trainer(
TrainerTrainingTricksMixin,
TrainerDataLoadingMixin,
):
@overwrite_by_env_vars
def __init__(
self,
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,

View File

@@ -1,4 +1,5 @@
import inspect
import os
from argparse import ArgumentParser, Namespace
from typing import Union, List, Tuple, Any
from pytorch_lightning.utilities import parsing
@@ -7,6 +8,7 @@ from pytorch_lightning.utilities import parsing
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
"""
Create an instance from CLI arguments.
Additionally, it uses variables from the OS environment which are defined as "PL_<CLASS-NAME>_<CLASS_ARGUMENT_NAME>".
Args:
args: The parser or namespace to take arguments from. Only known arguments will be
@@ -22,8 +24,11 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
>>> args = Trainer.parse_argparser(parser.parse_args(""))
>>> trainer = Trainer.from_argparse_args(args, logger=False)
"""
# first check if any args are defined in the environment for the class and set them as defaults
if isinstance(args, ArgumentParser):
args = cls.parse_argparser(args)
# if other arg passed, update parameters
params = vars(args)
# we only want to pass in valid Trainer args, the rest may be user specific
@@ -61,6 +66,35 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
return Namespace(**modified_args)
def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
"""Parse environment arguments if they are defined.
Example:
>>> from pytorch_lightning import Trainer
>>> parse_env_variables(Trainer)
Namespace()
>>> import os
>>> os.environ["PL_TRAINER_GPUS"] = '42'
>>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
>>> parse_env_variables(Trainer)
Namespace(gpus=42)
>>> del os.environ["PL_TRAINER_GPUS"]
"""
cls_arg_defaults = get_init_arguments_and_types(cls)
env_args = {}
for arg_name, _, _ in cls_arg_defaults:
env = template % {'cls_name': cls.__name__.upper(), 'cls_argument': arg_name.upper()}
val = os.environ.get(env)
if not (val is None or val == ''):
try: # converting to native types like int/float/bool
val = eval(val)
except Exception:
pass
env_args[arg_name] = val
return Namespace(**env_args)
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the Trainer signature and returns argument names, types and default values.

View File

@@ -0,0 +1,28 @@
import os
from pytorch_lightning import Trainer
def test_passing_env_variables(tmpdir):
"""Testing overwriting trainer arguments """
trainer = Trainer()
assert trainer.logger is not None
assert trainer.max_steps is None
trainer = Trainer(False, max_steps=42)
assert trainer.logger is None
assert trainer.max_steps == 42
os.environ['PL_TRAINER_LOGGER'] = 'False'
os.environ['PL_TRAINER_MAX_STEPS'] = '7'
trainer = Trainer()
assert trainer.logger is None
assert trainer.max_steps == 7
os.environ['PL_TRAINER_LOGGER'] = 'True'
trainer = Trainer(False, max_steps=42)
assert trainer.logger is not None
assert trainer.max_steps == 7
# this has to be cleaned
del os.environ['PL_TRAINER_LOGGER']
del os.environ['PL_TRAINER_MAX_STEPS']
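
As an alternative to the manual cleanup above (not part of this commit), pytest's built-in monkeypatch fixture restores the environment automatically on teardown; a sketch reusing the Trainer import from this test file:

def test_passing_env_variables_monkeypatch(monkeypatch):
    """Same scenario via pytest's monkeypatch, which undoes setenv after the test."""
    monkeypatch.setenv('PL_TRAINER_LOGGER', 'False')
    monkeypatch.setenv('PL_TRAINER_MAX_STEPS', '7')
    trainer = Trainer()
    assert trainer.logger is None
    assert trainer.max_steps == 7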