diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py new file mode 100644 index 0000000000..2cbbc8e40e --- /dev/null +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -0,0 +1,43 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps +from typing import Callable + +from pytorch_lightning.utilities.argparse_utils import parse_env_variables, get_init_arguments_and_types + + +def overwrite_by_env_vars(fn: Callable) -> Callable: + """ + Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which + input arguments should be overwritten by matching environment variables. 
+ + """ + @wraps(fn) + def overwrite_by_env_vars(self, *args, **kwargs): + # get the class + cls = self.__class__ + if args: # in case any args passed move them to kwargs + # parse only the argument names + cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] + # convert args to kwargs + kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + # update the kwargs by env variables + # todo: maybe add a warning that some init args were overwritten by Env arguments + kwargs.update(vars(parse_env_variables(cls))) + + # all args were already moved to kwargs + return fn(self, **kwargs) + + return overwrite_by_env_vars diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index ce8041e5b8..76df783d2d 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -112,6 +112,10 @@ class TrainerProperties(ABC): def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: return argparse_utils.parse_argparser(cls, arg_parser) + @classmethod + def match_env_arguments(cls) -> Namespace: + return argparse_utils.parse_env_variables(cls) + @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: return argparse_utils.add_argparse_args(cls, parent_parser) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f1bad385b8..c2926ddd13 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -28,6 +28,7 @@ from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import BaseProfiler from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import ConfigValidator +from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin from 
pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin @@ -79,6 +80,7 @@ class Trainer( TrainerTrainingTricksMixin, TrainerDataLoadingMixin, ): + @overwrite_by_env_vars def __init__( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index 09b3fdafe0..8b61989dc9 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -1,4 +1,5 @@ import inspect +import os from argparse import ArgumentParser, Namespace from typing import Union, List, Tuple, Any from pytorch_lightning.utilities import parsing @@ -7,6 +8,7 @@ from pytorch_lightning.utilities import parsing def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): """ Create an instance from CLI arguments. + Eventually use variables from OS environment which are defined as "PL_<CLASS-NAME>_<ARG-NAME>" Args: args: The parser or namespace to take arguments from. Only known arguments will be @@ -22,8 +24,11 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): >>> args = Trainer.parse_argparser(parser.parse_args("")) >>> trainer = Trainer.from_argparse_args(args, logger=False) """ + # first check if any args are defined in environment for the class and set as default + if isinstance(args, ArgumentParser): args = cls.parse_argparser(args) + # if other arg passed, update parameters params = vars(args) # we only want to pass in valid Trainer args, the rest may be user specific @@ -61,6 +66,35 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp return Namespace(**modified_args) +def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: + """Parse environment arguments if they are defined. 
+ + Example: + >>> from pytorch_lightning import Trainer + >>> parse_env_variables(Trainer) + Namespace() + >>> import os + >>> os.environ["PL_TRAINER_GPUS"] = '42' + >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23' + >>> parse_env_variables(Trainer) + Namespace(gpus=42) + >>> del os.environ["PL_TRAINER_GPUS"] + """ + cls_arg_defaults = get_init_arguments_and_types(cls) + + env_args = {} + for arg_name, _, _ in cls_arg_defaults: + env = template % {'cls_name': cls.__name__.upper(), 'cls_argument': arg_name.upper()} + val = os.environ.get(env) + if not (val is None or val == ''): + try: # converting to native types like int/float/bool + val = eval(val) + except Exception: + pass + env_args[arg_name] = val + return Namespace(**env_args) + + def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: r"""Scans the Trainer signature and returns argument names, types and default values. diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py new file mode 100644 index 0000000000..b34a419a87 --- /dev/null +++ b/tests/trainer/flags/test_env_vars.py @@ -0,0 +1,28 @@ +import os + +from pytorch_lightning import Trainer + + +def test_passing_env_variables(tmpdir): + """Testing overwriting trainer arguments """ + trainer = Trainer() + assert trainer.logger is not None + assert trainer.max_steps is None + trainer = Trainer(False, max_steps=42) + assert trainer.logger is None + assert trainer.max_steps == 42 + + os.environ['PL_TRAINER_LOGGER'] = 'False' + os.environ['PL_TRAINER_MAX_STEPS'] = '7' + trainer = Trainer() + assert trainer.logger is None + assert trainer.max_steps == 7 + + os.environ['PL_TRAINER_LOGGER'] = 'True' + trainer = Trainer(False, max_steps=42) + assert trainer.logger is not None + assert trainer.max_steps == 7 + + # this has to be cleaned + del os.environ['PL_TRAINER_LOGGER'] + del os.environ['PL_TRAINER_MAX_STEPS']