[bugfix] Add support for omegaconf and tpu (#6741)

* fix_hydra

* update changelog

Co-authored-by: Your Name <you@example.com>
thomas chaton authored 2021-03-30 16:21:25 +01:00, committed by GitHub
parent 583fcf281c
commit bb92754119
2 changed files with 10 additions and 1 deletion

CHANGELOG.md

@@ -226,6 +226,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
+- Fixed a bug with omegaconf and `xm.save` ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741))
+
 ## [1.2.4] - 2021-03-16
 ### Changed
pytorch_lightning/plugins/training_type/tpu_spawn.py

@@ -23,10 +23,11 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -37,6 +38,10 @@ if _TPU_AVAILABLE:
 else:
     xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import OmegaConf
+    from omegaconf import DictConfig, ListConfig
+
 
 class TPUSpawnPlugin(DDPSpawnPlugin):
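For context: `_OMEGACONF_AVAILABLE` follows the same guarded-import pattern as `_TPU_AVAILABLE`, so the new omegaconf handling is skipped cleanly when the package is not installed. A minimal sketch of how such a flag can be derived (Lightning computes it with its own internal helper; this reimplementation is an assumption for illustration):

import importlib.util

# True only when omegaconf can be imported; a hypothetical stand-in for
# Lightning's internal _OMEGACONF_AVAILABLE flag.
_OMEGACONF_AVAILABLE = importlib.util.find_spec("omegaconf") is not None

if _OMEGACONF_AVAILABLE:
    from omegaconf import DictConfig, ListConfig, OmegaConf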
@@ -304,4 +309,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
             filepath: write-target file's path
         """
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
+        if _OMEGACONF_AVAILABLE:
+            checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
         self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
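The added conversion is the heart of the fix: `apply_to_collection` walks the checkpoint dict and replaces every `DictConfig`/`ListConfig` with plain built-in containers via `OmegaConf.to_container`, so `xm.save` never has to serialize an omegaconf object, sidestepping errors like the TypeError noted in the Todo. A self-contained sketch with made-up checkpoint contents:

from omegaconf import DictConfig, ListConfig, OmegaConf

from pytorch_lightning.utilities.apply_func import apply_to_collection

# Hypothetical checkpoint whose hyperparameters were stored as an omegaconf
# container, e.g. by a hydra-configured LightningModule.
checkpoint = {
    "state_dict": {"layer.weight": [0.1, 0.2]},
    "hyper_parameters": OmegaConf.create({"lr": 1e-3, "hidden": [64, 64]}),
}

# The same call the plugin now makes in save_checkpoint: every DictConfig or
# ListConfig found anywhere in the collection becomes a builtin dict/list.
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)

assert isinstance(checkpoint["hyper_parameters"], dict)  # plain dict now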