bump: Torch `2.5` (#20351)
* bump: Torch `2.5.0`
* push docker
* docker
* 2.5.1 and mypy
* update USE_DISTRIBUTED=0 test
* also for pytorch lightning no distributed
* set USE_LIBUV=0 on windows
* try drop pickle warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* disable compiling update_metrics
* bump 2.2.x to bugfix
* disable also log in logger connector (also calls metric)
* more point release bumps
* remove unloved type ignore and print some more on exit
* update checkgroup
* minor versions
* shortened version in build-pl
* pytorch 2.4 is with python 3.11
* 2.1 and 2.3 without patch release
* for 2.4.1: docker with 3.11 test with 3.12
---------
Co-authored-by: Thomas Viehmann <tv.code@beamnet.de>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 61a403a512
)
This commit is contained in:
parent
d62b53a643
commit
b1eceb1516
|
@ -46,7 +46,7 @@ jobs:
|
|||
variables:
|
||||
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
|
||||
container:
|
||||
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
|
||||
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
|
||||
options: "--gpus=all --shm-size=32g"
|
||||
strategy:
|
||||
matrix:
|
||||
|
|
|
@ -60,7 +60,7 @@ jobs:
|
|||
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
|
||||
PACKAGE_NAME: "fabric"
|
||||
"Lightning | latest":
|
||||
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
|
||||
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
|
||||
PACKAGE_NAME: "lightning"
|
||||
workspace:
|
||||
clean: all
|
||||
|
|
|
@ -53,7 +53,7 @@ jobs:
|
|||
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
|
||||
PACKAGE_NAME: "pytorch"
|
||||
"Lightning | latest":
|
||||
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
|
||||
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
|
||||
PACKAGE_NAME: "lightning"
|
||||
pool: lit-rtx-3090
|
||||
variables:
|
||||
|
|
|
@ -21,19 +21,22 @@ subprojects:
|
|||
checks:
|
||||
- "pl-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.10, 2.1)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.11, 2.2)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.11, 2.2.2)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.11, 2.3)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.12, 2.4)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.12, 2.4.1)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.12, 2.5.1)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.12, 2.4)"
|
||||
- "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)"
|
||||
- "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.9, 2.1, oldest)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.10, 2.1)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.11, 2.2)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.11, 2.2.2)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.11, 2.3)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.12, 2.4)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.12, 2.4.1)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.12, 2.5.1)"
|
||||
- "pl-cpu (macOS-14, pytorch, 3.9, 2.1)"
|
||||
- "pl-cpu (ubuntu-20.04, pytorch, 3.9, 2.1)"
|
||||
- "pl-cpu (windows-2022, pytorch, 3.9, 2.1)"
|
||||
|
@ -141,15 +144,17 @@ subprojects:
|
|||
- "!*.md"
|
||||
- "!**/*.md"
|
||||
checks:
|
||||
- "build-cuda (3.11, 2.1, 12.1.0)"
|
||||
- "build-cuda (3.11, 2.2, 12.1.0)"
|
||||
- "build-cuda (3.11, 2.3, 12.1.0)"
|
||||
- "build-cuda (3.12, 2.4, 12.1.0)"
|
||||
- "build-cuda (3.10, 2.1.2, 12.1.0)"
|
||||
- "build-cuda (3.11, 2.2.2, 12.1.0)"
|
||||
- "build-cuda (3.11, 2.3.1, 12.1.0)"
|
||||
- "build-cuda (3.11, 2.4.1, 12.1.0)"
|
||||
- "build-cuda (3.12, 2.5.1, 12.1.0)"
|
||||
#- "build-NGC"
|
||||
- "build-pl (3.11, 2.1, 12.1.0)"
|
||||
- "build-pl (3.10, 2.1, 12.1.0)"
|
||||
- "build-pl (3.11, 2.2, 12.1.0)"
|
||||
- "build-pl (3.11, 2.3, 12.1.0)"
|
||||
- "build-pl (3.12, 2.4, 12.1.0)"
|
||||
- "build-pl (3.11, 2.4, 12.1.0)"
|
||||
- "build-pl (3.12, 2.5, 12.1.0)"
|
||||
|
||||
# SECTION: lightning_fabric
|
||||
|
||||
|
@ -168,19 +173,22 @@ subprojects:
|
|||
checks:
|
||||
- "fabric-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.10, 2.1)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2.2)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.11, 2.3)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.12, 2.4)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.12, 2.4.1)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.12, 2.5.1)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.12, 2.4)"
|
||||
- "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)"
|
||||
- "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.9, 2.1, oldest)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.10, 2.1)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2.2)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.11, 2.3)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.12, 2.4)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.12, 2.4.1)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.12, 2.5.1)"
|
||||
- "fabric-cpu (macOS-14, fabric, 3.9, 2.1)"
|
||||
- "fabric-cpu (ubuntu-20.04, fabric, 3.9, 2.1)"
|
||||
- "fabric-cpu (windows-2022, fabric, 3.9, 2.1)"
|
||||
|
|
|
@ -43,15 +43,18 @@ jobs:
|
|||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
|
||||
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
|
||||
- { os: "macOS-13", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
|
||||
|
|
|
@ -47,15 +47,18 @@ jobs:
|
|||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
|
||||
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
|
||||
- { os: "macOS-13", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
|
||||
|
|
|
@ -43,10 +43,11 @@ jobs:
|
|||
include:
|
||||
# We only release one docker image per PyTorch version.
|
||||
# Make sure the matrix here matches the one below.
|
||||
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.12", pytorch_version: "2.5", cuda_version: "12.1.0" }
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
|
@ -103,10 +104,11 @@ jobs:
|
|||
include:
|
||||
# These are the base images for PL release docker images.
|
||||
# Make sure the matrix here matches the one above.
|
||||
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.10", pytorch_version: "2.1.2", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.2.2", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.3.1", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.4.1", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.0" }
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: docker/setup-buildx-action@v3
|
||||
|
@ -115,6 +117,12 @@ jobs:
|
|||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: shorten Torch version
|
||||
run: |
|
||||
# convert 1.10.2 to 1.10
|
||||
pt_version=$(echo ${{ matrix.pytorch_version }} | cut -d. -f1,2)
|
||||
echo "PT_VERSION=$pt_version" >> $GITHUB_ENV
|
||||
- uses: docker/build-push-action@v6
|
||||
with:
|
||||
build-args: |
|
||||
|
@ -123,7 +131,7 @@ jobs:
|
|||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
file: dockers/base-cuda/Dockerfile
|
||||
push: ${{ env.PUSH_NIGHTLY }}
|
||||
tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }}"
|
||||
tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ env.PT_VERSION }}-cuda${{ matrix.cuda_version }}"
|
||||
timeout-minutes: 95
|
||||
- uses: ravsamhq/notify-slack-action@v2
|
||||
if: failure() && env.PUSH_NIGHTLY == 'true'
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
|
||||
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
|
||||
|
||||
torch >=2.1.0, <2.5.0
|
||||
torch >=2.1.0, <2.6.0
|
||||
fsspec[http] >=2022.5.0, <2024.4.0
|
||||
packaging >=20.0, <=23.1
|
||||
typing-extensions >=4.4.0, <4.10.0
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
|
||||
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
|
||||
|
||||
torchvision >=0.16.0, <0.20.0
|
||||
torchmetrics >=0.10.0, <1.3.0
|
||||
torchvision >=0.16.0, <0.21.0
|
||||
torchmetrics >=0.10.0, <1.5.0
|
||||
lightning-utilities >=0.8.0, <0.12.0
|
||||
|
|
|
@ -7,4 +7,4 @@ pytest-rerunfailures ==12.0
|
|||
pytest-random-order ==1.1.0
|
||||
click ==8.1.7
|
||||
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
|
||||
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
|
||||
torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
|
||||
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
|
||||
|
||||
torch >=2.1.0, <2.5.0
|
||||
torch >=2.1.0, <2.6.0
|
||||
tqdm >=4.57.0, <4.67.0
|
||||
PyYAML >=5.4, <6.1.0
|
||||
fsspec[http] >=2022.5.0, <2024.4.0
|
||||
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
|
||||
torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version
|
||||
packaging >=20.0, <=23.1
|
||||
typing-extensions >=4.4.0, <4.10.0
|
||||
lightning-utilities >=0.10.0, <0.12.0
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
|
||||
|
||||
requests <2.32.0
|
||||
torchvision >=0.16.0, <0.20.0
|
||||
torchvision >=0.16.0, <0.21.0
|
||||
ipython[all] <8.15.0
|
||||
torchmetrics >=0.10.0, <1.3.0
|
||||
torchmetrics >=0.10.0, <1.5.0
|
||||
lightning-utilities >=0.8.0, <0.12.0
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
mypy==1.11.0
|
||||
torch==2.4.1
|
||||
torch==2.5.1
|
||||
|
||||
types-Markdown
|
||||
types-PyYAML
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from lightning_utilities.core.imports import package_available
|
||||
|
||||
|
@ -26,6 +27,10 @@ if not _root_logger.hasHandlers():
|
|||
# https://github.com/pytorch/pytorch/issues/83973
|
||||
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"
|
||||
|
||||
# see https://github.com/pytorch/pytorch/issues/139990
|
||||
if sys.platform == "win32":
|
||||
os.environ["USE_LIBUV"] = "0"
|
||||
|
||||
|
||||
from lightning.fabric.fabric import Fabric # noqa: E402
|
||||
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
|
||||
|
|
|
@ -531,7 +531,7 @@ class LightningModule(
|
|||
logger=logger,
|
||||
on_step=on_step,
|
||||
on_epoch=on_epoch,
|
||||
reduce_fx=reduce_fx, # type: ignore[arg-type]
|
||||
reduce_fx=reduce_fx,
|
||||
enable_graph=enable_graph,
|
||||
add_dataloader_idx=add_dataloader_idx,
|
||||
batch_size=batch_size,
|
||||
|
@ -1405,7 +1405,9 @@ class LightningModule(
|
|||
input_sample = self._apply_batch_transfer_handler(input_sample)
|
||||
|
||||
file_path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||
torch.onnx.export(self, input_sample, file_path, **kwargs)
|
||||
# PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
|
||||
# BytesIO does work, too.
|
||||
torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
|
||||
self.train(mode)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -351,6 +351,7 @@ class _ResultCollection(dict):
|
|||
|
||||
return batch_size
|
||||
|
||||
@torch.compiler.disable
|
||||
def log(
|
||||
self,
|
||||
fx: str,
|
||||
|
@ -413,6 +414,7 @@ class _ResultCollection(dict):
|
|||
batch_size = self._extract_batch_size(self[key], batch_size, meta)
|
||||
self.update_metrics(key, value, batch_size)
|
||||
|
||||
@torch.compiler.disable
|
||||
def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
|
||||
result_metric = self[key]
|
||||
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
|
||||
|
|
|
@ -48,6 +48,7 @@ function show_batched_output {
|
|||
# heuristic: stop if there's mentions of errors. this can prevent false negatives when only some of the ranks fail
|
||||
if perl -nle 'print if /error|(?<!(?-i)on_)exception|traceback|(?<!(?-i)x)failed/i' standalone_test_output.txt | grep -qv -f testnames.txt; then
|
||||
echo "Potential error! Stopping."
|
||||
perl -nle 'print if /error|(?<!(?-i)on_)exception|traceback|(?<!(?-i)x)failed/i' standalone_test_output.txt
|
||||
rm standalone_test_output.txt
|
||||
exit 1
|
||||
fi
|
||||
|
|
|
@ -23,6 +23,13 @@ def test_import_fabric_with_torch_dist_unavailable():
|
|||
code = dedent(
|
||||
"""
|
||||
import torch
|
||||
try:
|
||||
# PyTorch 2.5 relies on torch.distributed._composable.fsdp not
|
||||
# existing with USE_DISTRIBUTED=0
|
||||
import torch._dynamo.variables.functions
|
||||
torch._dynamo.variables.functions._fsdp_param_group = None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# pretend torch.distributed not available
|
||||
for name in list(torch.distributed.__dict__.keys()):
|
||||
|
@ -31,6 +38,11 @@ def test_import_fabric_with_torch_dist_unavailable():
|
|||
|
||||
torch.distributed.is_available = lambda: False
|
||||
|
||||
# needed for Dynamo in PT 2.5+; compare the torch.distributed source
|
||||
class _ProcessGroupStub:
|
||||
pass
|
||||
torch.distributed.ProcessGroup = _ProcessGroupStub
|
||||
|
||||
import lightning.fabric
|
||||
"""
|
||||
)
|
||||
|
|
|
@ -15,7 +15,6 @@ import logging
|
|||
import math
|
||||
import os
|
||||
import pickle
|
||||
from contextlib import nullcontext
|
||||
from typing import List, Optional
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock
|
||||
|
@ -23,7 +22,6 @@ from unittest.mock import Mock
|
|||
import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.pytorch import Trainer, seed_everything
|
||||
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
|
@ -193,12 +191,10 @@ def test_pickling():
|
|||
early_stopping = EarlyStopping(monitor="foo")
|
||||
|
||||
early_stopping_pickled = pickle.dumps(early_stopping)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
early_stopping_loaded = pickle.loads(early_stopping_pickled)
|
||||
assert vars(early_stopping) == vars(early_stopping_loaded)
|
||||
|
||||
early_stopping_pickled = cloudpickle.dumps(early_stopping)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
|
||||
assert vars(early_stopping) == vars(early_stopping_loaded)
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ import pickle
|
|||
import re
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from contextlib import nullcontext
|
||||
from datetime import timedelta
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
|
@ -32,7 +31,6 @@ import torch
|
|||
import yaml
|
||||
from jsonargparse import ArgumentParser
|
||||
from lightning.fabric.utilities.cloud_io import _load as pl_load
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.pytorch import Trainer, seed_everything
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
|
@ -352,12 +350,10 @@ def test_pickling(tmp_path):
|
|||
ckpt = ModelCheckpoint(dirpath=tmp_path)
|
||||
|
||||
ckpt_pickled = pickle.dumps(ckpt)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
ckpt_loaded = pickle.loads(ckpt_pickled)
|
||||
assert vars(ckpt) == vars(ckpt_loaded)
|
||||
|
||||
ckpt_pickled = cloudpickle.dumps(ckpt)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
|
||||
assert vars(ckpt) == vars(ckpt_loaded)
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ from unittest import mock
|
|||
import lightning.pytorch as pl
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.callbacks import OnExceptionCheckpoint
|
||||
|
@ -254,7 +253,6 @@ def test_result_collection_restoration(tmp_path):
|
|||
}
|
||||
|
||||
# make sure can be pickled
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
pickle.loads(pickle.dumps(result))
|
||||
# make sure can be torch.loaded
|
||||
filepath = str(tmp_path / "result")
|
||||
|
|
|
@ -12,12 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pickle
|
||||
from contextlib import nullcontext
|
||||
|
||||
import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
|
||||
from tests_pytorch import _PATH_DATASETS
|
||||
from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST
|
||||
|
@ -44,9 +42,7 @@ def test_pickling_dataset_mnist(dataset_cls, args):
|
|||
mnist = dataset_cls(**args)
|
||||
|
||||
mnist_pickled = pickle.dumps(mnist)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
pickle.loads(mnist_pickled)
|
||||
|
||||
mnist_pickled = cloudpickle.dumps(mnist)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
cloudpickle.loads(mnist_pickled)
|
||||
|
|
|
@ -14,13 +14,11 @@
|
|||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
from contextlib import nullcontext
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY, Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1
|
||||
from lightning.pytorch import Callback, Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
from lightning.pytorch.loggers import (
|
||||
|
@ -184,11 +182,6 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger):
|
|||
trainer = Trainer(max_epochs=1, logger=logger)
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
|
||||
with (
|
||||
pytest.warns(FutureWarning, match="`weights_only=False`")
|
||||
if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger))
|
||||
else nullcontext()
|
||||
):
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
trainer2.logger.log_metrics({"acc": 1.0})
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
import pickle
|
||||
from argparse import Namespace
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import patch
|
||||
|
@ -21,7 +20,6 @@ from unittest.mock import patch
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.fabric.utilities.logger import _convert_params, _sanitize_params
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
|
||||
|
@ -124,7 +122,6 @@ def test_multiple_loggers_pickle(tmp_path):
|
|||
|
||||
trainer = Trainer(logger=[logger1, logger2])
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
for logger in trainer2.loggers:
|
||||
logger.log_metrics({"acc": 1.0}, 0)
|
||||
|
|
|
@ -13,13 +13,11 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import pickle
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.cli import LightningCLI
|
||||
|
@ -162,7 +160,6 @@ def test_wandb_pickle(wandb_mock, tmp_path):
|
|||
assert trainer.logger.experiment, "missing experiment"
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
|
||||
assert os.environ["WANDB_MODE"] == "dryrun"
|
||||
|
|
|
@ -117,6 +117,13 @@ def test_import_pytorch_lightning_with_torch_dist_unavailable():
|
|||
code = dedent(
|
||||
"""
|
||||
import torch
|
||||
try:
|
||||
# PyTorch 2.5 relies on torch.distributed._composable.fsdp not
|
||||
# existing with USE_DISTRIBUTED=0
|
||||
import torch._dynamo.variables.functions
|
||||
torch._dynamo.variables.functions._fsdp_param_group = None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# pretend torch.distributed not available
|
||||
for name in list(torch.distributed.__dict__.keys()):
|
||||
|
@ -125,6 +132,11 @@ def test_import_pytorch_lightning_with_torch_dist_unavailable():
|
|||
|
||||
torch.distributed.is_available = lambda: False
|
||||
|
||||
# needed for Dynamo in PT 2.5+; compare the torch.distributed source
|
||||
class _ProcessGroupStub:
|
||||
pass
|
||||
torch.distributed.ProcessGroup = _ProcessGroupStub
|
||||
|
||||
import lightning.pytorch
|
||||
"""
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue