# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.7.0."""
import os
from re import escape
from unittest import mock
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.progress import ProgressBar
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.plugins.environments import (
    KubeflowEnvironment,
    LightningEnvironment,
    LSFEnvironment,
    SLURMEnvironment,
    TorchElasticEnvironment,
)
from pytorch_lightning.strategies import SingleDeviceStrategy
from tests.callbacks.test_callbacks import OldStatefulCallback
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
from tests.helpers.runif import RunIf
from tests.loggers.test_base import CustomLogger


def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir):
    """Calling ``summarize`` on a module must emit the v1.5 deprecation warning."""
    expected = "The `LightningModule.summarize` method is deprecated in v1.5"
    module = BoringModel()
    with pytest.deprecated_call(match=expected):
        module.summarize(max_depth=1)


def test_v1_7_0_moved_model_summary_and_layer_summary(tmpdir):
    """Importing the summary classes from ``core.memory`` must warn about their new location."""
    expected = "to `pytorch_lightning.utilities.model_summary` since v1.5"
    # Soft-unimport the module first so the import statement below is actually executed again.
    _soft_unimport_module("pytorch_lightning.core.memory")
    with pytest.deprecated_call(match=expected):
        from pytorch_lightning.core.memory import LayerSummary, ModelSummary  # noqa: F401


def test_v1_7_0_moved_get_memory_profile_and_get_gpu_memory_map(tmpdir):
    """Importing the memory helpers from ``core.memory`` must warn about their new location."""
    expected = "to `pytorch_lightning.utilities.memory` since v1.5"
    # Soft-unimport the module first so the import statement below is actually executed again.
    _soft_unimport_module("pytorch_lightning.core.memory")
    with pytest.deprecated_call(match=expected):
        from pytorch_lightning.core.memory import get_gpu_memory_map, get_memory_profile  # noqa: F401


def test_v1_7_0_deprecated_model_size():
    """Reading the ``model_size`` property must emit a deprecation warning."""
    expected = "LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7"
    module = BoringModel()
    with pytest.deprecated_call(match=expected):
        _ = module.model_size


def test_v1_7_0_datamodule_transform_properties(tmpdir):
    """Setting a ``*_transforms`` attribute, or passing it to the constructor, must warn."""
    # (attribute name, value to assign, expected warning pattern)
    cases = (
        ("train_transforms", "a", r"DataModule property `train_transforms` was deprecated in v1.5"),
        ("val_transforms", "b", r"DataModule property `val_transforms` was deprecated in v1.5"),
        ("test_transforms", "c", r"DataModule property `test_transforms` was deprecated in v1.5"),
    )

    # First through attribute assignment on an existing datamodule ...
    datamodule = MNISTDataModule()
    for attribute, value, expected in cases:
        with pytest.deprecated_call(match=expected):
            setattr(datamodule, attribute, value)

    # ... then through the `LightningDataModule` constructor keyword of the same name.
    for attribute, value, expected in cases:
        with pytest.deprecated_call(match=expected):
            _ = LightningDataModule(**{attribute: value})

    # The transform warning must still be raised when `dims` is passed alongside it.
    with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
        _ = LightningDataModule(test_transforms="c", dims=(1, 1, 1))


def test_v1_7_0_datamodule_size_property(tmpdir):
    """Calling ``size()`` on a datamodule must emit a deprecation warning."""
    expected = r"DataModule property `size` was deprecated in v1.5"
    datamodule = MNISTDataModule()
    with pytest.deprecated_call(match=expected):
        datamodule.size()


def test_v1_7_0_datamodule_dims_property(tmpdir):
    """Reading ``dims`` or passing it to the constructor must emit a deprecation warning."""
    expected = r"DataModule property `dims` was deprecated in v1.5"
    datamodule = MNISTDataModule()
    with pytest.deprecated_call(match=expected):
        _ = datamodule.dims
    with pytest.deprecated_call(match=expected):
        _ = LightningDataModule(dims=(1, 1, 1))


def test_v1_7_0_moved_get_progress_bar_dict(tmpdir):
    """Overriding ``get_progress_bar_dict`` must warn while still affecting the progress bar."""

    class TestModel(BoringModel):
        def get_progress_bar_dict(self):
            # Drop the version entry from whatever the parent implementation reports.
            progress_items = super().get_progress_bar_dict()
            progress_items.pop("v_num", None)
            return progress_items

    trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None, fast_dev_run=True)
    module = TestModel()
    with pytest.deprecated_call(match=r"`LightningModule.get_progress_bar_dict` method was deprecated in v1.5"):
        trainer.fit(module)

    # The override is still honoured: "v_num" was removed, "loss" is still shown.
    postfix = trainer.progress_bar_callback.main_progress_bar.postfix
    assert "loss" in postfix
    assert "v_num" not in postfix

    with pytest.deprecated_call(match=r"`trainer.progress_bar_dict` is deprecated in v1.5"):
        _ = trainer.progress_bar_dict


def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
    """Passing ``prepare_data_per_node`` to the Trainer must emit a deprecation warning."""
    expected = "Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0"
    with pytest.deprecated_call(match=expected):
        _ = Trainer(prepare_data_per_node=False)


def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
    """Passing ``stochastic_weight_avg=True`` to the Trainer must emit a deprecation warning."""
    expected = r"Setting `Trainer\(stochastic_weight_avg=True\)` is deprecated in v1.5"
    with pytest.deprecated_call(match=expected):
        _ = Trainer(stochastic_weight_avg=True)


@pytest.mark.parametrize("terminate_on_nan", [True, False])
def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
    """The ``terminate_on_nan`` argument, getter and setter must each emit a deprecation warning."""
    constructor_warning = "Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7"
    getter_warning = r"`Trainer.terminate_on_nan` is deprecated in v1.5"
    setter_warning = r"Setting `Trainer.terminate_on_nan = True` is deprecated in v1.5"

    with pytest.deprecated_call(match=constructor_warning):
        configured = Trainer(terminate_on_nan=terminate_on_nan)
        # The value passed in is kept as-is and anomaly detection stays switched off.
        assert configured.terminate_on_nan is terminate_on_nan
        assert configured._detect_anomaly is False

    default_trainer = Trainer()
    with pytest.deprecated_call(match=getter_warning):
        _ = default_trainer.terminate_on_nan
    with pytest.deprecated_call(match=setter_warning):
        default_trainer.terminate_on_nan = True


def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
    """Overriding any ``on_*_dataloader`` hook must warn when the matching trainer task runs."""

    class CustomBoringModel(BoringModel):
        def on_train_dataloader(self):
            print("on_train_dataloader")

        def on_val_dataloader(self):
            print("on_val_dataloader")

        def on_test_dataloader(self):
            print("on_test_dataloader")

        def on_predict_dataloader(self):
            print("on_predict_dataloader")

    def _run(model, task="fit"):
        # A fresh trainer is created for every call; `task` names the Trainer method to invoke.
        entry_point = getattr(Trainer(default_root_dir=tmpdir, fast_dev_run=2), task)
        entry_point(model)

    # (trainer task, expected warning pattern), executed in this exact order.
    expectations = (
        ("fit", "Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."),
        ("fit", "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."),
        ("validate", "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."),
        ("test", "Method `on_test_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."),
        ("predict", "Method `on_predict_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."),
    )

    module = CustomBoringModel()
    for task, expected in expectations:
        with pytest.deprecated_call(match=expected):
            _run(module, task)


@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
def test_v1_7_0_test_tube_logger(_, tmpdir):
    """Instantiating ``TestTubeLogger`` (with its ``Experiment`` patched out) must warn."""
    expected = "The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"
    with pytest.deprecated_call(match=expected):
        _ = TestTubeLogger(tmpdir)


def test_v1_7_0_on_interrupt(tmpdir):
    """A callback that overrides ``on_keyboard_interrupt`` must trigger a deprecation warning on fit."""

    class HandleInterruptCallback(Callback):
        def on_keyboard_interrupt(self, trainer, pl_module):
            print("keyboard interrupt")

    module = BoringModel()
    interrupt_handler = HandleInterruptCallback()
    trainer_options = dict(
        callbacks=[interrupt_handler],
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
        enable_progress_bar=False,
        logger=False,
        default_root_dir=tmpdir,
    )
    trainer = Trainer(**trainer_options)

    expected = "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7"
    with pytest.deprecated_call(match=expected):
        trainer.fit(module)


def test_v1_7_0_process_position_trainer_constructor(tmpdir):
    """Passing ``process_position`` to the Trainer must emit a deprecation warning."""
    expected = r"Setting `Trainer\(process_position=5\)` is deprecated in v1.5"
    with pytest.deprecated_call(match=expected):
        _ = Trainer(process_position=5)


def test_v1_7_0_flush_logs_every_n_steps_trainer_constructor(tmpdir):
    """Passing ``flush_logs_every_n_steps`` to the Trainer must emit a deprecation warning."""
    expected = r"Setting `Trainer\(flush_logs_every_n_steps=10\)` is deprecated in v1.5"
    with pytest.deprecated_call(match=expected):
        _ = Trainer(flush_logs_every_n_steps=10)


class BoringCallbackDDPSpawnModel(BoringModel):
    """``BoringModel`` that overrides both queue hooks with empty bodies."""

    def add_to_queue(self, queue):
        ...

    def get_from_queue(self, queue):
        ...
def test_v1_7_0_deprecate_add_get_queue(tmpdir):
    """Overriding ``add_to_queue`` / ``get_from_queue`` on the module must warn during fit."""
    model = BoringCallbackDDPSpawnModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)

    # Each hook has its own warning, so fit is run once per expected message.
    with pytest.deprecated_call(match=r"`LightningModule.add_to_queue` method was deprecated in v1.5"):
        trainer.fit(model)

    with pytest.deprecated_call(match=r"`LightningModule.get_from_queue` method was deprecated in v1.5"):
        trainer.fit(model)


def test_v1_7_0_progress_bar_refresh_rate_trainer_constructor(tmpdir):
    """Passing ``progress_bar_refresh_rate`` to the Trainer must emit a deprecation warning."""
    with pytest.deprecated_call(match=r"Setting `Trainer\(progress_bar_refresh_rate=1\)` is deprecated in v1.5"):
        _ = Trainer(progress_bar_refresh_rate=1)


def test_v1_7_0_lightning_logger_base_close(tmpdir):
    """``close()`` must warn both on a single logger and on a ``LoggerCollection``."""
    logger = CustomLogger()
    with pytest.deprecated_call(
        match="`LightningLoggerBase.close` method is deprecated in v1.5 and will be removed in v1.7."
    ):
        logger.close()
    with pytest.deprecated_call(
        match="`LoggerCollection.close` method is deprecated in v1.5 and will be removed in v1.7."
    ):
        # The collection is deliberately built inside the context, exactly as before.
        logger = LoggerCollection([logger])
        logger.close()


def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
    """Importing and instantiating ``LightningDistributed`` must emit a deprecation warning."""
    with pytest.deprecated_call(match="LightningDistributed is deprecated in v1.5 and will be removed in v1.7."):
        from pytorch_lightning.distributed.dist import LightningDistributed

        _ = LightningDistributed()


def test_v1_7_0_checkpoint_callback_trainer_constructor(tmpdir):
    """Passing ``checkpoint_callback`` to the Trainer must emit a deprecation warning."""
    with pytest.deprecated_call(match=r"Setting `Trainer\(checkpoint_callback=True\)` is deprecated in v1.5"):
        _ = Trainer(checkpoint_callback=True)


def test_v1_7_0_old_on_train_batch_start(tmpdir):
    """An ``on_train_batch_start`` hook that still takes ``dataloader_idx`` must warn.

    The callback hook and the LightningModule hook are checked separately, each with its own trainer.
    """

    class OldSignature(Callback):
        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
            ...

    class OldSignatureModel(BoringModel):
        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
            ...

    # Old signature on the callback only.
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
        trainer.fit(model)

    # Old signature on the model only.
    model = OldSignatureModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
        trainer.fit(model)


def test_v1_7_0_old_on_train_batch_end(tmpdir):
    """An ``on_train_batch_end`` hook that still takes ``dataloader_idx`` must warn.

    The callback hook and the LightningModule hook are checked separately, each with its own trainer.
    """

    class OldSignature(Callback):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            ...

    class OldSignatureModel(BoringModel):
        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
            ...

    # Old signature on the callback only.
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature(), fast_dev_run=True)
    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
        trainer.fit(model)

    # Old signature on the model only. The callback is intentionally NOT attached here: with it
    # present, the callback could produce the warning and the model hook would never be verified
    # on its own (this mirrors `test_v1_7_0_old_on_train_batch_start`).
    model = OldSignatureModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=True)
    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
        trainer.fit(model)


def test_v1_7_0_deprecate_on_post_move_to_device(tmpdir):
    """Overriding ``on_post_move_to_device`` on the module must warn during fit."""

    class TestModel(BoringModel):
        def on_post_move_to_device(self):
            print("on_post_move_to_device")

    model = TestModel()

    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, max_epochs=1)

    with pytest.deprecated_call(
        match=r"Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7"
    ):
        trainer.fit(model)


def test_v1_7_0_deprecate_parameter_validation():
    """Importing ``parameter_validation`` from ``core.decorators`` must emit a deprecation warning."""
    # Soft-unimport the module first so the import statement below is actually executed again.
    _soft_unimport_module("pytorch_lightning.core.decorators")
    with pytest.deprecated_call(
        match="Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5"
    ):
        from pytorch_lightning.core.decorators import parameter_validation  # noqa: F401


def test_v1_7_0_passing_strategy_to_accelerator_trainer_flag():
    """Passing a strategy name through ``accelerator=`` must emit a deprecation warning."""
    with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
        Trainer(accelerator="ddp_spawn")


def test_v1_7_0_passing_strategy_to_plugins_flag():
    """Passing a strategy name through ``plugins=`` must emit a deprecation warning."""
    with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
        Trainer(plugins="ddp_spawn")


def test_v1_7_0_weights_summary_trainer(tmpdir):
    """The ``weights_summary`` constructor flag, getter and setter must each warn."""
    # The trainers built in these two blocks are never used afterwards, hence `_`.
    with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"):
        _ = Trainer(weights_summary="full")

    with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=None\)` is deprecated in v1.5"):
        _ = Trainer(weights_summary=None)

    t = Trainer(weights_summary="top")
    with pytest.deprecated_call(match=r"`Trainer.weights_summary` is deprecated in v1.5"):
        _ = t.weights_summary

    with pytest.deprecated_call(match=r"Setting `Trainer.weights_summary` is deprecated in v1.5"):
        t.weights_summary = "blah"


def test_v1_7_0_trainer_log_gpu_memory(tmpdir):
    """Passing ``log_gpu_memory`` to the Trainer must emit a deprecation warning."""
    with pytest.deprecated_call(
        match="Setting `log_gpu_memory` with the trainer flag is deprecated in v1.5 and will be removed"
    ):
        _ = Trainer(log_gpu_memory="min_max")


def test_v1_7_0_deprecated_slurm_job_id():
    """Reading ``Trainer.slurm_job_id`` must emit a deprecation warning."""
    trainer = Trainer()
    with pytest.deprecated_call(match="Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0."):
        # Bind the result so the access is not a bare, seemingly useless expression.
        _ = trainer.slurm_job_id


@RunIf(min_gpus=1)
def test_v1_7_0_deprecate_gpu_stats_monitor(tmpdir):
    """Instantiating ``GPUStatsMonitor`` must emit a deprecation warning (gated by ``RunIf(min_gpus=1)``)."""
    with pytest.deprecated_call(match="The `GPUStatsMonitor` callback was deprecated in v1.5"):
        _ = GPUStatsMonitor()


@RunIf(tpu=True)
def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir):
    """Instantiating ``XLAStatsMonitor`` must emit a deprecation warning (gated by ``RunIf(tpu=True)``)."""
    with pytest.deprecated_call(match="The `XLAStatsMonitor` callback was deprecated in v1.5"):
        _ = XLAStatsMonitor()


def test_v1_7_0_progress_bar():
    """Instantiating ``ProgressBar`` must emit a deprecation warning."""
    with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
        _ = ProgressBar()


def test_v1_7_0_deprecated_max_steps_none(tmpdir):
    """``max_steps=None`` must warn both in the constructor and when set on the fit loop."""
    with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
        _ = Trainer(max_steps=None)

    trainer = Trainer()
    with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
        trainer.fit_loop.max_steps = None


def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
    """``Trainer(resume_from_checkpoint=...)`` must warn while still being honoured by ``fit``."""
    # test resume_from_checkpoint still works until v1.7 deprecation
    model = BoringModel()
    callback = OldStatefulCallback(state=111)
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
    trainer.fit(model)
    ckpt_path = trainer.checkpoint_callback.best_model_path

    # A second trainer, whose callback starts from a different state, resumes from that checkpoint.
    callback = OldStatefulCallback(state=222)
    with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
        trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
    with pytest.deprecated_call(
        match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
    ):
        _ = trainer.resume_from_checkpoint
    assert trainer.checkpoint_connector.resume_checkpoint_path is None
    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path

    # `validate` leaves both the callback state and the stored fit path untouched.
    trainer.validate(model=model, ckpt_path=ckpt_path)
    assert callback.state == 222
    assert trainer.checkpoint_connector.resume_checkpoint_path is None
    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path

    # After `fit` the callback holds the checkpointed state (111) and the stored fit path is cleared.
    with pytest.deprecated_call(
        match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
    ):
        trainer.fit(model)
    assert callback.state == 111
    assert trainer.checkpoint_connector.resume_checkpoint_path is None
    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None

    # Later `predict` / `fit` calls keep both paths cleared.
    trainer.predict(model=model, ckpt_path=ckpt_path)
    assert trainer.checkpoint_connector.resume_checkpoint_path is None
    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
    trainer.fit(model)
    assert trainer.checkpoint_connector.resume_checkpoint_path is None
    assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None

    # test fit(ckpt_path=) precedence over Trainer(resume_from_checkpoint=) path
    model = BoringModel()
    with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
        trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
    with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
        trainer.fit(model, ckpt_path="fit_arg_ckpt_path")


def test_v1_7_0_deprecate_lr_sch_names(tmpdir):
    """Reading ``LearningRateMonitor.lr_sch_names`` after a fit must warn and still return the names."""
    model = BoringModel()
    lr_monitor = LearningRateMonitor()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[lr_monitor])
    trainer.fit(model)

    with pytest.deprecated_call(match="`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5"):
        assert lr_monitor.lr_sch_names == ["lr-SGD"]


@pytest.mark.parametrize(
    "cls",
    [
        KubeflowEnvironment,
        LightningEnvironment,
        SLURMEnvironment,
        TorchElasticEnvironment,
    ],
)
def test_v1_7_0_cluster_environment_master_address(cls):
    """Instantiating an environment subclass that defines ``master_address`` must warn."""

    class MyClusterEnvironment(cls):
        def master_address(self):
            pass

    with pytest.deprecated_call(
        match="MyClusterEnvironment.master_address` has been deprecated in v1.6 and will be removed in v1.7"
    ):
        MyClusterEnvironment()


@pytest.mark.parametrize(
    "cls",
    [
        KubeflowEnvironment,
        LightningEnvironment,
        SLURMEnvironment,
        TorchElasticEnvironment,
    ],
)
def test_v1_7_0_cluster_environment_master_port(cls):
    """Instantiating an environment subclass that defines ``master_port`` must warn."""

    class MyClusterEnvironment(cls):
        def master_port(self):
            pass

    with pytest.deprecated_call(
        match="MyClusterEnvironment.master_port` has been deprecated in v1.6 and will be removed in v1.7"
    ):
        MyClusterEnvironment()


@pytest.mark.parametrize(
    "cls,method_name",
    [
        (KubeflowEnvironment, "is_using_kubeflow"),
        (LSFEnvironment, "is_using_lsf"),
        (TorchElasticEnvironment, "is_using_torchelastic"),
    ],
)
@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1", "LSB_JOBID": "1234"})
def test_v1_7_0_cluster_environment_detection(cls, method_name):
    """Instantiating an environment subclass that defines an ``is_using_*`` method must warn.

    The subclass defines all three detection methods for every parametrization; only the warning
    naming ``method_name`` is matched.
    """

    class MyClusterEnvironment(cls):
        @staticmethod
        def is_using_kubeflow():
            pass

        @staticmethod
        def is_using_lsf():
            pass

        @staticmethod
        def is_using_torchelastic():
            pass

    with pytest.deprecated_call(
        match=f"MyClusterEnvironment.{method_name}` has been deprecated in v1.6 and will be removed in v1.7"
    ):
        MyClusterEnvironment()


def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():
    """Getting or setting ``IndexBatchSamplerWrapper.batch_indices`` must emit a deprecation warning."""
    sampler = IndexBatchSamplerWrapper(Mock())
    with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
        _ = sampler.batch_indices

    with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
        sampler.batch_indices = []


def test_v1_7_0_post_dispatch_hook():
    """Instantiating a strategy subclass that defines ``post_dispatch`` must emit a deprecation warning."""

    class CustomPlugin(SingleDeviceStrategy):
        def post_dispatch(self, trainer):
            pass

    # `escape` is needed because the expected message contains literal parentheses.
    with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")):
        CustomPlugin(torch.device("cpu"))