Add back-compatibility for checkpoint io plugins in pl/plugins/io (#14519)
This commit is contained in:
parent
463439e624
commit
cbbd148089
|
@ -15,7 +15,6 @@ import logging
|
|||
import os
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
|
||||
from lightning_lite.utilities.cloud_io import atomic_save, get_filesystem
|
||||
from lightning_lite.utilities.cloud_io import load as pl_load
|
||||
|
@ -53,9 +52,10 @@ class TorchCheckpointIO(CheckpointIO):
|
|||
# write the checkpoint dictionary on the file
|
||||
atomic_save(checkpoint, path)
|
||||
except AttributeError as err:
|
||||
# todo (sean): is this try catch necessary still?
|
||||
# todo: is this try catch necessary still?
|
||||
# https://github.com/Lightning-AI/lightning/pull/431
|
||||
key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
|
||||
# TODO(lite): Lite doesn't support hyperparameters in the checkpoint, so this should be refactored
|
||||
key = "hyper_parameters"
|
||||
checkpoint.pop(key, None)
|
||||
rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
|
||||
atomic_save(checkpoint, path)
|
||||
|
|
|
@ -18,8 +18,8 @@ from lightning_utilities.core.apply_func import apply_to_collection
|
|||
|
||||
from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO
|
||||
from lightning_lite.utilities.cloud_io import get_filesystem
|
||||
from lightning_lite.utilities.imports import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
|
||||
from lightning_lite.utilities.types import _PATH
|
||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
|
||||
|
||||
if _TPU_AVAILABLE:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# For backward-compatibility
|
||||
from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# For backward-compatibility
|
||||
from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# For backward-compatibility
|
||||
from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO # noqa: F401
|
|
@ -90,7 +90,7 @@ def test_trainer_save_checkpoint_storage_options(tmpdir):
|
|||
instance_path = tmpdir + "/path.ckpt"
|
||||
instance_storage_options = "my instance storage options"
|
||||
|
||||
with mock.patch("pytorch_lightning.plugins.io.torch_plugin.TorchCheckpointIO.save_checkpoint") as io_mock:
|
||||
with mock.patch("lightning_lite.plugins.io.torch_plugin.TorchCheckpointIO.save_checkpoint") as io_mock:
|
||||
trainer.save_checkpoint(instance_path, storage_options=instance_storage_options)
|
||||
io_mock.assert_called_with(ANY, instance_path, storage_options=instance_storage_options)
|
||||
trainer.save_checkpoint(instance_path)
|
||||
|
|
|
@ -19,13 +19,13 @@ import pytest
|
|||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_lite.plugins.environments import LightningEnvironment
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
|
||||
from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule
|
||||
from pytorch_lightning.overrides.base import unwrap_lightning_module
|
||||
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
|
||||
from pytorch_lightning.plugins.environments import LightningEnvironment
|
||||
from pytorch_lightning.strategies.bagua import LightningBaguaModule
|
||||
from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule
|
||||
from pytorch_lightning.strategies.ipu import LightningIPUModule
|
||||
|
|
Loading…
Reference in New Issue