Changed LearningRateLogger to LearningRateMonitor (#3251)
* Change LearningRateLogger to LearningRateMonitor * file rename * docs * add LearningRateLogger with deprecation warning * deprecated LearningRateLogger * move deprecation check * chlog Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
This commit is contained in:
parent
2d8c1b7c54
commit
4a22fca524
|
@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Changed
|
||||
|
||||
- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
|
||||
|
||||
- Used `fsspec` instead of `gfile` for all IO ([#3320](https://github.com/PyTorchLightning/pytorch-lightning/pull/3320))
|
||||
|
||||
### Deprecated
|
||||
|
|
|
@ -120,7 +120,7 @@ Lightning has a few built-in callbacks.
|
|||
|
||||
----------------
|
||||
|
||||
.. automodule:: pytorch_lightning.callbacks.lr_logger
|
||||
.. automodule:: pytorch_lightning.callbacks.lr_monitor
|
||||
:noindex:
|
||||
:exclude-members:
|
||||
_extract_lr,
|
||||
|
|
|
@ -1,18 +1,20 @@
|
|||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
|
||||
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
|
||||
from pytorch_lightning.callbacks.lr_logger import LearningRateLogger
|
||||
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar
|
||||
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
|
||||
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
|
||||
|
||||
__all__ = [
|
||||
'Callback',
|
||||
'EarlyStopping',
|
||||
'ModelCheckpoint',
|
||||
'GPUStatsMonitor',
|
||||
'GradientAccumulationScheduler',
|
||||
'LearningRateLogger',
|
||||
'ProgressBarBase',
|
||||
'LearningRateMonitor',
|
||||
'ModelCheckpoint',
|
||||
'ProgressBar',
|
||||
'GPUStatsMonitor'
|
||||
'ProgressBarBase',
|
||||
]
|
||||
|
|
|
@ -1,156 +1,9 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""
|
||||
|
||||
Learning Rate Logger
|
||||
====================
|
||||
|
||||
Log learning rate for lr schedulers during training
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
class LearningRateLogger(Callback):
|
||||
r"""
|
||||
Automatically logs learning rate for learning rate schedulers during training.
|
||||
|
||||
Args:
|
||||
logging_interval: set to `epoch` or `step` to log `lr` of all optimizers
|
||||
at the same interval, set to `None` to log at individual interval
|
||||
according to the `interval` key of each scheduler. Defaults to ``None``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from pytorch_lightning import Trainer
|
||||
>>> from pytorch_lightning.callbacks import LearningRateLogger
|
||||
>>> lr_logger = LearningRateLogger(logging_interval='step')
|
||||
>>> trainer = Trainer(callbacks=[lr_logger])
|
||||
|
||||
Logging names are automatically determined based on optimizer class name.
|
||||
In case of multiple optimizers of same type, they will be named `Adam`,
|
||||
`Adam-1` etc. If a optimizer has multiple parameter groups they will
|
||||
be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a
|
||||
`name` keyword in the construction of the learning rate schdulers
|
||||
|
||||
Example::
|
||||
|
||||
def configure_optimizer(self):
|
||||
optimizer = torch.optim.Adam(...)
|
||||
lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...)
|
||||
'name': 'my_logging_name'}
|
||||
return [optimizer], [lr_scheduler]
|
||||
"""
|
||||
def __init__(self, logging_interval: Optional[str] = None):
|
||||
if logging_interval not in (None, 'step', 'epoch'):
|
||||
raise MisconfigurationException(
|
||||
'logging_interval should be `step` or `epoch` or `None`.'
|
||||
)
|
||||
|
||||
self.logging_interval = logging_interval
|
||||
self.lrs = None
|
||||
self.lr_sch_names = []
|
||||
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
""" Called before training, determines unique names for all lr
|
||||
schedulers in the case of multiple of the same type or in
|
||||
the case of multiple parameter groups
|
||||
"""
|
||||
if not trainer.logger:
|
||||
raise MisconfigurationException(
|
||||
'Cannot use LearningRateLogger callback with Trainer that has no logger.'
|
||||
)
|
||||
|
||||
if not trainer.lr_schedulers:
|
||||
rank_zero_warn(
|
||||
'You are using LearningRateLogger callback with models that'
|
||||
' have no learning rate schedulers. Please see documentation'
|
||||
' for `configure_optimizers` method.', RuntimeWarning
|
||||
)
|
||||
|
||||
# Find names for schedulers
|
||||
names = self._find_names(trainer.lr_schedulers)
|
||||
|
||||
# Initialize for storing values
|
||||
self.lrs = {name: [] for name in names}
|
||||
|
||||
def on_batch_start(self, trainer, pl_module):
|
||||
if self.logging_interval != 'epoch':
|
||||
interval = 'step' if self.logging_interval is None else 'any'
|
||||
latest_stat = self._extract_lr(trainer, interval)
|
||||
|
||||
if trainer.logger is not None and latest_stat:
|
||||
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
|
||||
|
||||
def on_epoch_start(self, trainer, pl_module):
|
||||
if self.logging_interval != 'step':
|
||||
interval = 'epoch' if self.logging_interval is None else 'any'
|
||||
latest_stat = self._extract_lr(trainer, interval)
|
||||
|
||||
if trainer.logger is not None and latest_stat:
|
||||
trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch)
|
||||
|
||||
def _extract_lr(self, trainer, interval):
|
||||
latest_stat = {}
|
||||
|
||||
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
|
||||
if scheduler['interval'] == interval or interval == 'any':
|
||||
param_groups = scheduler['scheduler'].optimizer.param_groups
|
||||
if len(param_groups) != 1:
|
||||
for i, pg in enumerate(param_groups):
|
||||
lr, key = pg['lr'], f'{name}/pg{i + 1}'
|
||||
self.lrs[key].append(lr)
|
||||
latest_stat[key] = lr
|
||||
else:
|
||||
self.lrs[name].append(param_groups[0]['lr'])
|
||||
latest_stat[name] = param_groups[0]['lr']
|
||||
|
||||
return latest_stat
|
||||
|
||||
def _find_names(self, lr_schedulers):
|
||||
# Create uniqe names in the case we have multiple of the same learning
|
||||
# rate schduler + multiple parameter groups
|
||||
names = []
|
||||
for scheduler in lr_schedulers:
|
||||
sch = scheduler['scheduler']
|
||||
if 'name' in scheduler:
|
||||
name = scheduler['name']
|
||||
else:
|
||||
opt_name = 'lr-' + sch.optimizer.__class__.__name__
|
||||
i, name = 1, opt_name
|
||||
# Multiple schduler of the same type
|
||||
while True:
|
||||
if name not in names:
|
||||
break
|
||||
i, name = i + 1, f'{opt_name}-{i}'
|
||||
|
||||
# Multiple param groups for the same schduler
|
||||
param_groups = sch.optimizer.param_groups
|
||||
|
||||
if len(param_groups) != 1:
|
||||
for i, pg in enumerate(param_groups):
|
||||
temp = f'{name}/pg{i + 1}'
|
||||
names.append(temp)
|
||||
else:
|
||||
names.append(name)
|
||||
|
||||
self.lr_sch_names.append(name)
|
||||
|
||||
return names
|
||||
class LearningRateLogger(LearningRateMonitor):
|
||||
def __init__(self, *args, **kwargs):
|
||||
rank_zero_warn("`LearningRateLogger` is now `LearningRateMonitor`"
|
||||
" and this will be removed in v0.11.0", DeprecationWarning)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
|
@ -0,0 +1,159 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""
|
||||
|
||||
Learning Rate Monitor
|
||||
=====================
|
||||
|
||||
Monitor and logs learning rate for lr schedulers during training.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
class LearningRateMonitor(Callback):
|
||||
r"""
|
||||
Automatically monitor and logs learning rate for learning rate schedulers during training.
|
||||
|
||||
Args:
|
||||
logging_interval: set to `epoch` or `step` to log `lr` of all optimizers
|
||||
at the same interval, set to `None` to log at individual interval
|
||||
according to the `interval` key of each scheduler. Defaults to ``None``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from pytorch_lightning import Trainer
|
||||
>>> from pytorch_lightning.callbacks import LearningRateMonitor
|
||||
>>> lr_monitor = LearningRateMonitor(logging_interval='step')
|
||||
>>> trainer = Trainer(callbacks=[lr_monitor])
|
||||
|
||||
Logging names are automatically determined based on optimizer class name.
|
||||
In case of multiple optimizers of same type, they will be named `Adam`,
|
||||
`Adam-1` etc. If a optimizer has multiple parameter groups they will
|
||||
be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a
|
||||
`name` keyword in the construction of the learning rate schdulers
|
||||
|
||||
Example::
|
||||
|
||||
def configure_optimizer(self):
|
||||
optimizer = torch.optim.Adam(...)
|
||||
lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...)
|
||||
'name': 'my_logging_name'}
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
"""
|
||||
def __init__(self, logging_interval: Optional[str] = None):
|
||||
if logging_interval not in (None, 'step', 'epoch'):
|
||||
raise MisconfigurationException(
|
||||
'logging_interval should be `step` or `epoch` or `None`.'
|
||||
)
|
||||
|
||||
self.logging_interval = logging_interval
|
||||
self.lrs = None
|
||||
self.lr_sch_names = []
|
||||
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
"""
|
||||
Called before training, determines unique names for all lr
|
||||
schedulers in the case of multiple of the same type or in
|
||||
the case of multiple parameter groups
|
||||
"""
|
||||
if not trainer.logger:
|
||||
raise MisconfigurationException(
|
||||
'Cannot use LearningRateMonitor callback with Trainer that has no logger.'
|
||||
)
|
||||
|
||||
if not trainer.lr_schedulers:
|
||||
rank_zero_warn(
|
||||
'You are using LearningRateMonitor callback with models that'
|
||||
' have no learning rate schedulers. Please see documentation'
|
||||
' for `configure_optimizers` method.', RuntimeWarning
|
||||
)
|
||||
|
||||
# Find names for schedulers
|
||||
names = self._find_names(trainer.lr_schedulers)
|
||||
|
||||
# Initialize for storing values
|
||||
self.lrs = {name: [] for name in names}
|
||||
|
||||
def on_batch_start(self, trainer, pl_module):
|
||||
if self.logging_interval != 'epoch':
|
||||
interval = 'step' if self.logging_interval is None else 'any'
|
||||
latest_stat = self._extract_lr(trainer, interval)
|
||||
|
||||
if trainer.logger is not None and latest_stat:
|
||||
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
|
||||
|
||||
def on_epoch_start(self, trainer, pl_module):
|
||||
if self.logging_interval != 'step':
|
||||
interval = 'epoch' if self.logging_interval is None else 'any'
|
||||
latest_stat = self._extract_lr(trainer, interval)
|
||||
|
||||
if trainer.logger is not None and latest_stat:
|
||||
trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch)
|
||||
|
||||
def _extract_lr(self, trainer, interval):
|
||||
latest_stat = {}
|
||||
|
||||
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
|
||||
if scheduler['interval'] == interval or interval == 'any':
|
||||
param_groups = scheduler['scheduler'].optimizer.param_groups
|
||||
if len(param_groups) != 1:
|
||||
for i, pg in enumerate(param_groups):
|
||||
lr, key = pg['lr'], f'{name}/pg{i + 1}'
|
||||
self.lrs[key].append(lr)
|
||||
latest_stat[key] = lr
|
||||
else:
|
||||
self.lrs[name].append(param_groups[0]['lr'])
|
||||
latest_stat[name] = param_groups[0]['lr']
|
||||
|
||||
return latest_stat
|
||||
|
||||
def _find_names(self, lr_schedulers):
|
||||
# Create uniqe names in the case we have multiple of the same learning
|
||||
# rate schduler + multiple parameter groups
|
||||
names = []
|
||||
for scheduler in lr_schedulers:
|
||||
sch = scheduler['scheduler']
|
||||
if 'name' in scheduler:
|
||||
name = scheduler['name']
|
||||
else:
|
||||
opt_name = 'lr-' + sch.optimizer.__class__.__name__
|
||||
i, name = 1, opt_name
|
||||
|
||||
# Multiple schduler of the same type
|
||||
while True:
|
||||
if name not in names:
|
||||
break
|
||||
i, name = i + 1, f'{opt_name}-{i}'
|
||||
|
||||
# Multiple param groups for the same schduler
|
||||
param_groups = sch.optimizer.param_groups
|
||||
|
||||
if len(param_groups) != 1:
|
||||
for i, pg in enumerate(param_groups):
|
||||
temp = f'{name}/pg{i + 1}'
|
||||
names.append(temp)
|
||||
else:
|
||||
names.append(name)
|
||||
|
||||
self.lr_sch_names.append(name)
|
||||
|
||||
return names
|
|
@ -1,78 +1,96 @@
|
|||
import pytest
|
||||
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import LearningRateLogger
|
||||
from pytorch_lightning.callbacks import LearningRateMonitor
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import EvalModelTemplate
|
||||
import tests.base.develop_utils as tutils
|
||||
|
||||
|
||||
def test_lr_logger_single_lr(tmpdir):
|
||||
def test_lr_monitor_single_lr(tmpdir):
|
||||
""" Test that learning rates are extracted and logged for single lr scheduler. """
|
||||
tutils.reset_seed()
|
||||
|
||||
model = EvalModelTemplate()
|
||||
model.configure_optimizers = model.configure_optimizers__single_scheduler
|
||||
|
||||
lr_logger = LearningRateLogger()
|
||||
lr_monitor = LearningRateMonitor()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
limit_val_batches=0.1,
|
||||
limit_train_batches=0.5,
|
||||
callbacks=[lr_logger],
|
||||
callbacks=[lr_monitor],
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
assert result
|
||||
|
||||
assert lr_logger.lrs, 'No learning rates logged'
|
||||
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
|
||||
assert lr_monitor.lrs, 'No learning rates logged'
|
||||
assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \
|
||||
'Number of learning rates logged does not match number of lr schedulers'
|
||||
assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
|
||||
assert all([k in ['lr-Adam'] for k in lr_monitor.lrs.keys()]), \
|
||||
'Names of learning rates not set correctly'
|
||||
|
||||
|
||||
def test_lr_logger_no_lr(tmpdir):
|
||||
def test_lr_monitor_no_lr_scheduler(tmpdir):
|
||||
tutils.reset_seed()
|
||||
|
||||
model = EvalModelTemplate()
|
||||
|
||||
lr_logger = LearningRateLogger()
|
||||
lr_monitor = LearningRateMonitor()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
limit_val_batches=0.1,
|
||||
limit_train_batches=0.5,
|
||||
callbacks=[lr_logger],
|
||||
callbacks=[lr_monitor],
|
||||
)
|
||||
|
||||
with pytest.warns(RuntimeWarning):
|
||||
with pytest.warns(RuntimeWarning, match='have no learning rate schedulers'):
|
||||
result = trainer.fit(model)
|
||||
assert result
|
||||
|
||||
|
||||
def test_lr_monitor_no_logger(tmpdir):
|
||||
tutils.reset_seed()
|
||||
|
||||
model = EvalModelTemplate()
|
||||
|
||||
lr_monitor = LearningRateMonitor()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
callbacks=[lr_monitor],
|
||||
logger=False
|
||||
)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match='Trainer that has no logger'):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logging_interval", ['step', 'epoch'])
|
||||
def test_lr_logger_multi_lrs(tmpdir, logging_interval):
|
||||
def test_lr_monitor_multi_lrs(tmpdir, logging_interval):
|
||||
""" Test that learning rates are extracted and logged for multi lr schedulers. """
|
||||
tutils.reset_seed()
|
||||
|
||||
model = EvalModelTemplate()
|
||||
model.configure_optimizers = model.configure_optimizers__multiple_schedulers
|
||||
|
||||
lr_logger = LearningRateLogger(logging_interval=logging_interval)
|
||||
lr_monitor = LearningRateMonitor(logging_interval=logging_interval)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
limit_val_batches=0.1,
|
||||
limit_train_batches=0.5,
|
||||
callbacks=[lr_logger],
|
||||
callbacks=[lr_monitor],
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
assert result
|
||||
|
||||
assert lr_logger.lrs, 'No learning rates logged'
|
||||
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
|
||||
assert lr_monitor.lrs, 'No learning rates logged'
|
||||
assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \
|
||||
'Number of learning rates logged does not match number of lr schedulers'
|
||||
assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
|
||||
assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_monitor.lrs.keys()]), \
|
||||
'Names of learning rates not set correctly'
|
||||
|
||||
if logging_interval == 'step':
|
||||
|
@ -80,30 +98,30 @@ def test_lr_logger_multi_lrs(tmpdir, logging_interval):
|
|||
if logging_interval == 'epoch':
|
||||
expected_number_logged = trainer.max_epochs
|
||||
|
||||
assert all(len(lr) == expected_number_logged for lr in lr_logger.lrs.values()), \
|
||||
assert all(len(lr) == expected_number_logged for lr in lr_monitor.lrs.values()), \
|
||||
'Length of logged learning rates do not match the expected number'
|
||||
|
||||
|
||||
def test_lr_logger_param_groups(tmpdir):
|
||||
def test_lr_monitor_param_groups(tmpdir):
|
||||
""" Test that learning rates are extracted and logged for single lr scheduler. """
|
||||
tutils.reset_seed()
|
||||
|
||||
model = EvalModelTemplate()
|
||||
model.configure_optimizers = model.configure_optimizers__param_groups
|
||||
|
||||
lr_logger = LearningRateLogger()
|
||||
lr_monitor = LearningRateMonitor()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
limit_val_batches=0.1,
|
||||
limit_train_batches=0.5,
|
||||
callbacks=[lr_logger],
|
||||
callbacks=[lr_monitor],
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
assert result
|
||||
|
||||
assert lr_logger.lrs, 'No learning rates logged'
|
||||
assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
|
||||
assert lr_monitor.lrs, 'No learning rates logged'
|
||||
assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers), \
|
||||
'Number of learning rates logged does not match number of param groups'
|
||||
assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
|
||||
assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_monitor.lrs.keys()]), \
|
||||
'Names of learning rates not set correctly'
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import LearningRateLogger
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
|
@ -15,6 +16,11 @@ def _soft_unimport_module(str_module):
|
|||
del sys.modules[str_module]
|
||||
|
||||
|
||||
def test_tbd_remove_in_v0_11_0_trainer():
|
||||
with pytest.deprecated_call(match='will be removed in v0.11.0'):
|
||||
lr_logger = LearningRateLogger()
|
||||
|
||||
|
||||
def test_tbd_remove_in_v0_10_0_trainer():
|
||||
rnd_val = random.random()
|
||||
with pytest.deprecated_call(match='will be removed in v0.10.0'):
|
||||
|
|
Loading…
Reference in New Issue