[bugfix] Add support for omegaconf and tpu (#6741)
* fix_hydra

* update changelog

Co-authored-by: Your Name <you@example.com>
parent 583fcf281c
commit bb92754119
CHANGELOG.md

@@ -226,6 +226,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
 
+- Fixed a bug with omegaconf and xm.save ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741))
+
 ## [1.2.4] - 2021-03-16
 
 ### Changed
 
pytorch_lightning/plugins/training_type/tpu_spawn.py

@@ -23,10 +23,11 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
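The newly imported `apply_to_collection` is the workhorse of this fix: it walks a (possibly nested) collection and applies a function to every element of a given type, leaving everything else untouched. A minimal sketch of that behavior, with illustrative values:

```python
from pytorch_lightning.utilities.apply_func import apply_to_collection

# Doubles every int anywhere in the nested structure; non-matching
# elements (the string here) pass through unchanged.
nested = {"a": 1, "b": {"c": 2, "d": "x"}}
doubled = apply_to_collection(nested, int, lambda x: x * 2)
assert doubled == {"a": 2, "b": {"c": 4, "d": "x"}}
```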
@@ -37,6 +38,10 @@ if _TPU_AVAILABLE:
 else:
     xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import OmegaConf
+    from omegaconf import DictConfig, ListConfig
+
 
 class TPUSpawnPlugin(DDPSpawnPlugin):
 
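The omegaconf names are imported behind `_OMEGACONF_AVAILABLE`, mirroring the `_TPU_AVAILABLE` guard above, so the plugin still imports cleanly when omegaconf is not installed. A sketch of the probe idiom, using a simplified stand-in for the real helper in `pytorch_lightning.utilities`:

```python
from importlib.util import find_spec

def _module_available(name: str) -> bool:
    """Check whether `name` is importable without actually importing it."""
    try:
        return find_spec(name) is not None
    except ModuleNotFoundError:
        return False

_OMEGACONF_AVAILABLE = _module_available("omegaconf")

if _OMEGACONF_AVAILABLE:
    from omegaconf import DictConfig, ListConfig, OmegaConf
```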
@@ -304,4 +309,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
             filepath: write-target file's path
         """
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
+        if _OMEGACONF_AVAILABLE:
+            checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
         self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
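Putting the pieces together: xm.save trips over omegaconf containers (the TypeError about a read-only 'mappingproxy' noted in the Todo), so before saving, the fix converts any DictConfig/ListConfig in the checkpoint to plain dicts and lists. A self-contained sketch of the conversion, with a hypothetical checkpoint layout (no TPU required):

```python
from omegaconf import DictConfig, ListConfig, OmegaConf

from pytorch_lightning.utilities.apply_func import apply_to_collection

# Hypothetical checkpoint: hyper_parameters is a DictConfig, as it would be
# when a LightningModule's hparams come from Hydra/omegaconf.
checkpoint = {
    "state_dict": {"layer.weight": [0.1, 0.2]},  # stand-in for real tensors
    "hyper_parameters": OmegaConf.create({"lr": 1e-3, "layers": [64, 32]}),
}

# Recursively replace every DictConfig/ListConfig with a plain dict/list so
# the result contains only serializable builtin containers.
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)

assert isinstance(checkpoint["hyper_parameters"], dict)
assert checkpoint["hyper_parameters"]["layers"] == [64, 32]
```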