move pytorch_lightning >> lightning/pytorch (#16594)
This commit is contained in:
parent 01b152f169
commit 7d4780adb1
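Most of the diff below is a mechanical rename of the import root; the docstring and doctest hunks all follow the same before/after pattern, sketched here for orientation:

# before this commit
from pytorch_lightning import Trainer

# after this commit (the legacy name stays available as a generated mirror package; see the setup.py hunks below)
from lightning.pytorch import Trainer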
@@ -26,7 +26,8 @@ pr:
- "examples/pl_basics/backbone_image_classifier.py"
- "examples/pl_basics/autoencoder.py"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "requirements/fabric/**"

@@ -24,7 +24,8 @@ pr:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
exclude:

@@ -21,7 +21,8 @@ pr:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
exclude:

@@ -11,7 +11,8 @@ subprojects:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "tests/legacy/**"
- "pyproject.toml" # includes pytest config

@@ -49,7 +50,8 @@ subprojects:
- "examples/pl_basics/backbone_image_classifier.py"
- "examples/pl_basics/autoencoder.py"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "requirements/fabric/**"

@@ -82,7 +84,8 @@ subprojects:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "!requirements/*/docs.txt"

@@ -99,7 +102,8 @@ subprojects:
- "src/lightning/fabric/**"
- "src/lightning_fabric/*"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "!requirements/docs.txt"

@@ -130,7 +134,8 @@ subprojects:
- id: "pytorch_lightning: Docs"
paths:
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "docs/source-pytorch/**"
- ".actions/**"
- ".github/workflows/docs-checks.yml"

@@ -9,7 +9,8 @@ app:
- 'requirements/app/**'

pl:
- 'src/pytorch_lightning/**'
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- 'tests/tests_pytorch/**'
- 'tests/legacy/**'
- 'examples/pl_*/**'

@@ -10,7 +10,8 @@ on:
paths:
- ".actions/**"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "tests/legacy/**"
- "pyproject.toml" # includes pytest config

@@ -15,7 +15,8 @@ on:
- "src/lightning_fabric/*"
- "tests/tests_fabric/**"
- "requirements/pytorch/**"
- "src/pytorch_lightning/**"
- "src/lightning/pytorch/**"
- "src/pytorch_lightning/*"
- "tests/tests_pytorch/**"
- "pyproject.toml" # includes pytest config
- "!requirements/*/docs.txt"

@@ -57,7 +57,10 @@ src/lightning_fabric/
!src/lightning_fabric/__*__.py
!src/lightning_fabric/MANIFEST.in
!src/lightning_fabric/README.md
src/lightning/pytorch/
src/pytorch_lightning/
!src/pytorch_lightning/__*__.py
!src/pytorch_lightning/MANIFEST.in
!src/pytorch_lightning/README.md

# PyInstaller
# Usually these files are written by a python script from a template

@@ -42,9 +42,8 @@ repos:
docs/source-pytorch/_static/images/general/pl_quick_start_full_compressed.gif|
docs/source-pytorch/_static/images/general/pl_overview_flat.jpg|
docs/source-pytorch/_static/images/general/pl_overview.gif|
src/lightning_app/cli/pl-app-template/ui/yarn.lock|
src/pytorch_lightning/CHANGELOG.md|
src/lightning/fabric/CHANGELOG.md
src/lightning/fabric/CHANGELOG.md|
src/lightning/pytorch/CHANGELOG.md
)$
- id: detect-private-key

@@ -100,7 +99,7 @@ repos:
(?x)^(
src/lightning/app/CHANGELOG.md|
src/lightning/fabric/CHANGELOG.md|
src/pytorch_lightning/CHANGELOG.md
src/lightning/pytorch/CHANGELOG.md
)$

- repo: https://github.com/charliermarsh/ruff-pre-commit

setup.py

@@ -58,7 +58,7 @@ _PACKAGE_MAPPING = {
"fabric": "lightning_fabric",
}
# TODO: drop this reverse list when all packages are moved
_MIRROR_PACKAGE_REVERSED = ("app", "fabric")
_MIRROR_PACKAGE_REVERSED = ("app", "fabric", "pytorch")
# https://packaging.python.org/guides/single-sourcing-package-version/
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
_PATH_ROOT = os.path.dirname(__file__)

@@ -142,7 +142,7 @@ if __name__ == "__main__":
package_to_install = _PACKAGE_MAPPING.get(_PACKAGE_NAME, "lightning")
if package_to_install == "lightning":
# merge all requirements files
assistant._load_aggregate_requirements(_PATH_REQUIRE, _FREEZE_REQUIREMENTS) # install everything
assistant._load_aggregate_requirements(_PATH_REQUIRE, _FREEZE_REQUIREMENTS)
# replace imports and copy the code
assistant.create_mirror_package(_PATH_SRC, _PACKAGE_MAPPING, reverse=_MIRROR_PACKAGE_REVERSED)
else:

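The `create_mirror_package(..., reverse=_MIRROR_PACKAGE_REVERSED)` call above is what keeps the legacy `pytorch_lightning` import path alive once `src/lightning/pytorch` becomes the source of truth. A minimal sketch of that idea only; the function name, regex, and layout here are illustrative assumptions, not the actual `assistant` implementation:

import re
import shutil
from pathlib import Path

def make_mirror(src_root: Path, new_pkg: str = "lightning.pytorch", old_pkg: str = "pytorch_lightning") -> None:
    """Copy the new-namespace sources into the legacy layout and rewrite imports (sketch only)."""
    src = src_root / Path(*new_pkg.split("."))  # e.g. src/lightning/pytorch
    dst = src_root / old_pkg                    # e.g. src/pytorch_lightning
    shutil.rmtree(dst, ignore_errors=True)
    shutil.copytree(src, dst)
    pattern = re.compile(rf"\b{re.escape(new_pkg)}\b")
    for py in dst.rglob("*.py"):
        py.write_text(pattern.sub(old_pkg, py.read_text()))

With `"pytorch"` added to `_MIRROR_PACKAGE_REVERSED`, this rewrite now runs in the `lightning.pytorch -> pytorch_lightning` direction, the same way it already did for the app and fabric packages.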
@@ -207,6 +207,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [1.9.0] - 2023-01-17

### Added

- Added support for native logging of `MetricCollection` with enabled compute groups ([#15580](https://github.com/Lightning-AI/lightning/pull/15580))
- Added support for custom artifact names in `pl.loggers.WandbLogger` ([#16173](https://github.com/Lightning-AI/lightning/pull/16173))
- Added support for DDP with `LRFinder` ([#15304](https://github.com/Lightning-AI/lightning/pull/15304))

@@ -223,7 +224,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added info message for Ampere CUDA GPU users to enable tf32 matmul precision ([#16037](https://github.com/Lightning-AI/lightning/pull/16037))
- Added support for returning optimizer-like classes in `LightningModule.configure_optimizers` ([#16189](https://github.com/Lightning-AI/lightning/pull/16189))

### Changed

- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))

@@ -1,12 +1,18 @@
"""Root package info."""

import logging
import os
from typing import Any

from pytorch_lightning.__about__ import * # noqa: F401, F403
from lightning_utilities import module_available

if os.path.isfile(os.path.join(os.path.dirname(__file__), "__about__.py")):
from lightning.pytorch.__about__ import * # noqa: F401, F403
if "__version__" not in locals():
from pytorch_lightning.__version__ import version as __version__ # noqa: F401
if os.path.isfile(os.path.join(os.path.dirname(__file__), "__version__.py")):
from lightning.pytorch.__version__ import version as __version__
elif module_available("lightning"):
from lightning import __version__ # noqa: F401

_DETAIL = 15 # between logging.INFO and logging.DEBUG, used for logging in production use cases

@@ -30,13 +36,13 @@ if not _root_logger.hasHandlers():
_logger.addHandler(logging.StreamHandler())
_logger.propagate = False

from lightning_fabric.utilities.seed import seed_everything # noqa: E402
from pytorch_lightning.callbacks import Callback # noqa: E402
from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402
from pytorch_lightning.trainer import Trainer # noqa: E402
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
from lightning.pytorch.callbacks import Callback # noqa: E402
from lightning.pytorch.core import LightningDataModule, LightningModule # noqa: E402
from lightning.pytorch.trainer import Trainer # noqa: E402

# this import needs to go last as it will patch other modules
import pytorch_lightning._graveyard # noqa: E402, F401 # isort: skip
import lightning.pytorch._graveyard # noqa: E402, F401 # isort: skip

__all__ = ["Trainer", "LightningDataModule", "LightningModule", "Callback", "seed_everything"]

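The reworked `__about__`/`__version__` block above first looks for metadata files next to the package and only then falls back to the installed `lightning` distribution. A quick, hedged way to see which namespace resolves in a given environment, using only `module_available`, which this file already imports (assuming it accepts dotted module names):

from lightning_utilities import module_available

for name in ("lightning.pytorch", "pytorch_lightning"):
    print(f"{name}: {'available' if module_available(name) else 'missing'}")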
@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning._graveyard.legacy_import_unpickler # noqa: F401
import lightning.pytorch._graveyard.legacy_import_unpickler # noqa: F401

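The `legacy_import_unpickler` kept here is what lets checkpoints pickled against the old `pytorch_lightning.*` module paths still resolve after the rename. The general technique is overriding `pickle.Unpickler.find_class`; a standalone sketch of the idea, not the module's actual code:

import pickle

class _RedirectingUnpickler(pickle.Unpickler):
    """Resolve classes pickled under the legacy namespace against the new one (sketch)."""

    def find_class(self, module: str, name: str):
        if module == "pytorch_lightning" or module.startswith("pytorch_lightning."):
            module = "lightning.pytorch" + module[len("pytorch_lightning"):]
        return super().find_class(module, name)

# usage sketch: _RedirectingUnpickler(open("old_checkpoint.pkl", "rb")).load()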
@@ -11,7 +11,7 @@ def _patch_pl_to_mirror_if_necessary(module: str) -> str:
if module.startswith(pl):
# for the standalone package this won't do anything,
# for the unified mirror package it will redirect the imports
module = "pytorch_lightning" + module[len(pl) :]
module = "lightning.pytorch" + module[len(pl) :]
return module

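`_patch_pl_to_mirror_if_necessary` now maps any module path starting with the legacy prefix onto `lightning.pytorch`. A self-contained sketch of that mapping, with the prefix spelled out explicitly since the surrounding definition of `pl` is not part of this hunk:

def patch_to_mirror(module: str, pl: str = "pytorch_lightning") -> str:
    # rewrite the legacy prefix onto the unified namespace, leave everything else untouched
    if module.startswith(pl):
        return "lightning.pytorch" + module[len(pl):]
    return module

assert patch_to_mirror("pytorch_lightning.callbacks.model_checkpoint") == "lightning.pytorch.callbacks.model_checkpoint"
assert patch_to_mirror("torch.nn") == "torch.nn"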
@@ -29,7 +29,7 @@ def compare_version(package: str, op: Callable, version: str, use_base_version:
return _compare_version(new_package, op, version, use_base_version)

# patching is necessary, since up to v.0.7.3 torchmetrics has a hardcoded reference to pytorch_lightning,
# patching is necessary, since up to v.0.7.3 torchmetrics has a hardcoded reference to lightning.pytorch,
# which has to be redirected to the unified package:
# https://github.com/Lightning-AI/metrics/blob/v0.7.3/torchmetrics/metric.py#L96
try:

@@ -10,16 +10,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_fabric.accelerators import find_usable_cuda_devices # noqa: F401
from lightning_fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.cuda import CUDAAccelerator # noqa: F401
from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.mps import MPSAccelerator # noqa: F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401
from lightning.fabric.accelerators import find_usable_cuda_devices # noqa: F401
from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from lightning.pytorch.accelerators.accelerator import Accelerator # noqa: F401
from lightning.pytorch.accelerators.cpu import CPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.cuda import CUDAAccelerator # noqa: F401
from lightning.pytorch.accelerators.hpu import HPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.ipu import IPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.mps import MPSAccelerator # noqa: F401
from lightning.pytorch.accelerators.tpu import TPUAccelerator # noqa: F401

ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators"
ACCELERATORS_BASE_MODULE = "lightning.pytorch.accelerators"
AcceleratorRegistry = _AcceleratorRegistry()
call_register_accelerators(AcceleratorRegistry, ACCELERATORS_BASE_MODULE)

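Since the re-exports above keep their public names, user code only swaps the import root. A hedged usage sketch (assumes a CUDA machine and that these helpers keep their pre-move signatures):

from lightning.pytorch.accelerators import CUDAAccelerator, find_usable_cuda_devices

if CUDAAccelerator.is_available():       # same API, new namespace
    print(find_usable_cuda_devices(1))   # e.g. [0]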
@@ -14,9 +14,9 @@
from abc import ABC
from typing import Any, Dict

import pytorch_lightning as pl
from lightning_fabric.accelerators.accelerator import Accelerator as _Accelerator
from lightning_fabric.utilities.types import _DEVICE
import lightning.pytorch as pl
from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator
from lightning.fabric.utilities.types import _DEVICE

class Accelerator(_Accelerator, ABC):

@@ -15,11 +15,11 @@ from typing import Any, Dict, List, Union
import torch

from lightning_fabric.accelerators.cpu import _parse_cpu_cores
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
from lightning.fabric.accelerators.cpu import _parse_cpu_cores
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PSUTIL_AVAILABLE

class CPUAccelerator(Accelerator):

@@ -19,12 +19,12 @@ from typing import Any, Dict, List, Optional, Union
import torch

import pytorch_lightning as pl
from lightning_fabric.accelerators.cuda import _check_cuda_matmul_precision, num_cuda_devices
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import lightning.pytorch as pl
from lightning.fabric.accelerators.cuda import _check_cuda_matmul_precision, num_cuda_devices
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.exceptions import MisconfigurationException

_log = logging.getLogger(__name__)

@@ -16,11 +16,11 @@ from typing import Any, Dict, List, Optional, Union
import torch

from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _HPU_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_debug

if _HPU_AVAILABLE:
import habana_frameworks.torch.hpu as torch_hpu

@@ -108,7 +108,7 @@ class HPUAccelerator(Accelerator):
def _parse_hpus(devices: Optional[Union[int, str, List[int]]]) -> Optional[int]:
"""
Parses the hpus given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer` for the `devices` flag.
:class:`~lightning.pytorch.trainer.Trainer` for the `devices` flag.

Args:
devices: An integer that indicates the number of Gaudi devices to be used

@@ -15,9 +15,9 @@ from typing import Any, Dict, List
import torch

from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.imports import _IPU_AVAILABLE
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.imports import _IPU_AVAILABLE

class IPUAccelerator(Accelerator):

@@ -15,12 +15,12 @@ from typing import Any, Dict, List, Optional, Union
import torch

from lightning_fabric.accelerators.mps import MPSAccelerator as _MPSAccelerator
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
from lightning.fabric.accelerators.mps import MPSAccelerator as _MPSAccelerator
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PSUTIL_AVAILABLE

class MPSAccelerator(Accelerator):

@@ -15,10 +15,10 @@ from typing import Any, Dict, List, Optional, Union
import torch

from lightning_fabric.accelerators.tpu import _parse_tpu_devices, _XLA_AVAILABLE
from lightning_fabric.accelerators.tpu import TPUAccelerator as FabricTPUAccelerator
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.tpu import _parse_tpu_devices, _XLA_AVAILABLE
from lightning.fabric.accelerators.tpu import TPUAccelerator as FabricTPUAccelerator
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator

class TPUAccelerator(Accelerator):

@@ -11,26 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.lambda_function import LambdaCallback
from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_summary import ModelSummary
from pytorch_lightning.callbacks.on_exception_checkpoint import OnExceptionCheckpoint
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBarBase, RichProgressBar, TQDMProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from pytorch_lightning.callbacks.timer import Timer
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.callbacks.checkpoint import Checkpoint
from lightning.pytorch.callbacks.device_stats_monitor import DeviceStatsMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
from lightning.pytorch.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from lightning.pytorch.callbacks.lambda_function import LambdaCallback
from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
from lightning.pytorch.callbacks.lr_monitor import LearningRateMonitor
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.callbacks.model_summary import ModelSummary
from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheckpoint
from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
from lightning.pytorch.callbacks.progress import ProgressBarBase, RichProgressBar, TQDMProgressBar
from lightning.pytorch.callbacks.pruning import ModelPruning
from lightning.pytorch.callbacks.quantization import QuantizationAwareTraining
from lightning.pytorch.callbacks.rich_model_summary import RichModelSummary
from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from lightning.pytorch.callbacks.timer import Timer

__all__ = [
"BackboneFinetuning",

@@ -20,12 +20,12 @@ Finds optimal batch size
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.tuner.batch_size_scaling import _scale_batch_size
from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.tuner.batch_size_scaling import _scale_batch_size
from lightning.pytorch.utilities.exceptions import _TunerExitException, MisconfigurationException
from lightning.pytorch.utilities.parsing import lightning_hasattr
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

class BatchSizeFinder(Callback):

@@ -63,7 +63,7 @@ class BatchSizeFinder(Callback):
# 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
# useful while fine-tuning models since you can't always use the same batch size after
# unfreezing the backbone.
from pytorch_lightning.callbacks import BatchSizeFinder
from lightning.pytorch.callbacks import BatchSizeFinder

class FineTuneBatchSizeFinder(BatchSizeFinder):

@@ -85,7 +85,7 @@ class BatchSizeFinder(Callback):
Example::

# 2. Run batch size finder for validate/test/predict.
from pytorch_lightning.callbacks import BatchSizeFinder
from lightning.pytorch.callbacks import BatchSizeFinder

class EvalBatchSizeFinder(BatchSizeFinder):

@@ -21,8 +21,8 @@ from typing import Any, Dict, List, Optional, Type
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import STEP_OUTPUT
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import STEP_OUTPUT

class Callback:

@@ -217,8 +217,8 @@ class Callback:
Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Args:
trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance.
pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance.
checkpoint: the checkpoint dictionary that will be saved.
"""

@@ -229,8 +229,8 @@ class Callback:
Called when loading a model checkpoint, use to reload state.

Args:
trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance.
pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance.
checkpoint: the full checkpoint dictionary that got loaded by the Trainer.
"""

@@ -1,9 +1,9 @@
from pytorch_lightning.callbacks.callback import Callback
from lightning.pytorch.callbacks.callback import Callback

class Checkpoint(Callback):
r"""
This is the base class for model checkpointing. Expert users may want to subclass it in case of writing
custom :class:`~pytorch_lightning.callbacksCheckpoint` callback, so that
custom :class:`~lightning.pytorch.callbacksCheckpoint` callback, so that
the trainer recognizes the custom class as a checkpointing callback.
"""

@@ -20,12 +20,12 @@ Monitors and logs device stats during training.
"""
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import STEP_OUTPUT
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PSUTIL_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.types import STEP_OUTPUT

class DeviceStatsMonitor(Callback):

@@ -46,8 +46,8 @@ class DeviceStatsMonitor(Callback):
If ``Trainer`` has no logger.

Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import DeviceStatsMonitor
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import DeviceStatsMonitor
>>> device_stats = DeviceStatsMonitor() # doctest: +SKIP
>>> trainer = Trainer(callbacks=[device_stats]) # doctest: +SKIP
"""

@@ -91,7 +91,7 @@ class DeviceStatsMonitor(Callback):
if self._cpu_stats and device.type != "cpu":
# Don't query CPU stats twice if CPU is accelerator
from pytorch_lightning.accelerators.cpu import get_cpu_stats
from lightning.pytorch.accelerators.cpu import get_cpu_stats

device_stats.update(get_cpu_stats())

@@ -25,11 +25,11 @@ import numpy as np
import torch
from torch import Tensor

import pytorch_lightning as pl
from lightning_fabric.utilities.rank_zero import _get_rank
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
import lightning.pytorch as pl
from lightning.fabric.utilities.rank_zero import _get_rank
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn

log = logging.getLogger(__name__)

@@ -74,8 +74,8 @@ class EarlyStopping(Callback):
Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import EarlyStopping
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(callbacks=[early_stopping])

@@ -174,7 +174,7 @@ class EarlyStopping(Callback):
self.patience = state_dict["patience"]

def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn
from lightning.pytorch.trainer.states import TrainerFn

return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

@@ -24,10 +24,10 @@ from torch.nn import Module, ModuleDict
from torch.nn.modules.batchnorm import _BatchNorm
from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

log = logging.getLogger(__name__)

@@ -328,8 +328,8 @@ class BackboneFinetuning(BaseFinetuning):
Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])

@@ -22,9 +22,9 @@ Trainer also calls ``optimizer.step()`` for the last indivisible step number.
from typing import Any, Dict

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException

class GradientAccumulationScheduler(Callback):

@@ -51,8 +51,8 @@ class GradientAccumulationScheduler(Callback):
Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GradientAccumulationScheduler
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import GradientAccumulationScheduler

# from epoch 5, it starts accumulating every 2 batches. Here we have 4 instead of 5
# because epoch (key) should be zero-indexed.

@@ -21,7 +21,7 @@ Create a simple callback on the fly using lambda functions.
from typing import Callable, Optional

from pytorch_lightning.callbacks.callback import Callback
from lightning.pytorch.callbacks.callback import Callback

class LambdaCallback(Callback):

@@ -29,12 +29,12 @@ class LambdaCallback(Callback):
Create a simple callback on the fly using lambda functions.

Args:
**kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.callback.Callback`
**kwargs: hooks supported by :class:`~lightning.pytorch.callbacks.callback.Callback`

Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LambdaCallback
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import LambdaCallback
>>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
"""

@@ -19,11 +19,11 @@ Finds optimal learning rate
"""
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.tuner.lr_finder import _lr_find, _LRFinder
from pytorch_lightning.utilities.exceptions import _TunerExitException
from pytorch_lightning.utilities.seed import isolate_rng
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.tuner.lr_finder import _lr_find, _LRFinder
from lightning.pytorch.utilities.exceptions import _TunerExitException
from lightning.pytorch.utilities.seed import isolate_rng

class LearningRateFinder(Callback):

@@ -50,7 +50,7 @@ class LearningRateFinder(Callback):
# Customize LearningRateFinder callback to run at different epochs.
# This feature is useful while fine-tuning models.
from pytorch_lightning.callbacks import LearningRateFinder
from lightning.pytorch.callbacks import LearningRateFinder

class FineTuneLearningRateFinder(LearningRateFinder):

@@ -25,11 +25,11 @@ from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type
from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig

class LearningRateMonitor(Callback):

@@ -49,8 +49,8 @@ class LearningRateMonitor(Callback):
Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LearningRateMonitor
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import LearningRateMonitor
>>> lr_monitor = LearningRateMonitor(logging_interval='step')
>>> trainer = Trainer(callbacks=[lr_monitor])

@@ -32,13 +32,13 @@ import torch
import yaml
from torch import Tensor

import pytorch_lightning as pl
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.types import STEP_OUTPUT
import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
from lightning.pytorch.utilities.types import STEP_OUTPUT

log = logging.getLogger(__name__)
warning_cache = WarningCache()

@@ -47,7 +47,7 @@ warning_cache = WarningCache()
class ModelCheckpoint(Checkpoint):
r"""
Save the model periodically by monitoring a quantity. Every metric logged with
:meth:`~pytorch_lightning.core.module.log` or :meth:`~pytorch_lightning.core.module.log_dict` in
:meth:`~lightning.pytorch.core.module.log` or :meth:`~lightning.pytorch.core.module.log_dict` in
LightningModule is a candidate for the monitor key. For more information, see
:ref:`checkpointing`.

@@ -64,8 +64,8 @@ class ModelCheckpoint(Checkpoint):
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')

By default, dirpath is ``None`` and will be set at runtime to the location
specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` argument,
specified by :class:`~lightning.pytorch.trainer.trainer.Trainer`'s
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.default_root_dir` argument,
and if the Trainer uses a logger, the path will also contain logger name and version.

filename: checkpoint filename. Can contain named formatting options to be auto-filled.

@@ -157,8 +157,8 @@ class ModelCheckpoint(Checkpoint):
Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')

@@ -371,7 +371,7 @@ class ModelCheckpoint(Checkpoint):
logger.after_save_checkpoint(proxy(self))

def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn
from lightning.pytorch.trainer.states import TrainerFn

return (
bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run

@@ -15,7 +15,7 @@
Model Summary
=============

Generates a summary of all layers in a :class:`~pytorch_lightning.core.module.LightningModule`.
Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule`.

The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.

@@ -24,19 +24,19 @@ the name, type and number of parameters for each layer.
import logging
from typing import List, Tuple, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.model_summary import DeepSpeedSummary
from pytorch_lightning.utilities.model_summary import ModelSummary as Summary
from pytorch_lightning.utilities.model_summary import summarize
from pytorch_lightning.utilities.model_summary.model_summary import _format_summary_table
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.model_summary import DeepSpeedSummary
from lightning.pytorch.utilities.model_summary import ModelSummary as Summary
from lightning.pytorch.utilities.model_summary import summarize
from lightning.pytorch.utilities.model_summary.model_summary import _format_summary_table

log = logging.getLogger(__name__)

class ModelSummary(Callback):
r"""
Generates a summary of all layers in a :class:`~pytorch_lightning.core.module.LightningModule`.
Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule`.

Args:
max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the

@@ -44,8 +44,8 @@ class ModelSummary(Callback):
Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelSummary
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import ModelSummary
>>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)])
"""

@@ -66,7 +66,7 @@ class ModelSummary(Callback):
self.summarize(summary_data, total_parameters, trainable_parameters, model_size)

def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Union[DeepSpeedSummary, Summary]:
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy

if isinstance(trainer.strategy, DeepSpeedStrategy) and trainer.strategy.zero_stage_3:
return DeepSpeedSummary(pl_module, max_depth=self._max_depth)

@@ -20,9 +20,9 @@ Automatically save a checkpoints on exception.
import os
from typing import Any

import pytorch_lightning as pl
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
import lightning.pytorch as pl
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import Checkpoint

class OnExceptionCheckpoint(Checkpoint):

@@ -38,8 +38,8 @@ class OnExceptionCheckpoint(Checkpoint):
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import OnExceptionCheckpoint
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import OnExceptionCheckpoint
>>> trainer = Trainer(callbacks=[OnExceptionCheckpoint(".")])
"""

@@ -19,10 +19,10 @@ Aids in saving predictions
"""
from typing import Any, Literal, Optional, Sequence

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities import LightningEnum
from lightning.pytorch.utilities.exceptions import MisconfigurationException

class WriteInterval(LightningEnum):

@@ -48,7 +48,7 @@ class BasePredictionWriter(Callback):
Example::

import torch
from pytorch_lightning.callbacks import BasePredictionWriter
from lightning.pytorch.callbacks import BasePredictionWriter

class CustomWriter(BasePredictionWriter):

@@ -75,7 +75,7 @@ class BasePredictionWriter(Callback):
# multi-device inference example

import torch
from pytorch_lightning.callbacks import BasePredictionWriter
from lightning.pytorch.callbacks import BasePredictionWriter

class CustomWriter(BasePredictionWriter):

@@ -18,6 +18,6 @@ Progress Bars
Use or override one of the progress bar callbacks.

"""
from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401
from lightning.pytorch.callbacks.progress.base import ProgressBarBase # noqa: F401
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401

@@ -13,16 +13,16 @@
# limitations under the License.
from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.logger import _version
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.logger import _version
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

class ProgressBarBase(Callback):
r"""
The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback`
that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback`
that keeps track of the batch progress in the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
You should implement your highly custom progress bars with this as the base class.

Example::

@@ -207,7 +207,7 @@ class ProgressBarBase(Callback):
def enable(self) -> None:
"""You should provide a way to enable the progress bar.

The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training
The :class:`~lightning.pytorch.trainer.trainer.Trainer` will call this in e.g. pre-training
routines like the :ref:`learning rate finder <advanced/training_tricks:Learning Rate Finder>`.
to temporarily enable and disable the main progress bar.
"""

@@ -18,9 +18,9 @@ from typing import Any, cast, Dict, Optional, Union
from lightning_utilities.core.imports import RequirementCache

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.types import STEP_OUTPUT
import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.utilities.types import STEP_OUTPUT

_RICH_AVAILABLE: bool = RequirementCache("rich>=10.2.2")

@@ -215,8 +215,8 @@ class RichProgressBar(ProgressBarBase):
.. code-block:: python

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar

trainer = Trainer(callbacks=RichProgressBar())

@@ -25,9 +25,9 @@ if importlib.util.find_spec("ipywidgets") is not None:
else:
from tqdm import tqdm as _tqdm

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.utilities.rank_zero import rank_zero_debug

_PAD_SIZE = 5

@@ -65,7 +65,7 @@ class TQDMProgressBar(ProgressBarBase):
- **sanity check progress:** the progress during the sanity check run
- **main progress:** shows training + validation progress combined. It also accounts for
multiple validation runs during training when
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.val_check_interval` is used.
- **validation progress:** only visible during validation;
shows total progress over all validation datasets.
- **test progress:** only active when testing; shows total progress over all test datasets.

@@ -74,7 +74,7 @@ class TQDMProgressBar(ProgressBarBase):
If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
specific methods of the callback class and pass your custom implementation to the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
:class:`~lightning.pytorch.trainer.trainer.Trainer`.

Example:

@@ -85,7 +85,7 @@ class TQDMProgressBar(ProgressBarBase):
... return bar
...
>>> bar = LitProgressBar()
>>> from pytorch_lightning import Trainer
>>> from lightning.pytorch import Trainer
>>> trainer = Trainer(callbacks=[bar])

Args:

@@ -94,8 +94,8 @@ class TQDMProgressBar(ProgressBarBase):
process_position: Set this to a value greater than ``0`` to offset the progress bars by this many lines.
This is useful when you have progress bars defined elsewhere and want to show all of them
together. This corresponds to
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the
:class:`~lightning.pytorch.trainer.trainer.Trainer`.
"""

def __init__(self, refresh_rate: int = 1, process_position: int = 0):

@@ -26,11 +26,11 @@ from lightning_utilities.core.apply_func import apply_to_collection
from torch import nn, Tensor
from typing_extensions import TypedDict

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_only

log = logging.getLogger(__name__)

@@ -25,10 +25,10 @@ from torch import Tensor
from torch.ao.quantization.qconfig import QConfig
from torch.quantization import FakeQuantizeBase

import pytorch_lightning as pl
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import lightning.pytorch as pl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException

if _TORCH_GREATER_EQUAL_1_11:
from torch.ao.quantization import fuse_modules_qat as fuse_modules

@@ -13,9 +13,9 @@
# limitations under the License.
from typing import List, Tuple

from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from pytorch_lightning.utilities.model_summary import get_human_readable_count
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
from lightning.pytorch.utilities.model_summary import get_human_readable_count

if _RICH_AVAILABLE:
from rich import get_console

@@ -24,7 +24,7 @@ if _RICH_AVAILABLE:
class RichModelSummary(ModelSummary):
r"""
Generates a summary of all layers in a :class:`~pytorch_lightning.core.module.LightningModule`
Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule`
with `rich text formatting <https://github.com/Textualize/rich>`_.

Install it with pip:

@@ -35,17 +35,17 @@ class RichModelSummary(ModelSummary):
.. code-block:: python

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichModelSummary
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichModelSummary

trainer = Trainer(callbacks=RichModelSummary())

You could also enable ``RichModelSummary`` using the :class:`~pytorch_lightning.callbacks.RichProgressBar`
You could also enable ``RichModelSummary`` using the :class:`~lightning.pytorch.callbacks.RichProgressBar`

.. code-block:: python

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar

trainer = Trainer(callbacks=RichProgressBar())

@@ -22,14 +22,14 @@ import torch
from torch import nn, Tensor
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from lightning_fabric.utilities.types import LRScheduler
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DeepSpeedStrategy
from pytorch_lightning.strategies.fsdp import FSDPStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig
import lightning.pytorch as pl
from lightning.fabric.utilities.types import LRScheduler
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.strategies import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig

_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]

@ -20,12 +20,12 @@ import time
|
|||
from datetime import timedelta
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.callback import Callback
|
||||
from pytorch_lightning.trainer.states import RunningStage
|
||||
from pytorch_lightning.utilities import LightningEnum
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info
|
||||
import lightning.pytorch as pl
|
||||
from lightning.pytorch.callbacks.callback import Callback
|
||||
from lightning.pytorch.trainer.states import RunningStage
|
||||
from lightning.pytorch.utilities import LightningEnum
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -52,8 +52,8 @@ class Timer(Callback):

Example::

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Timer
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Timer

# stop training after 12 hours
timer = Timer(duration="00:12:00:00")
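The example in this hunk stops at constructing the callback; a minimal continuation as an editor's sketch (the ``interval`` argument and ``time_elapsed`` query reflect the Timer API as I understand it and are not part of the shown hunk):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import Timer

    # stop training after 12 hours, checking the clock at every step
    timer = Timer(duration="00:12:00:00", interval="step")
    trainer = Trainer(callbacks=[timer])
    # after fitting, the elapsed time per stage can be queried, e.g. timer.time_elapsed("train")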
@@ -22,13 +22,13 @@ from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import _warn
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.17.0")
@@ -292,21 +292,21 @@ class LightningCLI:
.. warning:: ``LightningCLI`` is in beta and subject to change.

Args:
model_class: An optional :class:`~pytorch_lightning.core.module.LightningModule` class to train on or a
callable which returns a :class:`~pytorch_lightning.core.module.LightningModule` instance when
model_class: An optional :class:`~lightning.pytorch.core.module.LightningModule` class to train on or a
callable which returns a :class:`~lightning.pytorch.core.module.LightningModule` instance when
called. If ``None``, you can pass a registered model with ``--model=MyModel``.
datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
datamodule_class: An optional :class:`~lightning.pytorch.core.datamodule.LightningDataModule` class or a
callable which returns a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` instance when
called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
save_config_callback: A callback class to save the config.
save_config_kwargs: Parameters that will be used to instantiate the save_config_callback.
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
trainer_class: An optional subclass of the :class:`~lightning.pytorch.trainer.trainer.Trainer` class or a
callable which returns a :class:`~lightning.pytorch.trainer.trainer.Trainer` instance when called.
trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
this argument will not be configurable from a configuration file and will always be present for
this particular CLI. Alternatively, configurable callbacks can be added as explained in
:ref:`the CLI docs <lightning-cli>`.
seed_everything_default: Number for the :func:`~lightning_fabric.utilities.seed.seed_everything`
seed_everything_default: Number for the :func:`~lightning.fabric.utilities.seed.seed_everything`
seed value. Set to True to automatically choose a seed value.
Setting it to False will avoid calling ``seed_everything``.
parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
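The Args block above documents the constructor; a minimal usage sketch under the new import path (editor's illustration: ``DemoModel`` and ``DemoDataModule`` are hypothetical placeholders, and ``lightning.pytorch.cli`` is assumed to be the post-move location of this module):

.. code-block:: python

    from lightning.pytorch.cli import LightningCLI

    from my_project.models import DemoModel        # hypothetical LightningModule
    from my_project.data import DemoDataModule     # hypothetical LightningDataModule

    # fit/validate/test/predict subcommands are added because run defaults to True
    cli = LightningCLI(model_class=DemoModel, datamodule_class=DemoDataModule)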
@@ -319,7 +319,7 @@ class LightningCLI:
args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style
arguments can be given in a ``list``. Alternatively, structured config options can be given in a
``dict`` or ``jsonargparse.Namespace``.
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
"""
self.save_config_callback = save_config_callback

@@ -362,7 +362,7 @@ class LightningCLI:

def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
"""Method that instantiates the argument parser."""
kwargs.setdefault("dump_header", [f"pytorch_lightning=={pl.__version__}"])
kwargs.setdefault("dump_header", [f"lightning.pytorch=={pl.__version__}"])
parser = LightningArgumentParser(**kwargs)
parser.add_argument(
"-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."

@@ -549,7 +549,7 @@ class LightningCLI:
def configure_optimizers(
lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
) -> Any:
"""Override to customize the :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers`
"""Override to customize the :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers`
method.

Args:

@@ -567,7 +567,7 @@ class LightningCLI:
return [optimizer], [lr_scheduler]

def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
"""Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method
"""Overrides the model's :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers` method
if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
if not self.auto_configure_optimizers:
return
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.module import LightningModule
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning.pytorch.core.module import LightningModule

__all__ = ["LightningDataModule", "LightningModule"]

@@ -19,18 +19,18 @@ from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Unio
from torch.utils.data import DataLoader, Dataset, IterableDataset
from typing_extensions import Self

import pytorch_lightning as pl
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.core.hooks import DataHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import _load_from_checkpoint
from pytorch_lightning.utilities.argparse import (
import lightning.pytorch as pl
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.core.hooks import DataHooks
from lightning.pytorch.core.mixins import HyperparametersMixin
from lightning.pytorch.core.saving import _load_from_checkpoint
from lightning.pytorch.utilities.argparse import (
add_argparse_args,
from_argparse_args,
get_init_arguments_and_types,
parse_argparser,
)
from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, EVAL_DATALOADERS, TRAIN_DATALOADERS
from lightning.pytorch.utilities.types import _ADD_ARGPARSE_RETURN, EVAL_DATALOADERS, TRAIN_DATALOADERS


class LightningDataModule(DataHooks, HyperparametersMixin):

@@ -91,7 +91,7 @@ class LightningDataModule(DataHooks, HyperparametersMixin):

Args:
args: The parser or namespace to take arguments from. Only known arguments will be
parsed and passed to the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
parsed and passed to the :class:`~lightning.pytorch.core.datamodule.LightningDataModule`.
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
These must be valid DataModule arguments.

@@ -19,9 +19,9 @@ import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer

from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
from lightning.pytorch.utilities import move_data_to_device
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS


class ModelHooks:
@@ -233,7 +233,7 @@ class ModelHooks:
"""Called before ``optimizer.step()``.

If using gradient accumulation, the hook is called once the gradients have been accumulated.
See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
See: :paramref:`~lightning.pytorch.trainer.Trainer.accumulate_grad_batches`.

If using AMP, the loss will be unscaled before calling this hook.
See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
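For context on the hunk above, ``accumulate_grad_batches`` is a Trainer argument; a one-line editor's sketch of enabling it (not part of the diff):

.. code-block:: python

    from lightning.pytorch import Trainer

    # with this setting, on_before_optimizer_step fires once every 4 batches,
    # after the gradients from those batches have been accumulated
    trainer = Trainer(accumulate_grad_batches=4)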
@@ -378,7 +378,7 @@ class DataHooks:
In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.

The dataloader you return will not be reloaded unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
:paramref:`~lightning.pytorch.trainer.Trainer.reload_dataloaders_every_n_epochs` to
a positive integer.

For data processing use the following pattern:
@@ -390,7 +390,7 @@ class DataHooks:

.. warning:: do not assign state in prepare_data

- :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
- :meth:`prepare_data`
- :meth:`setup`

@@ -455,7 +455,7 @@ class DataHooks:
.. warning:: do not assign state in prepare_data


- :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`
- :meth:`prepare_data`
- :meth:`setup`

@@ -500,13 +500,13 @@ class DataHooks:
Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
:paramref:`~lightning.pytorch.trainer.Trainer.reload_dataloaders_every_n_epochs` to
a positive integer.

It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

- :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`
- :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`
- :meth:`prepare_data`
- :meth:`setup`
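Both dataloader hunks above reference the same Trainer argument; a one-line editor's sketch of the reload behaviour they describe (not part of the diff):

.. code-block:: python

    from lightning.pytorch import Trainer

    # re-call train_dataloader/val_dataloader at the start of every epoch
    trainer = Trainer(reload_dataloaders_every_n_epochs=1)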
@@ -552,7 +552,7 @@ class DataHooks:

It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

- :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`
- :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`
- :meth:`prepare_data`
- :meth:`setup`

@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin  # noqa: F401
from lightning.pytorch.core.mixins.hparams_mixin import HyperparametersMixin  # noqa: F401

@@ -17,8 +17,8 @@ import types
from argparse import Namespace
from typing import Any, List, MutableMapping, Optional, Sequence, Union

from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from pytorch_lightning.utilities.parsing import AttributeDict, save_hyperparameters
from lightning.pytorch.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters


class HyperparametersMixin:

@@ -47,7 +47,7 @@ class HyperparametersMixin:
logger: Whether to send the hyperparameters to the logger. Default: True

Example::
>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()

@@ -60,7 +60,7 @@ class HyperparametersMixin:
"arg1": 1
"arg3": 3.14

>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class AutomaticArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()

@@ -74,7 +74,7 @@ class HyperparametersMixin:
"arg2": abc
"arg3": 3.14

>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class SingleArgModel(HyperparametersMixin):
... def __init__(self, params):
... super().__init__()

@@ -88,7 +88,7 @@ class HyperparametersMixin:
"p2": abc
"p3": 3.14

>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
@@ -27,28 +27,28 @@ from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric, MetricCollection

import lightning_fabric as lf
import pytorch_lightning as pl
from lightning_fabric.loggers import Logger as FabricLogger
from lightning_fabric.utilities.apply_func import convert_to_tensors
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_fabric.utilities.distributed import _distributed_available, _sync_ddp
from lightning_fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_2_0
from lightning_fabric.wrappers import _FabricOptimizer
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.loggers import Logger
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCHMETRICS_GREATER_EQUAL_0_9_1
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import (
import lightning.fabric as lf
import lightning.pytorch as pl
from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.utilities.apply_func import convert_to_tensors
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from lightning.pytorch.core.mixins import HyperparametersMixin
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.core.saving import ModelIO
from lightning.pytorch.loggers import Logger
from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCHMETRICS_GREATER_EQUAL_0_9_1
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import (
_METRIC,
EPOCH_OUTPUT,
LRSchedulerPLType,
@@ -135,7 +135,7 @@ class LightningModule(

Args:
use_pl_optimizer: If ``True``, will wrap the optimizer(s) in a
:class:`~pytorch_lightning.core.optimizer.LightningOptimizer` for automatic handling of precision and
:class:`~lightning.pytorch.core.optimizer.LightningOptimizer` for automatic handling of precision and
profiling.

Returns:
@@ -1168,16 +1168,16 @@ class LightningModule(
"""

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
"""Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it
calls :meth:`~pytorch_lightning.core.module.LightningModule.forward`. Override to add any processing logic.
"""Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it
calls :meth:`~lightning.pytorch.core.module.LightningModule.forward`. Override to add any processing logic.

The :meth:`~pytorch_lightning.core.module.LightningModule.predict_step` is used
The :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` is used
to scale inference on multi-devices.

To prevent an OOM error, it is possible to use :class:`~pytorch_lightning.callbacks.BasePredictionWriter`
To prevent an OOM error, it is possible to use :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
callback to write the predictions to disk or database after each batch or on epoch end.

The :class:`~pytorch_lightning.callbacks.BasePredictionWriter` should be used while using a spawn
The :class:`~lightning.pytorch.callbacks.BasePredictionWriter` should be used while using a spawn
based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)`` as predictions won't be returned.
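A minimal override matching the signature shown in the hunk above (editor's sketch; it assumes the model defines ``forward()``, and the post-processing line is purely illustrative):

.. code-block:: python

    from typing import Any

    from lightning.pytorch import LightningModule


    class LitModel(LightningModule):
        def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
            # Trainer.predict() calls this once per batch; the default behaviour is
            # equivalent to returning self(batch), and extra processing can be added here
            predictions = self(batch)
            return predictions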
@@ -1209,7 +1209,7 @@ class LightningModule(
gets called, the list or a callback returned here will be merged with the list of callbacks passed to the
Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks
already present in the Trainer's callbacks list, it will take priority and replace them. In addition,
Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
Lightning will make sure :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks
run last.

Return:
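The behaviour described above (presumably the ``configure_callbacks`` hook: model-returned callbacks override same-type Trainer callbacks, and ModelCheckpoint runs last) is easiest to see in a small override; editor's sketch, not part of the diff:

.. code-block:: python

    from lightning.pytorch import LightningModule
    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint


    class LitModel(LightningModule):
        def configure_callbacks(self):
            # merged with, and taking priority over, same-type callbacks passed to Trainer(callbacks=...)
            return [EarlyStopping(monitor="val_loss"), ModelCheckpoint(monitor="val_loss")]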
@@ -1309,7 +1309,7 @@ class LightningModule(
)

Metrics can be made available to monitor by simply logging it using
``self.log('metric_to_track', metric_val)`` in your :class:`~pytorch_lightning.core.module.LightningModule`.
``self.log('metric_to_track', metric_val)`` in your :class:`~lightning.pytorch.core.module.LightningModule`.

Note:
Some things to know:
@@ -1504,7 +1504,7 @@ class LightningModule(
def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Optional[Any]) -> None:
r"""
Override this method to adjust the default way the
:class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler.
:class:`~lightning.pytorch.trainer.trainer.Trainer` calls each scheduler.
By default, Lightning calls ``step()`` and as shown in the example
for each scheduler based on its ``interval``.
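A minimal override with the signature shown in the hunk above (editor's sketch; the metric branch mirrors the default behaviour as I understand it, where schedulers such as ReduceLROnPlateau receive the monitored value):

.. code-block:: python

    from typing import Any, Optional

    from lightning.pytorch import LightningModule


    class LitModel(LightningModule):
        def lr_scheduler_step(self, scheduler, metric: Optional[Any]) -> None:
            if metric is None:
                scheduler.step()        # e.g. StepLR
            else:
                scheduler.step(metric)  # e.g. ReduceLROnPlateau monitoring a logged value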
@@ -1539,7 +1539,7 @@ class LightningModule(
optimizer_closure: Optional[Callable[[], Any]] = None,
) -> None:
r"""
Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
the optimizer.

By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example.

@@ -1695,7 +1695,7 @@ class LightningModule(

Note:
- Requires the implementation of the
:meth:`~pytorch_lightning.core.module.LightningModule.forward` method.
:meth:`~lightning.pytorch.core.module.LightningModule.forward` method.
- The exported script will be set to evaluation mode.
- It is recommended that you install the latest supported version of PyTorch
to use this feature without limitations. See also the :mod:`torch.jit`
@@ -20,12 +20,12 @@ import torch
from torch import optim
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_fabric.utilities.types import _Stateful, Optimizable, ReduceLROnPlateau
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple
import lightning.pytorch as pl
from lightning.fabric.utilities.types import _Stateful, Optimizable, ReduceLROnPlateau
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple


def do_nothing_closure() -> None:

@@ -79,7 +79,7 @@ class LightningOptimizer:
Setting `sync_grad` to False will block this synchronization and improve performance.
"""
# local import here to avoid circular import
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior

assert self._strategy is not None
lightning_module = self._strategy.lightning_module

@@ -28,15 +28,15 @@ import yaml
from lightning_utilities.core.apply_func import apply_to_collection
from typing_extensions import Self

import pytorch_lightning as pl
from lightning_fabric.utilities.cloud_io import _load as pl_load
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.pytorch.utilities import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.migration import pl_legacy_patch
from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
from lightning.pytorch.utilities.parsing import AttributeDict, parse_class_init_keys
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

log = logging.getLogger(__name__)
PRIMITIVE_TYPES = (bool, int, float, str)
@@ -20,10 +20,10 @@ from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

from lightning_fabric.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT


class RandomDictDataset(Dataset):

@@ -24,9 +24,9 @@ import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, random_split

from lightning_fabric.utilities.imports import _IS_WINDOWS
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch import LightningDataModule
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib

@@ -13,13 +13,13 @@
# limitations under the License.
import os

from pytorch_lightning.loggers.comet import _COMET_AVAILABLE, CometLogger  # noqa: F401
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger  # noqa: F401
from pytorch_lightning.loggers.neptune import NeptuneLogger  # noqa: F401
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loggers.wandb import WandbLogger  # noqa: F401
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE, CometLogger  # noqa: F401
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger  # noqa: F401
from lightning.pytorch.loggers.neptune import NeptuneLogger  # noqa: F401
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.wandb import WandbLogger  # noqa: F401

__all__ = ["CSVLogger", "Logger", "TensorBoardLogger"]
@@ -25,10 +25,10 @@ from lightning_utilities.core.imports import module_available
from torch import Tensor
from torch.nn import Module

from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_only

log = logging.getLogger(__name__)
_COMET_AVAILABLE = module_available("comet_ml")

@@ -54,7 +54,7 @@ else:
class CometLogger(Logger):
r"""
Track your parameters, metrics, source code and more using
`Comet <https://www.comet.com/?utm_source=pytorch_lightning&utm_medium=referral>`_.
`Comet <https://www.comet.com/?utm_source=lightning.pytorch&utm_medium=referral>`_.

Install it with pip:

@@ -69,8 +69,8 @@ class CometLogger(Logger):
.. code-block:: python

import os
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CometLogger

# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(

@@ -88,7 +88,7 @@ class CometLogger(Logger):

.. code-block:: python

from pytorch_lightning.loggers import CometLogger
from lightning.pytorch.loggers import CometLogger

# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(

@@ -102,7 +102,7 @@ class CometLogger(Logger):

**Log Hyperparameters:**

Log parameters used to initialize a :class:`~pytorch_lightning.core.module.LightningModule`:
Log parameters used to initialize a :class:`~lightning.pytorch.core.module.LightningModule`:

.. code-block:: python

@@ -270,7 +270,7 @@ class CometLogger(Logger):
def experiment(self) -> Union[CometExperiment, CometExistingExperiment, CometOfflineExperiment]:
r"""
Actual Comet object. To use Comet features in your
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
:class:`~lightning.pytorch.core.module.LightningModule` do the following.

Example::
@@ -23,14 +23,14 @@ import os
from argparse import Namespace
from typing import Any, Dict, Optional, Union

from lightning_fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter
from lightning_fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger
from lightning_fabric.loggers.logger import rank_zero_experiment
from lightning_fabric.utilities.logger import _convert_params
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from lightning.fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter
from lightning.fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger
from lightning.fabric.loggers.logger import rank_zero_experiment
from lightning.fabric.utilities.logger import _convert_params
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.core.saving import save_hparams_to_yaml
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.utilities.rank_zero import rank_zero_only

log = logging.getLogger(__name__)

@@ -70,8 +70,8 @@ class CSVLogger(Logger, FabricCSVLogger):
Logs are saved to ``os.path.join(save_dir, name, version)``.

Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import CSVLogger
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.loggers import CSVLogger
>>> logger = CSVLogger("logs", name="my_exp_name")
>>> trainer = Trainer(logger=logger)

@@ -144,7 +144,7 @@ class CSVLogger(Logger, FabricCSVLogger):
r"""

Actual _ExperimentWriter object. To use _ExperimentWriter features in your
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
:class:`~lightning.pytorch.core.module.LightningModule` do the following.

Example::

@@ -22,10 +22,10 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence

import numpy as np

from lightning_fabric.loggers import Logger as FabricLogger
from lightning_fabric.loggers.logger import _DummyExperiment as DummyExperiment  # for backward compatibility
from lightning_fabric.loggers.logger import rank_zero_experiment  # noqa: F401  # for backward compatibility
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment  # for backward compatibility
from lightning.fabric.loggers.logger import rank_zero_experiment  # noqa: F401  # for backward compatibility
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint


class Logger(FabricLogger, ABC):

@@ -28,11 +28,11 @@ import yaml
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor

from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.logger import _scan_checkpoints
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.logger import _scan_checkpoints
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
@@ -83,17 +83,17 @@ class MLFlowLogger(Logger):

.. code-block:: python

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger

mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
trainer = Trainer(logger=mlf_logger)

Use the logger anywhere in your :class:`~pytorch_lightning.core.module.LightningModule` as follows:
Use the logger anywhere in your :class:`~lightning.pytorch.core.module.LightningModule` as follows:

.. code-block:: python

from pytorch_lightning import LightningModule
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
@@ -115,12 +115,12 @@ class MLFlowLogger(Logger):
save_dir: A path to a local directory where the MLflow runs get saved.
Defaults to `./mlflow` if `tracking_uri` is not provided.
Has no effect if `tracking_uri` is provided.
log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
log_model: Log checkpoints created by :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint`
as MLFlow artifacts.

* if ``log_model == 'all'``, checkpoints are logged during training.
* if ``log_model == True``, checkpoints are logged at the end of training, except when
:paramref:`~pytorch_lightning.callbacks.Checkpoint.save_top_k` ``== -1``
:paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
which also logs every checkpoint during training.
* if ``log_model == False`` (default), no checkpoint is logged.
|
|||
def experiment(self) -> MlflowClient:
|
||||
r"""
|
||||
Actual MLflow object. To use MLflow features in your
|
||||
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
|
||||
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
|
||||
|
||||
Example::
|
||||
|
|
@ -27,12 +27,12 @@ from typing import Any, Dict, Generator, List, Optional, Set, Union
|
|||
from lightning_utilities.core.imports import RequirementCache
|
||||
from torch import Tensor
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
|
||||
from pytorch_lightning.callbacks import Checkpoint
|
||||
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
|
||||
from pytorch_lightning.utilities.model_summary import ModelSummary
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
|
||||
from lightning.pytorch.callbacks import Checkpoint
|
||||
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
|
||||
from lightning.pytorch.utilities.model_summary import ModelSummary
|
||||
from lightning.pytorch.utilities.rank_zero import rank_zero_only
|
||||
|
||||
_NEPTUNE_AVAILABLE = RequirementCache("neptune-client")
|
||||
if _NEPTUNE_AVAILABLE:
|
||||
|
@ -70,8 +70,8 @@ class NeptuneLogger(Logger):
|
|||
.. code-block:: python
|
||||
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import NeptuneLogger
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.loggers import NeptuneLogger
|
||||
|
||||
neptune_logger = NeptuneLogger(
|
||||
api_key="ANONYMOUS", # replace with your own
|
||||
|
@ -82,12 +82,12 @@ class NeptuneLogger(Logger):
|
|||
|
||||
**How to use NeptuneLogger?**
|
||||
|
||||
Use the logger anywhere in your :class:`~pytorch_lightning.core.module.LightningModule` as follows:
|
||||
Use the logger anywhere in your :class:`~lightning.pytorch.core.module.LightningModule` as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from neptune.new.types import File
|
||||
from pytorch_lightning import LightningModule
|
||||
from lightning.pytorch import LightningModule
|
||||
|
||||
|
||||
class LitModel(LightningModule):
|
||||
|
@ -137,7 +137,7 @@ class NeptuneLogger(Logger):
|
|||
|
||||
**Log model checkpoints**
|
||||
|
||||
If you have :class:`~pytorch_lightning.callbacks.ModelCheckpoint` configured,
|
||||
If you have :class:`~lightning.pytorch.callbacks.ModelCheckpoint` configured,
|
||||
Neptune logger automatically logs model checkpoints.
|
||||
Model weights will be uploaded to the: "model/checkpoints" namespace in the Neptune Run.
|
||||
You can disable this option:
|
||||
|
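The snippet that follows "You can disable this option:" in the full docstring lies outside this hunk; presumably it toggles the ``log_model_checkpoints`` flag. A sketch under that assumption (editor's illustration, the flag name is assumed, not shown in this diff):

.. code-block:: python

    from lightning.pytorch.loggers import NeptuneLogger

    # assumed flag name; disables automatic upload of ModelCheckpoint files
    neptune_logger = NeptuneLogger(log_model_checkpoints=False)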
@@ -153,8 +153,8 @@ class NeptuneLogger(Logger):
.. testcode::
:skipif: not _NEPTUNE_AVAILABLE

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import NeptuneLogger

neptune_logger = NeptuneLogger(
project="common/pytorch-lightning-integration",

@@ -326,7 +326,7 @@ class NeptuneLogger(Logger):
def experiment(self) -> Run:
r"""
Actual Neptune run object. Allows you to use neptune logging features in your
:class:`~pytorch_lightning.core.module.LightningModule`.
:class:`~lightning.pytorch.core.module.LightningModule`.

Example::

@@ -382,7 +382,7 @@ class NeptuneLogger(Logger):

Example::

from pytorch_lightning.loggers import NeptuneLogger
from lightning.pytorch.loggers import NeptuneLogger

PARAMS = {
"batch_size": 64,
@@ -23,16 +23,16 @@ from typing import Any, Dict, Optional, Union

from torch import Tensor

import pytorch_lightning as pl
from lightning_fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning_fabric.loggers.tensorboard import TensorBoardLogger as FabricTensorBoardLogger
from lightning_fabric.utilities.logger import _convert_params
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
import lightning.pytorch as pl
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning.fabric.loggers.tensorboard import TensorBoardLogger as FabricTensorBoardLogger
from lightning.fabric.utilities.logger import _convert_params
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.core.saving import save_hparams_to_yaml
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)

@@ -56,8 +56,8 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):

.. testcode::

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="my_model")
trainer = Trainer(logger=logger)

@@ -24,13 +24,13 @@ import torch.nn as nn
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor

from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _scan_checkpoints
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.logger import _scan_checkpoints
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn

try:
import wandb
|
@ -61,7 +61,7 @@ class WandbLogger(Logger):
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
wandb_logger = WandbLogger(project="MNIST")
|
||||
|
||||
|
@ -75,7 +75,7 @@ class WandbLogger(Logger):
|
|||
|
||||
**Log metrics**
|
||||
|
||||
Log from :class:`~pytorch_lightning.core.module.LightningModule`:
|
||||
Log from :class:`~lightning.pytorch.core.module.LightningModule`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -91,7 +91,7 @@ class WandbLogger(Logger):
|
|||
|
||||
**Log hyper-parameters**
|
||||
|
||||
Save :class:`~pytorch_lightning.core.module.LightningModule` parameters:
|
||||
Save :class:`~lightning.pytorch.core.module.LightningModule` parameters:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -151,7 +151,7 @@ class WandbLogger(Logger):
|
|||
|
||||
wandb_logger = WandbLogger(log_model="all")
|
||||
|
||||
Custom checkpointing can be set up through :class:`~pytorch_lightning.callbacks.ModelCheckpoint`:
|
||||
Custom checkpointing can be set up through :class:`~lightning.pytorch.callbacks.ModelCheckpoint`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -227,7 +227,7 @@ class WandbLogger(Logger):
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
artifact_dir = WandbLogger.download_artifact(artifact="path/to/artifact")
|
||||
|
||||
|
@ -244,7 +244,7 @@ class WandbLogger(Logger):
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
wandb_logger = WandbLogger(project="my_project", name="my_run")
|
||||
wandb_logger.use_artifact(artifact="path/to/artifact")
|
||||
|
@@ -262,12 +262,12 @@ class WandbLogger(Logger):
id: Same as version.
anonymous: Enables or explicitly disables anonymous logging.
project: The name of the project to which this run will belong.
log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.ModelCheckpoint`
log_model: Log checkpoints created by :class:`~lightning.pytorch.callbacks.ModelCheckpoint`
as W&B artifacts. `latest` and `best` aliases are automatically set.

* if ``log_model == 'all'``, checkpoints are logged during training.
* if ``log_model == True``, checkpoints are logged at the end of training, except when
:paramref:`~pytorch_lightning.callbacks.ModelCheckpoint.save_top_k` ``== -1``
:paramref:`~lightning.pytorch.callbacks.ModelCheckpoint.save_top_k` ``== -1``
which also logs every checkpoint during training.
* if ``log_model == False`` (default), no checkpoint is logged.
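As with the MLflow logger earlier, the options listed above are values of one constructor argument; a short editor's sketch reusing the ``project="MNIST"`` value from the earlier example in this docstring:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.loggers import WandbLogger

    # upload checkpoints as W&B artifacts at the end of training (latest/best aliases set automatically)
    wandb_logger = WandbLogger(project="MNIST", log_model=True)
    trainer = Trainer(logger=wandb_logger)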
@@ -376,7 +376,7 @@ class WandbLogger(Logger):
r"""

Actual wandb object. To use wandb features in your
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
:class:`~lightning.pytorch.core.module.LightningModule` do the following.

Example::

@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.loop import _Loop  # noqa: F401 isort: skip (avoids circular imports)
from pytorch_lightning.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop  # noqa: F401
from pytorch_lightning.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop  # noqa: F401
from pytorch_lightning.loops.fit_loop import _FitLoop  # noqa: F401
from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop  # noqa: F401
from lightning.pytorch.loops.loop import _Loop  # noqa: F401 isort: skip (avoids circular imports)
from lightning.pytorch.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop  # noqa: F401
from lightning.pytorch.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop  # noqa: F401
from lightning.pytorch.loops.fit_loop import _FitLoop  # noqa: F401
from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop  # noqa: F401

@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop  # noqa: F401
from pytorch_lightning.loops.dataloader.evaluation_loop import _EvaluationLoop  # noqa: F401
from pytorch_lightning.loops.dataloader.prediction_loop import _PredictionLoop  # noqa: F401
from lightning.pytorch.loops.dataloader.dataloader_loop import _DataLoaderLoop  # noqa: F401
from lightning.pytorch.loops.dataloader.evaluation_loop import _EvaluationLoop  # noqa: F401
from lightning.pytorch.loops.dataloader.prediction_loop import _PredictionLoop  # noqa: F401

@@ -17,8 +17,8 @@ from typing import Sequence

from torch.utils.data import DataLoader

from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.loops.progress import DataLoaderProgress
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import DataLoaderProgress


class _DataLoaderLoop(_Loop):

@@ -21,17 +21,17 @@ from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.utils.data.dataloader import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.loops.dataloader import _DataLoaderLoop
from pytorch_lightning.loops.epoch import _EvaluationEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.loops.dataloader import _DataLoaderLoop
from lightning.pytorch.loops.epoch import _EvaluationEpochLoop
from lightning.pytorch.loops.utilities import _set_sampler_epoch
from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import EPOCH_OUTPUT

if _RICH_AVAILABLE:
from rich import get_console

@@ -2,12 +2,12 @@ from typing import Any, List, Optional, Sequence, Union

from torch.utils.data import DataLoader

from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
from lightning.pytorch.loops.dataloader.dataloader_loop import _DataLoaderLoop
from lightning.pytorch.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop
from lightning.pytorch.loops.utilities import _set_sampler_epoch
from lightning.pytorch.strategies import DDPSpawnStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import _PREDICT_OUTPUT


class _PredictionLoop(_DataLoaderLoop):

@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop  # noqa: F401
from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop  # noqa: F401
from pytorch_lightning.loops.epoch.training_epoch_loop import _TrainingEpochLoop  # noqa: F401
from lightning.pytorch.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop  # noqa: F401
from lightning.pytorch.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop  # noqa: F401
from lightning.pytorch.loops.epoch.training_epoch_loop import _TrainingEpochLoop  # noqa: F401
@@ -16,13 +16,13 @@ from collections import OrderedDict
from functools import lru_cache
from typing import Any, Optional, Union

from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.loops.progress import BatchProgress
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import SIGTERMException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import BatchProgress
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import SIGTERMException
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT


class _EvaluationEpochLoop(_Loop):

@@ -3,11 +3,11 @@ from typing import Any, Dict, Iterator, List, Tuple, Union

import torch

from lightning_fabric.utilities import move_data_to_device
from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.loops.progress import Progress
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.utilities.rank_zero import WarningCache
from lightning.fabric.utilities import move_data_to_device
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper
from lightning.pytorch.utilities.rank_zero import WarningCache

warning_cache = WarningCache()

@@ -168,7 +168,7 @@ class _PredictionEpochLoop(_Loop):

def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
"""Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
:class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
:class:`~lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper`."""
# the batch_sampler is not be defined in case of CombinedDataLoaders
assert self.trainer.predict_dataloaders
batch_sampler = getattr(
@@ -17,18 +17,18 @@ from typing import Any, Dict, List, Optional, Union

import torch

-from pytorch_lightning import loops # import as loops to avoid circular imports
-from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop
-from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
-from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
-from pytorch_lightning.loops.progress import BatchProgress, SchedulerProgress
-from pytorch_lightning.loops.utilities import _is_max_limit_reached
-from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
-from pytorch_lightning.utilities.exceptions import MisconfigurationException, SIGTERMException
-from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
-from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
+from lightning.pytorch import loops # import as loops to avoid circular imports
+from lightning.pytorch.loops.optimization import _ManualOptimization, _OptimizerLoop
+from lightning.pytorch.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
+from lightning.pytorch.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
+from lightning.pytorch.loops.progress import BatchProgress, SchedulerProgress
+from lightning.pytorch.loops.utilities import _is_max_limit_reached
+from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
+from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
+from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
+from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

_BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
@@ -37,19 +37,19 @@ _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
class _TrainingEpochLoop(loops._Loop):
    """
    Iterates over all batches in the dataloader (one epoch) that the user returns in their
-    :meth:`~pytorch_lightning.core.module.LightningModule.train_dataloader` method.
+    :meth:`~lightning.pytorch.core.module.LightningModule.train_dataloader` method.

    Its main responsibilities are calling the ``*_epoch_{start,end}`` hooks, accumulating outputs if the user request
    them in one of these hooks, and running validation at the requested interval.

    The validation is carried out by yet another loop,
-    :class:`~pytorch_lightning.loops.epoch.validation_epoch_loop.ValidationEpochLoop`.
+    :class:`~lightning.pytorch.loops.epoch.validation_epoch_loop.ValidationEpochLoop`.

    In the ``run()`` method, the training epoch loop could in theory simply call the
    ``LightningModule.training_step`` already and perform the optimization.
    However, Lightning has built-in support for automatic optimization with multiple optimizers.
    For this reason there are actually two more loops nested under
-    :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`.
+    :class:`~lightning.pytorch.loops.epoch.training_epoch_loop.TrainingEpochLoop`.

    Args:
        min_steps: The minimum number of steps (batches) to process
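To ground the docstring above: the epoch loop iterates over whatever the user's ``train_dataloader`` hook returns and calls ``training_step`` once per batch. A minimal sketch under the ``lightning.pytorch`` namespace this commit introduces (``ToyModel`` is a hypothetical example, not part of the diff):

import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def train_dataloader(self):
        # The training epoch loop iterates over the batches yielded here.
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()  # the returned loss is handed to the optimization loops

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)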
@@ -14,19 +14,19 @@
import logging
from typing import Any, Optional, Type

-import pytorch_lightning as pl
-from pytorch_lightning.loops import _Loop
-from pytorch_lightning.loops.epoch import _TrainingEpochLoop
-from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
-from pytorch_lightning.loops.progress import Progress
-from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
-from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
-from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.utilities.exceptions import MisconfigurationException, SIGTERMException
-from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
-from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
+import lightning.pytorch as pl
+from lightning.pytorch.loops import _Loop
+from lightning.pytorch.loops.epoch import _TrainingEpochLoop
+from lightning.pytorch.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
+from lightning.pytorch.loops.progress import Progress
+from lightning.pytorch.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
+from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
+from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
+from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
+from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

log = logging.getLogger(__name__)

@@ -13,8 +13,8 @@
# limitations under the License.
from typing import Dict, Optional

-import pytorch_lightning as pl
-from pytorch_lightning.loops.progress import BaseProgress
+import lightning.pytorch as pl
+from lightning.pytorch.loops.progress import BaseProgress


class _Loop:
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from pytorch_lightning.loops.optimization.manual_loop import _ManualOptimization # noqa: F401
-from pytorch_lightning.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401
+from lightning.pytorch.loops.optimization.manual_loop import _ManualOptimization # noqa: F401
+from lightning.pytorch.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401
@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, Optional, TypeVar

-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.exceptions import MisconfigurationException

T = TypeVar("T")

@@ -18,19 +18,19 @@ from typing import Any, Dict, Optional

from torch import Tensor

-from pytorch_lightning.core.optimizer import do_nothing_closure
-from pytorch_lightning.loops import _Loop
-from pytorch_lightning.loops.optimization.closure import OutputResult
-from pytorch_lightning.loops.progress import Progress, ReadyCompletedTracker
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from lightning.pytorch.core.optimizer import do_nothing_closure
+from lightning.pytorch.loops import _Loop
+from lightning.pytorch.loops.optimization.closure import OutputResult
+from lightning.pytorch.loops.progress import Progress, ReadyCompletedTracker
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.types import STEP_OUTPUT


@dataclass
class ManualResult(OutputResult):
    """A container to hold the result returned by the ``ManualLoop``.

-    It is created from the output of :meth:`~pytorch_lightning.core.module.LightningModule.training_step`.
+    It is created from the output of :meth:`~lightning.pytorch.core.module.LightningModule.training_step`.

    Attributes:
        extra: Anything returned by the ``training_step``.
@@ -66,11 +66,11 @@ _OUTPUTS_TYPE = Dict[str, Any]

class _ManualOptimization(_Loop):
    """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
-    entirely in the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` and therefore the user is
+    entirely in the :meth:`~lightning.pytorch.core.module.LightningModule.training_step` and therefore the user is
    responsible for back-propagating gradients and making calls to the optimizers.

    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
-    :meth:`~pytorch_lightning.core.module.LightningModule.training_step`) and passing through the output(s).
+    :meth:`~lightning.pytorch.core.module.LightningModule.training_step`) and passing through the output(s).
    """

    output_result_cls = ManualResult
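From the user's side, the manual optimization flow this loop implements looks roughly as follows (a sketch relying on the public ``LightningModule`` manual-optimization API: ``automatic_optimization``, ``optimizers()`` and ``manual_backward``; the module itself is hypothetical):

import torch

import lightning.pytorch as pl


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # training now goes through the manual optimization loop
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)  # the user, not the loop, triggers the backward pass
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)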
@@ -19,21 +19,21 @@ import torch
from torch import Tensor
from torch.optim import Optimizer

-from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.loops.loop import _Loop
-from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
-from pytorch_lightning.loops.progress import OptimizationProgress
-from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.rank_zero import WarningCache
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from lightning.pytorch.core.optimizer import LightningOptimizer
+from lightning.pytorch.loops.loop import _Loop
+from lightning.pytorch.loops.optimization.closure import AbstractClosure, OutputResult
+from lightning.pytorch.loops.progress import OptimizationProgress
+from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.rank_zero import WarningCache
+from lightning.pytorch.utilities.types import STEP_OUTPUT


@dataclass
class ClosureResult(OutputResult):
    """A container to hold the result of a :class:`Closure` call.

-    It is created from the output of :meth:`~pytorch_lightning.core.module.LightningModule.training_step`.
+    It is created from the output of :meth:`~lightning.pytorch.core.module.LightningModule.training_step`.

    Attributes:
        closure_loss: The loss with a graph attached.
@@ -95,7 +95,7 @@ class Closure(AbstractClosure[ClosureResult]):
    do something with the output.

    Args:
-        step_fn: This is typically the :meth:`pytorch_lightning.core.module.LightningModule.training_step
+        step_fn: This is typically the :meth:`lightning.pytorch.core.module.LightningModule.training_step
            wrapped with processing for its outputs
        backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value.
            Can be set to ``None`` to skip the backward operation.
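Why ``step_fn`` and ``backward_fn`` are bundled into one callable: some optimizers re-evaluate the model several times per step, so PyTorch's ``Optimizer.step`` accepts a closure. A plain-PyTorch sketch of the pattern this class wraps (illustrative only, not Lightning's internal code):

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.LBFGS(model.parameters())
x, y = torch.randn(16, 8), torch.randn(16, 1)


def closure():
    # step_fn analogue: compute the loss; backward_fn analogue: run the backward pass.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss


optimizer.step(closure)  # LBFGS may call the closure several times within a single step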
@@ -18,15 +18,15 @@ import torch
from torch import Tensor
from torch.utils.data import DataLoader

-import pytorch_lightning as pl
-from lightning_fabric.utilities.warnings import PossibleUserWarning
-from pytorch_lightning.callbacks.timer import Timer
-from pytorch_lightning.loops import _Loop
-from pytorch_lightning.loops.progress import BaseProgress
-from pytorch_lightning.strategies.parallel import ParallelStrategy
-from pytorch_lightning.strategies.strategy import Strategy
-from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+import lightning.pytorch as pl
+from lightning.fabric.utilities.warnings import PossibleUserWarning
+from lightning.pytorch.callbacks.timer import Timer
+from lightning.pytorch.loops import _Loop
+from lightning.pytorch.loops.progress import BaseProgress
+from lightning.pytorch.strategies.parallel import ParallelStrategy
+from lightning.pytorch.strategies.strategy import Strategy
+from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn


def check_finite_loss(loss: Optional[Tensor]) -> None:
@@ -83,7 +83,7 @@ def _parse_loop_limits(

@contextmanager
def _block_parallel_sync_behavior(strategy: Strategy, block: bool = True) -> Generator[None, None, None]:
-    """Blocks synchronization in :class:`~pytorch_lightning.strategies.parallel.ParallelStrategy`. This is useful
+    """Blocks synchronization in :class:`~lightning.pytorch.strategies.parallel.ParallelStrategy`. This is useful
    for example when accumulating gradients to reduce communication when it is not needed.

    Args:
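The synchronization being blocked is the per-backward gradient all-reduce of data-parallel wrappers. A plain-PyTorch sketch of the same idea using ``DistributedDataParallel.no_sync`` (the ``accumulate`` helper is hypothetical; Lightning's strategy handles this internally):

import contextlib

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def accumulate(model: DDP, batches, optimizer, accumulate_grad_batches: int = 4) -> None:
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches):
        last_of_window = (i + 1) % accumulate_grad_batches == 0
        # Skip the gradient all-reduce on intermediate steps; synchronize only before step().
        ctx = contextlib.nullcontext() if last_of_window else model.no_sync()
        with ctx:
            torch.nn.functional.mse_loss(model(x), y).backward()
        if last_of_window:
            optimizer.step()
            optimizer.zero_grad()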
@@ -15,8 +15,8 @@ from typing import Any, Union

import torch

-import pytorch_lightning as pl
-from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
+import lightning.pytorch as pl
+from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin


class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
@@ -19,9 +19,9 @@ import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

-import pytorch_lightning as pl
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+import lightning.pytorch as pl
+from lightning.pytorch.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn


def _ignore_scalar_return_in_dp() -> None:
@@ -19,7 +19,7 @@ from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler

-from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
+from lightning.fabric.utilities.distributed import _DatasetSamplerWrapper


def _find_tensors(
@@ -0,0 +1,39 @@
+from typing import Union
+
+from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, TorchCheckpointIO, XLACheckpointIO
+from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
+from lightning.pytorch.plugins.io.hpu_plugin import HPUCheckpointIO
+from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
+from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
+from lightning.pytorch.plugins.precision.colossalai import ColossalAIPrecisionPlugin
+from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
+from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
+from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
+from lightning.pytorch.plugins.precision.hpu import HPUPrecisionPlugin
+from lightning.pytorch.plugins.precision.ipu import IPUPrecisionPlugin
+from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
+from lightning.pytorch.plugins.precision.tpu import TPUPrecisionPlugin
+from lightning.pytorch.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
+
+PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
+PLUGIN_INPUT = Union[PLUGIN, str]
+
+__all__ = [
+    "AsyncCheckpointIO",
+    "CheckpointIO",
+    "TorchCheckpointIO",
+    "XLACheckpointIO",
+    "HPUCheckpointIO",
+    "ColossalAIPrecisionPlugin",
+    "DeepSpeedPrecisionPlugin",
+    "DoublePrecisionPlugin",
+    "IPUPrecisionPlugin",
+    "HPUPrecisionPlugin",
+    "MixedPrecisionPlugin",
+    "PrecisionPlugin",
+    "FSDPMixedPrecisionPlugin",
+    "TPUPrecisionPlugin",
+    "TPUBf16PrecisionPlugin",
+    "LayerSync",
+    "TorchSyncBatchNorm",
+]
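The ``PLUGIN``/``PLUGIN_INPUT`` unions above describe what may be passed to the Trainer's ``plugins`` argument. A small sketch using the new import paths (assumes the re-exports in this ``__init__`` resolve as written):

import lightning.pytorch as pl
from lightning.pytorch.plugins import AsyncCheckpointIO

# A CheckpointIO instance is one valid PLUGIN_INPUT; string aliases are accepted as well.
trainer = pl.Trainer(max_epochs=1, plugins=[AsyncCheckpointIO()])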
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from lightning_fabric.plugins import ClusterEnvironment # noqa: F401
-from lightning_fabric.plugins.environments import ( # noqa: F401
+from lightning.fabric.plugins import ClusterEnvironment # noqa: F401
+from lightning.fabric.plugins.environments import ( # noqa: F401
    KubeflowEnvironment,
    LightningEnvironment,
    LSFEnvironment,
@@ -20,4 +20,4 @@ from lightning_fabric.plugins.environments import ( # noqa: F401
    TorchElasticEnvironment,
    XLAEnvironment,
)
-from pytorch_lightning.plugins.environments.bagua_environment import BaguaEnvironment # noqa: F401
+from lightning.pytorch.plugins.environments.bagua_environment import BaguaEnvironment # noqa: F401
@@ -15,7 +15,7 @@
import logging
import os

-from lightning_fabric.plugins import ClusterEnvironment
+from lightning.fabric.plugins import ClusterEnvironment

log = logging.getLogger(__name__)

@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from lightning_fabric.plugins import CheckpointIO, TorchCheckpointIO, XLACheckpointIO
-from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
-from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
+from lightning.fabric.plugins import CheckpointIO, TorchCheckpointIO, XLACheckpointIO
+from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
+from lightning.pytorch.plugins.io.hpu_plugin import HPUCheckpointIO

__all__ = ["AsyncCheckpointIO", "CheckpointIO", "HPUCheckpointIO", "TorchCheckpointIO", "XLACheckpointIO"]
@@ -15,8 +15,8 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

-from lightning_fabric.plugins import CheckpointIO
-from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
+from lightning.fabric.plugins import CheckpointIO
+from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO


class AsyncCheckpointIO(_WrappingCheckpointIO):
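``AsyncCheckpointIO`` is a ``_WrappingCheckpointIO``: it delegates to another ``CheckpointIO`` and off-loads saving to the ``ThreadPoolExecutor`` imported above so training does not block on disk writes. A hedged usage sketch (the ``checkpoint_io`` keyword for the wrapped plugin is assumed from ``_WrappingCheckpointIO``):

import lightning.pytorch as pl
from lightning.fabric.plugins import TorchCheckpointIO
from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO

# Wrap the default torch-based checkpoint IO so checkpoints are written in a background thread.
async_ckpt_io = AsyncCheckpointIO(checkpoint_io=TorchCheckpointIO())
trainer = pl.Trainer(plugins=[async_ckpt_io])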
@@ -13,4 +13,4 @@
# limitations under the License.

# For backward-compatibility
-from lightning_fabric.plugins import CheckpointIO # noqa: F401
+from lightning.fabric.plugins import CheckpointIO # noqa: F401
@@ -17,10 +17,10 @@ from typing import Any, Dict, Optional

import torch

-from lightning_fabric.plugins import TorchCheckpointIO
-from lightning_fabric.utilities import move_data_to_device
-from lightning_fabric.utilities.cloud_io import _atomic_save, get_filesystem
-from lightning_fabric.utilities.types import _PATH
+from lightning.fabric.plugins import TorchCheckpointIO
+from lightning.fabric.utilities import move_data_to_device
+from lightning.fabric.utilities.cloud_io import _atomic_save, get_filesystem
+from lightning.fabric.utilities.types import _PATH


class HPUCheckpointIO(TorchCheckpointIO):
@@ -13,4 +13,4 @@
# limitations under the License.

# For backward-compatibility
-from lightning_fabric.plugins import TorchCheckpointIO # noqa: F401
+from lightning.fabric.plugins import TorchCheckpointIO # noqa: F401
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Optional

-from lightning_fabric.plugins import CheckpointIO
+from lightning.fabric.plugins import CheckpointIO


class _WrappingCheckpointIO(CheckpointIO):
@@ -13,4 +13,4 @@
# limitations under the License.

# For backward-compatibility
-from lightning_fabric.plugins import XLACheckpointIO # noqa: F401
+from lightning.fabric.plugins import XLACheckpointIO # noqa: F401
@@ -11,16 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from pytorch_lightning.plugins.precision.amp import MixedPrecisionPlugin
-from pytorch_lightning.plugins.precision.colossalai import ColossalAIPrecisionPlugin
-from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
-from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
-from pytorch_lightning.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
-from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
-from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
-from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
-from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
+from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
+from lightning.pytorch.plugins.precision.colossalai import ColossalAIPrecisionPlugin
+from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
+from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
+from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
+from lightning.pytorch.plugins.precision.hpu import HPUPrecisionPlugin
+from lightning.pytorch.plugins.precision.ipu import IPUPrecisionPlugin
+from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
+from lightning.pytorch.plugins.precision.tpu import TPUPrecisionPlugin
+from lightning.pytorch.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin

__all__ = [
    "ColossalAIPrecisionPlugin",
Some files were not shown because too many files have changed in this diff.