[fix] Use training type plugin hook when saving (FSDP 1/n) (#6321)

* Rely on training type plugin when saving

* Add better typing to training type plugin
This commit is contained in:
Sean Naren 2021-03-04 18:09:33 +00:00 committed by GitHub
parent e038e747a0
commit d01e8fdc86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@@ -379,7 +379,7 @@ class Accelerator(object):
return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
-        return checkpoint
+        return self.training_type_plugin.on_save(checkpoint)
def barrier(self, name: Optional[str] = None) -> None:
self.training_type_plugin.barrier(name=name)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union, Dict
import torch
from torch.nn import Module
@@ -153,7 +153,7 @@ class TrainingTypePlugin(Plugin, ABC):
def test_step_end(self, output):
return output
-    def on_save(self, checkpoint: dict) -> dict:
+    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
return checkpoint
def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: