From d01e8fdc86ff13b6b8dce90895f320c6512d5a14 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Thu, 4 Mar 2021 18:09:33 +0000
Subject: [PATCH] [fix] Use training type plugin hook when saving (FSDP 1/n)
 (#6321)

* Rely on training type plugin when saving

* Add better typing to training type plugin
---
 pytorch_lightning/accelerators/accelerator.py     | 2 +-
 .../plugins/training_type/training_type_plugin.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 38fb423d22..ea9cb03d18 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -379,7 +379,7 @@ class Accelerator(object):
         return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
 
     def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
-        return checkpoint
+        return self.training_type_plugin.on_save(checkpoint)
 
     def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index cf4b93e04e..5817987520 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union, Dict
 
 import torch
 from torch.nn import Module
@@ -153,7 +153,7 @@ class TrainingTypePlugin(Plugin, ABC):
     def test_step_end(self, output):
         return output
 
-    def on_save(self, checkpoint: dict) -> dict:
+    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
         return checkpoint
 
     def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
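
For illustration: the change above routes `Accelerator.on_save` through the
training type plugin's `on_save` hook instead of returning the checkpoint
unchanged, so a plugin (e.g. a sharded/FSDP plugin, per the "FSDP 1/n" in the
subject) can rewrite the checkpoint dict before it is written to disk. Below
is a minimal, self-contained sketch of that dispatch using stand-in classes;
`CPUOffloadPlugin` and the demo values are hypothetical, not real
pytorch_lightning APIs.

    # Stand-in classes illustrating the dispatch introduced by this patch.
    # None of these are the real pytorch_lightning classes.
    from typing import Any, Dict, Union

    import torch


    class TrainingTypePlugin:
        """Stand-in for pytorch_lightning's TrainingTypePlugin."""

        def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
            # Default behaviour matches the base class: pass the
            # checkpoint through unchanged.
            return checkpoint


    class CPUOffloadPlugin(TrainingTypePlugin):
        """Hypothetical plugin that moves tensors to CPU before saving."""

        def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
            return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in checkpoint.items()}


    class Accelerator:
        """Stand-in for pytorch_lightning's Accelerator."""

        def __init__(self, training_type_plugin: TrainingTypePlugin) -> None:
            self.training_type_plugin = training_type_plugin

        def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
            # After this patch: defer to the plugin hook instead of
            # returning the checkpoint unchanged.
            return self.training_type_plugin.on_save(checkpoint)


    if __name__ == "__main__":
        accelerator = Accelerator(CPUOffloadPlugin())
        checkpoint = {"epoch": 3, "state_dict": torch.zeros(2, 2)}
        print(accelerator.on_save(checkpoint)["state_dict"].device)  # cpu

Keeping the base `on_save` a pass-through preserves behaviour for every
existing plugin, while plugins that need custom serialization (such as the
FSDP work this PR prepares for) only have to override the one hook.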