Deprecate `TrainerOptimizersMixin` and move functionality to `core/optimizer.py` (#11155)
parent 81301dbba7
commit a6a28e08d2
@@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068))

+- Deprecated `TrainerOptimizersMixin` and moved functionality to `core/optimizer.py` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))

 - Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148))

 ### Removed

@@ -351,6 +354,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))

+- Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))

 ## [1.5.7] - 2021-12-21
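A minimal sketch of the deprecated entry point next to its replacement, using a toy LightningModule and the same trainer wiring as the deprecation test further down in this diff; the model itself is purely illustrative and not part of the change:

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers


class TinyModel(LightningModule):  # hypothetical module, only for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


trainer = Trainer()
model = TinyModel()
trainer.strategy.connect(model)
trainer.lightning_module.trainer = trainer

# deprecated in v1.6, removed in v1.8 (emits a deprecation warning):
# optimizers, lr_schedulers, frequencies = trainer.init_optimizers(model)

# replacement helper, now living in pytorch_lightning/core/optimizer.py:
optimizers, lr_schedulers, frequencies = _init_optimizers_and_lr_schedulers(model)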
@@ -24,7 +24,7 @@ from torch.optim.swa_utils import SWALR

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
+from pytorch_lightning.core.optimizer import _get_default_scheduler_config
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -11,14 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import weakref
 from contextlib import contextmanager
-from typing import Any, Callable, Generator, Optional
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 from weakref import proxy

 import torch
+from torch import optim
 from torch.optim import Optimizer

 import pytorch_lightning as pl
-from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities import AMPType, rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -54,7 +57,8 @@ class LightningOptimizer:
         return self._optimizer

     def _on_trainer_init(self, trainer: "pl.Trainer") -> None:
-        self._trainer = proxy(trainer)
+        # check if trainer is already of type weakproxy since we can't call proxy on a weakproxy
+        self._trainer = trainer if isinstance(trainer, weakref.ProxyType) else proxy(trainer)
         for opt_idx, opt in enumerate(trainer.optimizers):
             if opt == self._optimizer:
                 self._optimizer_idx = opt_idx
@@ -162,3 +166,227 @@ class LightningOptimizer:
         assert trainer is not None
         with trainer.profiler.profile(profiler_action):
             trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
+
+
+def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]:
+    """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
+    model.trainer._lightning_optimizers = None
+    optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
+
+    if optim_conf is None:
+        rank_zero_warn(
+            "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
+        )
+        optim_conf = _MockOptimizer()
+
+    optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
+    lr_schedulers = _configure_schedulers(lr_schedulers, monitor, not model.automatic_optimization)
+    _validate_scheduler_optimizer(optimizers, lr_schedulers)
+    return optimizers, lr_schedulers, optimizer_frequencies
+
+
+def _configure_optimizers(
+    optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
+) -> Tuple[List, List, List, Optional[str]]:
+    optimizers, lr_schedulers, optimizer_frequencies = [], [], []
+    monitor = None
+
+    # single output, single optimizer
+    if isinstance(optim_conf, Optimizer):
+        optimizers = [optim_conf]
+    # two lists, optimizer + lr schedulers
+    elif (
+        isinstance(optim_conf, (list, tuple))
+        and len(optim_conf) == 2
+        and isinstance(optim_conf[0], list)
+        and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
+    ):
+        opt, sch = optim_conf
+        optimizers = opt
+        lr_schedulers = sch if isinstance(sch, list) else [sch]
+    # single dictionary
+    elif isinstance(optim_conf, dict):
+        _validate_optim_conf(optim_conf)
+        optimizers = [optim_conf["optimizer"]]
+        monitor = optim_conf.get("monitor", None)
+        lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
+    # multiple dictionaries
+    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
+        for opt_dict in optim_conf:
+            _validate_optim_conf(opt_dict)
+        optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
+        scheduler_dict = (
+            lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
+            if isinstance(scheduler, dict)
+            else {"scheduler": scheduler, "opt_idx": opt_idx}
+        )
+
+        lr_schedulers = [
+            scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
+            for opt_idx, opt_dict in enumerate(optim_conf)
+            if "lr_scheduler" in opt_dict
+        ]
+        optimizer_frequencies = [
+            opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
+        ]
+        # assert that if frequencies are present, they are given for all optimizers
+        if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
+            raise ValueError("A frequency must be given to each optimizer.")
+    # single list or tuple, multiple optimizer
+    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
+        optimizers = list(optim_conf)
+    # unknown configuration
+    else:
+        raise MisconfigurationException(
+            "Unknown configuration for model optimizers."
+            " Output from `model.configure_optimizers()` should be one of:\n"
+            " * `Optimizer`\n"
+            " * [`Optimizer`]\n"
+            " * ([`Optimizer`], [`_LRScheduler`])\n"
+            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
+            ' * A list of the previously described dict format, with an optional "frequency" key (int)'
+        )
+    return optimizers, lr_schedulers, optimizer_frequencies, monitor
+
+
+def _configure_schedulers(
+    schedulers: list, monitor: Optional[str], is_manual_optimization: bool
+) -> List[Dict[str, Any]]:
+    """Convert each scheduler into dict structure with relevant information."""
+    lr_schedulers = []
+    default_config = _get_default_scheduler_config()
+    # TODO: move is_manual_optimization check out of for loop
+    for scheduler in schedulers:
+        if is_manual_optimization:
+            if isinstance(scheduler, dict):
+                invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
+                keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
+
+                if keys_to_warn:
+                    rank_zero_warn(
+                        f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
+                        " You need to call `lr_scheduler.step()` manually in manual optimization.",
+                        category=RuntimeWarning,
+                    )
+
+                scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
+                lr_schedulers.append({**default_config, **scheduler})
+            else:
+                lr_schedulers.append({**default_config, "scheduler": scheduler})
+        else:
+            if isinstance(scheduler, dict):
+                # check provided keys
+                extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
+                if extra_keys:
+                    rank_zero_warn(
+                        f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
+                    )
+                if "scheduler" not in scheduler:
+                    raise MisconfigurationException(
+                        'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
+                    )
+                if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
+                    raise MisconfigurationException(
+                        'The "interval" key in lr scheduler dict must be "step" or "epoch"'
+                        f' but is "{scheduler["interval"]}"'
+                    )
+                scheduler["reduce_on_plateau"] = isinstance(
+                    scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
+                )
+                if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
+                    raise MisconfigurationException(
+                        "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
+                        ' For example: {"optimizer": optimizer, "lr_scheduler":'
+                        ' {"scheduler": scheduler, "monitor": "your_loss"}}'
+                    )
+                is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
+                if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
+                    rank_zero_warn(
+                        "A `OneCycleLR` scheduler is using 'interval': 'epoch'."
+                        " Are you sure you didn't mean 'interval': 'step'?",
+                        category=RuntimeWarning,
+                    )
+                lr_schedulers.append({**default_config, **scheduler})
+            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+                if monitor is None:
+                    raise MisconfigurationException(
+                        "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
+                        " scheduler is used. For example:"
+                        ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
+                    )
+                lr_schedulers.append(
+                    {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
+                )
+            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
+                lr_schedulers.append({**default_config, "scheduler": scheduler})
+            else:
+                raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
+    return lr_schedulers
+
+
+def _get_default_scheduler_config() -> Dict[str, Any]:
+    return {
+        "scheduler": None,
+        "name": None,  # no custom name
+        "interval": "epoch",  # after epoch is over
+        "frequency": 1,  # every epoch/batch
+        "reduce_on_plateau": False,  # most often not ReduceLROnPlateau scheduler
+        "monitor": None,  # value to monitor for ReduceLROnPlateau
+        "strict": True,  # enforce that the monitor exists for ReduceLROnPlateau
+        "opt_idx": None,  # necessary to store opt_idx when optimizer frequencies are specified
+    }
+
+
+def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
+    if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
+        raise MisconfigurationException(
+            "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
+        )
+
+
+def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
+    valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
+    extra_keys = optim_conf.keys() - valid_keys
+    if extra_keys:
+        rank_zero_warn(
+            f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
+        )
+
+
+def _convert_to_lightning_optimizers(trainer: "pl.Trainer") -> None:
+    def _convert_to_lightning_optimizer(optimizer: Optimizer) -> LightningOptimizer:
+        if not isinstance(optimizer, LightningOptimizer):
+            optimizer = LightningOptimizer(optimizer)  # type: ignore [assignment]
+        optimizer._on_trainer_init(trainer)
+        return optimizer  # type: ignore [return-value]
+
+    trainer._lightning_optimizers = {  # type: ignore [assignment]
+        opt_idx: _convert_to_lightning_optimizer(opt) for opt_idx, opt in enumerate(trainer.optimizers)
+    }
+
+
+class _MockOptimizer(Optimizer):
+    """The `_MockOptimizer` will be used in place of an optimizer in the event that `None` is returned from
+    `configure_optimizers`."""
+
+    def __init__(self) -> None:
+        super().__init__([torch.zeros(1)], {})
+
+    def add_param_group(self, param_group: Dict[Any, Any]) -> None:
+        pass  # Do Nothing
+
+    def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
+        pass  # Do Nothing
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {}  # Return Empty
+
+    def step(self, closure: Callable = None) -> None:
+        if closure is not None:
+            closure()
+
+    def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
+        pass  # Do Nothing
+
+    def __repr__(self) -> str:
+        return "No Optimizer"
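The branches of `_configure_optimizers` above accept exactly the shapes listed in its error message; a hedged sketch of those shapes with placeholder objects (the linear module, SGD optimizer, and StepLR scheduler are illustrative, not part of the diff):

import torch
from torch import optim

model = torch.nn.Linear(4, 1)  # placeholder module
opt = optim.SGD(model.parameters(), lr=0.1)
sched = optim.lr_scheduler.StepLR(opt, step_size=10)

# each of these is a valid `configure_optimizers` output for `_configure_optimizers`:
conf_single = opt                                      # a single `Optimizer`
conf_list = [opt]                                      # a list of optimizers
conf_two_lists = ([opt], [sched])                      # ([`Optimizer`], [`_LRScheduler`])
conf_dict = {"optimizer": opt, "lr_scheduler": sched}  # dict with optional "lr_scheduler"/"monitor"
conf_multi = (                                         # dicts with an optional per-optimizer "frequency"
    {"optimizer": opt, "lr_scheduler": sched, "frequency": 1},
)

Each of these is normalized into the `(optimizers, lr_schedulers, optimizer_frequencies, monitor)` tuple that `_configure_schedulers` and `_validate_scheduler_optimizer` then consume.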
@@ -31,7 +31,7 @@ from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
-from pytorch_lightning.core.optimizer import LightningOptimizer
+from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment

@@ -347,7 +347,7 @@ class DDPStrategy(ParallelStrategy):
         del optimizer
         trainer = self.lightning_module.trainer
         trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        _convert_to_lightning_optimizers(trainer)

     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
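The rewritten call above fills `trainer._lightning_optimizers` through the module-level helper instead of the deprecated mixin method; a rough, illustrative sketch of the effect (the bare `Trainer` and the SGD optimizer are stand-ins, not part of this diff):

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer

trainer = Trainer()
trainer.optimizers = [torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)]

_convert_to_lightning_optimizers(trainer)

# the wrappers are stored as a dict keyed by optimizer index
assert isinstance(trainer._lightning_optimizers[0], LightningOptimizer)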
@@ -27,11 +27,11 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler

 import pytorch_lightning as pl
+from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
-from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection

@@ -446,9 +446,7 @@ class DeepSpeedStrategy(DDPStrategy):
             self._initialize_deepspeed_inference(model)

     def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
-        optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
-            self.lightning_module
-        )
+        optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
         if len(optimizers) > 1 or len(schedulers) > 1:
             raise MisconfigurationException(
                 "DeepSpeed currently only supports single optimizer, single optional scheduler."
@@ -19,7 +19,7 @@ from torch.nn import Module
 from torch.optim import Optimizer

 import pytorch_lightning as pl
-from pytorch_lightning.core.optimizer import LightningOptimizer
+from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only

@@ -50,7 +50,7 @@ class DDPShardedStrategy(DDPStrategy):
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        _convert_to_lightning_optimizers(trainer)

     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.
@@ -22,6 +22,7 @@ from torch.optim import Optimizer
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
+from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO

@@ -377,7 +378,7 @@ class Strategy(ABC):
         return dataloader

     def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
-        return trainer.init_optimizers(model)
+        return _init_optimizers_and_lr_schedulers(model)

     @property
     def restore_checkpoint_after_setup(self) -> bool:
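After this change, `Strategy.init_optimizers` resolves optimizers through the shared helper rather than the trainer mixin; a one-line hedged sketch of the call shape, assuming a `trainer`/`model` pair that has already been connected as in the tests near the end of this diff:

# hedged sketch: `trainer` and `model` are assumed to be an attached Trainer/LightningModule pair
optimizers, lr_scheduler_dicts, optimizer_frequencies = trainer.strategy.init_optimizers(trainer, model)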
@@ -13,239 +13,43 @@
 # limitations under the License.

 from abc import ABC
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import optim
-from torch.optim.optimizer import Optimizer
+from typing import List, Optional, Tuple

 import pytorch_lightning as pl
-from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.types import LRSchedulerConfig
+from pytorch_lightning.core.optimizer import (
+    _convert_to_lightning_optimizers,
+    _init_optimizers_and_lr_schedulers,
+    LightningOptimizer,
+)
+from pytorch_lightning.utilities import rank_zero_deprecation


 class TrainerOptimizersMixin(ABC):
+    r"""
+    .. deprecated:: v1.6
+        The `TrainerOptimizersMixin` was deprecated in v1.6 and will be removed in v1.8.
+    """

     _lightning_optimizers: Optional[List[LightningOptimizer]]

     def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
+        r"""
+        .. deprecated:: v1.6
+            `TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8.
+        """
+        rank_zero_deprecation(
+            "`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8."
+        )
         pl_module = self.lightning_module or model
-        self._lightning_optimizers = None
-        optim_conf = self._call_lightning_module_hook("configure_optimizers", pl_module=pl_module)
-        if optim_conf is None:
-            rank_zero_warn(
-                "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
-            )
-            optim_conf = _MockOptimizer()
-
-        optimizers, lr_schedulers, optimizer_frequencies, monitor = self._configure_optimizers(optim_conf)
-        lr_schedulers = self._configure_schedulers(lr_schedulers, monitor, not pl_module.automatic_optimization)
-        _validate_scheduler_optimizer(optimizers, lr_schedulers)
-        return optimizers, lr_schedulers, optimizer_frequencies
-
-    @staticmethod
-    def _configure_optimizers(
-        optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
-    ) -> Tuple[List, List, List, Optional[str]]:
-        optimizers, lr_schedulers, optimizer_frequencies = [], [], []
-        monitor = None
-
-        # single output, single optimizer
-        if isinstance(optim_conf, Optimizer):
-            optimizers = [optim_conf]
-        # two lists, optimizer + lr schedulers
-        elif (
-            isinstance(optim_conf, (list, tuple))
-            and len(optim_conf) == 2
-            and isinstance(optim_conf[0], list)
-            and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
-        ):
-            opt, sch = optim_conf
-            optimizers = opt
-            lr_schedulers = sch if isinstance(sch, list) else [sch]
-        # single dictionary
-        elif isinstance(optim_conf, dict):
-            _validate_optim_conf(optim_conf)
-            optimizers = [optim_conf["optimizer"]]
-            monitor = optim_conf.get("monitor", None)
-            lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
-        # multiple dictionaries
-        elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
-            for opt_dict in optim_conf:
-                _validate_optim_conf(opt_dict)
-            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
-            scheduler_dict = (
-                lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
-                if isinstance(scheduler, dict)
-                else {"scheduler": scheduler, "opt_idx": opt_idx}
-            )
-
-            lr_schedulers = [
-                scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
-                for opt_idx, opt_dict in enumerate(optim_conf)
-                if "lr_scheduler" in opt_dict
-            ]
-            optimizer_frequencies = [
-                opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
-            ]
-            # assert that if frequencies are present, they are given for all optimizers
-            if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
-                raise ValueError("A frequency must be given to each optimizer.")
-        # single list or tuple, multiple optimizer
-        elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
-            optimizers = list(optim_conf)
-        # unknown configuration
-        else:
-            raise MisconfigurationException(
-                "Unknown configuration for model optimizers."
-                " Output from `model.configure_optimizers()` should either be:\n"
-                " * `torch.optim.Optimizer`\n"
-                " * [`torch.optim.Optimizer`]\n"
-                " * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n"
-                ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
-                ' * A list of the previously described dict format, with an optional "frequency" key (int)'
-            )
-        return optimizers, lr_schedulers, optimizer_frequencies, monitor
+        return _init_optimizers_and_lr_schedulers(pl_module)

     def convert_to_lightning_optimizers(self):
-        def _convert_to_lightning_optimizer(trainer, optimizer):
-            if not isinstance(optimizer, LightningOptimizer):
-                optimizer = LightningOptimizer(optimizer)
-            optimizer._on_trainer_init(trainer)
-            return optimizer
-
-        self._lightning_optimizers = {
-            opt_idx: _convert_to_lightning_optimizer(self, opt) for opt_idx, opt in enumerate(self.optimizers)
-        }
-
-    @staticmethod
-    def _configure_schedulers(
-        schedulers: list, monitor: Optional[str], is_manual_optimization: bool
-    ) -> List[LRSchedulerConfig]:
-        """Convert each scheduler into dict structure with relevant information."""
-        lr_schedulers = []
-        default_config = _get_default_scheduler_config()
-        for scheduler in schedulers:
-            if is_manual_optimization:
-                if isinstance(scheduler, dict):
-                    invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
-                    keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
-
-                    if keys_to_warn:
-                        rank_zero_warn(
-                            f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
-                            " You need to call `lr_scheduler.step()` manually in manual optimization.",
-                            category=RuntimeWarning,
-                        )
-
-                    scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
-                    lr_schedulers.append({**default_config, **scheduler})
-                else:
-                    lr_schedulers.append({**default_config, "scheduler": scheduler})
-            else:
-                if isinstance(scheduler, dict):
-                    # check provided keys
-                    extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
-                    if extra_keys:
-                        rank_zero_warn(
-                            f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
-                        )
-                    if "scheduler" not in scheduler:
-                        raise MisconfigurationException(
-                            'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
-                        )
-                    if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
-                        raise MisconfigurationException(
-                            'The "interval" key in lr scheduler dict must be "step" or "epoch"'
-                            f' but is "{scheduler["interval"]}"'
-                        )
-                    scheduler["reduce_on_plateau"] = isinstance(
-                        scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
-                    )
-                    if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
-                        raise MisconfigurationException(
-                            "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
-                            ' For example: {"optimizer": optimizer, "lr_scheduler":'
-                            ' {"scheduler": scheduler, "monitor": "your_loss"}}'
-                        )
-                    is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
-                    if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
-                        rank_zero_warn(
-                            "A `OneCycleLR` scheduler is using 'interval': 'epoch'."
-                            " Are you sure you didn't mean 'interval': 'step'?",
-                            category=RuntimeWarning,
-                        )
-                    lr_schedulers.append({**default_config, **scheduler})
-                elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
-                    if monitor is None:
-                        raise MisconfigurationException(
-                            "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
-                            " scheduler is used. For example:"
-                            ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
-                        )
-                    lr_schedulers.append(
-                        {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
-                    )
-                elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
-                    lr_schedulers.append({**default_config, "scheduler": scheduler})
-                else:
-                    raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
-        return lr_schedulers
-
-
-class _MockOptimizer(Optimizer):
-    """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from
-    `configure_optimizers`."""
-
-    def __init__(self):
-        super().__init__([torch.zeros(1)], {})
-
-    def add_param_group(self, param_group):
-        pass  # Do Nothing
-
-    def load_state_dict(self, state_dict):
-        pass  # Do Nothing
-
-    def state_dict(self):
-        return {}  # Return Empty
-
-    def step(self, closure=None):
-        if closure is not None:
-            closure()
-
-    def zero_grad(self):
-        pass  # Do Nothing
-
-    def __repr__(self):
-        return "No Optimizer"
-
-
-def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
-    valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
-    extra_keys = optim_conf.keys() - valid_keys
-    if extra_keys:
-        rank_zero_warn(
-            f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
-        )
-
-
-def _validate_scheduler_optimizer(optimizers, lr_schedulers):
-    if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
-        raise MisconfigurationException(
-            "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
-        )
-
-
-def _get_default_scheduler_config() -> Dict[str, Any]:
-    return {
-        "scheduler": None,
-        "name": None,  # no custom name
-        "interval": "epoch",  # after epoch is over
-        "frequency": 1,  # every epoch/batch
-        "reduce_on_plateau": False,  # most often not ReduceLROnPlateau scheduler
-        "monitor": None,  # value to monitor for ReduceLROnPlateau
-        "strict": True,  # enforce that the monitor exists for ReduceLROnPlateau
-        "opt_idx": None,  # necessary to store opt_idx when optimizer frequencies are specified
-    }
+        r"""
+        .. deprecated:: v1.6
+            `TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in v1.8.
+        """
+        rank_zero_deprecation(
+            "`TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in "
+            "v1.8."
+        )
+        _convert_to_lightning_optimizers(self)
@@ -33,7 +33,7 @@ from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
-from pytorch_lightning.core.optimizer import LightningOptimizer
+from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

@@ -480,8 +480,7 @@ class Trainer(
         # default .predict() loop
         self.predict_loop = PredictionLoop()

-        # Needed because of LightningOptimizer
-        self._lightning_optimizers = None
+        self._lightning_optimizers: Optional[Dict[int, LightningOptimizer]] = None

         # .validate() and .test() set this when they load a checkpoint
         self.validated_ckpt_path: Optional[str] = None

@@ -1926,9 +1925,9 @@ class Trainer(
         return SLURMEnvironment.job_id()

     @property
-    def lightning_optimizers(self) -> List[LightningOptimizer]:
+    def lightning_optimizers(self) -> Dict[int, LightningOptimizer]:
         if self._lightning_optimizers is None:
-            self.convert_to_lightning_optimizers()
+            _convert_to_lightning_optimizers(self)
         return self._lightning_optimizers

     @property
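The corrected return annotation matches how the property is populated (an index-keyed mapping rather than a list); a small hedged illustration, assuming a trainer whose optimizers have already been set up:

# hedged sketch: iterate the index -> LightningOptimizer mapping exposed by the property
for opt_idx, lightning_opt in trainer.lightning_optimizers.items():
    print(opt_idx, type(lightning_opt).__name__)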
@@ -24,8 +24,8 @@ from torch.optim.lr_scheduler import _LRScheduler

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.core.optimizer import _get_default_scheduler_config
 from pytorch_lightning.loggers.base import DummyLogger
-from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -98,15 +98,15 @@ class _LRFinder:
         self.results = {}
         self._total_batch_idx = 0  # for debug purpose

-    def _exchange_scheduler(self, trainer: "pl.Trainer"):
-        """Decorate `trainer.init_optimizers` method such that it returns the users originally specified optimizer
-        together with a new scheduler that that takes care of the learning rate search."""
-        init_optimizers = trainer.init_optimizers
+    def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
+        """Decorate `trainer.strategy.init_optimizers` method such that it returns the user's originally specified
+        optimizer together with a new scheduler that takes care of the learning rate search."""
+        init_optimizers = trainer.strategy.init_optimizers

         @wraps(init_optimizers)
-        def func(model):
-            # Decide the structure of the output from init_optimizers
-            optimizers, _, _ = init_optimizers(model)
+        def func(trainer, model):
+            # Decide the structure of the output from trainer.strategy.init_optimizers
+            optimizers, _, _ = init_optimizers(trainer, model)

             if len(optimizers) != 1:
                 raise MisconfigurationException(

@@ -232,7 +232,7 @@ def lr_find(
     trainer.save_checkpoint(str(save_path))

     # Configure optimizer and scheduler
-    trainer.init_optimizers = lr_finder._exchange_scheduler(trainer)
+    trainer.strategy.init_optimizers = lr_finder._exchange_scheduler(trainer, model)

     # Fit, lr & loss logged in callback
     trainer.tuner._run(model)

@@ -278,7 +278,6 @@ def __lr_finder_dump_params(trainer, model):
         "max_steps": trainer.max_steps,
         "checkpoint_callback": trainer.checkpoint_callback,
         "current_epoch": trainer.current_epoch,
-        "init_optimizers": trainer.init_optimizers,
     }

@@ -289,7 +288,6 @@ def __lr_finder_restore_params(trainer, model):
     trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
    trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
     trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
-    trainer.init_optimizers = trainer.__dumped_params["init_optimizers"]
     del trainer.__dumped_params
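`_exchange_scheduler` now monkeypatches the strategy-level hook, so the wrapped callable takes `(trainer, model)`; a condensed, hedged sketch of that decorator pattern (not the tuner's actual implementation; the `LambdaLR` below is only a placeholder for the learning-rate-search schedule):

from functools import wraps

from torch import optim

from pytorch_lightning.core.optimizer import _get_default_scheduler_config


def exchange_scheduler_sketch(trainer, model):
    # hedged sketch of the pattern used by `_exchange_scheduler` above
    init_optimizers = trainer.strategy.init_optimizers

    @wraps(init_optimizers)
    def func(trainer, model):
        optimizers, _, _ = init_optimizers(trainer, model)
        # swap in a placeholder scheduler; the real tuner installs its LR-search schedule here
        sched = optim.lr_scheduler.LambdaLR(optimizers[0], lambda step: 1.0)
        sched_config = {**_get_default_scheduler_config(), "scheduler": sched, "interval": "step"}
        return optimizers, [sched_config], []

    return func


# usage mirrors the `lr_find` hunk above:
# trainer.strategy.init_optimizers = exchange_scheduler_sketch(trainer, model)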
@@ -140,6 +140,24 @@ def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir):
     _ = trainer.should_rank_save_checkpoint


+def test_v1_8_0_trainer_optimizers_mixin():
+    trainer = Trainer()
+    model = BoringModel()
+    trainer.strategy.connect(model)
+    trainer.lightning_module.trainer = trainer
+
+    with pytest.deprecated_call(
+        match=r"`TrainerOptimizersMixin.init_optimizers` was deprecated in v1.6 and will be removed in v1.8."
+    ):
+        trainer.init_optimizers(model)
+
+    with pytest.deprecated_call(
+        match=r"`TrainerOptimizersMixin.convert_to_lightning_optimizers` was deprecated in v1.6 and will be removed in "
+        "v1.8."
+    ):
+        trainer.convert_to_lightning_optimizers()
+
+
 def test_v1_8_0_deprecate_trainer_callback_hook_mixin():
     methods_with_self = [
         "on_before_accelerator_backend_setup",
@@ -15,7 +15,6 @@ import torch
 from torchmetrics.functional import accuracy

 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.utilities import _StrategyType
 from tests.helpers import BoringModel
 from tests.helpers.utils import get_default_logger, load_model_from_checkpoint, reset_seed

@@ -70,7 +69,7 @@ def run_model_test(
     assert change_ratio > 0.03, f"the model is changed of {change_ratio}"

     # test model loading
-    pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))
+    _ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

     # test new model accuracy
     test_loaders = model.test_dataloader() if not data else data.test_dataloader()

@@ -82,12 +81,6 @@ def run_model_test(
         run_model_prediction(model, dataloader, min_acc=min_acc)

     if with_hpc:
-        if trainer._distrib_type in (_StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2):
-            # on hpc this would work fine... but need to hack it for the purpose of the test
-            trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers(
-                pretrained_model
-            )
-
         # test HPC saving
         trainer.checkpoint_connector.hpc_save(save_dir, logger)
         # test HPC loading
@@ -19,7 +19,11 @@ from torch import optim

 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
+from pytorch_lightning.core.optimizer import (
+    _configure_optimizers,
+    _configure_schedulers,
+    _init_optimizers_and_lr_schedulers,
+)
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf

@@ -106,7 +110,7 @@ def test_onecyclelr_with_epoch_interval_warns():
     optimizer = optim.Adam(model.parameters())
     lr_scheduler = {"scheduler": optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, total_steps=3)}
     with pytest.warns(RuntimeWarning, match="Are you sure you didn't mean 'interval': 'step'?"):
-        TrainerOptimizersMixin._configure_schedulers([lr_scheduler], None, False)
+        _configure_schedulers([lr_scheduler], None, False)


 def test_reducelronplateau_scheduling(tmpdir):

@@ -144,6 +148,8 @@ def test_reducelronplateau_scheduling(tmpdir):
 def test_optimizer_return_options(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir)
     model = BoringModel()
+    trainer.strategy.connect(model)
+    trainer.lightning_module.trainer = trainer

     # single optimizer
     opt_a = optim.Adam(model.parameters(), lr=0.002)

@@ -153,18 +159,18 @@ def test_optimizer_return_options(tmpdir):

     # single optimizer
     model.configure_optimizers = lambda: opt_a
-    opt, lr_sched, freq = trainer.init_optimizers(model)
+    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
     assert len(opt) == 1 and len(lr_sched) == len(freq) == 0

     # opt tuple
     model.configure_optimizers = lambda: (opt_a, opt_b)
-    opt, lr_sched, freq = trainer.init_optimizers(model)
+    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
     assert opt == [opt_a, opt_b]
     assert len(lr_sched) == len(freq) == 0

     # opt list
     model.configure_optimizers = lambda: [opt_a, opt_b]
-    opt, lr_sched, freq = trainer.init_optimizers(model)
+    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
     assert opt == [opt_a, opt_b]
     assert len(lr_sched) == len(freq) == 0

@@ -181,7 +187,7 @@ def test_optimizer_return_options(tmpdir):

     # opt tuple of 2 lists
     model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
-    opt, lr_sched, freq = trainer.init_optimizers(model)
+    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
     assert len(opt) == len(lr_sched) == 1
     assert len(freq) == 0
     assert opt[0] == opt_a

@@ -189,7 +195,7 @@ def test_optimizer_return_options(tmpdir):

     # opt tuple of 1 list
     model.configure_optimizers = lambda: ([opt_a], scheduler_a)
-    opt, lr_sched, freq = trainer.init_optimizers(model)
+    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
     assert len(opt) == len(lr_sched) == 1
     assert len(freq) == 0
     assert opt[0] == opt_a

@@ -197,7 +203,7 @@ def test_optimizer_return_options(tmpdir):

     # opt single dictionary
     model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a}
-    opt, lr_sched, freq = trainer.init_optimizers(model)
+    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
     assert len(opt) == len(lr_sched) == 1
     assert len(freq) == 0
     assert opt[0] == opt_a

@@ -208,7 +214,7 @@ def test_optimizer_return_options(tmpdir):
         {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
         {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
     )
-    opt, lr_sched, freq = trainer.init_optimizers(model)
+    opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model)
     assert len(opt) == len(lr_sched) == len(freq) == 2
     assert opt[0] == opt_a
     ref_lr_sched["opt_idx"] = 0

@@ -436,7 +442,7 @@ def test_optimizer_config_dict_with_extra_keys_warns(tmpdir):
         "bar": 2,
     }
     with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the optimizer configuration: \{.+\}"):
-        TrainerOptimizersMixin._configure_optimizers(optim_conf)
+        _configure_optimizers(optim_conf)


 def test_multiple_optimizer_config_dicts_with_extra_keys_warns(tmpdir):

@@ -451,7 +457,7 @@ def test_multiple_optimizer_config_dicts_with_extra_keys_warns(tmpdir):
         {"optimizer": optimizer2, "lr_scheduler": lr_scheduler_config_2, "foo": 1, "bar": 2},
     ]
     with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the optimizer configuration: \{.+\}"):
-        TrainerOptimizersMixin._configure_optimizers(optim_conf)
+        _configure_optimizers(optim_conf)


 def test_lr_scheduler_with_unknown_interval_raises(tmpdir):