diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6355b86096..bb46b7039d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for `Bagua` training strategy ([#11146](https://github.com/PyTorchLightning/pytorch-lightning/pull/11146))
 
+- Added `rank_zero` module to centralize utilities ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
+
+
 ### Changed
 
 - Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
@@ -323,6 +326,24 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Deprecated `on_configure_sharded_model` callback hook in favor of `setup` ([#11627](https://github.com/PyTorchLightning/pytorch-lightning/pull/11627))
 
+- Deprecated `pytorch_lightning.utilities.distributed.rank_zero_only` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_only` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
+
+
+- Deprecated `pytorch_lightning.utilities.distributed.rank_zero_debug` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_debug` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
+
+
+- Deprecated `pytorch_lightning.utilities.distributed.rank_zero_info` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_info` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
+
+
+- Deprecated `pytorch_lightning.utilities.warnings.rank_zero_warn` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_warn` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
+
+
+- Deprecated `pytorch_lightning.utilities.warnings.rank_zero_deprecation` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_deprecation` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
+
+
+- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`
+
+
 ### Removed
 
 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst
index d3e7ef38a6..67208a127f 100644
--- a/docs/source/api_references.rst
+++ b/docs/source/api_references.rst
@@ -289,5 +289,6 @@ Utilities API
     memory
     model_summary
     parsing
+    rank_zero
     seed
     warnings
diff --git a/docs/source/common/checkpointing.rst b/docs/source/common/checkpointing.rst
index c55130d49b..8d96bfb3c1 100644
--- a/docs/source/common/checkpointing.rst
+++ b/docs/source/common/checkpointing.rst
@@ -98,7 +98,7 @@ Lightning automatically ensures that the model is saved only on the main process
 
     trainer.save_checkpoint("example.ckpt")
 
 Not using :meth:`~pytorch_lightning.trainer.trainer.Trainer.save_checkpoint` can lead to unexpected behavior and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the Trainer's save functionality.
-If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
+If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.rank_zero.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
 model parallel distributed strategies such as deepspeed or sharded training.
 
diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst
index e8c4800df2..f1066bf672 100644
--- a/docs/source/extensions/logging.rst
+++ b/docs/source/extensions/logging.rst
@@ -205,7 +205,7 @@ Make a Custom Logger
 ********************
 
 You can implement your own logger by writing a class that inherits from :class:`~pytorch_lightning.loggers.base.LightningLoggerBase`.
-Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorators to make sure that only the first process in DDP training creates the experiment and logs the data respectively.
+Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`~pytorch_lightning.utilities.rank_zero.rank_zero_only` decorators to make sure that only the first process in DDP training creates the experiment and logs the data respectively.
 
 .. testcode::
 
diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py
index dc1e3d09d0..bd34389b21 100644
--- a/pl_examples/basic_examples/autoencoder.py
+++ b/pl_examples/basic_examples/autoencoder.py
@@ -25,9 +25,9 @@ from torch.utils.data import DataLoader, random_split
 import pytorch_lightning as pl
 from pl_examples import _DATASETS_PATH, cli_lightning_logo
 from pl_examples.basic_examples.mnist_datamodule import MNIST
-from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.cli import LightningCLI
 from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
 if _TORCHVISION_AVAILABLE:
     import torchvision
diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py
index c507a6f0e9..a1f27c9f08 100644
--- a/pl_examples/domain_templates/computer_vision_fine_tuning.py
+++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py
@@ -58,8 +58,8 @@ import pytorch_lightning as pl
 from pl_examples import cli_lightning_logo
 from pytorch_lightning import LightningDataModule
 from pytorch_lightning.callbacks.finetuning import BaseFinetuning
-from pytorch_lightning.utilities import rank_zero_info
 from pytorch_lightning.utilities.cli import LightningCLI
+from pytorch_lightning.utilities.rank_zero import rank_zero_info
 
 log = logging.getLogger(__name__)
 DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 0a2fe81ab2..90bc0643db 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -26,8 +26,8 @@ import torch
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
 log = logging.getLogger(__name__)
 
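The two documentation hunks above describe decorating custom save and logger methods so they run on the main process only. A minimal sketch of that pattern, using the new import path introduced by this PR, is shown below; the `save_extra_state` helper and its arguments are hypothetical illustrations, not part of the diff.

```python
import torch

from pytorch_lightning.utilities.rank_zero import rank_zero_only


@rank_zero_only
def save_extra_state(state: dict, path: str) -> None:
    # The decorated body executes on global rank 0 only; all other ranks return None.
    # As the checkpointing docs caution, this assumes every rank holds identical state.
    torch.save(state, path)
```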
diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index b00065f36e..8f20416af6 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -26,8 +26,8 @@ from torch.optim.optimizer import Optimizer import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_warn log = logging.getLogger(__name__) diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 6ca9d7712f..f2aa17e111 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -29,9 +29,10 @@ import torch import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import _AcceleratorType, rank_zero_deprecation, rank_zero_only +from pytorch_lightning.utilities import _AcceleratorType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only from pytorch_lightning.utilities.types import STEP_OUTPUT diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 4e529a7175..0f3519d8fe 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -27,8 +27,8 @@ from torch.optim.optimizer import Optimizer import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.types import LRSchedulerConfig diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f857bbdeac..75b1adb10c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -33,9 +33,9 @@ import yaml import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 77340f9f28..291fb495a8 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -15,7 +15,7 @@ from typing import Any, Dict, Optional, Union import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_warn class ProgressBarBase(Callback): diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py 
b/pytorch_lightning/callbacks/progress/tqdm_progress.py index bb80c22d3f..813cecd92c 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -27,7 +27,7 @@ else: import pytorch_lightning as pl from pytorch_lightning.callbacks.progress.base import ProgressBarBase -from pytorch_lightning.utilities.distributed import rank_zero_debug +from pytorch_lightning.utilities.rank_zero import rank_zero_debug _PAD_SIZE = 5 diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index cc6f586159..fd237c5413 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -30,8 +30,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only log = logging.getLogger(__name__) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 70605457da..42dd67b724 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,8 +24,8 @@ from torch.optim.swa_utils import SWALR import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.types import LRSchedulerConfig _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py index dbc929385f..a6649ded7e 100644 --- a/pytorch_lightning/callbacks/timer.py +++ b/pytorch_lightning/callbacks/timer.py @@ -24,8 +24,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import LightningEnum -from pytorch_lightning.utilities.distributed import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_info log = logging.getLogger(__name__) diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py index 5e71ecabb1..20555f5228 100644 --- a/pytorch_lightning/callbacks/xla_stats_monitor.py +++ b/pytorch_lightning/callbacks/xla_stats_monitor.py @@ -22,8 +22,9 @@ import time import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info +from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm diff --git a/pytorch_lightning/core/decorators.py 
b/pytorch_lightning/core/decorators.py index f2a3a5c41f..33c83b4b10 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn rank_zero_deprecation( "Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, " @@ -22,8 +22,6 @@ rank_zero_deprecation( from functools import wraps # noqa: E402 from typing import Callable # noqa: E402 -from pytorch_lightning.utilities import rank_zero_warn # noqa: E402 - def parameter_validation(fn: Callable) -> Callable: """Validates that the module parameter lengths match after moving to the device. It is useful when tying diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c7444f149f..908bc9ab90 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,20 +38,15 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, Hyperparameter from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator -from pytorch_lightning.utilities import ( - _IS_WINDOWS, - _TORCH_GREATER_EQUAL_1_10, - GradClipAlgorithmType, - rank_zero_deprecation, - rank_zero_warn, -) +from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_debug, sync_ddp +from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import get_model_size_mb from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.parsing import collect_init_args +from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 3825382b2f..03037f3826 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -21,9 +21,9 @@ from torch import optim from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 2c9d772493..c27b63ae60 100644 
--- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -26,12 +26,13 @@ from warnings import warn import torch import yaml -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, rank_zero_warn +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.migration import pl_legacy_patch from pytorch_lightning.utilities.parsing import parse_class_init_keys +from pytorch_lightning.utilities.rank_zero import rank_zero_warn log = logging.getLogger(__name__) PRIMITIVE_TYPES = (bool, int, float, str) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index ab159e8d2d..269066a044 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -26,8 +26,7 @@ import numpy as np import pytorch_lightning as pl from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.warnings import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only def rank_zero_experiment(fn: Callable) -> Callable: diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 7fe57b243b..d180fc21ab 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -26,9 +26,10 @@ from torch import is_tensor import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _module_available from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from pytorch_lightning.utilities.rank_zero import rank_zero_only log = logging.getLogger(__name__) _COMET_AVAILABLE = _module_available("comet_ml") diff --git a/pytorch_lightning/loggers/csv_logs.py b/pytorch_lightning/loggers/csv_logs.py index 9d28154211..1ef17417f2 100644 --- a/pytorch_lightning/loggers/csv_logs.py +++ b/pytorch_lightning/loggers/csv_logs.py @@ -28,9 +28,8 @@ import torch from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.logger import _add_prefix, _convert_params +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn log = logging.getLogger(__name__) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 0548599329..c1b0df2f94 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -23,8 +23,9 @@ from time import time from typing import Any, Dict, Optional, Union from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.imports import _module_available from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, 
_flatten_dict +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn log = logging.getLogger(__name__) LOCAL_FILE_URI_PREFIX = "file:" diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 96fa45b18e..fe2a810163 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -32,10 +32,10 @@ import torch from pytorch_lightning import __version__ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9 from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params from pytorch_lightning.utilities.model_summary import ModelSummary +from pytorch_lightning.utilities.rank_zero import rank_zero_only if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9: try: diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 39ddca40fe..80b5480477 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -29,10 +29,11 @@ from torch.utils.tensorboard.summary import hparams import pytorch_lightning as pl from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict from pytorch_lightning.utilities.logger import _sanitize_params as _utils_sanitize_params +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn log = logging.getLogger(__name__) diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index df158c8253..31ec893a91 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -20,9 +20,9 @@ from typing import Any, Dict, Optional, Union import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities import _module_available from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn _TESTTUBE_AVAILABLE = _module_available("test_tube") diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 705d06d013..6e409029e0 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -26,11 +26,10 @@ import torch.nn as nn from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _compare_version +from 
pytorch_lightning.utilities.imports import _compare_version, _module_available from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params -from pytorch_lightning.utilities.warnings import rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn _WANDB_AVAILABLE = _module_available("wandb") _WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22") diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 7438c25298..5f0a7f4666 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -23,14 +23,14 @@ from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _ from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache +from pytorch_lightning.utilities.warnings import WarningCache _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 4d59af02f8..4954a2b74d 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -22,11 +22,10 @@ from pytorch_lightning.loops.utilities import _is_max_limit_reached from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.warnings import rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn log = logging.getLogger(__name__) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index fd32619ed8..ea44bc0683 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -19,8 +19,8 @@ import torch import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.rank_zero import rank_zero_warn def _ignore_scalar_return_in_dp() -> None: diff --git 
a/pytorch_lightning/plugins/environments/lightning_environment.py b/pytorch_lightning/plugins/environments/lightning_environment.py index 44ec210b56..dbf8b1cfbf 100644 --- a/pytorch_lightning/plugins/environments/lightning_environment.py +++ b/pytorch_lightning/plugins/environments/lightning_environment.py @@ -16,7 +16,7 @@ import os import socket from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.rank_zero import rank_zero_only class LightningEnvironment(ClusterEnvironment): diff --git a/pytorch_lightning/plugins/environments/torchelastic_environment.py b/pytorch_lightning/plugins/environments/torchelastic_environment.py index a5eed77509..49bc0755d1 100644 --- a/pytorch_lightning/plugins/environments/torchelastic_environment.py +++ b/pytorch_lightning/plugins/environments/torchelastic_environment.py @@ -16,7 +16,7 @@ import logging import os from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn log = logging.getLogger(__name__) diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py index 4413afc5d4..22c6ff02ef 100644 --- a/pytorch_lightning/plugins/io/torch_plugin.py +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -17,9 +17,9 @@ from typing import Any, Callable, Dict, Optional import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.types import _PATH log = logging.getLogger(__name__) diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 8c542b8876..2d4478a17f 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -24,9 +24,9 @@ from torch import nn, Tensor from torch.autograd.profiler import record_function from pytorch_lightning.profiler.base import BaseProfiler -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.warnings import WarningCache if TYPE_CHECKING: diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index e3281f3ad3..7b7e4e83a2 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -46,18 +46,13 @@ from pytorch_lightning.utilities import ( _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_10, - rank_zero_warn, ) from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import ( - init_dist_connection, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from 
pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import DeadlockDetectedException +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index faddfd9b27..03407e1c14 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -32,19 +32,14 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn, TrainerState -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import ( - init_dist_connection, - rank_zero_debug, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index bdf0f3afa8..9ca83e3c65 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -34,11 +34,12 @@ from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import log, rank_zero_info +from pytorch_lightning.utilities.distributed import log from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_info from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py index bf21be4c74..3bd81fd754 100644 --- a/pytorch_lightning/strategies/horovod.py +++ b/pytorch_lightning/strategies/horovod.py @@ -23,11 +23,12 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from 
pytorch_lightning.strategies.parallel import ParallelStrategy -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.distributed import group as dist_group -from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.enums import _StrategyType +from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE +from pytorch_lightning.utilities.rank_zero import rank_zero_only if _HOROVOD_AVAILABLE: import horovod.torch as hvd diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 166b2a1c5d..2d1584a2e1 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -22,9 +22,10 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE +from pytorch_lightning.utilities.rank_zero import rank_zero_only if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel diff --git a/pytorch_lightning/strategies/sharded_spawn.py b/pytorch_lightning/strategies/sharded_spawn.py index b313420a9f..289e3491be 100644 --- a/pytorch_lightning/strategies/sharded_spawn.py +++ b/pytorch_lightning/strategies/sharded_spawn.py @@ -21,9 +21,10 @@ from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities.rank_zero import rank_zero_only if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 3e601f6489..a6e82441da 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -32,9 +32,10 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.data import has_len -from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only from pytorch_lightning.utilities.seed import reset_seed from 
pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index a1002bfd55..28c799e0f0 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -20,7 +20,7 @@ from packaging.version import Version import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import STEP_OUTPUT diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index e5349234e4..cd8754b938 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -15,8 +15,8 @@ import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn def verify_loop_configurations(trainer: "pl.Trainer") -> None: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 977e5807c0..fd65975618 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -64,15 +64,7 @@ from pytorch_lightning.strategies import ( StrategyRegistry, TPUSpawnStrategy, ) -from pytorch_lightning.utilities import ( - _AcceleratorType, - _StrategyType, - AMPType, - device_parser, - rank_zero_deprecation, - rank_zero_info, - rank_zero_warn, -) +from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, AMPType, device_parser from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import ( @@ -81,6 +73,7 @@ from pytorch_lightning.utilities.imports import ( _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE, ) +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn if _HOROVOD_AVAILABLE: import horovod.torch as hvd diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 74f55c16ed..45a05a446b 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -26,9 +26,9 @@ from pytorch_lightning.callbacks import ( ) from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer -from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info +from pytorch_lightning.utilities import ModelSummaryMode from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.warnings import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import 
rank_zero_deprecation, rank_zero_info class CallbackConnector: diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6c0b40927e..8e279ef318 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -24,11 +24,12 @@ import pytorch_lightning as pl from pytorch_lightning.loops.utilities import _is_max_limit_reached from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info from pytorch_lightning.utilities.types import _PATH from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index c83060244c..2d508de8d0 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -26,7 +26,6 @@ from pytorch_lightning.accelerators import GPUAccelerator from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator -from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( _teardown_dataloader_get_iterators, @@ -50,9 +49,10 @@ from pytorch_lightning.utilities.fetching import ( ) from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS -from pytorch_lightning.utilities.warnings import PossibleUserWarning, rank_zero_warn +from pytorch_lightning.utilities.warnings import PossibleUserWarning class DataConnector: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index a3acea7fcb..f7f708fdd1 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -21,12 +21,12 @@ from torchmetrics import Metric from typing_extensions import TypedDict from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device from pytorch_lightning.utilities.data import extract_batch_size from pytorch_lightning.utilities.exceptions import MisconfigurationException from 
pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.metrics import metrics_to_scalars +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.warnings import WarningCache _IN_METRIC = Union[Metric, torch.Tensor] # Do not include scalars as they were converted to tensors diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index f5a357cddf..3fe6a75df1 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Set, Union import pytorch_lightning as pl from pytorch_lightning.plugins.environments import SLURMEnvironment -from pytorch_lightning.utilities.distributed import rank_zero_info from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS +from pytorch_lightning.utilities.rank_zero import rank_zero_info # copied from signal.pyi _SIGNUM = Union[int, signal.Signals] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6fb1618317..f0b56e35e1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -83,9 +83,6 @@ from pytorch_lightning.utilities import ( device_parser, GradClipAlgorithmType, parsing, - rank_zero_deprecation, - rank_zero_info, - rank_zero_warn, ) from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.argparse import ( @@ -103,6 +100,7 @@ from pytorch_lightning.utilities.exceptions import ExitGracefullyException, Misc from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import ( diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 1406be20cc..bf37e4357d 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -20,11 +20,11 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.loggers.base import DummyLogger -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr +from pytorch_lightning.utilities.rank_zero import rank_zero_warn log = logging.getLogger(__name__) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 878f754873..ebfa9a1dd5 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -26,9 +26,9 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx from pytorch_lightning.loggers.base import 
DummyLogger -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.types import LRSchedulerConfig # check if ipywidgets is installed before importing tqdm.auto diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index b1e6dd7178..916532d964 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -16,7 +16,7 @@ import numpy from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 -from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only # noqa: F401 +from pytorch_lightning.utilities.distributed import AllGatherGrad # noqa: F401 from pytorch_lightning.utilities.enums import ( # noqa: F401 _AcceleratorType, _StrategyType, @@ -57,7 +57,12 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401 ) from pytorch_lightning.utilities.parameter_tying import find_shared_parameters, set_shared_parameters # noqa: F401 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401 -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn # noqa: F401 +from pytorch_lightning.utilities.rank_zero import ( # noqa: F401 + rank_zero_deprecation, + rank_zero_info, + rank_zero_only, + rank_zero_warn, +) FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index ab2b932297..a9c06febbb 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -27,10 +27,11 @@ from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer -from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, rank_zero_warn, warnings from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_warn from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple, LRSchedulerTypeUnion if _JSONARGPARSE_AVAILABLE: @@ -795,7 +796,7 @@ class LightningCLI: lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) if is_overridden("configure_optimizers", self.model): - warnings._warn( + _warn( f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by " f"`{self.__class__.__name__}.configure_optimizers`." 
)
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
index 7e698da8e4..2b68c51db5 100644
--- a/pytorch_lightning/utilities/data.py
+++ b/pytorch_lightning/utilities/data.py
@@ -25,10 +25,10 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler,
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 from pytorch_lightning.utilities.seed import pl_worker_init_function
 from pytorch_lightning.utilities.warnings import WarningCache
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 82c000197c..7c8b1162e3 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -15,8 +15,6 @@
 import logging
 import os
-from functools import wraps
-from platform import python_version
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -25,6 +23,10 @@ from torch.nn.parallel.distributed import DistributedDataParallel
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
+from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug
+from pytorch_lightning.utilities.rank_zero import rank_zero_only  # noqa: F401
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
+from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -44,56 +46,6 @@ else:
 
 log = logging.getLogger(__name__)
 
 
-def rank_zero_only(fn: Callable) -> Callable:
-    """Function that can be used as a decorator to enable a function/method being called only on rank 0."""
-
-    @wraps(fn)
-    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
-        if rank_zero_only.rank == 0:
-            return fn(*args, **kwargs)
-        return None
-
-    return wrapped_fn
-
-
-# TODO: this should be part of the cluster environment
-def _get_rank() -> int:
-    rank_keys = ("RANK", "SLURM_PROCID", "LOCAL_RANK")
-    for key in rank_keys:
-        rank = os.environ.get(key)
-        if rank is not None:
-            return int(rank)
-    return 0
-
-
-# add the attribute to the function but don't overwrite in case Trainer has already set it
-rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())
-
-
-def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
-    if python_version() >= "3.8.0":
-        kwargs["stacklevel"] = stacklevel
-    log.info(*args, **kwargs)
-
-
-def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
-    if python_version() >= "3.8.0":
-        kwargs["stacklevel"] = stacklevel
-    log.debug(*args, **kwargs)
-
-
-@rank_zero_only
-def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
-    """Function used to log debug-level messages only on rank 0."""
-    _debug(*args, stacklevel=stacklevel, **kwargs)
-
-
-@rank_zero_only
-def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
-    """Function used to log info-level messages only on rank 0."""
-    _info(*args, stacklevel=stacklevel, **kwargs)
-
-
 def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
     """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
@@ -447,3 +399,21 @@ def _revert_sync_batchnorm(module: Module) -> Module:
         converted_module.add_module(name, _revert_sync_batchnorm(child))
     del module
     return converted_module
+
+
+def rank_zero_info(*args: Any, **kwargs: Any) -> Any:
+    rank_zero_deprecation(
+        "pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
+        " and will be removed in v1.8."
+        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
+    )
+    return new_rank_zero_info(*args, **kwargs)
+
+
+def rank_zero_debug(*args: Any, **kwargs: Any) -> Any:
+    rank_zero_deprecation(
+        "pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
+        " and will be removed in v1.8."
+        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
+    )
+    return new_rank_zero_debug(*args, **kwargs)
diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py
index 6d3c1d6b5f..0b9b21193b 100644
--- a/pytorch_lightning/utilities/meta.py
+++ b/pytorch_lightning/utilities/meta.py
@@ -26,9 +26,9 @@ from torch.nn import Module
 from torch.nn.modules.container import ModuleDict, ModuleList, Sequential
 
 import pytorch_lightning as pl
-from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
 if _TORCH_GREATER_EQUAL_1_10:
     from torch._C import _DisableTorchDispatch  # type: ignore[attr-defined]
diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py
index ebb7661008..f9d50bc532 100644
--- a/pytorch_lightning/utilities/parsing.py
+++ b/pytorch_lightning/utilities/parsing.py
@@ -24,7 +24,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
 from typing_extensions import Literal
 
 import pytorch_lightning as pl
-from pytorch_lightning.utilities.warnings import rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
 
 def str_to_bool_or_str(val: str) -> Union[str, bool]:
diff --git a/pytorch_lightning/utilities/rank_zero.py b/pytorch_lightning/utilities/rank_zero.py
new file mode 100644
index 0000000000..df1e679208
--- /dev/null
+++ b/pytorch_lightning/utilities/rank_zero.py
@@ -0,0 +1,97 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities that can be used for calling functions on a particular rank."""
+import logging
+import os
+import warnings
+from functools import partial, wraps
+from platform import python_version
+from typing import Any, Callable, Optional, Union
+
+log = logging.getLogger(__name__)
+
+
+def rank_zero_only(fn: Callable) -> Callable:
+    """Function that can be used as a decorator to enable a function/method being called only on rank 0."""
+
+    @wraps(fn)
+    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
+        if rank_zero_only.rank == 0:
+            return fn(*args, **kwargs)
+        return None
+
+    return wrapped_fn
+
+
+# TODO: this should be part of the cluster environment
+def _get_rank() -> int:
+    rank_keys = ("RANK", "SLURM_PROCID", "LOCAL_RANK")
+    for key in rank_keys:
+        rank = os.environ.get(key)
+        if rank is not None:
+            return int(rank)
+    return 0
+
+
+# add the attribute to the function but don't overwrite in case Trainer has already set it
+rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())
+
+
+def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
+    if python_version() >= "3.8.0":
+        kwargs["stacklevel"] = stacklevel
+    log.info(*args, **kwargs)
+
+
+def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
+    if python_version() >= "3.8.0":
+        kwargs["stacklevel"] = stacklevel
+    log.debug(*args, **kwargs)
+
+
+@rank_zero_only
+def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
+    """Function used to log debug-level messages only on rank 0."""
+    _debug(*args, stacklevel=stacklevel, **kwargs)
+
+
+@rank_zero_only
+def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
+    """Function used to log info-level messages only on rank 0."""
+    _info(*args, stacklevel=stacklevel, **kwargs)
+
+
+def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
+    if type(stacklevel) is type and issubclass(stacklevel, Warning):
+        rank_zero_deprecation(
+            "Support for passing the warning category positionally is deprecated in v1.6 and will be removed in v1.8"
+            f" Please, use `category={stacklevel.__name__}`."
+        )
+        kwargs["category"] = stacklevel
+        stacklevel = kwargs.pop("stacklevel", 2)
+    warnings.warn(message, stacklevel=stacklevel, **kwargs)
+
+
+@rank_zero_only
+def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
+    """Function used to log warn-level messages only on rank 0."""
+    _warn(message, stacklevel=stacklevel, **kwargs)
+
+
+class LightningDeprecationWarning(DeprecationWarning):
+    """Deprecation warnings raised by PyTorch Lightning."""
+
+
+rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)
diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py
index 99071edbc6..efafc0afad 100644
--- a/pytorch_lightning/utilities/seed.py
+++ b/pytorch_lightning/utilities/seed.py
@@ -21,8 +21,7 @@ from typing import Optional
 import numpy as np
 import torch
 
-from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
 
 log = logging.getLogger(__name__)
 
diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py
index 0323d5333c..61d3a2b8e2 100644
--- a/pytorch_lightning/utilities/warnings.py
+++ b/pytorch_lightning/utilities/warnings.py
@@ -14,50 +14,55 @@
 
 """Warning-related utilities."""
 import warnings
-from functools import partial
-from typing import Any, Union
+from typing import Any
 
-from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning as NewLightningDeprecationWarning
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation as new_rank_zero_deprecation
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn as new_rank_zero_warn
 
-
-def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
-    if type(stacklevel) is type and issubclass(stacklevel, Warning):
-        rank_zero_deprecation(
-            "Support for passing the warning category positionally is deprecated in v1.6 and will be removed in v1.8"
-            f" Please, use `category={stacklevel.__name__}`."
-        )
-        kwargs["category"] = stacklevel
-        stacklevel = kwargs.pop("stacklevel", 2)
-    warnings.warn(message, stacklevel=stacklevel, **kwargs)
-
-
-@rank_zero_only
-def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
-    """Function used to log warn-level messages only on rank 0."""
-    _warn(message, stacklevel=stacklevel, **kwargs)
+# enable our warnings
+warnings.simplefilter("default", category=NewLightningDeprecationWarning)
 
 
 class PossibleUserWarning(UserWarning):
     """Warnings that could be false positives."""
 
 
-class LightningDeprecationWarning(DeprecationWarning):
-    """Deprecation warnings raised by PyTorch Lightning."""
-
-
-# enable our warnings
-warnings.simplefilter("default", category=LightningDeprecationWarning)
-
-rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)
-
-
 class WarningCache(set):
     def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
         if message not in self:
             self.add(message)
-            rank_zero_warn(message, stacklevel=stacklevel, **kwargs)
+            new_rank_zero_warn(message, stacklevel=stacklevel, **kwargs)
 
     def deprecation(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
         if message not in self:
             self.add(message)
-            rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)
+            new_rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)
+
+
+def rank_zero_warn(*args: Any, **kwargs: Any) -> Any:
+    new_rank_zero_deprecation(
+        "pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
+        " and will be removed in v1.8."
+        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
+    )
+    return new_rank_zero_warn(*args, **kwargs)
+
+
+def rank_zero_deprecation(*args: Any, **kwargs: Any) -> Any:
+    new_rank_zero_deprecation(
+        "pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
+        " and will be removed in v1.8."
+        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
+    )
+    return new_rank_zero_deprecation(*args, **kwargs)
+
+
+class LightningDeprecationWarning(NewLightningDeprecationWarning):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        new_rank_zero_deprecation(
+            "pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
+            " and will be removed in v1.8."
+            " Use the equivalent class from the pytorch_lightning.utilities.rank_zero module instead."
+        )
+        super().__init__(*args, **kwargs)
diff --git a/tests/callbacks/test_device_stats_monitor.py b/tests/callbacks/test_device_stats_monitor.py
index 582a0471d7..c90ce4a4ba 100644
--- a/tests/callbacks/test_device_stats_monitor.py
+++ b/tests/callbacks/test_device_stats_monitor.py
@@ -19,8 +19,8 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import DeviceStatsMonitor
 from pytorch_lightning.callbacks.device_stats_monitor import _prefix_metric_keys
 from pytorch_lightning.loggers import CSVLogger
-from pytorch_lightning.utilities.distributed import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
diff --git a/tests/conftest.py b/tests/conftest.py
index e88ec1b3f6..8ad7faa3cd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -35,7 +35,7 @@ def datadir():
 @pytest.fixture(scope="function", autouse=True)
 def preserve_global_rank_variable():
     """Ensures that the rank_zero_only.rank global variable gets reset in each test."""
-    from pytorch_lightning.utilities.distributed import rank_zero_only
+    from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
     rank = getattr(rank_zero_only, "rank", None)
     yield
diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py
index 053f413e7d..7a969e9894 100644
--- a/tests/deprecated_api/test_remove_1-8.py
+++ b/tests/deprecated_api/test_remove_1-8.py
@@ -32,10 +32,10 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePl
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import DeviceType, DistributedType
 from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 from tests.helpers.boring_model import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
 from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
@@ -413,3 +413,37 @@ def test_v1_8_0_on_configure_sharded_model(tmpdir):
         match="The `on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8."
     ):
         trainer.fit(model)
+
+
+def test_v1_8_0_rank_zero_imports():
+
+    import warnings
+
+    from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_info
+    from pytorch_lightning.utilities.warnings import LightningDeprecationWarning, rank_zero_deprecation, rank_zero_warn
+
+    with pytest.deprecated_call(
+        match="pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
+        " and will be removed in v1.8."
+    ):
+        rank_zero_debug("foo")
+    with pytest.deprecated_call(
+        match="pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
+        " and will be removed in v1.8."
+    ):
+        rank_zero_info("foo")
+    with pytest.deprecated_call(
+        match="pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
+        " and will be removed in v1.8."
+    ):
+        rank_zero_warn("foo")
+    with pytest.deprecated_call(
+        match="pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
+        " and will be removed in v1.8."
+    ):
+        rank_zero_deprecation("foo")
+    with pytest.deprecated_call(
+        match="pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
+        " and will be removed in v1.8."
+    ):
+        warnings.warn("foo", LightningDeprecationWarning, stacklevel=5)
diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py
index eb7dd949e5..41a3aec103 100644
--- a/tests/loggers/test_base.py
+++ b/tests/loggers/test_base.py
@@ -24,9 +24,9 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
-from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import _convert_params, _sanitize_params
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from tests.helpers.boring_model import BoringDataModule, BoringModel
 
 
diff --git a/tests/plugins/test_cluster_integration.py b/tests/plugins/test_cluster_integration.py
index c430081036..04fe995cdd 100644
--- a/tests/plugins/test_cluster_integration.py
+++ b/tests/plugins/test_cluster_integration.py
@@ -20,7 +20,7 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment
 from pytorch_lightning.strategies import DDP2Strategy, DDPShardedStrategy, DDPStrategy, DeepSpeedStrategy
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from tests.helpers.runif import RunIf
 
 
diff --git a/tests/utilities/rank_zero.py b/tests/utilities/rank_zero.py
new file mode 100644
index 0000000000..61bcf61c0c
--- /dev/null
+++ b/tests/utilities/rank_zero.py
@@ -0,0 +1,55 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Mapping
+from unittest import mock
+
+import pytest
+
+
+@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}])
+def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
+    """Test that SLURM environment variables are properly checked for rank_zero_only."""
+    from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
+
+    rank_zero_only.rank = _get_rank()
+
+    with mock.patch.dict(os.environ, env_vars):
+        from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
+
+        rank_zero_only.rank = _get_rank()
+
+        @rank_zero_only
+        def foo():  # The return type is optional because on non-zero ranks it will not be called
+            return 1
+
+        x = foo()
+        assert x == 1
+
+
+@pytest.mark.parametrize("rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3")])
+def test_rank_zero_none_set(rank_key, rank):
+    """Test that function is not called when rank environment variables are not global zero."""
+
+    with mock.patch.dict(os.environ, {rank_key: rank}):
+        from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
+
+        rank_zero_only.rank = _get_rank()
+
+        @rank_zero_only
+        def foo():
+            return 1
+
+        x = foo()
+        assert x is None
diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py
index 6226aadecb..a78c86e131 100644
--- a/tests/utilities/test_distributed.py
+++ b/tests/utilities/test_distributed.py
@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Mapping
-from unittest import mock
 
-import pytest
 import torch
 import torch.multiprocessing as mp
 
@@ -24,43 +21,6 @@
 from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
 from tests.helpers.runif import RunIf
 
 
-@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}])
-def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
-    """Test that SLURM environment variables are properly checked for rank_zero_only."""
-    from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
-
-    rank_zero_only.rank = _get_rank()
-
-    with mock.patch.dict(os.environ, env_vars):
-        from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
-
-        rank_zero_only.rank = _get_rank()
-
-        @rank_zero_only
-        def foo():  # The return type is optional because on non-zero ranks it will not be called
-            return 1
-
-        x = foo()
-        assert x == 1
-
-
-@pytest.mark.parametrize("rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3")])
-def test_rank_zero_none_set(rank_key, rank):
-    """Test that function is not called when rank environment variables are not global zero."""
-
-    with mock.patch.dict(os.environ, {rank_key: rank}):
-        from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
-
-        rank_zero_only.rank = _get_rank()
-
-        @rank_zero_only
-        def foo():
-            return 1
-
-        x = foo()
-        assert x is None
-
-
 def _test_collect_states(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
diff --git a/tests/utilities/test_warnings.py b/tests/utilities/test_warnings.py
index 5262426a16..45a0d5f8bb 100644
--- a/tests/utilities/test_warnings.py
+++ b/tests/utilities/test_warnings.py
@@ -19,7 +19,8 @@ import os
 from contextlib import redirect_stderr
 from io import StringIO
 
-from pytorch_lightning.utilities.warnings import _warn, rank_zero_deprecation, rank_zero_warn, WarningCache
+from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities.warnings import WarningCache
 
 standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1"
 if standalone:
@@ -40,16 +41,16 @@ if standalone:
         cache.deprecation("test7")
 
     output = stderr.getvalue()
-    assert "test_warnings.py:30: UserWarning: test1" in output
-    assert "test_warnings.py:31: DeprecationWarning: test2" in output
+    assert "test_warnings.py:31: UserWarning: test1" in output
+    assert "test_warnings.py:32: DeprecationWarning: test2" in output
 
-    assert "test_warnings.py:33: UserWarning: test3" in output
-    assert "test_warnings.py:34: DeprecationWarning: test4" in output
+    assert "test_warnings.py:34: UserWarning: test3" in output
+    assert "test_warnings.py:35: DeprecationWarning: test4" in output
 
-    assert "test_warnings.py:36: LightningDeprecationWarning: test5" in output
+    assert "test_warnings.py:37: LightningDeprecationWarning: test5" in output
 
-    assert "test_warnings.py:39: UserWarning: test6" in output
-    assert "test_warnings.py:40: LightningDeprecationWarning: test7" in output
+    assert "test_warnings.py:40: UserWarning: test6" in output
+    assert "test_warnings.py:41: LightningDeprecationWarning: test7" in output
 
     # check that logging is properly configured
    import logging
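
For downstream code the deprecations above amount to an import-path change; a minimal migration sketch follows (only the imported names come from this patch, the `save_vocab` helper is hypothetical and shown purely for illustration):

    # Deprecated in v1.6, scheduled for removal in v1.8:
    #     from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
    #     from pytorch_lightning.utilities.warnings import rank_zero_warn
    # Preferred location going forward:
    from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only


    @rank_zero_only
    def save_vocab(path: str) -> None:
        # The decorated function runs only where rank_zero_only.rank == 0; other ranks get None back.
        rank_zero_info(f"Writing vocabulary to {path}")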