diff --git a/CHANGELOG.md b/CHANGELOG.md index 375d6ee060..188b5f9809 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -226,6 +226,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) +- Fixed resolve a bug with omegaconf and xm.save ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741)) + ## [1.2.4] - 2021-03-16 ### Changed diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index ba074e7cfb..a29310f65f 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -23,10 +23,11 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.apply_func import apply_to_collection if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm @@ -37,6 +38,10 @@ if _TPU_AVAILABLE: else: xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5 +if _OMEGACONF_AVAILABLE: + from omegaconf import OmegaConf + from omegaconf import DictConfig, ListConfig + class TPUSpawnPlugin(DDPSpawnPlugin): @@ -304,4 +309,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin): filepath: write-target file's path """ # Todo: TypeError: 'mappingproxy' object does not support item assignment + if _OMEGACONF_AVAILABLE: + checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)