Clean up last `ModelCheckpoint` `makedirs` call to IOPlugin (#11035)
This commit is contained in:
parent
7aee00c679
commit
d0b67f7376
|
@ -249,7 +249,7 @@ class ModelCheckpoint(Callback):
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
"""When pretrain routine starts we build the ckpt dir on the fly."""
|
"""When pretrain routine starts we resolve the ckpt dir on the fly."""
|
||||||
if self._save_on_train_epoch_end is None:
|
if self._save_on_train_epoch_end is None:
|
||||||
# if the user runs validation multiple times per training epoch or multiple training epochs without
|
# if the user runs validation multiple times per training epoch or multiple training epochs without
|
||||||
# validation, then we run after validation instead of on train epoch end
|
# validation, then we run after validation instead of on train epoch end
|
||||||
|
@ -600,9 +600,6 @@ class ModelCheckpoint(Callback):
|
||||||
|
|
||||||
self.dirpath = ckpt_path
|
self.dirpath = ckpt_path
|
||||||
|
|
||||||
if not trainer.fast_dev_run and trainer.training_type_plugin.should_rank_save_checkpoint:
|
|
||||||
self._fs.makedirs(self.dirpath, exist_ok=True)
|
|
||||||
|
|
||||||
def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
|
def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
|
||||||
if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
|
if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
|
||||||
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
|
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
|
||||||
|
|
|
@ -11,11 +11,13 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import os
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
|
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
|
||||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
|
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
|
||||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||||
|
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||||
from pytorch_lightning.utilities.types import _PATH
|
from pytorch_lightning.utilities.types import _PATH
|
||||||
|
|
||||||
if _TPU_AVAILABLE:
|
if _TPU_AVAILABLE:
|
||||||
|
@ -36,6 +38,8 @@ class XLACheckpointIO(TorchCheckpointIO):
|
||||||
path: write-target path
|
path: write-target path
|
||||||
storage_options: Optional parameters when saving the model/training states.
|
storage_options: Optional parameters when saving the model/training states.
|
||||||
"""
|
"""
|
||||||
|
fs = get_filesystem(path)
|
||||||
|
fs.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
# Todo: TypeError: 'mappingproxy' object does not support item assignment
|
# Todo: TypeError: 'mappingproxy' object does not support item assignment
|
||||||
# Ref: https://github.com/pytorch/xla/issues/2773
|
# Ref: https://github.com/pytorch/xla/issues/2773
|
||||||
if _OMEGACONF_AVAILABLE:
|
if _OMEGACONF_AVAILABLE:
|
||||||
|
|
|
@ -75,9 +75,6 @@ class SingleTPUPlugin(SingleDevicePlugin):
|
||||||
self.tpu_local_core_rank = xm.get_local_ordinal()
|
self.tpu_local_core_rank = xm.get_local_ordinal()
|
||||||
self.tpu_global_core_rank = xm.get_ordinal()
|
self.tpu_global_core_rank = xm.get_ordinal()
|
||||||
|
|
||||||
def save(self, state_dict: Dict, path: _PATH) -> None:
|
|
||||||
xm.save(state_dict, path)
|
|
||||||
|
|
||||||
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
|
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
|
||||||
"""Save model/training states as a checkpoint file through state-dump and file-write.
|
"""Save model/training states as a checkpoint file through state-dump and file-write.
|
||||||
|
|
||||||
|
|
|
@ -1700,10 +1700,6 @@ class Trainer(
|
||||||
# some training types define a world size
|
# some training types define a world size
|
||||||
return getattr(self.training_type_plugin, "world_size", 1)
|
return getattr(self.training_type_plugin, "world_size", 1)
|
||||||
|
|
||||||
@property
|
|
||||||
def should_rank_save_checkpoint(self) -> bool:
|
|
||||||
return self.training_type_plugin.should_rank_save_checkpoint
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _distrib_type(self) -> _StrategyType:
|
def _distrib_type(self) -> _StrategyType:
|
||||||
return self._accelerator_connector._distrib_type
|
return self._accelerator_connector._distrib_type
|
||||||
|
|
Loading…
Reference in New Issue