Moved `env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` (#10501)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
8ea39d2c8f
commit
ce0a977742
|
@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))
|
||||
|
||||
|
||||
- Moved `trainer.connectors.env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` ([#10501](https://github.com/PyTorchLightning/pytorch-lightning/pull/10501))
|
||||
|
||||
|
||||
- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))
|
||||
|
||||
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables
|
||||
|
||||
|
||||
def _defaults_from_env_vars(fn: Callable) -> Callable:
|
||||
"""Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should
|
||||
be moved automatically to the correct device."""
|
||||
|
||||
@wraps(fn)
|
||||
def insert_env_defaults(self, *args, **kwargs):
|
||||
cls = self.__class__ # get the class
|
||||
if args: # inace any args passed move them to kwargs
|
||||
# parse only the argument names
|
||||
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
|
||||
# convert args to kwargs
|
||||
kwargs.update(dict(zip(cls_arg_names, args)))
|
||||
env_variables = vars(parse_env_variables(cls))
|
||||
# update the kwargs by env variables
|
||||
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
|
||||
|
||||
# all args were already moved to kwargs
|
||||
return fn(self, **kwargs)
|
||||
|
||||
return insert_env_defaults
|
|
@ -53,7 +53,6 @@ from pytorch_lightning.trainer.connectors.accelerator_connector import Accelerat
|
|||
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
|
||||
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
|
||||
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
|
||||
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
|
||||
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
|
||||
|
@ -75,6 +74,7 @@ from pytorch_lightning.utilities import (
|
|||
rank_zero_warn,
|
||||
)
|
||||
from pytorch_lightning.utilities.argparse import (
|
||||
_defaults_from_env_vars,
|
||||
add_argparse_args,
|
||||
from_argparse_args,
|
||||
parse_argparser,
|
||||
|
|
|
@ -16,6 +16,7 @@ import os
|
|||
from abc import ABC
|
||||
from argparse import _ArgumentGroup, ArgumentParser, Namespace
|
||||
from contextlib import suppress
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
@ -312,3 +313,22 @@ def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
|
|||
return int(x)
|
||||
except ValueError:
|
||||
return x
|
||||
|
||||
|
||||
def _defaults_from_env_vars(fn: Callable) -> Callable:
|
||||
@wraps(fn)
|
||||
def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
cls = self.__class__ # get the class
|
||||
if args: # in case any args passed move them to kwargs
|
||||
# parse only the argument names
|
||||
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
|
||||
# convert args to kwargs
|
||||
kwargs.update(dict(zip(cls_arg_names, args)))
|
||||
env_variables = vars(parse_env_variables(cls))
|
||||
# update the kwargs by env variables
|
||||
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
|
||||
|
||||
# all args were already moved to kwargs
|
||||
return fn(self, **kwargs)
|
||||
|
||||
return insert_env_defaults
|
||||
|
|
Loading…
Reference in New Issue