Centralize rank_zero_only utilities into their own module (#11747)

* Centralize rank_zero_only utilities into their own module

Fixes #11746

* PossibleUserWarning

* Update test_warnings.py

* update imports

* more imports

* Update CHANGELOG.md

* Update mlflow.py

* Update cli.py

* Update api_references.rst

* Update meta.py

* add deprecation tests

* debug standalone

* fix standalone tests

* Update CHANGELOG.md
This commit is contained in:
ananthsub 2022-02-07 00:09:55 -08:00 committed by GitHub
parent 34c454c756
commit a64438c897
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
71 changed files with 370 additions and 239 deletions

View File

@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `Bagua` training strategy ([#11146](https://github.com/PyTorchLightning/pytorch-lightning/pull/11146))
- Added `rank_zero` module to centralize utilities ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
### Changed
- Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
@ -323,6 +326,24 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `on_configure_sharded_model` callback hook in favor of `setup` ([#11627](https://github.com/PyTorchLightning/pytorch-lightning/pull/11627))
- Deprecated `pytorch_lightning.utilities.distributed.rank_zero_only` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_only` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
- Deprecated `pytorch_lightning.utilities.distributed.rank_zero_debug` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_debug` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
- Deprecated `pytorch_lightning.utilities.distributed.rank_zero_info` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_info` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
- Deprecated `pytorch_lightning.utilities.warnings.rank_zero_warn` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_warn` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
- Deprecated `pytorch_lightning.utilities.warnings.rank_zero_deprecation` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_deprecation` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`
### Removed
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

View File

@ -289,5 +289,6 @@ Utilities API
memory
model_summary
parsing
rank_zero
seed
warnings

View File

@ -98,7 +98,7 @@ Lightning automatically ensures that the model is saved only on the main process
trainer.save_checkpoint("example.ckpt")
Not using :meth:`~pytorch_lightning.trainer.trainer.Trainer.save_checkpoint` can lead to unexpected behavior and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the Trainer's save functionality.
If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.rank_zero.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
model parallel distributed strategies such as deepspeed or sharded training.

View File

@ -205,7 +205,7 @@ Make a Custom Logger
********************
You can implement your own logger by writing a class that inherits from :class:`~pytorch_lightning.loggers.base.LightningLoggerBase`.
Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorators to make sure that only the first process in DDP training creates the experiment and logs the data respectively.
Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`~pytorch_lightning.utilities.rank_zero.rank_zero_only` decorators to make sure that only the first process in DDP training creates the experiment and logs the data respectively.
.. testcode::

View File

@ -25,9 +25,9 @@ from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, cli_lightning_logo
from pl_examples.basic_examples.mnist_datamodule import MNIST
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only
if _TORCHVISION_AVAILABLE:
import torchvision

View File

@ -58,8 +58,8 @@ import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pytorch_lightning import LightningDataModule
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.rank_zero import rank_zero_info
log = logging.getLogger(__name__)
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"

View File

@ -26,8 +26,8 @@ import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
log = logging.getLogger(__name__)

View File

@ -26,8 +26,8 @@ from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
log = logging.getLogger(__name__)

View File

@ -29,9 +29,10 @@ import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _AcceleratorType, rank_zero_deprecation, rank_zero_only
from pytorch_lightning.utilities import _AcceleratorType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
from pytorch_lightning.utilities.types import STEP_OUTPUT

View File

@ -27,8 +27,8 @@ from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig

View File

@ -33,9 +33,9 @@ import yaml
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

View File

@ -15,7 +15,7 @@ from typing import Any, Dict, Optional, Union
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
class ProgressBarBase(Callback):

View File

@ -27,7 +27,7 @@ else:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.distributed import rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
_PAD_SIZE = 5

View File

@ -30,8 +30,8 @@ import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only
log = logging.getLogger(__name__)

View File

@ -24,8 +24,8 @@ from torch.optim.swa_utils import SWALR
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig
_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]

View File

@ -24,8 +24,8 @@ import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info
log = logging.getLogger(__name__)

View File

@ -22,8 +22,9 @@ import time
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info
from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
rank_zero_deprecation(
"Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, "
@ -22,8 +22,6 @@ rank_zero_deprecation(
from functools import wraps # noqa: E402
from typing import Callable # noqa: E402
from pytorch_lightning.utilities import rank_zero_warn # noqa: E402
def parameter_validation(fn: Callable) -> Callable:
"""Validates that the module parameter lengths match after moving to the device. It is useful when tying

View File

@ -38,20 +38,15 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, Hyperparameter
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import (
_IS_WINDOWS,
_TORCH_GREATER_EQUAL_1_10,
GradClipAlgorithmType,
rank_zero_deprecation,
rank_zero_warn,
)
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_debug, sync_ddp
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import get_model_size_mb
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.parsing import collect_init_args
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

View File

@ -21,9 +21,9 @@ from torch import optim
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau

View File

@ -26,12 +26,13 @@ from warnings import warn
import torch
import yaml
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, rank_zero_warn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.parsing import parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
log = logging.getLogger(__name__)
PRIMITIVE_TYPES = (bool, int, float, str)

View File

@ -26,8 +26,7 @@ import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
def rank_zero_experiment(fn: Callable) -> Callable:

View File

@ -26,9 +26,10 @@ from torch import is_tensor
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _module_available
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.rank_zero import rank_zero_only
log = logging.getLogger(__name__)
_COMET_AVAILABLE = _module_available("comet_ml")

View File

@ -28,9 +28,8 @@ import torch
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)

View File

@ -23,8 +23,9 @@ from time import time
from typing import Any, Dict, Optional, Union
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.imports import _module_available
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"

View File

@ -32,10 +32,10 @@ import torch
from pytorch_lightning import __version__
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.utilities.rank_zero import rank_zero_only
if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9:
try:

View File

@ -29,10 +29,11 @@ from torch.utils.tensorboard.summary import hparams
import pytorch_lightning as pl
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.logger import _sanitize_params as _utils_sanitize_params
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)

View File

@ -20,9 +20,9 @@ from typing import Any, Dict, Optional, Union
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
_TESTTUBE_AVAILABLE = _module_available("test_tube")

View File

@ -26,11 +26,10 @@ import torch.nn as nn
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.imports import _compare_version, _module_available
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
from pytorch_lightning.utilities.warnings import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
_WANDB_AVAILABLE = _module_available("wandb")
_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22")

View File

@ -23,14 +23,14 @@ from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache
from pytorch_lightning.utilities.warnings import WarningCache
_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

View File

@ -22,11 +22,10 @@ from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
log = logging.getLogger(__name__)

View File

@ -19,8 +19,8 @@ import torch
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
def _ignore_scalar_return_in_dp() -> None:

View File

@ -16,7 +16,7 @@ import os
import socket
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only
class LightningEnvironment(ClusterEnvironment):

View File

@ -16,7 +16,7 @@ import logging
import os
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
log = logging.getLogger(__name__)

View File

@ -17,9 +17,9 @@ from typing import Any, Callable, Dict, Optional
import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _PATH
log = logging.getLogger(__name__)

View File

@ -24,9 +24,9 @@ from torch import nn, Tensor
from torch.autograd.profiler import record_function
from pytorch_lightning.profiler.base import BaseProfiler
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.warnings import WarningCache
if TYPE_CHECKING:

View File

@ -46,18 +46,13 @@ from pytorch_lightning.utilities import (
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
_TORCH_GREATER_EQUAL_1_10,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
init_dist_connection,
rank_zero_only,
ReduceOp,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

View File

@ -32,19 +32,14 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
init_dist_connection,
rank_zero_debug,
rank_zero_only,
ReduceOp,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

View File

@ -34,11 +34,12 @@ from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info
from pytorch_lightning.utilities.distributed import log
from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

View File

@ -23,11 +23,12 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only
if _HOROVOD_AVAILABLE:
import horovod.torch as hvd

View File

@ -22,9 +22,10 @@ import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only
if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

View File

@ -21,9 +21,10 @@ from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only
if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

View File

@ -32,9 +32,10 @@ from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only, ReduceOp
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

View File

@ -20,7 +20,7 @@ from packaging.version import Version
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT

View File

@ -15,8 +15,8 @@ import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
def verify_loop_configurations(trainer: "pl.Trainer") -> None:

View File

@ -64,15 +64,7 @@ from pytorch_lightning.strategies import (
StrategyRegistry,
TPUSpawnStrategy,
)
from pytorch_lightning.utilities import (
_AcceleratorType,
_StrategyType,
AMPType,
device_parser,
rank_zero_deprecation,
rank_zero_info,
rank_zero_warn,
)
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, AMPType, device_parser
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import (
@ -81,6 +73,7 @@ from pytorch_lightning.utilities.imports import (
_TORCH_GREATER_EQUAL_1_8,
_TPU_AVAILABLE,
)
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
if _HOROVOD_AVAILABLE:
import horovod.torch as hvd

View File

@ -26,9 +26,9 @@ from pytorch_lightning.callbacks import (
)
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
from pytorch_lightning.utilities import ModelSummaryMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
class CallbackConnector:

View File

@ -24,11 +24,12 @@ import pytorch_lightning as pl
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
from pytorch_lightning.utilities.types import _PATH
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

View File

@ -26,7 +26,6 @@ from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
_teardown_dataloader_get_iterators,
@ -50,9 +49,10 @@ from pytorch_lightning.utilities.fetching import (
)
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import PossibleUserWarning, rank_zero_warn
from pytorch_lightning.utilities.warnings import PossibleUserWarning
class DataConnector:

View File

@ -21,12 +21,12 @@ from torchmetrics import Metric
from typing_extensions import TypedDict
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.warnings import WarningCache
_IN_METRIC = Union[Metric, torch.Tensor] # Do not include scalars as they were converted to tensors

View File

@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Set, Union
import pytorch_lightning as pl
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS
from pytorch_lightning.utilities.rank_zero import rank_zero_info
# copied from signal.pyi
_SIGNUM = Union[int, signal.Signals]

View File

@ -83,9 +83,6 @@ from pytorch_lightning.utilities import (
device_parser,
GradClipAlgorithmType,
parsing,
rank_zero_deprecation,
rank_zero_info,
rank_zero_warn,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.argparse import (
@ -103,6 +100,7 @@ from pytorch_lightning.utilities.exceptions import ExitGracefullyException, Misc
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import (

View File

@ -20,11 +20,11 @@ from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
log = logging.getLogger(__name__)

View File

@ -26,9 +26,9 @@ import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig
# check if ipywidgets is installed before importing tqdm.auto

View File

@ -16,7 +16,7 @@
import numpy
from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only # noqa: F401
from pytorch_lightning.utilities.distributed import AllGatherGrad # noqa: F401
from pytorch_lightning.utilities.enums import ( # noqa: F401
_AcceleratorType,
_StrategyType,
@ -57,7 +57,12 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
)
from pytorch_lightning.utilities.parameter_tying import find_shared_parameters, set_shared_parameters # noqa: F401
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn # noqa: F401
from pytorch_lightning.utilities.rank_zero import ( # noqa: F401
rank_zero_deprecation,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
)
FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps

View File

@ -27,10 +27,11 @@ from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, rank_zero_warn, warnings
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple, LRSchedulerTypeUnion
if _JSONARGPARSE_AVAILABLE:
@ -795,7 +796,7 @@ class LightningCLI:
lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)
if is_overridden("configure_optimizers", self.model):
warnings._warn(
_warn(
f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
f"`{self.__class__.__name__}.configure_optimizers`."
)

View File

@ -25,10 +25,10 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler,
import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.seed import pl_worker_init_function
from pytorch_lightning.utilities.warnings import WarningCache

View File

@ -15,8 +15,6 @@
import logging
import os
from functools import wraps
from platform import python_version
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
@ -25,6 +23,10 @@ from torch.nn.parallel.distributed import DistributedDataParallel
import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_only # noqa: F401
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
@ -44,56 +46,6 @@ else:
log = logging.getLogger(__name__)
def rank_zero_only(fn: Callable) -> Callable:
"""Function that can be used as a decorator to enable a function/method being called only on rank 0."""
@wraps(fn)
def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
if rank_zero_only.rank == 0:
return fn(*args, **kwargs)
return None
return wrapped_fn
# TODO: this should be part of the cluster environment
def _get_rank() -> int:
rank_keys = ("RANK", "SLURM_PROCID", "LOCAL_RANK")
for key in rank_keys:
rank = os.environ.get(key)
if rank is not None:
return int(rank)
return 0
# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())
def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
if python_version() >= "3.8.0":
kwargs["stacklevel"] = stacklevel
log.info(*args, **kwargs)
def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
if python_version() >= "3.8.0":
kwargs["stacklevel"] = stacklevel
log.debug(*args, **kwargs)
@rank_zero_only
def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
"""Function used to log debug-level messages only on rank 0."""
_debug(*args, stacklevel=stacklevel, **kwargs)
@rank_zero_only
def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
"""Function used to log info-level messages only on rank 0."""
_info(*args, stacklevel=stacklevel, **kwargs)
def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
@ -447,3 +399,21 @@ def _revert_sync_batchnorm(module: Module) -> Module:
converted_module.add_module(name, _revert_sync_batchnorm(child))
del module
return converted_module
def rank_zero_info(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
" and will be removed in v1.8."
" Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
)
return new_rank_zero_info(*args, **kwargs)
def rank_zero_debug(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
" and will be removed in v1.8."
" Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
)
return new_rank_zero_debug(*args, **kwargs)

View File

@ -26,9 +26,9 @@ from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
if _TORCH_GREATER_EQUAL_1_10:
from torch._C import _DisableTorchDispatch # type: ignore[attr-defined]

View File

@ -24,7 +24,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Literal
import pytorch_lightning as pl
from pytorch_lightning.utilities.warnings import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
def str_to_bool_or_str(val: str) -> Union[str, bool]:

View File

@ -0,0 +1,97 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used for calling functions on a particular rank."""
import logging
import os
import warnings
from functools import partial, wraps
from platform import python_version
from typing import Any, Callable, Optional, Union
log = logging.getLogger(__name__)
def rank_zero_only(fn: Callable) -> Callable:
"""Function that can be used as a decorator to enable a function/method being called only on rank 0."""
@wraps(fn)
def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
if rank_zero_only.rank == 0:
return fn(*args, **kwargs)
return None
return wrapped_fn
# TODO: this should be part of the cluster environment
def _get_rank() -> int:
rank_keys = ("RANK", "SLURM_PROCID", "LOCAL_RANK")
for key in rank_keys:
rank = os.environ.get(key)
if rank is not None:
return int(rank)
return 0
# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())
def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
if python_version() >= "3.8.0":
kwargs["stacklevel"] = stacklevel
log.info(*args, **kwargs)
def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
if python_version() >= "3.8.0":
kwargs["stacklevel"] = stacklevel
log.debug(*args, **kwargs)
@rank_zero_only
def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
"""Function used to log debug-level messages only on rank 0."""
_debug(*args, stacklevel=stacklevel, **kwargs)
@rank_zero_only
def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
"""Function used to log info-level messages only on rank 0."""
_info(*args, stacklevel=stacklevel, **kwargs)
def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
if type(stacklevel) is type and issubclass(stacklevel, Warning):
rank_zero_deprecation(
"Support for passing the warning category positionally is deprecated in v1.6 and will be removed in v1.8"
f" Please, use `category={stacklevel.__name__}`."
)
kwargs["category"] = stacklevel
stacklevel = kwargs.pop("stacklevel", 2)
warnings.warn(message, stacklevel=stacklevel, **kwargs)
@rank_zero_only
def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
"""Function used to log warn-level messages only on rank 0."""
_warn(message, stacklevel=stacklevel, **kwargs)
class LightningDeprecationWarning(DeprecationWarning):
"""Deprecation warnings raised by PyTorch Lightning."""
rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)

View File

@ -21,8 +21,7 @@ from typing import Optional
import numpy as np
import torch
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)

View File

@ -14,50 +14,55 @@
"""Warning-related utilities."""
import warnings
from functools import partial
from typing import Any, Union
from typing import Any
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning as NewLightningDeprecationWarning
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation as new_rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_warn as new_rank_zero_warn
def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
if type(stacklevel) is type and issubclass(stacklevel, Warning):
rank_zero_deprecation(
"Support for passing the warning category positionally is deprecated in v1.6 and will be removed in v1.8"
f" Please, use `category={stacklevel.__name__}`."
)
kwargs["category"] = stacklevel
stacklevel = kwargs.pop("stacklevel", 2)
warnings.warn(message, stacklevel=stacklevel, **kwargs)
@rank_zero_only
def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
"""Function used to log warn-level messages only on rank 0."""
_warn(message, stacklevel=stacklevel, **kwargs)
# enable our warnings
warnings.simplefilter("default", category=NewLightningDeprecationWarning)
class PossibleUserWarning(UserWarning):
"""Warnings that could be false positives."""
class LightningDeprecationWarning(DeprecationWarning):
"""Deprecation warnings raised by PyTorch Lightning."""
# enable our warnings
warnings.simplefilter("default", category=LightningDeprecationWarning)
rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)
class WarningCache(set):
def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
if message not in self:
self.add(message)
rank_zero_warn(message, stacklevel=stacklevel, **kwargs)
new_rank_zero_warn(message, stacklevel=stacklevel, **kwargs)
def deprecation(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
if message not in self:
self.add(message)
rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)
new_rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)
def rank_zero_warn(*args: Any, **kwargs: Any) -> Any:
new_rank_zero_deprecation(
"pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
" and will be removed in v1.8."
" Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
)
return new_rank_zero_warn(*args, **kwargs)
def rank_zero_deprecation(*args: Any, **kwargs: Any) -> Any:
new_rank_zero_deprecation(
"pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
" and will be removed in v1.8."
" Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
)
return new_rank_zero_deprecation(*args, **kwargs)
class LightningDeprecationWarning(NewLightningDeprecationWarning):
def __init__(self, *args: Any, **kwargs: Any) -> None:
new_rank_zero_deprecation(
"pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
" and will be removed in v1.8."
" Use the equivalent class from the pytorch_lightning.utilities.rank_zero module instead."
)
super().__init__(*args, **kwargs)

View File

@ -19,8 +19,8 @@ from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks.device_stats_monitor import _prefix_metric_keys
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

View File

@ -35,7 +35,7 @@ def datadir():
@pytest.fixture(scope="function", autouse=True)
def preserve_global_rank_variable():
"""Ensures that the rank_zero_only.rank global variable gets reset in each test."""
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only
rank = getattr(rank_zero_only, "rank", None)
yield

View File

@ -32,10 +32,10 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePl
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
@ -413,3 +413,37 @@ def test_v1_8_0_on_configure_sharded_model(tmpdir):
match="The `on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8."
):
trainer.fit(model)
def test_v1_8_0_rank_zero_imports():
import warnings
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_info
from pytorch_lightning.utilities.warnings import LightningDeprecationWarning, rank_zero_deprecation, rank_zero_warn
with pytest.deprecated_call(
match="pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
" and will be removed in v1.8."
):
rank_zero_debug("foo")
with pytest.deprecated_call(
match="pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
" and will be removed in v1.8."
):
rank_zero_info("foo")
with pytest.deprecated_call(
match="pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
" and will be removed in v1.8."
):
rank_zero_warn("foo")
with pytest.deprecated_call(
match="pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
" and will be removed in v1.8."
):
rank_zero_deprecation("foo")
with pytest.deprecated_call(
match="pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
" and will be removed in v1.8."
):
warnings.warn("foo", LightningDeprecationWarning, stacklevel=5)

View File

@ -24,9 +24,9 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _convert_params, _sanitize_params
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from tests.helpers.boring_model import BoringDataModule, BoringModel

View File

@ -20,7 +20,7 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.strategies import DDP2Strategy, DDPShardedStrategy, DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from tests.helpers.runif import RunIf

View File

@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Mapping
from unittest import mock
import pytest
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}])
def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
"""Test that SLURM environment variables are properly checked for rank_zero_only."""
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
with mock.patch.dict(os.environ, env_vars):
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
@rank_zero_only
def foo(): # The return type is optional because on non-zero ranks it will not be called
return 1
x = foo()
assert x == 1
@pytest.mark.parametrize("rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3")])
def test_rank_zero_none_set(rank_key, rank):
"""Test that function is not called when rank environment variables are not global zero."""
with mock.patch.dict(os.environ, {rank_key: rank}):
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
@rank_zero_only
def foo():
return 1
x = foo()
assert x is None

View File

@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Mapping
from unittest import mock
import pytest
import torch
import torch.multiprocessing as mp
@ -24,43 +21,6 @@ from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
from tests.helpers.runif import RunIf
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}])
def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
"""Test that SLURM environment variables are properly checked for rank_zero_only."""
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
with mock.patch.dict(os.environ, env_vars):
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
@rank_zero_only
def foo(): # The return type is optional because on non-zero ranks it will not be called
return 1
x = foo()
assert x == 1
@pytest.mark.parametrize("rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3")])
def test_rank_zero_none_set(rank_key, rank):
"""Test that function is not called when rank environment variables are not global zero."""
with mock.patch.dict(os.environ, {rank_key: rank}):
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
@rank_zero_only
def foo():
return 1
x = foo()
assert x is None
def _test_collect_states(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"

View File

@ -19,7 +19,8 @@ import os
from contextlib import redirect_stderr
from io import StringIO
from pytorch_lightning.utilities.warnings import _warn, rank_zero_deprecation, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.warnings import WarningCache
standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1"
if standalone:
@ -40,16 +41,16 @@ if standalone:
cache.deprecation("test7")
output = stderr.getvalue()
assert "test_warnings.py:30: UserWarning: test1" in output
assert "test_warnings.py:31: DeprecationWarning: test2" in output
assert "test_warnings.py:31: UserWarning: test1" in output
assert "test_warnings.py:32: DeprecationWarning: test2" in output
assert "test_warnings.py:33: UserWarning: test3" in output
assert "test_warnings.py:34: DeprecationWarning: test4" in output
assert "test_warnings.py:34: UserWarning: test3" in output
assert "test_warnings.py:35: DeprecationWarning: test4" in output
assert "test_warnings.py:36: LightningDeprecationWarning: test5" in output
assert "test_warnings.py:37: LightningDeprecationWarning: test5" in output
assert "test_warnings.py:39: UserWarning: test6" in output
assert "test_warnings.py:40: LightningDeprecationWarning: test7" in output
assert "test_warnings.py:40: UserWarning: test6" in output
assert "test_warnings.py:41: LightningDeprecationWarning: test7" in output
# check that logging is properly configured
import logging