From ce0a977742872736f150f6d37ecaa301a318668f Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 22 Nov 2021 13:36:35 +0530 Subject: [PATCH] Moved `env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` (#10501) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- CHANGELOG.md | 3 ++ .../trainer/connectors/env_vars_connector.py | 40 ------------------- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/utilities/argparse.py | 20 ++++++++++ 4 files changed, 24 insertions(+), 41 deletions(-) delete mode 100644 pytorch_lightning/trainer/connectors/env_vars_connector.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 96830878d8..bba1cd319e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Moved `trainer.connectors.env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` ([#10501](https://github.com/PyTorchLightning/pytorch-lightning/pull/10501)) + + - Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426)) diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py deleted file mode 100644 index 4d130ca8e7..0000000000 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import wraps -from typing import Callable - -from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables - - -def _defaults_from_env_vars(fn: Callable) -> Callable: - """Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should - be moved automatically to the correct device.""" - - @wraps(fn) - def insert_env_defaults(self, *args, **kwargs): - cls = self.__class__ # get the class - if args: # inace any args passed move them to kwargs - # parse only the argument names - cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] - # convert args to kwargs - kwargs.update(dict(zip(cls_arg_names, args))) - env_variables = vars(parse_env_variables(cls)) - # update the kwargs by env variables - kwargs = dict(list(env_variables.items()) + list(kwargs.items())) - - # all args were already moved to kwargs - return fn(self, **kwargs) - - return insert_env_defaults diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2f6e987635..38eb44bced 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -53,7 +53,6 @@ from pytorch_lightning.trainer.connectors.accelerator_connector import Accelerat from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector -from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector @@ -75,6 +74,7 @@ from pytorch_lightning.utilities import ( rank_zero_warn, ) from pytorch_lightning.utilities.argparse import ( + _defaults_from_env_vars, add_argparse_args, from_argparse_args, parse_argparser, diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 61443bea07..ad707b0360 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -16,6 +16,7 @@ import os from abc import ABC from argparse import _ArgumentGroup, ArgumentParser, Namespace from contextlib import suppress +from functools import wraps from typing import Any, Callable, Dict, List, Tuple, Type, Union import pytorch_lightning as pl @@ -312,3 +313,22 @@ def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]: return int(x) except ValueError: return x + + +def _defaults_from_env_vars(fn: Callable) -> Callable: + @wraps(fn) + def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any: + cls = self.__class__ # get the class + if args: # in case any args passed move them to kwargs + # parse only the argument names + cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] + # convert args to kwargs + kwargs.update(dict(zip(cls_arg_names, args))) + env_variables = vars(parse_env_variables(cls)) + # update the kwargs by env variables + kwargs = dict(list(env_variables.items()) + list(kwargs.items())) + + # all args were already moved to kwargs + return fn(self, **kwargs) + + return insert_env_defaults