Centralize rank_zero_only utilities into their own module (#11747)
* Centralize rank_zero_only utilities into their own module. Fixes #11746
* PossibleUserWarning
* Update test_warnings.py
* update imports
* more imports
* Update CHANGELOG.md
* Update mlflow.py
* Update cli.py
* Update api_references.rst
* Update meta.py
* add deprecation tests
* debug standalone
* fix standalone tests
* Update CHANGELOG.md
This commit is contained in:
parent 34c454c756
commit a64438c897
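In practical terms, the refactor only changes where the rank-zero helpers live; their behavior is unchanged. A minimal before/after sketch (editorial illustration, based on the deprecations listed in the CHANGELOG diff below):

    # before (deprecated in v1.6, planned for removal in v1.8)
    from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
    from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn

    # after
    from pytorch_lightning.utilities.rank_zero import (
        rank_zero_deprecation,
        rank_zero_info,
        rank_zero_only,
        rank_zero_warn,
    )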
CHANGELOG.md
@@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `Bagua` training strategy ([#11146](https://github.com/PyTorchLightning/pytorch-lightning/pull/11146))


- Added `rank_zero` module to centralize utilities ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))


### Changed

- Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))

@@ -323,6 +326,24 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `on_configure_sharded_model` callback hook in favor of `setup` ([#11627](https://github.com/PyTorchLightning/pytorch-lightning/pull/11627))


- Deprecated `pytorch_lightning.utilities.distributed.rank_zero_only` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_only` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))


- Deprecated `pytorch_lightning.utilities.distributed.rank_zero_debug` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_debug` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))


- Deprecated `pytorch_lightning.utilities.distributed.rank_zero_info` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_info` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))


- Deprecated `pytorch_lightning.utilities.warnings.rank_zero_warn` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_warn` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))


- Deprecated `pytorch_lightning.utilities.warnings.rank_zero_deprecation` in favor of `pytorch_lightning.utilities.rank_zero.rank_zero_deprecation` ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))


- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`


### Removed

- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
@@ -289,5 +289,6 @@ Utilities API
    memory
    model_summary
    parsing
    rank_zero
    seed
    warnings
@@ -98,7 +98,7 @@ Lightning automatically ensures that the model is saved only on the main process
    trainer.save_checkpoint("example.ckpt")

Not using :meth:`~pytorch_lightning.trainer.trainer.Trainer.save_checkpoint` can lead to unexpected behavior and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the Trainer's save functionality.
If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.rank_zero.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
model parallel distributed strategies such as deepspeed or sharded training.
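For illustration, a custom save path guarded by the decorator might look like the following (a minimal sketch, not part of this diff; the function name and checkpoint path are placeholders):

    import torch
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    @rank_zero_only
    def save_custom_checkpoint(model, path="custom.ckpt"):
        # runs on the main process (global rank 0) only; other ranks return None
        torch.save(model.state_dict(), path)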
@@ -205,7 +205,7 @@ Make a Custom Logger
********************

You can implement your own logger by writing a class that inherits from :class:`~pytorch_lightning.loggers.base.LightningLoggerBase`.
Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorators to make sure that only the first process in DDP training creates the experiment and logs the data respectively.
Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`~pytorch_lightning.utilities.rank_zero.rank_zero_only` decorators to make sure that only the first process in DDP training creates the experiment and logs the data respectively.

.. testcode::
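(The body of the testcode example is elided from this hunk.) Purely as an illustration of the two decorators mentioned above, a skeleton logger could look like this (a sketch under the assumption that the standard LightningLoggerBase interface is implemented; method bodies are placeholders):

    from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    class MyLogger(LightningLoggerBase):
        @property
        def name(self):
            return "my_logger"

        @property
        def version(self):
            return "0.1"

        @property
        @rank_zero_experiment
        def experiment(self):
            # the experiment object is created on rank 0 only
            return object()

        @rank_zero_only
        def log_hyperparams(self, params):
            pass

        @rank_zero_only
        def log_metrics(self, metrics, step=None):
            pass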
@ -25,9 +25,9 @@ from torch.utils.data import DataLoader, random_split
|
|||
import pytorch_lightning as pl
|
||||
from pl_examples import _DATASETS_PATH, cli_lightning_logo
|
||||
from pl_examples.basic_examples.mnist_datamodule import MNIST
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
import torchvision
|
||||
|
|
|
@ -58,8 +58,8 @@ import pytorch_lightning as pl
|
|||
from pl_examples import cli_lightning_logo
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
|
||||
|
|
|
@ -26,8 +26,8 @@ import torch
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -26,8 +26,8 @@ from torch.optim.optimizer import Optimizer
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -29,9 +29,10 @@ import torch
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import _AcceleratorType, rank_zero_deprecation, rank_zero_only
|
||||
from pytorch_lightning.utilities import _AcceleratorType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.parsing import AttributeDict
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
||||
|
||||
|
|
|
@ -27,8 +27,8 @@ from torch.optim.optimizer import Optimizer
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.types import LRSchedulerConfig
|
||||
|
||||
|
||||
|
|
|
@ -33,9 +33,9 @@ import yaml
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
|
||||
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
|
||||
from pytorch_lightning.utilities.warnings import WarningCache
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from typing import Any, Dict, Optional, Union
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
|
||||
|
||||
class ProgressBarBase(Callback):
|
||||
|
|
|
@ -27,7 +27,7 @@ else:
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_debug
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
|
||||
|
||||
_PAD_SIZE = 5
|
||||
|
||||
|
|
|
@ -30,8 +30,8 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -24,8 +24,8 @@ from torch.optim.swa_utils import SWALR
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
|
||||
from pytorch_lightning.utilities.types import LRSchedulerConfig
|
||||
|
||||
_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
|
||||
|
|
|
@ -24,8 +24,8 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.trainer.states import RunningStage
|
||||
from pytorch_lightning.utilities import LightningEnum
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_info
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -22,8 +22,9 @@ import time
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info
|
||||
from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
|
||||
|
||||
if _TPU_AVAILABLE:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
|
||||
|
||||
rank_zero_deprecation(
|
||||
"Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, "
|
||||
|
@ -22,8 +22,6 @@ rank_zero_deprecation(
|
|||
from functools import wraps # noqa: E402
|
||||
from typing import Callable # noqa: E402
|
||||
|
||||
from pytorch_lightning.utilities import rank_zero_warn # noqa: E402
|
||||
|
||||
|
||||
def parameter_validation(fn: Callable) -> Callable:
|
||||
"""Validates that the module parameter lengths match after moving to the device. It is useful when tying
|
||||
|
|
|
@ -38,20 +38,15 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, Hyperparameter
|
|||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||
from pytorch_lightning.core.saving import ModelIO
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
|
||||
from pytorch_lightning.utilities import (
|
||||
_IS_WINDOWS,
|
||||
_TORCH_GREATER_EQUAL_1_10,
|
||||
GradClipAlgorithmType,
|
||||
rank_zero_deprecation,
|
||||
rank_zero_warn,
|
||||
)
|
||||
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_debug, sync_ddp
|
||||
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.memory import get_model_size_mb
|
||||
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
|
||||
from pytorch_lightning.utilities.parsing import collect_init_args
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
|
||||
from pytorch_lightning.utilities.warnings import WarningCache
|
||||
|
|
|
@ -21,9 +21,9 @@ from torch import optim
|
|||
from torch.optim import Optimizer
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau
|
||||
|
||||
|
||||
|
|
|
@ -26,12 +26,13 @@ from warnings import warn
|
|||
import torch
|
||||
import yaml
|
||||
|
||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, rank_zero_warn
|
||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from pytorch_lightning.utilities.migration import pl_legacy_patch
|
||||
from pytorch_lightning.utilities.parsing import parse_class_init_keys
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
PRIMITIVE_TYPES = (bool, int, float, str)
|
||||
|
|
|
@ -26,8 +26,7 @@ import numpy as np
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
|
||||
|
||||
|
||||
def rank_zero_experiment(fn: Callable) -> Callable:
|
||||
|
|
|
@ -26,9 +26,10 @@ from torch import is_tensor
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _module_available, rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _module_available
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
_COMET_AVAILABLE = _module_available("comet_ml")
|
||||
|
|
|
@ -28,9 +28,8 @@ import torch
|
|||
|
||||
from pytorch_lightning.core.saving import save_hparams_to_yaml
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ from time import time
|
|||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.imports import _module_available
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
LOCAL_FILE_URI_PREFIX = "file:"
|
||||
|
|
|
@ -32,10 +32,10 @@ import torch
|
|||
from pytorch_lightning import __version__
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
|
||||
from pytorch_lightning.utilities.model_summary import ModelSummary
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9:
|
||||
try:
|
||||
|
|
|
@ -29,10 +29,11 @@ from torch.utils.tensorboard.summary import hparams
|
|||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.core.saving import save_hparams_to_yaml
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
from pytorch_lightning.utilities.logger import _sanitize_params as _utils_sanitize_params
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@ from typing import Any, Dict, Optional, Union
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning.utilities import _module_available
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
|
||||
|
||||
_TESTTUBE_AVAILABLE = _module_available("test_tube")
|
||||
|
||||
|
|
|
@ -26,11 +26,10 @@ import torch.nn as nn
|
|||
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _module_available, rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _compare_version
|
||||
from pytorch_lightning.utilities.imports import _compare_version, _module_available
|
||||
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_warn
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
||||
|
||||
_WANDB_AVAILABLE = _module_available("wandb")
|
||||
_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22")
|
||||
|
|
|
@ -23,14 +23,14 @@ from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _
|
|||
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache
|
||||
from pytorch_lightning.utilities.warnings import WarningCache
|
||||
|
||||
_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
|
||||
|
||||
|
|
|
@ -22,11 +22,10 @@ from pytorch_lightning.loops.utilities import _is_max_limit_reached
|
|||
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from pytorch_lightning.trainer.progress import Progress
|
||||
from pytorch_lightning.trainer.supporters import TensorRunningAccum
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.enums import _FaultTolerantMode
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_warn
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -19,8 +19,8 @@ import torch
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
|
||||
|
||||
def _ignore_scalar_return_in_dp() -> None:
|
||||
|
|
|
@ -16,7 +16,7 @@ import os
|
|||
import socket
|
||||
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
|
||||
class LightningEnvironment(ClusterEnvironment):
|
||||
|
|
|
@ -16,7 +16,7 @@ import logging
|
|||
import os
|
||||
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -17,9 +17,9 @@ from typing import Any, Callable, Dict, Optional
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
from pytorch_lightning.utilities.types import _PATH
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
|
@ -24,9 +24,9 @@ from torch import nn, Tensor
|
|||
from torch.autograd.profiler import record_function
|
||||
|
||||
from pytorch_lightning.profiler.base import BaseProfiler
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
from pytorch_lightning.utilities.warnings import WarningCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
|
@ -46,18 +46,13 @@ from pytorch_lightning.utilities import (
|
|||
_TORCH_GREATER_EQUAL_1_8,
|
||||
_TORCH_GREATER_EQUAL_1_9,
|
||||
_TORCH_GREATER_EQUAL_1_10,
|
||||
rank_zero_warn,
|
||||
)
|
||||
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
|
||||
from pytorch_lightning.utilities.distributed import group as _group
|
||||
from pytorch_lightning.utilities.distributed import (
|
||||
init_dist_connection,
|
||||
rank_zero_only,
|
||||
ReduceOp,
|
||||
sync_ddp_if_available,
|
||||
)
|
||||
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
|
||||
from pytorch_lightning.utilities.enums import _StrategyType
|
||||
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.seed import reset_seed
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
||||
|
|
|
@ -32,19 +32,14 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
|||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||
from pytorch_lightning.strategies.parallel import ParallelStrategy
|
||||
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
|
||||
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
|
||||
from pytorch_lightning.utilities.distributed import group as _group
|
||||
from pytorch_lightning.utilities.distributed import (
|
||||
init_dist_connection,
|
||||
rank_zero_debug,
|
||||
rank_zero_only,
|
||||
ReduceOp,
|
||||
sync_ddp_if_available,
|
||||
)
|
||||
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
|
||||
from pytorch_lightning.utilities.enums import _StrategyType
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.seed import reset_seed
|
||||
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
|
||||
|
||||
|
|
|
@ -34,11 +34,12 @@ from pytorch_lightning.strategies.ddp import DDPStrategy
|
|||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import GradClipAlgorithmType
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.distributed import log, rank_zero_info
|
||||
from pytorch_lightning.utilities.distributed import log
|
||||
from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info
|
||||
from pytorch_lightning.utilities.seed import reset_seed
|
||||
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
|
||||
|
|
|
@ -23,11 +23,12 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
|
|||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||
from pytorch_lightning.strategies.parallel import ParallelStrategy
|
||||
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
|
||||
from pytorch_lightning.utilities.distributed import distributed_available
|
||||
from pytorch_lightning.utilities.distributed import group as dist_group
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
|
||||
from pytorch_lightning.utilities.distributed import ReduceOp
|
||||
from pytorch_lightning.utilities.enums import _StrategyType
|
||||
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
if _HOROVOD_AVAILABLE:
|
||||
import horovod.torch as hvd
|
||||
|
|
|
@ -22,9 +22,10 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||
from pytorch_lightning.strategies.ddp import DDPStrategy
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
|
||||
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
if _FAIRSCALE_AVAILABLE:
|
||||
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
|
||||
|
|
|
@ -21,9 +21,10 @@ from torch.optim import Optimizer
|
|||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
|
||||
from pytorch_lightning.utilities.enums import _StrategyType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
if _FAIRSCALE_AVAILABLE:
|
||||
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
|
||||
|
|
|
@ -32,9 +32,10 @@ from pytorch_lightning.trainer.states import TrainerFn
|
|||
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
|
||||
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
||||
from pytorch_lightning.utilities.data import has_len
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only, ReduceOp
|
||||
from pytorch_lightning.utilities.distributed import ReduceOp
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only
|
||||
from pytorch_lightning.utilities.seed import reset_seed
|
||||
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from packaging.version import Version
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
|
||||
|
||||
|
||||
def verify_loop_configurations(trainer: "pl.Trainer") -> None:
|
||||
|
|
|
@ -64,15 +64,7 @@ from pytorch_lightning.strategies import (
|
|||
StrategyRegistry,
|
||||
TPUSpawnStrategy,
|
||||
)
|
||||
from pytorch_lightning.utilities import (
|
||||
_AcceleratorType,
|
||||
_StrategyType,
|
||||
AMPType,
|
||||
device_parser,
|
||||
rank_zero_deprecation,
|
||||
rank_zero_info,
|
||||
rank_zero_warn,
|
||||
)
|
||||
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, AMPType, device_parser
|
||||
from pytorch_lightning.utilities.enums import PrecisionType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import (
|
||||
|
@ -81,6 +73,7 @@ from pytorch_lightning.utilities.imports import (
|
|||
_TORCH_GREATER_EQUAL_1_8,
|
||||
_TPU_AVAILABLE,
|
||||
)
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
|
||||
|
||||
if _HOROVOD_AVAILABLE:
|
||||
import horovod.torch as hvd
|
||||
|
|
|
@ -26,9 +26,9 @@ from pytorch_lightning.callbacks import (
|
|||
)
|
||||
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
|
||||
from pytorch_lightning.callbacks.timer import Timer
|
||||
from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
|
||||
from pytorch_lightning.utilities import ModelSummaryMode
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
|
||||
|
||||
|
||||
class CallbackConnector:
|
||||
|
|
|
@ -24,11 +24,12 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.loops.utilities import _is_max_limit_reached
|
||||
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info
|
||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _fault_tolerant_training
|
||||
from pytorch_lightning.utilities.migration import pl_legacy_patch
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
|
||||
from pytorch_lightning.utilities.types import _PATH
|
||||
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
|
||||
|
||||
|
|
|
@ -26,7 +26,6 @@ from pytorch_lightning.accelerators import GPUAccelerator
|
|||
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
|
||||
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.auto_restart import (
|
||||
_teardown_dataloader_get_iterators,
|
||||
|
@ -50,9 +49,10 @@ from pytorch_lightning.utilities.fetching import (
|
|||
)
|
||||
from pytorch_lightning.utilities.imports import _fault_tolerant_training
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
|
||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning, rank_zero_warn
|
||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||
|
||||
|
||||
class DataConnector:
|
||||
|
|
|
@ -21,12 +21,12 @@ from torchmetrics import Metric
|
|||
from typing_extensions import TypedDict
|
||||
|
||||
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
|
||||
from pytorch_lightning.utilities.data import extract_batch_size
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.memory import recursive_detach
|
||||
from pytorch_lightning.utilities.metrics import metrics_to_scalars
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
from pytorch_lightning.utilities.warnings import WarningCache
|
||||
|
||||
_IN_METRIC = Union[Metric, torch.Tensor] # Do not include scalars as they were converted to tensors
|
||||
|
|
|
@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Set, Union
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_info
|
||||
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info
|
||||
|
||||
# copied from signal.pyi
|
||||
_SIGNUM = Union[int, signal.Signals]
|
||||
|
|
|
@ -83,9 +83,6 @@ from pytorch_lightning.utilities import (
|
|||
device_parser,
|
||||
GradClipAlgorithmType,
|
||||
parsing,
|
||||
rank_zero_deprecation,
|
||||
rank_zero_info,
|
||||
rank_zero_warn,
|
||||
)
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.argparse import (
|
||||
|
@ -103,6 +100,7 @@ from pytorch_lightning.utilities.exceptions import ExitGracefullyException, Misc
|
|||
from pytorch_lightning.utilities.imports import _fault_tolerant_training
|
||||
from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
|
||||
from pytorch_lightning.utilities.seed import reset_seed
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.types import (
|
||||
|
|
|
@ -20,11 +20,11 @@ from torch.utils.data import DataLoader
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers.base import DummyLogger
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.data import has_len_all_ranks
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
|
||||
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -26,9 +26,9 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx
|
||||
from pytorch_lightning.loggers.base import DummyLogger
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
from pytorch_lightning.utilities.types import LRSchedulerConfig
|
||||
|
||||
# check if ipywidgets is installed before importing tqdm.auto
|
||||
|
|
|
@@ -16,7 +16,7 @@
import numpy

from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only  # noqa: F401
from pytorch_lightning.utilities.distributed import AllGatherGrad  # noqa: F401
from pytorch_lightning.utilities.enums import (  # noqa: F401
    _AcceleratorType,
    _StrategyType,

@@ -57,7 +57,12 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
)
from pytorch_lightning.utilities.parameter_tying import find_shared_parameters, set_shared_parameters  # noqa: F401
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable  # noqa: F401
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn  # noqa: F401
from pytorch_lightning.utilities.rank_zero import (  # noqa: F401
    rank_zero_deprecation,
    rank_zero_info,
    rank_zero_only,
    rank_zero_warn,
)

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
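Because the helpers are re-exported from `pytorch_lightning.utilities`, existing top-level imports keep working after this change; a quick sanity check (illustrative sketch, not part of the diff):

    from pytorch_lightning.utilities import rank_zero_warn as reexported
    from pytorch_lightning.utilities.rank_zero import rank_zero_warn as canonical

    assert reexported is canonical  # same function object, just re-exported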
@@ -27,10 +27,11 @@ from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, rank_zero_warn, warnings
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple, LRSchedulerTypeUnion

if _JSONARGPARSE_AVAILABLE:

@@ -795,7 +796,7 @@ class LightningCLI:
            lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

        if is_overridden("configure_optimizers", self.model):
            warnings._warn(
            _warn(
                f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
                f"`{self.__class__.__name__}.configure_optimizers`."
            )
@@ -25,10 +25,10 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler,
import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.seed import pl_worker_init_function
from pytorch_lightning.utilities.warnings import WarningCache
@@ -15,8 +15,6 @@

import logging
import os
from functools import wraps
from platform import python_version
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch

@@ -25,6 +23,10 @@ from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_only  # noqa: F401
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm

@@ -44,56 +46,6 @@ else:
log = logging.getLogger(__name__)


def rank_zero_only(fn: Callable) -> Callable:
    """Function that can be used as a decorator to enable a function/method being called only on rank 0."""

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped_fn


# TODO: this should be part of the cluster environment
def _get_rank() -> int:
    rank_keys = ("RANK", "SLURM_PROCID", "LOCAL_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())


def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
    if python_version() >= "3.8.0":
        kwargs["stacklevel"] = stacklevel
    log.info(*args, **kwargs)


def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
    if python_version() >= "3.8.0":
        kwargs["stacklevel"] = stacklevel
    log.debug(*args, **kwargs)


@rank_zero_only
def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
    """Function used to log debug-level messages only on rank 0."""
    _debug(*args, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
    """Function used to log info-level messages only on rank 0."""
    _info(*args, stacklevel=stacklevel, **kwargs)


def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
    """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.

@@ -447,3 +399,21 @@ def _revert_sync_batchnorm(module: Module) -> Module:
        converted_module.add_module(name, _revert_sync_batchnorm(child))
    del module
    return converted_module


def rank_zero_info(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_info(*args, **kwargs)


def rank_zero_debug(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_debug(*args, **kwargs)
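To illustrate the effect of these shims (a sketch, not part of the diff): importing from the old location still works, but each call now emits a `LightningDeprecationWarning` before forwarding to the new implementation, which can be surfaced or filtered with the standard `warnings` machinery:

    import warnings

    from pytorch_lightning.utilities.distributed import rank_zero_info  # old path, deprecated
    from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        rank_zero_info("hello from rank 0")  # forwards to the new implementation

    assert any(issubclass(w.category, LightningDeprecationWarning) for w in caught)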
@@ -26,9 +26,9 @@ from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

if _TORCH_GREATER_EQUAL_1_10:
    from torch._C import _DisableTorchDispatch  # type: ignore[attr-defined]

@@ -24,7 +24,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Literal

import pytorch_lightning as pl
from pytorch_lightning.utilities.warnings import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


def str_to_bool_or_str(val: str) -> Union[str, bool]:
@@ -0,0 +1,97 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities that can be used for calling functions on a particular rank."""
import logging
import os
import warnings
from functools import partial, wraps
from platform import python_version
from typing import Any, Callable, Optional, Union

log = logging.getLogger(__name__)


def rank_zero_only(fn: Callable) -> Callable:
    """Function that can be used as a decorator to enable a function/method being called only on rank 0."""

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped_fn


# TODO: this should be part of the cluster environment
def _get_rank() -> int:
    rank_keys = ("RANK", "SLURM_PROCID", "LOCAL_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())


def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
    if python_version() >= "3.8.0":
        kwargs["stacklevel"] = stacklevel
    log.info(*args, **kwargs)


def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
    if python_version() >= "3.8.0":
        kwargs["stacklevel"] = stacklevel
    log.debug(*args, **kwargs)


@rank_zero_only
def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
    """Function used to log debug-level messages only on rank 0."""
    _debug(*args, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
    """Function used to log info-level messages only on rank 0."""
    _info(*args, stacklevel=stacklevel, **kwargs)


def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
    if type(stacklevel) is type and issubclass(stacklevel, Warning):
        rank_zero_deprecation(
            "Support for passing the warning category positionally is deprecated in v1.6 and will be removed in v1.8"
            f" Please, use `category={stacklevel.__name__}`."
        )
        kwargs["category"] = stacklevel
        stacklevel = kwargs.pop("stacklevel", 2)
    warnings.warn(message, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
    """Function used to log warn-level messages only on rank 0."""
    _warn(message, stacklevel=stacklevel, **kwargs)


class LightningDeprecationWarning(DeprecationWarning):
    """Deprecation warnings raised by PyTorch Lightning."""


rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)
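As a quick usage illustration of the new module (a sketch, not part of the diff; the function name and file path are placeholders):

    from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only

    @rank_zero_only
    def write_summary(text: str) -> None:
        # executed on global rank 0 only; other ranks get None back
        with open("summary.txt", "w") as f:
            f.write(text)

    rank_zero_info("training finished")  # logged once, from rank 0
    write_summary("done")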
@@ -21,8 +21,7 @@ from typing import Optional
import numpy as np
import torch

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)
@@ -14,50 +14,55 @@
"""Warning-related utilities."""

import warnings
from functools import partial
from typing import Any, Union
from typing import Any

from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning as NewLightningDeprecationWarning
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation as new_rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_warn as new_rank_zero_warn


def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
    if type(stacklevel) is type and issubclass(stacklevel, Warning):
        rank_zero_deprecation(
            "Support for passing the warning category positionally is deprecated in v1.6 and will be removed in v1.8"
            f" Please, use `category={stacklevel.__name__}`."
        )
        kwargs["category"] = stacklevel
        stacklevel = kwargs.pop("stacklevel", 2)
    warnings.warn(message, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
    """Function used to log warn-level messages only on rank 0."""
    _warn(message, stacklevel=stacklevel, **kwargs)
# enable our warnings
warnings.simplefilter("default", category=NewLightningDeprecationWarning)


class PossibleUserWarning(UserWarning):
    """Warnings that could be false positives."""


class LightningDeprecationWarning(DeprecationWarning):
    """Deprecation warnings raised by PyTorch Lightning."""


# enable our warnings
warnings.simplefilter("default", category=LightningDeprecationWarning)

rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)


class WarningCache(set):
    def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
        if message not in self:
            self.add(message)
            rank_zero_warn(message, stacklevel=stacklevel, **kwargs)
            new_rank_zero_warn(message, stacklevel=stacklevel, **kwargs)

    def deprecation(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
        if message not in self:
            self.add(message)
            rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)
            new_rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)


def rank_zero_warn(*args: Any, **kwargs: Any) -> Any:
    new_rank_zero_deprecation(
        "pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_warn(*args, **kwargs)


def rank_zero_deprecation(*args: Any, **kwargs: Any) -> Any:
    new_rank_zero_deprecation(
        "pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_deprecation(*args, **kwargs)


class LightningDeprecationWarning(NewLightningDeprecationWarning):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        new_rank_zero_deprecation(
            "pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
            " and will be removed in v1.8."
            " Use the equivalent class from the pytorch_lightning.utilities.rank_zero module instead."
        )
        super().__init__(*args, **kwargs)
|
||||
|
|
|
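For code still importing from the old locations, the wrappers above emit a deprecation message and then forward to the new module. A minimal sketch of the caller-side effect (hypothetical snippet, assuming a single process, i.e. global rank 0):

import warnings

from pytorch_lightning.utilities.warnings import rank_zero_warn  # deprecated import path

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    rank_zero_warn("still works")  # forwarded to pytorch_lightning.utilities.rank_zero.rank_zero_warn

# the call produces the UserWarning itself plus a LightningDeprecationWarning about the old import path
assert any("has been deprecated in v1.6" in str(w.message) for w in caught)
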
@@ -19,8 +19,8 @@ from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks.device_stats_monitor import _prefix_metric_keys
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

@@ -35,7 +35,7 @@ def datadir():
@pytest.fixture(scope="function", autouse=True)
def preserve_global_rank_variable():
    """Ensures that the rank_zero_only.rank global variable gets reset in each test."""
    from pytorch_lightning.utilities.distributed import rank_zero_only
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    rank = getattr(rank_zero_only, "rank", None)
    yield

@@ -32,10 +32,10 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePl
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator

@@ -413,3 +413,37 @@ def test_v1_8_0_on_configure_sharded_model(tmpdir):
        match="The `on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8."
    ):
        trainer.fit(model)


def test_v1_8_0_rank_zero_imports():

    import warnings

    from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_info
    from pytorch_lightning.utilities.warnings import LightningDeprecationWarning, rank_zero_deprecation, rank_zero_warn

    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_debug("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_info("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_warn("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_deprecation("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        warnings.warn("foo", LightningDeprecationWarning, stacklevel=5)

@@ -24,9 +24,9 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _convert_params, _sanitize_params
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from tests.helpers.boring_model import BoringDataModule, BoringModel

@@ -20,7 +20,7 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.strategies import DDP2Strategy, DDPShardedStrategy, DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from tests.helpers.runif import RunIf

@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Mapping
from unittest import mock

import pytest


@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}])
def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
    """Test that SLURM environment variables are properly checked for rank_zero_only."""
    from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only

    rank_zero_only.rank = _get_rank()

    with mock.patch.dict(os.environ, env_vars):
        from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only

        rank_zero_only.rank = _get_rank()

        @rank_zero_only
        def foo():  # The return type is optional because on non-zero ranks it will not be called
            return 1

        x = foo()
        assert x == 1


@pytest.mark.parametrize("rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3")])
def test_rank_zero_none_set(rank_key, rank):
    """Test that function is not called when rank environment variables are not global zero."""

    with mock.patch.dict(os.environ, {rank_key: rank}):
        from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only

        rank_zero_only.rank = _get_rank()

        @rank_zero_only
        def foo():
            return 1

        x = foo()
        assert x is None

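The tests above reset `rank_zero_only.rank` from the environment through the private `_get_rank` helper. As a rough, assumed re-implementation (not the library's actual code), such a lookup could work as follows:

import os
from typing import Optional


def get_rank_from_environment() -> Optional[int]:
    # hypothetical helper: check the common cluster rank variables in priority order
    for key in ("RANK", "SLURM_PROCID", "LOCAL_RANK"):
        value = os.environ.get(key)
        if value is not None:
            return int(value)
    # nothing set: treat the process as a single, non-distributed run (rank 0 by convention)
    return None
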
@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Mapping
from unittest import mock

import pytest
import torch
import torch.multiprocessing as mp

@@ -24,43 +21,6 @@ from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
from tests.helpers.runif import RunIf


@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}])
def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
    """Test that SLURM environment variables are properly checked for rank_zero_only."""
    from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only

    rank_zero_only.rank = _get_rank()

    with mock.patch.dict(os.environ, env_vars):
        from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only

        rank_zero_only.rank = _get_rank()

        @rank_zero_only
        def foo():  # The return type is optional because on non-zero ranks it will not be called
            return 1

        x = foo()
        assert x == 1


@pytest.mark.parametrize("rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3")])
def test_rank_zero_none_set(rank_key, rank):
    """Test that function is not called when rank environment variables are not global zero."""

    with mock.patch.dict(os.environ, {rank_key: rank}):
        from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only

        rank_zero_only.rank = _get_rank()

        @rank_zero_only
        def foo():
            return 1

        x = foo()
        assert x is None


def _test_collect_states(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"

@@ -19,7 +19,8 @@ import os
from contextlib import redirect_stderr
from io import StringIO

from pytorch_lightning.utilities.warnings import _warn, rank_zero_deprecation, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.warnings import WarningCache

standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1"
if standalone:

@@ -40,16 +41,16 @@ if standalone:
    cache.deprecation("test7")

    output = stderr.getvalue()
    assert "test_warnings.py:30: UserWarning: test1" in output
    assert "test_warnings.py:31: DeprecationWarning: test2" in output
    assert "test_warnings.py:31: UserWarning: test1" in output
    assert "test_warnings.py:32: DeprecationWarning: test2" in output

    assert "test_warnings.py:33: UserWarning: test3" in output
    assert "test_warnings.py:34: DeprecationWarning: test4" in output
    assert "test_warnings.py:34: UserWarning: test3" in output
    assert "test_warnings.py:35: DeprecationWarning: test4" in output

    assert "test_warnings.py:36: LightningDeprecationWarning: test5" in output
    assert "test_warnings.py:37: LightningDeprecationWarning: test5" in output

    assert "test_warnings.py:39: UserWarning: test6" in output
    assert "test_warnings.py:40: LightningDeprecationWarning: test7" in output
    assert "test_warnings.py:40: UserWarning: test6" in output
    assert "test_warnings.py:41: LightningDeprecationWarning: test7" in output

    # check that logging is properly configured
    import logging