lightning/tests/deprecated_api/test_remove_1-7.py

554 lines
21 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.7.0."""
import os
from re import escape
from unittest import mock
from unittest.mock import Mock
import pytest
import torch
from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.progress import ProgressBar
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.plugins.environments import (
KubeflowEnvironment,
LightningEnvironment,
LSFEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
)
from pytorch_lightning.strategies import SingleDeviceStrategy
from tests.callbacks.test_callbacks import OldStatefulCallback
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
from tests.helpers.runif import RunIf
from tests.loggers.test_base import CustomLogger
def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir):
model = BoringModel()
with pytest.deprecated_call(match="The `LightningModule.summarize` method is deprecated in v1.5"):
model.summarize(max_depth=1)
def test_v1_7_0_moved_model_summary_and_layer_summary(tmpdir):
_soft_unimport_module("pytorch_lightning.core.memory")
with pytest.deprecated_call(match="to `pytorch_lightning.utilities.model_summary` since v1.5"):
from pytorch_lightning.core.memory import LayerSummary, ModelSummary # noqa: F401
def test_v1_7_0_moved_get_memory_profile_and_get_gpu_memory_map(tmpdir):
_soft_unimport_module("pytorch_lightning.core.memory")
with pytest.deprecated_call(match="to `pytorch_lightning.utilities.memory` since v1.5"):
from pytorch_lightning.core.memory import get_gpu_memory_map, get_memory_profile # noqa: F401
def test_v1_7_0_deprecated_model_size():
model = BoringModel()
with pytest.deprecated_call(
match="LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7"
):
_ = model.model_size
def test_v1_7_0_datamodule_transform_properties(tmpdir):
dm = MNISTDataModule()
with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"):
dm.train_transforms = "a"
with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"):
dm.val_transforms = "b"
with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
dm.test_transforms = "c"
with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"):
_ = LightningDataModule(train_transforms="a")
with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"):
_ = LightningDataModule(val_transforms="b")
with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
_ = LightningDataModule(test_transforms="c")
with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
_ = LightningDataModule(test_transforms="c", dims=(1, 1, 1))
def test_v1_7_0_datamodule_size_property(tmpdir):
dm = MNISTDataModule()
with pytest.deprecated_call(match=r"DataModule property `size` was deprecated in v1.5"):
dm.size()
def test_v1_7_0_datamodule_dims_property(tmpdir):
dm = MNISTDataModule()
with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
_ = dm.dims
with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
_ = LightningDataModule(dims=(1, 1, 1))
def test_v1_7_0_moved_get_progress_bar_dict(tmpdir):
class TestModel(BoringModel):
def get_progress_bar_dict(self):
items = super().get_progress_bar_dict()
items.pop("v_num", None)
return items
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=None,
fast_dev_run=True,
)
test_model = TestModel()
with pytest.deprecated_call(match=r"`LightningModule.get_progress_bar_dict` method was deprecated in v1.5"):
trainer.fit(test_model)
standard_metrics_postfix = trainer.progress_bar_callback.main_progress_bar.postfix
assert "loss" in standard_metrics_postfix
assert "v_num" not in standard_metrics_postfix
with pytest.deprecated_call(match=r"`trainer.progress_bar_dict` is deprecated in v1.5"):
_ = trainer.progress_bar_dict
def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
with pytest.deprecated_call(match="Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0"):
_ = Trainer(prepare_data_per_node=False)
def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(stochastic_weight_avg=True\)` is deprecated in v1.5"):
_ = Trainer(stochastic_weight_avg=True)
@pytest.mark.parametrize("terminate_on_nan", [True, False])
def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
with pytest.deprecated_call(
match="Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7"
):
trainer = Trainer(terminate_on_nan=terminate_on_nan)
assert trainer.terminate_on_nan is terminate_on_nan
assert trainer._detect_anomaly is False
trainer = Trainer()
with pytest.deprecated_call(match=r"`Trainer.terminate_on_nan` is deprecated in v1.5"):
_ = trainer.terminate_on_nan
with pytest.deprecated_call(match=r"Setting `Trainer.terminate_on_nan = True` is deprecated in v1.5"):
trainer.terminate_on_nan = True
def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
class CustomBoringModel(BoringModel):
def on_train_dataloader(self):
print("on_train_dataloader")
def on_val_dataloader(self):
print("on_val_dataloader")
def on_test_dataloader(self):
print("on_test_dataloader")
def on_predict_dataloader(self):
print("on_predict_dataloader")
def _run(model, task="fit"):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
getattr(trainer, task)(model)
model = CustomBoringModel()
with pytest.deprecated_call(
match="Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "fit")
with pytest.deprecated_call(
match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "fit")
with pytest.deprecated_call(
match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "validate")
with pytest.deprecated_call(
match="Method `on_test_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "test")
with pytest.deprecated_call(
match="Method `on_predict_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
):
_run(model, "predict")
@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
def test_v1_7_0_test_tube_logger(_, tmpdir):
with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):
_ = TestTubeLogger(tmpdir)
def test_v1_7_0_on_interrupt(tmpdir):
class HandleInterruptCallback(Callback):
def on_keyboard_interrupt(self, trainer, pl_module):
print("keyboard interrupt")
model = BoringModel()
handle_interrupt_callback = HandleInterruptCallback()
trainer = Trainer(
callbacks=[handle_interrupt_callback],
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2,
enable_progress_bar=False,
logger=False,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7"
):
trainer.fit(model)
def test_v1_7_0_process_position_trainer_constructor(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(process_position=5\)` is deprecated in v1.5"):
_ = Trainer(process_position=5)
def test_v1_7_0_flush_logs_every_n_steps_trainer_constructor(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(flush_logs_every_n_steps=10\)` is deprecated in v1.5"):
_ = Trainer(flush_logs_every_n_steps=10)
class BoringCallbackDDPSpawnModel(BoringModel):
def add_to_queue(self, queue):
...
def get_from_queue(self, queue):
...
def test_v1_7_0_deprecate_add_get_queue(tmpdir):
model = BoringCallbackDDPSpawnModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.deprecated_call(match=r"`LightningModule.add_to_queue` method was deprecated in v1.5"):
trainer.fit(model)
with pytest.deprecated_call(match=r"`LightningModule.get_from_queue` method was deprecated in v1.5"):
trainer.fit(model)
def test_v1_7_0_progress_bar_refresh_rate_trainer_constructor(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(progress_bar_refresh_rate=1\)` is deprecated in v1.5"):
_ = Trainer(progress_bar_refresh_rate=1)
def test_v1_7_0_lightning_logger_base_close(tmpdir):
logger = CustomLogger()
with pytest.deprecated_call(
match="`LightningLoggerBase.close` method is deprecated in v1.5 and will be removed in v1.7."
):
logger.close()
with pytest.deprecated_call(
match="`LoggerCollection.close` method is deprecated in v1.5 and will be removed in v1.7."
):
logger = LoggerCollection([logger])
logger.close()
def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
with pytest.deprecated_call(match="LightningDistributed is deprecated in v1.5 and will be removed in v1.7."):
from pytorch_lightning.distributed.dist import LightningDistributed
_ = LightningDistributed()
def test_v1_7_0_checkpoint_callback_trainer_constructor(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(checkpoint_callback=True\)` is deprecated in v1.5"):
_ = Trainer(checkpoint_callback=True)
def test_v1_7_0_old_on_train_batch_start(tmpdir):
class OldSignature(Callback):
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
...
class OldSignatureModel(BoringModel):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
...
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
trainer.fit(model)
model = OldSignatureModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
trainer.fit(model)
def test_v1_7_0_old_on_train_batch_end(tmpdir):
class OldSignature(Callback):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
...
class OldSignatureModel(BoringModel):
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
...
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature(), fast_dev_run=True)
with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
trainer.fit(model)
model = OldSignatureModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature(), fast_dev_run=True)
with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
trainer.fit(model)
def test_v1_7_0_deprecate_on_post_move_to_device(tmpdir):
class TestModel(BoringModel):
def on_post_move_to_device(self):
print("on_post_move_to_device")
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, max_epochs=1)
with pytest.deprecated_call(
match=r"Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7"
):
trainer.fit(model)
def test_v1_7_0_deprecate_parameter_validation():
_soft_unimport_module("pytorch_lightning.core.decorators")
with pytest.deprecated_call(
match="Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5"
):
from pytorch_lightning.core.decorators import parameter_validation # noqa: F401
def test_v1_7_0_passing_strategy_to_accelerator_trainer_flag():
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
Trainer(accelerator="ddp_spawn")
def test_v1_7_0_passing_strategy_to_plugins_flag():
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
Trainer(plugins="ddp_spawn")
def test_v1_7_0_weights_summary_trainer(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"):
t = Trainer(weights_summary="full")
with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=None\)` is deprecated in v1.5"):
t = Trainer(weights_summary=None)
t = Trainer(weights_summary="top")
with pytest.deprecated_call(match=r"`Trainer.weights_summary` is deprecated in v1.5"):
_ = t.weights_summary
with pytest.deprecated_call(match=r"Setting `Trainer.weights_summary` is deprecated in v1.5"):
t.weights_summary = "blah"
def test_v1_7_0_trainer_log_gpu_memory(tmpdir):
with pytest.deprecated_call(
match="Setting `log_gpu_memory` with the trainer flag is deprecated in v1.5 and will be removed"
):
_ = Trainer(log_gpu_memory="min_max")
def test_v1_7_0_deprecated_slurm_job_id():
trainer = Trainer()
with pytest.deprecated_call(match="Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0."):
trainer.slurm_job_id
@RunIf(min_gpus=1)
def test_v1_7_0_deprecate_gpu_stats_monitor(tmpdir):
with pytest.deprecated_call(match="The `GPUStatsMonitor` callback was deprecated in v1.5"):
_ = GPUStatsMonitor()
@RunIf(tpu=True)
def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir):
with pytest.deprecated_call(match="The `XLAStatsMonitor` callback was deprecated in v1.5"):
_ = XLAStatsMonitor()
def test_v1_7_0_progress_bar():
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
_ = ProgressBar()
def test_v1_7_0_deprecated_max_steps_none(tmpdir):
with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
_ = Trainer(max_steps=None)
trainer = Trainer()
with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
trainer.fit_loop.max_steps = None
def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
# test resume_from_checkpoint still works until v1.7 deprecation
model = BoringModel()
callback = OldStatefulCallback(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path
callback = OldStatefulCallback(state=222)
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
with pytest.deprecated_call(
match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
):
_ = trainer.resume_from_checkpoint
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
trainer.validate(model=model, ckpt_path=ckpt_path)
assert callback.state == 222
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
with pytest.deprecated_call(
match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
):
trainer.fit(model)
assert callback.state == 111
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
trainer.predict(model=model, ckpt_path=ckpt_path)
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
trainer.fit(model)
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
# test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path
model = BoringModel()
with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
trainer.fit(model, ckpt_path="fit_arg_ckpt_path")
def test_v1_7_0_deprecate_lr_sch_names(tmpdir):
model = BoringModel()
lr_monitor = LearningRateMonitor()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[lr_monitor])
trainer.fit(model)
with pytest.deprecated_call(match="`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5"):
assert lr_monitor.lr_sch_names == ["lr-SGD"]
@pytest.mark.parametrize(
"cls",
[
KubeflowEnvironment,
LightningEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
],
)
def test_v1_7_0_cluster_environment_master_address(cls):
class MyClusterEnvironment(cls):
def master_address(self):
pass
with pytest.deprecated_call(
match="MyClusterEnvironment.master_address` has been deprecated in v1.6 and will be removed in v1.7"
):
MyClusterEnvironment()
@pytest.mark.parametrize(
"cls",
[
KubeflowEnvironment,
LightningEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
],
)
def test_v1_7_0_cluster_environment_master_port(cls):
class MyClusterEnvironment(cls):
def master_port(self):
pass
with pytest.deprecated_call(
match="MyClusterEnvironment.master_port` has been deprecated in v1.6 and will be removed in v1.7"
):
MyClusterEnvironment()
@pytest.mark.parametrize(
"cls,method_name",
[
(KubeflowEnvironment, "is_using_kubeflow"),
(LSFEnvironment, "is_using_lsf"),
(TorchElasticEnvironment, "is_using_torchelastic"),
],
)
@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1", "LSB_JOBID": "1234"})
def test_v1_7_0_cluster_environment_detection(cls, method_name):
class MyClusterEnvironment(cls):
@staticmethod
def is_using_kubeflow():
pass
@staticmethod
def is_using_lsf():
pass
@staticmethod
def is_using_torchelastic():
pass
with pytest.deprecated_call(
match=f"MyClusterEnvironment.{method_name}` has been deprecated in v1.6 and will be removed in v1.7"
):
MyClusterEnvironment()
def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():
sampler = IndexBatchSamplerWrapper(Mock())
with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
_ = sampler.batch_indices
with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
sampler.batch_indices = []
def test_v1_7_0_post_dispatch_hook():
class CustomPlugin(SingleDeviceStrategy):
def post_dispatch(self, trainer):
pass
with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")):
CustomPlugin(torch.device("cpu"))