move pytorch_lightning >> lightning/pytorch (#16594)

This commit is contained in:
Jirka Borovec 2023-02-02 03:22:42 +09:00 committed by GitHub
parent 01b152f169
commit 7d4780adb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
197 changed files with 1283 additions and 1264 deletions

View File

@ -26,7 +26,8 @@ pr:
- "examples/pl_basics/backbone_image_classifier.py"
- "examples/pl_basics/autoencoder.py"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "requirements/fabric/**"

View File

@ -24,7 +24,8 @@ pr:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
exclude:

View File

@ -21,7 +21,8 @@ pr:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
exclude:

View File

@ -11,7 +11,8 @@ subprojects:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "tests/legacy/**"
- "pyproject.toml" # includes pytest config
@ -49,7 +50,8 @@ subprojects:
- "examples/pl_basics/backbone_image_classifier.py"
- "examples/pl_basics/autoencoder.py"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "requirements/fabric/**"
@ -82,7 +84,8 @@ subprojects:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "!requirements/*/docs.txt"
@ -99,7 +102,8 @@ subprojects:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "!requirements/docs.txt"
@ -130,7 +134,8 @@ subprojects:
- id: "pytorch_lightning: Docs"
paths:
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "docs/source-pytorch/**"
- ".actions/**"
- ".github/workflows/docs-checks.yml"

3
.github/labeler.yml vendored
View File

@ -9,7 +9,8 @@ app:
- 'requirements/app/**'
pl:
- 'src/pytorch_lightning/**'
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- 'tests/tests_pytorch/**'
- 'tests/legacy/**'
- 'examples/pl_*/**'

View File

@ -10,7 +10,8 @@ on:
paths:
- ".actions/**"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "tests/legacy/**"
- "pyproject.toml" # includes pytest config

View File

@ -15,7 +15,8 @@ on:
- "src/lightning_fabric/*"
- "tests/tests_fabric/**"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "!requirements/*/docs.txt"

5
.gitignore vendored
View File

@ -57,7 +57,10 @@ src/lightning_fabric/
!src/lightning_fabric/__*__.py
!src/lightning_fabric/MANIFEST.in
!src/lightning_fabric/README.md
src/lightning/pytorch/
src/pytorch_lightning/
!src/pytorch_lightning/__*__.py
!src/pytorch_lightning/MANIFEST.in
!src/pytorch_lightning/README.md
# PyInstaller
# Usually these files are written by a python script from a template

View File

@ -42,9 +42,8 @@ repos:
docs/source-pytorch/_static/images/general/pl_quick_start_full_compressed.gif|
docs/source-pytorch/_static/images/general/pl_overview_flat.jpg|
docs/source-pytorch/_static/images/general/pl_overview.gif|
src/lightning_app/cli/pl-app-template/ui/yarn.lock|
src/pytorch_lightning/CHANGELOG.md|
src/lightning/fabric/CHANGELOG.md
src/lightning/fabric/CHANGELOG.md|
src/lightning/pytorch/CHANGELOG.md
)$
- id: detect-private-key
@ -100,7 +99,7 @@ repos:
(?x)^(
src/lightning/app/CHANGELOG.md|
src/lightning/fabric/CHANGELOG.md|
src/pytorch_lightning/CHANGELOG.md
src/lightning/pytorch/CHANGELOG.md
)$
- repo: https://github.com/charliermarsh/ruff-pre-commit

View File

@ -58,7 +58,7 @@ _PACKAGE_MAPPING = {
"fabric": "lightning_fabric",
}
# TODO: drop this reverse list when all packages are moved
_MIRROR_PACKAGE_REVERSED = ("app", "fabric")
_MIRROR_PACKAGE_REVERSED = ("app", "fabric", "pytorch")
# https://packaging.python.org/guides/single-sourcing-package-version/
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
_PATH_ROOT = os.path.dirname(__file__)
@ -142,7 +142,7 @@ if __name__ == "__main__":
package_to_install = _PACKAGE_MAPPING.get(_PACKAGE_NAME, "lightning")
if package_to_install == "lightning":
# merge all requirements files
assistant._load_aggregate_requirements(_PATH_REQUIRE, _FREEZE_REQUIREMENTS) # install everything
assistant._load_aggregate_requirements(_PATH_REQUIRE, _FREEZE_REQUIREMENTS)
# replace imports and copy the code
assistant.create_mirror_package(_PATH_SRC, _PACKAGE_MAPPING, reverse=_MIRROR_PACKAGE_REVERSED)
else:

View File

@ -207,6 +207,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [1.9.0] - 2023-01-17
### Added
- Added support for native logging of `MetricCollection` with enabled compute groups ([#15580](https://github.com/Lightning-AI/lightning/pull/15580))
- Added support for custom artifact names in `pl.loggers.WandbLogger` ([#16173](https://github.com/Lightning-AI/lightning/pull/16173))
- Added support for DDP with `LRFinder` ([#15304](https://github.com/Lightning-AI/lightning/pull/15304))
@ -223,7 +224,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added info message for Ampere CUDA GPU users to enable tf32 matmul precision ([#16037](https://github.com/Lightning-AI/lightning/pull/16037))
- Added support for returning optimizer-like classes in `LightningModule.configure_optimizers` ([#16189](https://github.com/Lightning-AI/lightning/pull/16189))
### Changed
- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))

View File

@ -1,12 +1,18 @@
"""Root package info."""
import logging
import os
from typing import Any
from pytorch_lightning.__about__ import * # noqa: F401, F403
from lightning_utilities import module_available
if os.path.isfile(os.path.join(os.path.dirname(__file__), "__about__.py")):
from lightning.pytorch.__about__ import * # noqa: F401, F403
if "__version__" not in locals():
from pytorch_lightning.__version__ import version as __version__ # noqa: F401
if os.path.isfile(os.path.join(os.path.dirname(__file__), "__version__.py")):
from lightning.pytorch.__version__ import version as __version__
elif module_available("lightning"):
from lightning import __version__ # noqa: F401
_DETAIL = 15 # between logging.INFO and logging.DEBUG, used for logging in production use cases
@ -30,13 +36,13 @@ if not _root_logger.hasHandlers():
_logger.addHandler(logging.StreamHandler())
_logger.propagate = False
from lightning_fabric.utilities.seed import seed_everything # noqa: E402
from pytorch_lightning.callbacks import Callback # noqa: E402
from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402
from pytorch_lightning.trainer import Trainer # noqa: E402
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
from lightning.pytorch.callbacks import Callback # noqa: E402
from lightning.pytorch.core import LightningDataModule, LightningModule # noqa: E402
from lightning.pytorch.trainer import Trainer # noqa: E402
# this import needs to go last as it will patch other modules
import pytorch_lightning._graveyard # noqa: E402, F401 # isort: skip
import lightning.pytorch._graveyard # noqa: E402, F401 # isort: skip
__all__ = ["Trainer", "LightningDataModule", "LightningModule", "Callback", "seed_everything"]

View File

@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning._graveyard.legacy_import_unpickler # noqa: F401
import lightning.pytorch._graveyard.legacy_import_unpickler # noqa: F401

View File

@ -11,7 +11,7 @@ def _patch_pl_to_mirror_if_necessary(module: str) -> str:
if module.startswith(pl):
# for the standalone package this won't do anything,
# for the unified mirror package it will redirect the imports
module = "pytorch_lightning" + module[len(pl) :]
module = "lightning.pytorch" + module[len(pl) :]
return module
@ -29,7 +29,7 @@ def compare_version(package: str, op: Callable, version: str, use_base_version:
return _compare_version(new_package, op, version, use_base_version)
# patching is necessary, since up to v.0.7.3 torchmetrics has a hardcoded reference to pytorch_lightning,
# patching is necessary, since up to v.0.7.3 torchmetrics has a hardcoded reference to lightning.pytorch,
# which has to be redirected to the unified package:
# https://github.com/Lightning-AI/metrics/blob/v0.7.3/torchmetrics/metric.py#L96
try:

View File

@ -10,16 +10,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_fabric.accelerators import find_usable_cuda_devices # noqa: F401
from lightning_fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.cuda import CUDAAccelerator # noqa: F401
from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.mps import MPSAccelerator # noqa: F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401
from lightning.fabric.accelerators import find_usable_cuda_devices # noqa: F401
from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from lightning.pytorch.accelerators.accelerator import Accelerator # noqa: F401
from lightning.pytorch.accelerators.cpu import CPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.cuda import CUDAAccelerator # noqa: F401
from lightning.pytorch.accelerators.hpu import HPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.ipu import IPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.mps import MPSAccelerator # noqa: F401
from lightning.pytorch.accelerators.tpu import TPUAccelerator # noqa: F401
ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators"
ACCELERATORS_BASE_MODULE = "lightning.pytorch.accelerators"
AcceleratorRegistry = _AcceleratorRegistry()
call_register_accelerators(AcceleratorRegistry, ACCELERATORS_BASE_MODULE)

View File

@ -14,9 +14,9 @@
from abc import ABC
from typing import Any, Dict
import pytorch_lightning as pl
from lightning_fabric.accelerators.accelerator import Accelerator as _Accelerator
from lightning_fabric.utilities.types import _DEVICE
import lightning.pytorch as pl
from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator
from lightning.fabric.utilities.types import _DEVICE
class Accelerator(_Accelerator, ABC):

View File

@ -15,11 +15,11 @@ from typing import Any, Dict, List, Union
import torch
from lightning_fabric.accelerators.cpu import _parse_cpu_cores
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
from lightning.fabric.accelerators.cpu import _parse_cpu_cores
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PSUTIL_AVAILABLE
class CPUAccelerator(Accelerator):

View File

@ -19,12 +19,12 @@ from typing import Any, Dict, List, Optional, Union
import torch
import pytorch_lightning as pl
from lightning_fabric.accelerators.cuda import _check_cuda_matmul_precision, num_cuda_devices
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import lightning.pytorch as pl
from lightning.fabric.accelerators.cuda import _check_cuda_matmul_precision, num_cuda_devices
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.exceptions import MisconfigurationException
_log = logging.getLogger(__name__)

View File

@ -16,11 +16,11 @@ from typing import Any, Dict, List, Optional, Union
import torch
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _HPU_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_debug
if _HPU_AVAILABLE:
import habana_frameworks.torch.hpu as torch_hpu
@ -108,7 +108,7 @@ class HPUAccelerator(Accelerator):
def _parse_hpus(devices: Optional[Union[int, str, List[int]]]) -> Optional[int]:
"""
Parses the hpus given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer` for the `devices` flag.
:class:`~lightning.pytorch.trainer.Trainer` for the `devices` flag.
Args:
devices: An integer that indicates the number of Gaudi devices to be used

View File

@ -15,9 +15,9 @@ from typing import Any, Dict, List
import torch
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.imports import _IPU_AVAILABLE
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.imports import _IPU_AVAILABLE
class IPUAccelerator(Accelerator):

View File

@ -15,12 +15,12 @@ from typing import Any, Dict, List, Optional, Union
import torch
from lightning_fabric.accelerators.mps import MPSAccelerator as _MPSAccelerator
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
from lightning.fabric.accelerators.mps import MPSAccelerator as _MPSAccelerator
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PSUTIL_AVAILABLE
class MPSAccelerator(Accelerator):

View File

@ -15,10 +15,10 @@ from typing import Any, Dict, List, Optional, Union
import torch
from lightning_fabric.accelerators.tpu import _parse_tpu_devices, _XLA_AVAILABLE
from lightning_fabric.accelerators.tpu import TPUAccelerator as FabricTPUAccelerator
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.tpu import _parse_tpu_devices, _XLA_AVAILABLE
from lightning.fabric.accelerators.tpu import TPUAccelerator as FabricTPUAccelerator
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
class TPUAccelerator(Accelerator):

View File

@ -11,26 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.lambda_function import LambdaCallback
from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_summary import ModelSummary
from pytorch_lightning.callbacks.on_exception_checkpoint import OnExceptionCheckpoint
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBarBase, RichProgressBar, TQDMProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from pytorch_lightning.callbacks.timer import Timer
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.callbacks.checkpoint import Checkpoint
from lightning.pytorch.callbacks.device_stats_monitor import DeviceStatsMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
from lightning.pytorch.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from lightning.pytorch.callbacks.lambda_function import LambdaCallback
from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
from lightning.pytorch.callbacks.lr_monitor import LearningRateMonitor
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.callbacks.model_summary import ModelSummary
from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheckpoint
from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
from lightning.pytorch.callbacks.progress import ProgressBarBase, RichProgressBar, TQDMProgressBar
from lightning.pytorch.callbacks.pruning import ModelPruning
from lightning.pytorch.callbacks.quantization import QuantizationAwareTraining
from lightning.pytorch.callbacks.rich_model_summary import RichModelSummary
from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from lightning.pytorch.callbacks.timer import Timer
__all__ = [
"BackboneFinetuning",

View File

@ -20,12 +20,12 @@ Finds optimal batch size
from typing import Optional
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.tuner.batch_size_scaling import _scale_batch_size
from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.tuner.batch_size_scaling import _scale_batch_size
from lightning.pytorch.utilities.exceptions import _TunerExitException, MisconfigurationException
from lightning.pytorch.utilities.parsing import lightning_hasattr
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
class BatchSizeFinder(Callback):
@ -63,7 +63,7 @@ class BatchSizeFinder(Callback):
# 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
# useful while fine-tuning models since you can't always use the same batch size after
# unfreezing the backbone.
from pytorch_lightning.callbacks import BatchSizeFinder
from lightning.pytorch.callbacks import BatchSizeFinder
class FineTuneBatchSizeFinder(BatchSizeFinder):
@ -85,7 +85,7 @@ class BatchSizeFinder(Callback):
Example::
# 2. Run batch size finder for validate/test/predict.
from pytorch_lightning.callbacks import BatchSizeFinder
from lightning.pytorch.callbacks import BatchSizeFinder
class EvalBatchSizeFinder(BatchSizeFinder):

View File

@ -21,8 +21,8 @@ from typing import Any, Dict, List, Optional, Type
from torch import Tensor
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import STEP_OUTPUT
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import STEP_OUTPUT
class Callback:
@ -217,8 +217,8 @@ class Callback:
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
Args:
trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance.
pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance.
checkpoint: the checkpoint dictionary that will be saved.
"""
@ -229,8 +229,8 @@ class Callback:
Called when loading a model checkpoint, use to reload state.
Args:
trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance.
pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance.
checkpoint: the full checkpoint dictionary that got loaded by the Trainer.
"""

View File

@ -1,9 +1,9 @@
from pytorch_lightning.callbacks.callback import Callback
from lightning.pytorch.callbacks.callback import Callback
class Checkpoint(Callback):
r"""
This is the base class for model checkpointing. Expert users may want to subclass it in case of writing
custom :class:`~pytorch_lightning.callbacksCheckpoint` callback, so that
custom :class:`~lightning.pytorch.callbacksCheckpoint` callback, so that
the trainer recognizes the custom class as a checkpointing callback.
"""

View File

@ -20,12 +20,12 @@ Monitors and logs device stats during training.
"""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import STEP_OUTPUT
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PSUTIL_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.types import STEP_OUTPUT
class DeviceStatsMonitor(Callback):
@ -46,8 +46,8 @@ class DeviceStatsMonitor(Callback):
If ``Trainer`` has no logger.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import DeviceStatsMonitor
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import DeviceStatsMonitor
>>> device_stats = DeviceStatsMonitor() # doctest: +SKIP
>>> trainer = Trainer(callbacks=[device_stats]) # doctest: +SKIP
"""
@ -91,7 +91,7 @@ class DeviceStatsMonitor(Callback):
if self._cpu_stats and device.type != "cpu":
# Don't query CPU stats twice if CPU is accelerator
from pytorch_lightning.accelerators.cpu import get_cpu_stats
from lightning.pytorch.accelerators.cpu import get_cpu_stats
device_stats.update(get_cpu_stats())

View File

@ -25,11 +25,11 @@ import numpy as np
import torch
from torch import Tensor
import pytorch_lightning as pl
from lightning_fabric.utilities.rank_zero import _get_rank
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
import lightning.pytorch as pl
from lightning.fabric.utilities.rank_zero import _get_rank
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
log = logging.getLogger(__name__)
@ -74,8 +74,8 @@ class EarlyStopping(Callback):
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import EarlyStopping
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(callbacks=[early_stopping])
@ -174,7 +174,7 @@ class EarlyStopping(Callback):
self.patience = state_dict["patience"]
def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn
from lightning.pytorch.trainer.states import TrainerFn
return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

View File

@ -24,10 +24,10 @@ from torch.nn import Module, ModuleDict
from torch.nn.modules.batchnorm import _BatchNorm
from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
log = logging.getLogger(__name__)
@ -328,8 +328,8 @@ class BackboneFinetuning(BaseFinetuning):
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])

View File

@ -22,9 +22,9 @@ Trainer also calls ``optimizer.step()`` for the last indivisible step number.
from typing import Any, Dict
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
class GradientAccumulationScheduler(Callback):
@ -51,8 +51,8 @@ class GradientAccumulationScheduler(Callback):
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GradientAccumulationScheduler
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import GradientAccumulationScheduler
# from epoch 5, it starts accumulating every 2 batches. Here we have 4 instead of 5
# because epoch (key) should be zero-indexed.

View File

@ -21,7 +21,7 @@ Create a simple callback on the fly using lambda functions.
from typing import Callable, Optional
from pytorch_lightning.callbacks.callback import Callback
from lightning.pytorch.callbacks.callback import Callback
class LambdaCallback(Callback):
@ -29,12 +29,12 @@ class LambdaCallback(Callback):
Create a simple callback on the fly using lambda functions.
Args:
**kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.callback.Callback`
**kwargs: hooks supported by :class:`~lightning.pytorch.callbacks.callback.Callback`
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LambdaCallback
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import LambdaCallback
>>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
"""

View File

@ -19,11 +19,11 @@ Finds optimal learning rate
"""
from typing import Optional
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.tuner.lr_finder import _lr_find, _LRFinder
from pytorch_lightning.utilities.exceptions import _TunerExitException
from pytorch_lightning.utilities.seed import isolate_rng
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.tuner.lr_finder import _lr_find, _LRFinder
from lightning.pytorch.utilities.exceptions import _TunerExitException
from lightning.pytorch.utilities.seed import isolate_rng
class LearningRateFinder(Callback):
@ -50,7 +50,7 @@ class LearningRateFinder(Callback):
# Customize LearningRateFinder callback to run at different epochs.
# This feature is useful while fine-tuning models.
from pytorch_lightning.callbacks import LearningRateFinder
from lightning.pytorch.callbacks import LearningRateFinder
class FineTuneLearningRateFinder(LearningRateFinder):

View File

@ -25,11 +25,11 @@ from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type
from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig
class LearningRateMonitor(Callback):
@ -49,8 +49,8 @@ class LearningRateMonitor(Callback):
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LearningRateMonitor
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import LearningRateMonitor
>>> lr_monitor = LearningRateMonitor(logging_interval='step')
>>> trainer = Trainer(callbacks=[lr_monitor])

View File

@ -32,13 +32,13 @@ import torch
import yaml
from torch import Tensor
import pytorch_lightning as pl
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.types import STEP_OUTPUT
import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
from lightning.pytorch.utilities.types import STEP_OUTPUT
log = logging.getLogger(__name__)
warning_cache = WarningCache()
@ -47,7 +47,7 @@ warning_cache = WarningCache()
class ModelCheckpoint(Checkpoint):
r"""
Save the model periodically by monitoring a quantity. Every metric logged with
:meth:`~pytorch_lightning.core.module.log` or :meth:`~pytorch_lightning.core.module.log_dict` in
:meth:`~lightning.pytorch.core.module.log` or :meth:`~lightning.pytorch.core.module.log_dict` in
LightningModule is a candidate for the monitor key. For more information, see
:ref:`checkpointing`.
@ -64,8 +64,8 @@ class ModelCheckpoint(Checkpoint):
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
By default, dirpath is ``None`` and will be set at runtime to the location
specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` argument,
specified by :class:`~lightning.pytorch.trainer.trainer.Trainer`'s
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument,
and if the Trainer uses a logger, the path will also contain logger name and version.
filename: checkpoint filename. Can contain named formatting options to be auto-filled.
@ -157,8 +157,8 @@ class ModelCheckpoint(Checkpoint):
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import ModelCheckpoint
# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
@ -371,7 +371,7 @@ class ModelCheckpoint(Checkpoint):
logger.after_save_checkpoint(proxy(self))
def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn
from lightning.pytorch.trainer.states import TrainerFn
return (
bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run

View File

@ -15,7 +15,7 @@
Model Summary
=============
Generates a summary of all layers in a :class:`~pytorch_lightning.core.module.LightningModule`.
Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule`.
The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.
@ -24,19 +24,19 @@ the name, type and number of parameters for each layer.
import logging
from typing import List, Tuple, Union
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.model_summary import DeepSpeedSummary
from pytorch_lightning.utilities.model_summary import ModelSummary as Summary
from pytorch_lightning.utilities.model_summary import summarize
from pytorch_lightning.utilities.model_summary.model_summary import _format_summary_table
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.model_summary import DeepSpeedSummary
from lightning.pytorch.utilities.model_summary import ModelSummary as Summary
from lightning.pytorch.utilities.model_summary import summarize
from lightning.pytorch.utilities.model_summary.model_summary import _format_summary_table
log = logging.getLogger(__name__)
class ModelSummary(Callback):
r"""
Generates a summary of all layers in a :class:`~pytorch_lightning.core.module.LightningModule`.
Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule`.
Args:
max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
@ -44,8 +44,8 @@ class ModelSummary(Callback):
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelSummary
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import ModelSummary
>>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)])
"""
@ -66,7 +66,7 @@ class ModelSummary(Callback):
self.summarize(summary_data, total_parameters, trainable_parameters, model_size)
def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Union[DeepSpeedSummary, Summary]:
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
if isinstance(trainer.strategy, DeepSpeedStrategy) and trainer.strategy.zero_stage_3:
return DeepSpeedSummary(pl_module, max_depth=self._max_depth)

View File

@ -20,9 +20,9 @@ Automatically save a checkpoints on exception.
import os
from typing import Any
import pytorch_lightning as pl
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
import lightning.pytorch as pl
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import Checkpoint
class OnExceptionCheckpoint(Checkpoint):
@ -38,8 +38,8 @@ class OnExceptionCheckpoint(Checkpoint):
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import OnExceptionCheckpoint
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import OnExceptionCheckpoint
>>> trainer = Trainer(callbacks=[OnExceptionCheckpoint(".")])
"""

View File

@ -19,10 +19,10 @@ Aids in saving predictions
"""
from typing import Any, Literal, Optional, Sequence
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities import LightningEnum
from lightning.pytorch.utilities.exceptions import MisconfigurationException
class WriteInterval(LightningEnum):
@ -48,7 +48,7 @@ class BasePredictionWriter(Callback):
Example::
import torch
from pytorch_lightning.callbacks import BasePredictionWriter
from lightning.pytorch.callbacks import BasePredictionWriter
class CustomWriter(BasePredictionWriter):
@ -75,7 +75,7 @@ class BasePredictionWriter(Callback):
# multi-device inference example
import torch
from pytorch_lightning.callbacks import BasePredictionWriter
from lightning.pytorch.callbacks import BasePredictionWriter
class CustomWriter(BasePredictionWriter):

View File

@ -18,6 +18,6 @@ Progress Bars
Use or override one of the progress bar callbacks.
"""
from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401
from lightning.pytorch.callbacks.progress.base import ProgressBarBase # noqa: F401
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401

View File

@ -13,16 +13,16 @@
# limitations under the License.
from typing import Any, Dict, Optional, Union
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.logger import _version
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.logger import _version
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
class ProgressBarBase(Callback):
r"""
The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback`
that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback`
that keeps track of the batch progress in the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
You should implement your highly custom progress bars with this as the base class.
Example::
@ -207,7 +207,7 @@ class ProgressBarBase(Callback):
def enable(self) -> None:
"""You should provide a way to enable the progress bar.
The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training
The :class:`~lightning.pytorch.trainer.trainer.Trainer` will call this in e.g. pre-training
routines like the :ref:`learning rate finder <advanced/training_tricks:Learning Rate Finder>`.
to temporarily enable and disable the main progress bar.
"""

View File

@ -18,9 +18,9 @@ from typing import Any, cast, Dict, Optional, Union
from lightning_utilities.core.imports import RequirementCache
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.types import STEP_OUTPUT
import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.utilities.types import STEP_OUTPUT
_RICH_AVAILABLE: bool = RequirementCache("rich>=10.2.2")
@ -215,8 +215,8 @@ class RichProgressBar(ProgressBarBase):
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar
trainer = Trainer(callbacks=RichProgressBar())

View File

@ -25,9 +25,9 @@ if importlib.util.find_spec("ipywidgets") is not None:
else:
from tqdm import tqdm as _tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.utilities.rank_zero import rank_zero_debug
_PAD_SIZE = 5
@ -65,7 +65,7 @@ class TQDMProgressBar(ProgressBarBase):
- **sanity check progress:** the progress during the sanity check run
- **main progress:** shows training + validation progress combined. It also accounts for
multiple validation runs during training when
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.val_check_interval` is used.
- **validation progress:** only visible during validation;
shows total progress over all validation datasets.
- **test progress:** only active when testing; shows total progress over all test datasets.
@ -74,7 +74,7 @@ class TQDMProgressBar(ProgressBarBase):
If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
specific methods of the callback class and pass your custom implementation to the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
:class:`~lightning.pytorch.trainer.trainer.Trainer`.
Example:
@ -85,7 +85,7 @@ class TQDMProgressBar(ProgressBarBase):
... return bar
...
>>> bar = LitProgressBar()
>>> from pytorch_lightning import Trainer
>>> from lightning.pytorch import Trainer
>>> trainer = Trainer(callbacks=[bar])
Args:
@ -94,8 +94,8 @@ class TQDMProgressBar(ProgressBarBase):
process_position: Set this to a value greater than ``0`` to offset the progress bars by this many lines.
This is useful when you have progress bars defined elsewhere and want to show all of them
together. This corresponds to
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the
:class:`~lightning.pytorch.trainer.trainer.Trainer`.
"""
def __init__(self, refresh_rate: int = 1, process_position: int = 0):

View File

@ -26,11 +26,11 @@ from lightning_utilities.core.apply_func import apply_to_collection
from torch import nn, Tensor
from typing_extensions import TypedDict
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_only
log = logging.getLogger(__name__)

View File

@ -25,10 +25,10 @@ from torch import Tensor
from torch.ao.quantization.qconfig import QConfig
from torch.quantization import FakeQuantizeBase
import pytorch_lightning as pl
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import lightning.pytorch as pl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
if _TORCH_GREATER_EQUAL_1_11:
from torch.ao.quantization import fuse_modules_qat as fuse_modules

View File

@ -13,9 +13,9 @@
# limitations under the License.
from typing import List, Tuple
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from pytorch_lightning.utilities.model_summary import get_human_readable_count
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
from lightning.pytorch.utilities.model_summary import get_human_readable_count
if _RICH_AVAILABLE:
from rich import get_console
@ -24,7 +24,7 @@ if _RICH_AVAILABLE:
class RichModelSummary(ModelSummary):
r"""
Generates a summary of all layers in a :class:`~pytorch_lightning.core.module.LightningModule`
Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule`
with `rich text formatting <https://github.com/Textualize/rich>`_.
Install it with pip:
@ -35,17 +35,17 @@ class RichModelSummary(ModelSummary):
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichModelSummary
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichModelSummary
trainer = Trainer(callbacks=RichModelSummary())
You could also enable ``RichModelSummary`` using the :class:`~pytorch_lightning.callbacks.RichProgressBar`
You could also enable ``RichModelSummary`` using the :class:`~lightning.pytorch.callbacks.RichProgressBar`
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar
trainer = Trainer(callbacks=RichProgressBar())

View File

@ -22,14 +22,14 @@ import torch
from torch import nn, Tensor
from torch.optim.swa_utils import SWALR
import pytorch_lightning as pl
from lightning_fabric.utilities.types import LRScheduler
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DeepSpeedStrategy
from pytorch_lightning.strategies.fsdp import FSDPStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig
import lightning.pytorch as pl
from lightning.fabric.utilities.types import LRScheduler
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.strategies import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig
_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]

View File

@ -20,12 +20,12 @@ import time
from datetime import timedelta
from typing import Any, Dict, Optional, Union
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities import LightningEnum
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info
log = logging.getLogger(__name__)
@ -52,8 +52,8 @@ class Timer(Callback):
Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Timer
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Timer
# stop training after 12 hours
timer = Timer(duration="00:12:00:00")

View File

@ -22,13 +22,13 @@ from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import _warn
from torch.optim import Optimizer
import pytorch_lightning as pl
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.17.0")
@ -292,21 +292,21 @@ class LightningCLI:
.. warning:: ``LightningCLI`` is in beta and subject to change.
Args:
model_class: An optional :class:`~pytorch_lightning.core.module.LightningModule` class to train on or a
callable which returns a :class:`~pytorch_lightning.core.module.LightningModule` instance when
model_class: An optional :class:`~lightning.pytorch.core.module.LightningModule` class to train on or a
callable which returns a :class:`~lightning.pytorch.core.module.LightningModule` instance when
called. If ``None``, you can pass a registered model with ``--model=MyModel``.
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
datamodule_class: An optional :class:`~lightning.pytorch.core.datamodule.LightningDataModule` class or a
callable which returns a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` instance when
called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
save_config_callback: A callback class to save the config.
save_config_kwargs: Parameters that will be used to instantiate the save_config_callback.
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
trainer_class: An optional subclass of the :class:`~lightning.pytorch.trainer.trainer.Trainer` class or a
callable which returns a :class:`~lightning.pytorch.trainer.trainer.Trainer` instance when called.
trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
this argument will not be configurable from a configuration file and will always be present for
this particular CLI. Alternatively, configurable callbacks can be added as explained in
:ref:`the CLI docs <lightning-cli>`.
seed_everything_default: Number for the :func:`~lightning_fabric.utilities.seed.seed_everything`
seed_everything_default: Number for the :func:`~lightning.fabric.utilities.seed.seed_everything`
seed value. Set to True to automatically choose a seed value.
Setting it to False will avoid calling ``seed_everything``.
parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
@ -319,7 +319,7 @@ class LightningCLI:
args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style
arguments can be given in a ``list``. Alternatively, structured config options can be given in a
``dict`` or ``jsonargparse.Namespace``.
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
"""
self.save_config_callback = save_config_callback
@ -362,7 +362,7 @@ class LightningCLI:
def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
"""Method that instantiates the argument parser."""
kwargs.setdefault("dump_header", [f"pytorch_lightning=={pl.__version__}"])
kwargs.setdefault("dump_header", [f"lightning.pytorch=={pl.__version__}"])
parser = LightningArgumentParser(**kwargs)
parser.add_argument(
"-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
@ -549,7 +549,7 @@ class LightningCLI:
def configure_optimizers(
lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
) -> Any:
"""Override to customize the :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers`
"""Override to customize the :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers`
method.
Args:
@ -567,7 +567,7 @@ class LightningCLI:
return [optimizer], [lr_scheduler]
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
"""Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method
"""Overrides the model's :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers` method
if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
if not self.auto_configure_optimizers:
return

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.module import LightningModule
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning.pytorch.core.module import LightningModule
__all__ = ["LightningDataModule", "LightningModule"]

View File

@ -19,18 +19,18 @@ from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Unio
from torch.utils.data import DataLoader, Dataset, IterableDataset
from typing_extensions import Self
import pytorch_lightning as pl
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.core.hooks import DataHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import _load_from_checkpoint
from pytorch_lightning.utilities.argparse import (
import lightning.pytorch as pl
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.core.hooks import DataHooks
from lightning.pytorch.core.mixins import HyperparametersMixin
from lightning.pytorch.core.saving import _load_from_checkpoint
from lightning.pytorch.utilities.argparse import (
add_argparse_args,
from_argparse_args,
get_init_arguments_and_types,
parse_argparser,
)
from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, EVAL_DATALOADERS, TRAIN_DATALOADERS
from lightning.pytorch.utilities.types import _ADD_ARGPARSE_RETURN, EVAL_DATALOADERS, TRAIN_DATALOADERS
class LightningDataModule(DataHooks, HyperparametersMixin):
@ -91,7 +91,7 @@ class LightningDataModule(DataHooks, HyperparametersMixin):
Args:
args: The parser or namespace to take arguments from. Only known arguments will be
parsed and passed to the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
parsed and passed to the :class:`~lightning.pytorch.core.datamodule.LightningDataModule`.
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
These must be valid DataModule arguments.

View File

@ -19,9 +19,9 @@ import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
from lightning.pytorch.utilities import move_data_to_device
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
class ModelHooks:
@ -233,7 +233,7 @@ class ModelHooks:
"""Called before ``optimizer.step()``.
If using gradient accumulation, the hook is called once the gradients have been accumulated.
See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
See: :paramref:`~lightning.pytorch.trainer.Trainer.accumulate_grad_batches`.
If using AMP, the loss will be unscaled before calling this hook.
See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
@ -378,7 +378,7 @@ class DataHooks:
In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
The dataloader you return will not be reloaded unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
:paramref:`~lightning.pytorch.trainer.Trainer.reload_dataloaders_every_n_epochs` to
a positive integer.
For data processing use the following pattern:
@ -390,7 +390,7 @@ class DataHooks:
.. warning:: do not assign state in prepare_data
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
- :meth:`prepare_data`
- :meth:`setup`
@ -455,7 +455,7 @@ class DataHooks:
.. warning:: do not assign state in prepare_data
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`
- :meth:`prepare_data`
- :meth:`setup`
@ -500,13 +500,13 @@ class DataHooks:
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
:paramref:`~lightning.pytorch.trainer.Trainer.reload_dataloaders_every_n_epochs` to
a positive integer.
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`
- :meth:`prepare_data`
- :meth:`setup`
@ -552,7 +552,7 @@ class DataHooks:
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`
- :meth:`prepare_data`
- :meth:`setup`

View File

@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin # noqa: F401
from lightning.pytorch.core.mixins.hparams_mixin import HyperparametersMixin # noqa: F401

View File

@ -17,8 +17,8 @@ import types
from argparse import Namespace
from typing import Any, List, MutableMapping, Optional, Sequence, Union
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from pytorch_lightning.utilities.parsing import AttributeDict, save_hyperparameters
from lightning.pytorch.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters
class HyperparametersMixin:
@ -47,7 +47,7 @@ class HyperparametersMixin:
logger: Whether to send the hyperparameters to the logger. Default: True
Example::
>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
@ -60,7 +60,7 @@ class HyperparametersMixin:
"arg1": 1
"arg3": 3.14
>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class AutomaticArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
@ -74,7 +74,7 @@ class HyperparametersMixin:
"arg2": abc
"arg3": 3.14
>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class SingleArgModel(HyperparametersMixin):
... def __init__(self, params):
... super().__init__()
@ -88,7 +88,7 @@ class HyperparametersMixin:
"p2": abc
"p3": 3.14
>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()

View File

@ -27,28 +27,28 @@ from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric, MetricCollection
import lightning_fabric as lf
import pytorch_lightning as pl
from lightning_fabric.loggers import Logger as FabricLogger
from lightning_fabric.utilities.apply_func import convert_to_tensors
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_fabric.utilities.distributed import _distributed_available, _sync_ddp
from lightning_fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_2_0
from lightning_fabric.wrappers import _FabricOptimizer
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.loggers import Logger
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCHMETRICS_GREATER_EQUAL_0_9_1
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import (
import lightning.fabric as lf
import lightning.pytorch as pl
from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.utilities.apply_func import convert_to_tensors
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from lightning.pytorch.core.mixins import HyperparametersMixin
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.core.saving import ModelIO
from lightning.pytorch.loggers import Logger
from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCHMETRICS_GREATER_EQUAL_0_9_1
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import (
_METRIC,
EPOCH_OUTPUT,
LRSchedulerPLType,
@ -135,7 +135,7 @@ class LightningModule(
Args:
use_pl_optimizer: If ``True``, will wrap the optimizer(s) in a
:class:`~pytorch_lightning.core.optimizer.LightningOptimizer` for automatic handling of precision and
:class:`~lightning.pytorch.core.optimizer.LightningOptimizer` for automatic handling of precision and
profiling.
Returns:
@ -1168,16 +1168,16 @@ class LightningModule(
"""
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
"""Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it
calls :meth:`~pytorch_lightning.core.module.LightningModule.forward`. Override to add any processing logic.
"""Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it
calls :meth:`~lightning.pytorch.core.module.LightningModule.forward`. Override to add any processing logic.
The :meth:`~pytorch_lightning.core.module.LightningModule.predict_step` is used
The :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` is used
to scale inference on multi-devices.
To prevent an OOM error, it is possible to use :class:`~pytorch_lightning.callbacks.BasePredictionWriter`
To prevent an OOM error, it is possible to use :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
callback to write the predictions to disk or database after each batch or on epoch end.
The :class:`~pytorch_lightning.callbacks.BasePredictionWriter` should be used while using a spawn
The :class:`~lightning.pytorch.callbacks.BasePredictionWriter` should be used while using a spawn
based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)`` as predictions won't be returned.
@ -1209,7 +1209,7 @@ class LightningModule(
gets called, the list or a callback returned here will be merged with the list of callbacks passed to the
Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks
already present in the Trainer's callbacks list, it will take priority and replace them. In addition,
Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
Lightning will make sure :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks
run last.
Return:
@ -1309,7 +1309,7 @@ class LightningModule(
)
Metrics can be made available to monitor by simply logging it using
``self.log('metric_to_track', metric_val)`` in your :class:`~pytorch_lightning.core.module.LightningModule`.
``self.log('metric_to_track', metric_val)`` in your :class:`~lightning.pytorch.core.module.LightningModule`.
Note:
Some things to know:
@ -1504,7 +1504,7 @@ class LightningModule(
def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Optional[Any]) -> None:
r"""
Override this method to adjust the default way the
:class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler.
:class:`~lightning.pytorch.trainer.trainer.Trainer` calls each scheduler.
By default, Lightning calls ``step()`` and as shown in the example
for each scheduler based on its ``interval``.
@ -1539,7 +1539,7 @@ class LightningModule(
optimizer_closure: Optional[Callable[[], Any]] = None,
) -> None:
r"""
Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
the optimizer.
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example.
@ -1695,7 +1695,7 @@ class LightningModule(
Note:
- Requires the implementation of the
:meth:`~pytorch_lightning.core.module.LightningModule.forward` method.
:meth:`~lightning.pytorch.core.module.LightningModule.forward` method.
- The exported script will be set to evaluation mode.
- It is recommended that you install the latest supported version of PyTorch
to use this feature without limitations. See also the :mod:`torch.jit`

View File

@ -20,12 +20,12 @@ import torch
from torch import optim
from torch.optim import Optimizer
import pytorch_lightning as pl
from lightning_fabric.utilities.types import _Stateful, Optimizable, ReduceLROnPlateau
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple
import lightning.pytorch as pl
from lightning.fabric.utilities.types import _Stateful, Optimizable, ReduceLROnPlateau
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple
def do_nothing_closure() -> None:
@ -79,7 +79,7 @@ class LightningOptimizer:
Setting `sync_grad` to False will block this synchronization and improve performance.
"""
# local import here to avoid circular import
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior
assert self._strategy is not None
lightning_module = self._strategy.lightning_module

View File

@ -28,15 +28,15 @@ import yaml
from lightning_utilities.core.apply_func import apply_to_collection
from typing_extensions import Self
import pytorch_lightning as pl
from lightning_fabric.utilities.cloud_io import _load as pl_load
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.pytorch.utilities import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.migration import pl_legacy_patch
from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
from lightning.pytorch.utilities.parsing import AttributeDict, parse_class_init_keys
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
log = logging.getLogger(__name__)
PRIMITIVE_TYPES = (bool, int, float, str)

View File

@ -20,10 +20,10 @@ from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset
from lightning_fabric.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
class RandomDictDataset(Dataset):

View File

@ -24,9 +24,9 @@ import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, random_split
from lightning_fabric.utilities.imports import _IS_WINDOWS
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch import LightningDataModule
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib

View File

@ -13,13 +13,13 @@
# limitations under the License.
import os
from pytorch_lightning.loggers.comet import _COMET_AVAILABLE, CometLogger # noqa: F401
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger # noqa: F401
from pytorch_lightning.loggers.neptune import NeptuneLogger # noqa: F401
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F401
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE, CometLogger # noqa: F401
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger # noqa: F401
from lightning.pytorch.loggers.neptune import NeptuneLogger # noqa: F401
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.wandb import WandbLogger # noqa: F401
__all__ = ["CSVLogger", "Logger", "TensorBoardLogger"]

View File

@ -25,10 +25,10 @@ from lightning_utilities.core.imports import module_available
from torch import Tensor
from torch.nn import Module
from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_only
log = logging.getLogger(__name__)
_COMET_AVAILABLE = module_available("comet_ml")
@ -54,7 +54,7 @@ else:
class CometLogger(Logger):
r"""
Track your parameters, metrics, source code and more using
`Comet <https://www.comet.com/?utm_source=pytorch_lightning&utm_medium=referral>`_.
`Comet <https://www.comet.com/?utm_source=lightning.pytorch&utm_medium=referral>`_.
Install it with pip:
@ -69,8 +69,8 @@ class CometLogger(Logger):
.. code-block:: python
import os
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CometLogger
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(
@ -88,7 +88,7 @@ class CometLogger(Logger):
.. code-block:: python
from pytorch_lightning.loggers import CometLogger
from lightning.pytorch.loggers import CometLogger
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(
@ -102,7 +102,7 @@ class CometLogger(Logger):
**Log Hyperparameters:**
Log parameters used to initialize a :class:`~pytorch_lightning.core.module.LightningModule`:
Log parameters used to initialize a :class:`~lightning.pytorch.core.module.LightningModule`:
.. code-block:: python
@ -270,7 +270,7 @@ class CometLogger(Logger):
def experiment(self) -> Union[CometExperiment, CometExistingExperiment, CometOfflineExperiment]:
r"""
Actual Comet object. To use Comet features in your
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
Example::

View File

@ -23,14 +23,14 @@ import os
from argparse import Namespace
from typing import Any, Dict, Optional, Union
from lightning_fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter
from lightning_fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger
from lightning_fabric.loggers.logger import rank_zero_experiment
from lightning_fabric.utilities.logger import _convert_params
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from lightning.fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter
from lightning.fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger
from lightning.fabric.loggers.logger import rank_zero_experiment
from lightning.fabric.utilities.logger import _convert_params
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.core.saving import save_hparams_to_yaml
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.utilities.rank_zero import rank_zero_only
log = logging.getLogger(__name__)
@ -70,8 +70,8 @@ class CSVLogger(Logger, FabricCSVLogger):
Logs are saved to ``os.path.join(save_dir, name, version)``.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import CSVLogger
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.loggers import CSVLogger
>>> logger = CSVLogger("logs", name="my_exp_name")
>>> trainer = Trainer(logger=logger)
@ -144,7 +144,7 @@ class CSVLogger(Logger, FabricCSVLogger):
r"""
Actual _ExperimentWriter object. To use _ExperimentWriter features in your
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
Example::

View File

@ -22,10 +22,10 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence
import numpy as np
from lightning_fabric.loggers import Logger as FabricLogger
from lightning_fabric.loggers.logger import _DummyExperiment as DummyExperiment # for backward compatibility
from lightning_fabric.loggers.logger import rank_zero_experiment # noqa: F401 # for backward compatibility
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment # for backward compatibility
from lightning.fabric.loggers.logger import rank_zero_experiment # noqa: F401 # for backward compatibility
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
class Logger(FabricLogger, ABC):

View File

@ -28,11 +28,11 @@ import yaml
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.logger import _scan_checkpoints
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.logger import _scan_checkpoints
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
@ -83,17 +83,17 @@ class MLFlowLogger(Logger):
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger
mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
trainer = Trainer(logger=mlf_logger)
Use the logger anywhere in your :class:`~pytorch_lightning.core.module.LightningModule` as follows:
Use the logger anywhere in your :class:`~lightning.pytorch.core.module.LightningModule` as follows:
.. code-block:: python
from pytorch_lightning import LightningModule
from lightning.pytorch import LightningModule
class LitModel(LightningModule):
@ -115,12 +115,12 @@ class MLFlowLogger(Logger):
save_dir: A path to a local directory where the MLflow runs get saved.
Defaults to `./mlflow` if `tracking_uri` is not provided.
Has no effect if `tracking_uri` is provided.
log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
log_model: Log checkpoints created by :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint`
as MLFlow artifacts.
* if ``log_model == 'all'``, checkpoints are logged during training.
* if ``log_model == True``, checkpoints are logged at the end of training, except when
:paramref:`~pytorch_lightning.callbacks.Checkpoint.save_top_k` ``== -1``
:paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
which also logs every checkpoint during training.
* if ``log_model == False`` (default), no checkpoint is logged.
@ -177,7 +177,7 @@ class MLFlowLogger(Logger):
def experiment(self) -> MlflowClient:
r"""
Actual MLflow object. To use MLflow features in your
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
Example::

View File

@ -27,12 +27,12 @@ from typing import Any, Dict, Generator, List, Optional, Set, Union
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
import pytorch_lightning as pl
from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import lightning.pytorch as pl
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.rank_zero import rank_zero_only
_NEPTUNE_AVAILABLE = RequirementCache("neptune-client")
if _NEPTUNE_AVAILABLE:
@ -70,8 +70,8 @@ class NeptuneLogger(Logger):
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import NeptuneLogger
neptune_logger = NeptuneLogger(
api_key="ANONYMOUS", # replace with your own
@ -82,12 +82,12 @@ class NeptuneLogger(Logger):
**How to use NeptuneLogger?**
Use the logger anywhere in your :class:`~pytorch_lightning.core.module.LightningModule` as follows:
Use the logger anywhere in your :class:`~lightning.pytorch.core.module.LightningModule` as follows:
.. code-block:: python
from neptune.new.types import File
from pytorch_lightning import LightningModule
from lightning.pytorch import LightningModule
class LitModel(LightningModule):
@ -137,7 +137,7 @@ class NeptuneLogger(Logger):
**Log model checkpoints**
If you have :class:`~pytorch_lightning.callbacks.ModelCheckpoint` configured,
If you have :class:`~lightning.pytorch.callbacks.ModelCheckpoint` configured,
Neptune logger automatically logs model checkpoints.
Model weights will be uploaded to the: "model/checkpoints" namespace in the Neptune Run.
You can disable this option:
@ -153,8 +153,8 @@ class NeptuneLogger(Logger):
.. testcode::
:skipif: not _NEPTUNE_AVAILABLE
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import NeptuneLogger
neptune_logger = NeptuneLogger(
project="common/pytorch-lightning-integration",
@ -326,7 +326,7 @@ class NeptuneLogger(Logger):
def experiment(self) -> Run:
r"""
Actual Neptune run object. Allows you to use neptune logging features in your
:class:`~pytorch_lightning.core.module.LightningModule`.
:class:`~lightning.pytorch.core.module.LightningModule`.
Example::
@ -382,7 +382,7 @@ class NeptuneLogger(Logger):
Example::
from pytorch_lightning.loggers import NeptuneLogger
from lightning.pytorch.loggers import NeptuneLogger
PARAMS = {
"batch_size": 64,

View File

@ -23,16 +23,16 @@ from typing import Any, Dict, Optional, Union
from torch import Tensor
import pytorch_lightning as pl
from lightning_fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning_fabric.loggers.tensorboard import TensorBoardLogger as FabricTensorBoardLogger
from lightning_fabric.utilities.logger import _convert_params
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
import lightning.pytorch as pl
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning.fabric.loggers.tensorboard import TensorBoardLogger as FabricTensorBoardLogger
from lightning.fabric.utilities.logger import _convert_params
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.core.saving import save_hparams_to_yaml
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)
@ -56,8 +56,8 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
.. testcode::
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
logger = TensorBoardLogger("tb_logs", name="my_model")
trainer = Trainer(logger=logger)

View File

@ -24,13 +24,13 @@ import torch.nn as nn
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _scan_checkpoints
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.logger import _scan_checkpoints
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
try:
import wandb
@ -61,7 +61,7 @@ class WandbLogger(Logger):
.. code-block:: python
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.loggers import WandbLogger
wandb_logger = WandbLogger(project="MNIST")
@ -75,7 +75,7 @@ class WandbLogger(Logger):
**Log metrics**
Log from :class:`~pytorch_lightning.core.module.LightningModule`:
Log from :class:`~lightning.pytorch.core.module.LightningModule`:
.. code-block:: python
@ -91,7 +91,7 @@ class WandbLogger(Logger):
**Log hyper-parameters**
Save :class:`~pytorch_lightning.core.module.LightningModule` parameters:
Save :class:`~lightning.pytorch.core.module.LightningModule` parameters:
.. code-block:: python
@ -151,7 +151,7 @@ class WandbLogger(Logger):
wandb_logger = WandbLogger(log_model="all")
Custom checkpointing can be set up through :class:`~pytorch_lightning.callbacks.ModelCheckpoint`:
Custom checkpointing can be set up through :class:`~lightning.pytorch.callbacks.ModelCheckpoint`:
.. code-block:: python
@ -227,7 +227,7 @@ class WandbLogger(Logger):
.. code-block:: python
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.loggers import WandbLogger
artifact_dir = WandbLogger.download_artifact(artifact="path/to/artifact")
@ -244,7 +244,7 @@ class WandbLogger(Logger):
.. code-block:: python
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.loggers import WandbLogger
wandb_logger = WandbLogger(project="my_project", name="my_run")
wandb_logger.use_artifact(artifact="path/to/artifact")
@ -262,12 +262,12 @@ class WandbLogger(Logger):
id: Same as version.
anonymous: Enables or explicitly disables anonymous logging.
project: The name of the project to which this run will belong.
log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.ModelCheckpoint`
log_model: Log checkpoints created by :class:`~lightning.pytorch.callbacks.ModelCheckpoint`
as W&B artifacts. `latest` and `best` aliases are automatically set.
* if ``log_model == 'all'``, checkpoints are logged during training.
* if ``log_model == True``, checkpoints are logged at the end of training, except when
:paramref:`~pytorch_lightning.callbacks.ModelCheckpoint.save_top_k` ``== -1``
:paramref:`~lightning.pytorch.callbacks.ModelCheckpoint.save_top_k` ``== -1``
which also logs every checkpoint during training.
* if ``log_model == False`` (default), no checkpoint is logged.
@ -376,7 +376,7 @@ class WandbLogger(Logger):
r"""
Actual wandb object. To use wandb features in your
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
Example::

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.loop import _Loop # noqa: F401 isort: skip (avoids circular imports)
from pytorch_lightning.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import _FitLoop # noqa: F401
from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop # noqa: F401
from lightning.pytorch.loops.loop import _Loop # noqa: F401 isort: skip (avoids circular imports)
from lightning.pytorch.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop # noqa: F401
from lightning.pytorch.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop # noqa: F401
from lightning.pytorch.loops.fit_loop import _FitLoop # noqa: F401
from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop # noqa: F401

View File

@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop # noqa: F401
from pytorch_lightning.loops.dataloader.evaluation_loop import _EvaluationLoop # noqa: F401
from pytorch_lightning.loops.dataloader.prediction_loop import _PredictionLoop # noqa: F401
from lightning.pytorch.loops.dataloader.dataloader_loop import _DataLoaderLoop # noqa: F401
from lightning.pytorch.loops.dataloader.evaluation_loop import _EvaluationLoop # noqa: F401
from lightning.pytorch.loops.dataloader.prediction_loop import _PredictionLoop # noqa: F401

View File

@ -17,8 +17,8 @@ from typing import Sequence
from torch.utils.data import DataLoader
from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.loops.progress import DataLoaderProgress
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import DataLoaderProgress
class _DataLoaderLoop(_Loop):

View File

@ -21,17 +21,17 @@ from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.utils.data.dataloader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.loops.dataloader import _DataLoaderLoop
from pytorch_lightning.loops.epoch import _EvaluationEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.loops.dataloader import _DataLoaderLoop
from lightning.pytorch.loops.epoch import _EvaluationEpochLoop
from lightning.pytorch.loops.utilities import _set_sampler_epoch
from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import EPOCH_OUTPUT
if _RICH_AVAILABLE:
from rich import get_console

View File

@ -2,12 +2,12 @@ from typing import Any, List, Optional, Sequence, Union
from torch.utils.data import DataLoader
from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
from lightning.pytorch.loops.dataloader.dataloader_loop import _DataLoaderLoop
from lightning.pytorch.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop
from lightning.pytorch.loops.utilities import _set_sampler_epoch
from lightning.pytorch.strategies import DDPSpawnStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import _PREDICT_OUTPUT
class _PredictionLoop(_DataLoaderLoop):

View File

@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop # noqa: F401
from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop # noqa: F401
from pytorch_lightning.loops.epoch.training_epoch_loop import _TrainingEpochLoop # noqa: F401
from lightning.pytorch.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop # noqa: F401
from lightning.pytorch.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop # noqa: F401
from lightning.pytorch.loops.epoch.training_epoch_loop import _TrainingEpochLoop # noqa: F401

View File

@ -16,13 +16,13 @@ from collections import OrderedDict
from functools import lru_cache
from typing import Any, Optional, Union
from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.loops.progress import BatchProgress
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import SIGTERMException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import BatchProgress
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import SIGTERMException
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
class _EvaluationEpochLoop(_Loop):

View File

@ -3,11 +3,11 @@ from typing import Any, Dict, Iterator, List, Tuple, Union
import torch
from lightning_fabric.utilities import move_data_to_device
from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.loops.progress import Progress
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.utilities.rank_zero import WarningCache
from lightning.fabric.utilities import move_data_to_device
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper
from lightning.pytorch.utilities.rank_zero import WarningCache
warning_cache = WarningCache()
@ -168,7 +168,7 @@ class _PredictionEpochLoop(_Loop):
def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
"""Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
:class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
:class:`~lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper`."""
# the batch_sampler is not be defined in case of CombinedDataLoaders
assert self.trainer.predict_dataloaders
batch_sampler = getattr(

View File

@ -17,18 +17,18 @@ from typing import Any, Dict, List, Optional, Union
import torch
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop
from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.utilities.exceptions import MisconfigurationException, SIGTERMException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch import loops # import as loops to avoid circular imports
from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop
from lightning.pytorch.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from lightning.pytorch.loops.progress import BatchProgress, SchedulerProgress
from lightning.pytorch.loops.utilities import _is_max_limit_reached
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
_BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
@ -37,19 +37,19 @@ _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
class _TrainingEpochLoop(loops._Loop):
"""
Iterates over all batches in the dataloader (one epoch) that the user returns in their
:meth:`~pytorch_lightning.core.module.LightningModule.train_dataloader` method.
:meth:`~lightning.pytorch.core.module.LightningModule.train_dataloader` method.
Its main responsibilities are calling the ``*_epoch_{start,end}`` hooks, accumulating outputs if the user request
them in one of these hooks, and running validation at the requested interval.
The validation is carried out by yet another loop,
:class:`~pytorch_lightning.loops.epoch.validation_epoch_loop.ValidationEpochLoop`.
:class:`~lightning.pytorch.loops.epoch.validation_epoch_loop.ValidationEpochLoop`.
In the ``run()`` method, the training epoch loop could in theory simply call the
``LightningModule.training_step`` already and perform the optimization.
However, Lightning has built-in support for automatic optimization with multiple optimizers.
For this reason there are actually two more loops nested under
:class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`.
:class:`~lightning.pytorch.loops.epoch.training_epoch_loop.TrainingEpochLoop`.
Args:
min_steps: The minimum number of steps (batches) to process

View File

@ -14,19 +14,19 @@
import logging
from typing import Any, Optional, Type
import pytorch_lightning as pl
from pytorch_lightning.loops import _Loop
from pytorch_lightning.loops.epoch import _TrainingEpochLoop
from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
from pytorch_lightning.loops.progress import Progress
from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException, SIGTERMException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
import lightning.pytorch as pl
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.epoch import _TrainingEpochLoop
from lightning.pytorch.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
from lightning.pytorch.trainer.supporters import CombinedLoader
from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
log = logging.getLogger(__name__)

View File

@ -13,8 +13,8 @@
# limitations under the License.
from typing import Dict, Optional
import pytorch_lightning as pl
from pytorch_lightning.loops.progress import BaseProgress
import lightning.pytorch as pl
from lightning.pytorch.loops.progress import BaseProgress
class _Loop:

View File

@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.optimization.manual_loop import _ManualOptimization # noqa: F401
from pytorch_lightning.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401
from lightning.pytorch.loops.optimization.manual_loop import _ManualOptimization # noqa: F401
from lightning.pytorch.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401

View File

@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, Optional, TypeVar
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.exceptions import MisconfigurationException
T = TypeVar("T")

View File

@ -18,19 +18,19 @@ from typing import Any, Dict, Optional
from torch import Tensor
from pytorch_lightning.core.optimizer import do_nothing_closure
from pytorch_lightning.loops import _Loop
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.progress import Progress, ReadyCompletedTracker
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
from lightning.pytorch.core.optimizer import do_nothing_closure
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.optimization.closure import OutputResult
from lightning.pytorch.loops.progress import Progress, ReadyCompletedTracker
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import STEP_OUTPUT
@dataclass
class ManualResult(OutputResult):
"""A container to hold the result returned by the ``ManualLoop``.
It is created from the output of :meth:`~pytorch_lightning.core.module.LightningModule.training_step`.
It is created from the output of :meth:`~lightning.pytorch.core.module.LightningModule.training_step`.
Attributes:
extra: Anything returned by the ``training_step``.
@ -66,11 +66,11 @@ _OUTPUTS_TYPE = Dict[str, Any]
class _ManualOptimization(_Loop):
"""A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
entirely in the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` and therefore the user is
entirely in the :meth:`~lightning.pytorch.core.module.LightningModule.training_step` and therefore the user is
responsible for back-propagating gradients and making calls to the optimizers.
This loop is a trivial case because it performs only a single iteration (calling directly into the module's
:meth:`~pytorch_lightning.core.module.LightningModule.training_step`) and passing through the output(s).
:meth:`~lightning.pytorch.core.module.LightningModule.training_step`) and passing through the output(s).
"""
output_result_cls = ManualResult

View File

@ -19,21 +19,21 @@ import torch
from torch import Tensor
from torch.optim import Optimizer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
from pytorch_lightning.loops.progress import OptimizationProgress
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import WarningCache
from pytorch_lightning.utilities.types import STEP_OUTPUT
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.optimization.closure import AbstractClosure, OutputResult
from lightning.pytorch.loops.progress import OptimizationProgress
from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import WarningCache
from lightning.pytorch.utilities.types import STEP_OUTPUT
@dataclass
class ClosureResult(OutputResult):
"""A container to hold the result of a :class:`Closure` call.
It is created from the output of :meth:`~pytorch_lightning.core.module.LightningModule.training_step`.
It is created from the output of :meth:`~lightning.pytorch.core.module.LightningModule.training_step`.
Attributes:
closure_loss: The loss with a graph attached.
@ -95,7 +95,7 @@ class Closure(AbstractClosure[ClosureResult]):
do something with the output.
Args:
step_fn: This is typically the :meth:`pytorch_lightning.core.module.LightningModule.training_step
step_fn: This is typically the :meth:`lightning.pytorch.core.module.LightningModule.training_step
wrapped with processing for its outputs
backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value.
Can be set to ``None`` to skip the backward operation.

View File

@ -18,15 +18,15 @@ import torch
from torch import Tensor
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from lightning_fabric.utilities.warnings import PossibleUserWarning
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.loops import _Loop
from pytorch_lightning.loops.progress import BaseProgress
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.progress import BaseProgress
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import Strategy
from lightning.pytorch.trainer.supporters import CombinedLoader
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
def check_finite_loss(loss: Optional[Tensor]) -> None:
@ -83,7 +83,7 @@ def _parse_loop_limits(
@contextmanager
def _block_parallel_sync_behavior(strategy: Strategy, block: bool = True) -> Generator[None, None, None]:
"""Blocks synchronization in :class:`~pytorch_lightning.strategies.parallel.ParallelStrategy`. This is useful
"""Blocks synchronization in :class:`~lightning.pytorch.strategies.parallel.ParallelStrategy`. This is useful
for example when accumulating gradients to reduce communication when it is not needed.
Args:

View File

@ -15,8 +15,8 @@ from typing import Any, Union
import torch
import pytorch_lightning as pl
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
import lightning.pytorch as pl
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):

View File

@ -19,9 +19,9 @@ import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.pytorch.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
def _ignore_scalar_return_in_dp() -> None:

View File

@ -19,7 +19,7 @@ from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler
from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
from lightning.fabric.utilities.distributed import _DatasetSamplerWrapper
def _find_tensors(

View File

@ -0,0 +1,39 @@
from typing import Union
from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, TorchCheckpointIO, XLACheckpointIO
from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
from lightning.pytorch.plugins.io.hpu_plugin import HPUCheckpointIO
from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
from lightning.pytorch.plugins.precision.colossalai import ColossalAIPrecisionPlugin
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.plugins.precision.hpu import HPUPrecisionPlugin
from lightning.pytorch.plugins.precision.ipu import IPUPrecisionPlugin
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.plugins.precision.tpu import TPUPrecisionPlugin
from lightning.pytorch.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
PLUGIN_INPUT = Union[PLUGIN, str]
__all__ = [
"AsyncCheckpointIO",
"CheckpointIO",
"TorchCheckpointIO",
"XLACheckpointIO",
"HPUCheckpointIO",
"ColossalAIPrecisionPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"IPUPrecisionPlugin",
"HPUPrecisionPlugin",
"MixedPrecisionPlugin",
"PrecisionPlugin",
"FSDPMixedPrecisionPlugin",
"TPUPrecisionPlugin",
"TPUBf16PrecisionPlugin",
"LayerSync",
"TorchSyncBatchNorm",
]

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_fabric.plugins import ClusterEnvironment # noqa: F401
from lightning_fabric.plugins.environments import ( # noqa: F401
from lightning.fabric.plugins import ClusterEnvironment # noqa: F401
from lightning.fabric.plugins.environments import ( # noqa: F401
KubeflowEnvironment,
LightningEnvironment,
LSFEnvironment,
@ -20,4 +20,4 @@ from lightning_fabric.plugins.environments import ( # noqa: F401
TorchElasticEnvironment,
XLAEnvironment,
)
from pytorch_lightning.plugins.environments.bagua_environment import BaguaEnvironment # noqa: F401
from lightning.pytorch.plugins.environments.bagua_environment import BaguaEnvironment # noqa: F401

View File

@ -15,7 +15,7 @@
import logging
import os
from lightning_fabric.plugins import ClusterEnvironment
from lightning.fabric.plugins import ClusterEnvironment
log = logging.getLogger(__name__)

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_fabric.plugins import CheckpointIO, TorchCheckpointIO, XLACheckpointIO
from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from lightning.fabric.plugins import CheckpointIO, TorchCheckpointIO, XLACheckpointIO
from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
from lightning.pytorch.plugins.io.hpu_plugin import HPUCheckpointIO
__all__ = ["AsyncCheckpointIO", "CheckpointIO", "HPUCheckpointIO", "TorchCheckpointIO", "XLACheckpointIO"]

View File

@ -15,8 +15,8 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
from lightning_fabric.plugins import CheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.fabric.plugins import CheckpointIO
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
class AsyncCheckpointIO(_WrappingCheckpointIO):

View File

@ -13,4 +13,4 @@
# limitations under the License.
# For backward-compatibility
from lightning_fabric.plugins import CheckpointIO # noqa: F401
from lightning.fabric.plugins import CheckpointIO # noqa: F401

View File

@ -17,10 +17,10 @@ from typing import Any, Dict, Optional
import torch
from lightning_fabric.plugins import TorchCheckpointIO
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.cloud_io import _atomic_save, get_filesystem
from lightning_fabric.utilities.types import _PATH
from lightning.fabric.plugins import TorchCheckpointIO
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.cloud_io import _atomic_save, get_filesystem
from lightning.fabric.utilities.types import _PATH
class HPUCheckpointIO(TorchCheckpointIO):

View File

@ -13,4 +13,4 @@
# limitations under the License.
# For backward-compatibility
from lightning_fabric.plugins import TorchCheckpointIO # noqa: F401
from lightning.fabric.plugins import TorchCheckpointIO # noqa: F401

View File

@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Optional
from lightning_fabric.plugins import CheckpointIO
from lightning.fabric.plugins import CheckpointIO
class _WrappingCheckpointIO(CheckpointIO):

View File

@ -13,4 +13,4 @@
# limitations under the License.
# For backward-compatibility
from lightning_fabric.plugins import XLACheckpointIO # noqa: F401
from lightning.fabric.plugins import XLACheckpointIO # noqa: F401

View File

@ -11,16 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.precision.amp import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.colossalai import ColossalAIPrecisionPlugin
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
from lightning.pytorch.plugins.precision.colossalai import ColossalAIPrecisionPlugin
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.plugins.precision.hpu import HPUPrecisionPlugin
from lightning.pytorch.plugins.precision.ipu import IPUPrecisionPlugin
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.plugins.precision.tpu import TPUPrecisionPlugin
from lightning.pytorch.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
__all__ = [
"ColossalAIPrecisionPlugin",

Some files were not shown because too many files have changed in this diff Show More