From 83039ba4705b02fbb55d5922d15a258a2f041986 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 5 Jul 2021 12:55:41 +0200 Subject: [PATCH] Test deepcopy for progress tracking dataclasses (#8265) --- pytorch_lightning/trainer/progress.py | 47 +++++++++++++-------------- tests/trainer/test_progress.py | 13 ++++++-- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 2d7a1d7e8f..25f76ad085 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -18,19 +18,16 @@ from typing import Optional @dataclass class _DataclassStateDictMixin: - def __getstate__(self) -> dict: + def state_dict(self) -> dict: return asdict(self) - def __setstate__(self, state: dict) -> None: - self.__dict__.update(state) - - def state_dict(self) -> dict: - return self.__getstate__() + def load_state_dict(self, state_dict: dict) -> None: + self.__dict__.update(state_dict) @classmethod def from_state_dict(cls, state_dict: dict) -> "_DataclassStateDictMixin": obj = cls() - obj.__setstate__(state_dict) + obj.load_state_dict(state_dict) return obj @@ -115,9 +112,9 @@ class Progress(_DataclassStateDictMixin): def from_defaults(cls, **kwargs: Optional[int]) -> "Progress": return cls(total=Tracker(**kwargs), current=Tracker(**kwargs)) - def __setstate__(self, state: dict) -> None: - self.total.__setstate__(state["total"]) - self.current.__setstate__(state["current"]) + def load_state_dict(self, state_dict: dict) -> None: + self.total.load_state_dict(state_dict["total"]) + self.current.load_state_dict(state_dict["current"]) class BatchProgress(Progress): @@ -147,9 +144,9 @@ class EpochProgress(Progress): def reset_on_epoch(self) -> None: self.batch.current.reset() - def __setstate__(self, state: dict) -> None: - super().__setstate__(state) - self.batch.__setstate__(state["batch"]) + def load_state_dict(self, state_dict: dict) -> None: + super().load_state_dict(state_dict) + self.batch.load_state_dict(state_dict["batch"]) @dataclass @@ -169,9 +166,9 @@ class OptimizerProgress(_DataclassStateDictMixin): self.step.current.reset() self.zero_grad.current.reset() - def __setstate__(self, state: dict) -> None: - self.step.__setstate__(state["step"]) - self.zero_grad.__setstate__(state["zero_grad"]) + def load_state_dict(self, state_dict: dict) -> None: + self.step.load_state_dict(state_dict["step"]) + self.zero_grad.load_state_dict(state_dict["zero_grad"]) @dataclass @@ -200,9 +197,9 @@ class OptimizationProgress(_DataclassStateDictMixin): self.optimizer.reset_on_epoch() self.scheduler.current.reset() - def __setstate__(self, state: dict) -> None: - self.optimizer.__setstate__(state["optimizer"]) - self.scheduler.__setstate__(state["scheduler"]) + def load_state_dict(self, state_dict: dict) -> None: + self.optimizer.load_state_dict(state_dict["optimizer"]) + self.scheduler.load_state_dict(state_dict["scheduler"]) @dataclass @@ -225,8 +222,8 @@ class EpochLoopProgress(_DataclassStateDictMixin): self.epoch.reset_on_epoch() self.epoch.current.reset() - def __setstate__(self, state: dict) -> None: - self.epoch.__setstate__(state["epoch"]) + def load_state_dict(self, state_dict: dict) -> None: + self.epoch.load_state_dict(state_dict["epoch"]) @dataclass @@ -245,10 +242,10 @@ class TrainingEpochProgress(EpochProgress): optim: OptimizationProgress = field(default_factory=OptimizationProgress) val: EpochLoopProgress = field(default_factory=EpochLoopProgress) - def __setstate__(self, state: dict) -> None: - super().__setstate__(state) - self.optim.__setstate__(state["optim"]) - self.val.__setstate__(state["val"]) + def load_state_dict(self, state_dict: dict) -> None: + super().load_state_dict(state_dict) + self.optim.load_state_dict(state_dict["optim"]) + self.val.load_state_dict(state_dict["val"]) @dataclass diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 8c287e8cb3..a3bbd5a36a 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy + import pytest from pytorch_lightning.trainer.progress import ( @@ -135,14 +137,17 @@ def test_optimizer_progress_default_factory(): def test_fit_loop_progress_serialization(): fit_loop = FitLoopProgress() + _ = deepcopy(fit_loop) + fit_loop.epoch.increment_completed() # check `TrainingEpochProgress.load_state_dict` calls `super` + state_dict = fit_loop.state_dict() # yapf: disable assert state_dict == { 'epoch': { # number of epochs across `fit` calls - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, + 'total': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0}, # number of epochs this `fit` call - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, + 'current': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0}, 'batch': { # number of batches across `fit` calls 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, @@ -191,13 +196,16 @@ def test_fit_loop_progress_serialization(): } } # yapf: enable + new_loop = FitLoopProgress.from_state_dict(state_dict) assert fit_loop == new_loop def test_epoch_loop_progress_serialization(): loop = EpochLoopProgress() + _ = deepcopy(loop) state_dict = loop.state_dict() + # yapf: disable assert state_dict == { 'epoch': { @@ -214,5 +222,6 @@ def test_epoch_loop_progress_serialization(): } } # yapf: enable + new_loop = EpochLoopProgress.from_state_dict(state_dict) assert loop == new_loop