diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 38fb423d22..ea9cb03d18 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -379,7 +379,7 @@ class Accelerator(object):
         return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
 
     def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
-        return checkpoint
+        return self.training_type_plugin.on_save(checkpoint)
 
     def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index cf4b93e04e..5817987520 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union, Dict
 
 import torch
 from torch.nn import Module
@@ -153,7 +153,7 @@ class TrainingTypePlugin(Plugin, ABC):
     def test_step_end(self, output):
         return output
 
-    def on_save(self, checkpoint: dict) -> dict:
+    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
         return checkpoint
 
     def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
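
With this change, Accelerator.on_save no longer returns the checkpoint unchanged; it forwards it to TrainingTypePlugin.on_save, so a training type plugin can post-process the checkpoint dict before it is written. Below is a minimal sketch of that hook, not part of this diff: it assumes SingleDevicePlugin as a concrete base class and uses a hypothetical "custom_tag" key purely for illustration.

from typing import Any, Dict, Union

import torch
from pytorch_lightning.plugins import SingleDevicePlugin


class TaggingPlugin(SingleDevicePlugin):
    """Hypothetical plugin that annotates every checkpoint before it is saved."""

    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
        # Accelerator.on_save now delegates here, so the dict returned from
        # this method is what actually gets persisted.
        checkpoint["custom_tag"] = "tagged-by-plugin"
        return checkpoint

Passing such a plugin to the Trainer (e.g. Trainer(plugins=[TaggingPlugin(torch.device("cpu"))])) would route checkpoint saving through this override; the exact constructor arguments depend on the chosen plugin base class and are assumed here.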