Test deepcopy for progress tracking dataclasses (#8265)

Carlos Mocholí 2021-07-05 12:55:41 +02:00 committed by GitHub
parent ea88105b88
commit 83039ba470
2 changed files with 33 additions and 27 deletions
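Why `deepcopy` needed a test: the progress dataclasses previously implemented `__getstate__` / `__setstate__` directly, and those same hooks drive `copy.deepcopy`. The nested overrides (e.g. `Progress.__setstate__` calling `self.total.__setstate__(state["total"])`) run against an object that deepcopy builds without calling `__init__`, and against a state that is a plain nested dict, so copying blows up. Below is a minimal, self-contained sketch of that failure mode using simplified stand-ins for the library's `Tracker` and `Progress`; it is an illustration, not the pytorch_lightning code:

    from copy import deepcopy
    from dataclasses import asdict, dataclass, field


    @dataclass
    class Tracker:
        # simplified stand-in; the real Tracker also has started/processed/completed
        ready: int = 0

        def __getstate__(self) -> dict:
            return asdict(self)

        def __setstate__(self, state: dict) -> None:
            self.__dict__.update(state)


    @dataclass
    class Progress:
        # simplified stand-in for the nested progress dataclasses
        total: Tracker = field(default_factory=Tracker)

        def __getstate__(self) -> dict:
            return asdict(self)

        def __setstate__(self, state: dict) -> None:
            # deepcopy creates the new object with `cls.__new__(cls)` and passes a
            # deep-copied plain dict, so `self.total` does not exist yet
            self.total.__setstate__(state["total"])


    deepcopy(Progress())  # AttributeError: 'Progress' object has no attribute 'total'

Renaming the hooks to `state_dict` / `load_state_dict` (first file below) keeps the explicit serialization API while restoring the default copy/pickle behavior, and the new `deepcopy(...)` calls in the tests (second file below) guard against a regression.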

View File

@@ -18,19 +18,16 @@ from typing import Optional
 
 
 @dataclass
 class _DataclassStateDictMixin:
 
-    def __getstate__(self) -> dict:
+    def state_dict(self) -> dict:
         return asdict(self)
 
-    def __setstate__(self, state: dict) -> None:
-        self.__dict__.update(state)
-
-    def state_dict(self) -> dict:
-        return self.__getstate__()
+    def load_state_dict(self, state_dict: dict) -> None:
+        self.__dict__.update(state_dict)
 
     @classmethod
     def from_state_dict(cls, state_dict: dict) -> "_DataclassStateDictMixin":
         obj = cls()
-        obj.__setstate__(state_dict)
+        obj.load_state_dict(state_dict)
         return obj
@@ -115,9 +112,9 @@ class Progress(_DataclassStateDictMixin):
     def from_defaults(cls, **kwargs: Optional[int]) -> "Progress":
         return cls(total=Tracker(**kwargs), current=Tracker(**kwargs))
 
-    def __setstate__(self, state: dict) -> None:
-        self.total.__setstate__(state["total"])
-        self.current.__setstate__(state["current"])
+    def load_state_dict(self, state_dict: dict) -> None:
+        self.total.load_state_dict(state_dict["total"])
+        self.current.load_state_dict(state_dict["current"])
 
 
 class BatchProgress(Progress):
@@ -147,9 +144,9 @@ class EpochProgress(Progress):
     def reset_on_epoch(self) -> None:
         self.batch.current.reset()
 
-    def __setstate__(self, state: dict) -> None:
-        super().__setstate__(state)
-        self.batch.__setstate__(state["batch"])
+    def load_state_dict(self, state_dict: dict) -> None:
+        super().load_state_dict(state_dict)
+        self.batch.load_state_dict(state_dict["batch"])
 
 
 @dataclass
@@ -169,9 +166,9 @@ class OptimizerProgress(_DataclassStateDictMixin):
         self.step.current.reset()
         self.zero_grad.current.reset()
 
-    def __setstate__(self, state: dict) -> None:
-        self.step.__setstate__(state["step"])
-        self.zero_grad.__setstate__(state["zero_grad"])
+    def load_state_dict(self, state_dict: dict) -> None:
+        self.step.load_state_dict(state_dict["step"])
+        self.zero_grad.load_state_dict(state_dict["zero_grad"])
 
 
 @dataclass
@@ -200,9 +197,9 @@ class OptimizationProgress(_DataclassStateDictMixin):
         self.optimizer.reset_on_epoch()
         self.scheduler.current.reset()
 
-    def __setstate__(self, state: dict) -> None:
-        self.optimizer.__setstate__(state["optimizer"])
-        self.scheduler.__setstate__(state["scheduler"])
+    def load_state_dict(self, state_dict: dict) -> None:
+        self.optimizer.load_state_dict(state_dict["optimizer"])
+        self.scheduler.load_state_dict(state_dict["scheduler"])
 
 
 @dataclass
@@ -225,8 +222,8 @@ class EpochLoopProgress(_DataclassStateDictMixin):
         self.epoch.reset_on_epoch()
         self.epoch.current.reset()
 
-    def __setstate__(self, state: dict) -> None:
-        self.epoch.__setstate__(state["epoch"])
+    def load_state_dict(self, state_dict: dict) -> None:
+        self.epoch.load_state_dict(state_dict["epoch"])
 
 
 @dataclass
@@ -245,10 +242,10 @@ class TrainingEpochProgress(EpochProgress):
     optim: OptimizationProgress = field(default_factory=OptimizationProgress)
     val: EpochLoopProgress = field(default_factory=EpochLoopProgress)
 
-    def __setstate__(self, state: dict) -> None:
-        super().__setstate__(state)
-        self.optim.__setstate__(state["optim"])
-        self.val.__setstate__(state["val"])
+    def load_state_dict(self, state_dict: dict) -> None:
+        super().load_state_dict(state_dict)
+        self.optim.load_state_dict(state_dict["optim"])
+        self.val.load_state_dict(state_dict["val"])
 
 
 @dataclass
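For reference, here is how the renamed API reads after this change. This is an illustrative sketch, not part of the diff; it assumes the post-change `Progress` is importable from `pytorch_lightning.trainer.progress` and is default-constructible (which `from_state_dict` already relies on):

    from copy import deepcopy

    from pytorch_lightning.trainer.progress import Progress

    p = Progress()
    p.total.ready += 1                           # Tracker counts: ready/started/processed/completed

    copied = deepcopy(p)                         # no custom __getstate__/__setstate__ left, so this works
    assert copied == p and copied is not p

    state = p.state_dict()                       # nested plain dicts produced by dataclasses.asdict
    assert state["total"]["ready"] == 1

    restored = Progress.from_state_dict(state)   # cls() followed by load_state_dict(state)
    assert restored == p

The explicit names also mirror the `state_dict` / `load_state_dict` convention used elsewhere in PyTorch, while copying and pickling fall back to the default dataclass behavior.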

View File

@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from copy import deepcopy
+
 import pytest
 
 from pytorch_lightning.trainer.progress import (
@@ -135,14 +137,17 @@ def test_optimizer_progress_default_factory():
 
 def test_fit_loop_progress_serialization():
     fit_loop = FitLoopProgress()
+    _ = deepcopy(fit_loop)
+    fit_loop.epoch.increment_completed()  # check `TrainingEpochProgress.load_state_dict` calls `super`
+
     state_dict = fit_loop.state_dict()
     # yapf: disable
     assert state_dict == {
         'epoch': {
            # number of epochs across `fit` calls
-            'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
+            'total': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0},
            # number of epochs this `fit` call
-            'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
+            'current': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0},
             'batch': {
                # number of batches across `fit` calls
                 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
@@ -191,13 +196,16 @@ def test_fit_loop_progress_serialization():
         }
     }
     # yapf: enable
+
     new_loop = FitLoopProgress.from_state_dict(state_dict)
     assert fit_loop == new_loop
 
 
 def test_epoch_loop_progress_serialization():
     loop = EpochLoopProgress()
+    _ = deepcopy(loop)
+
     state_dict = loop.state_dict()
     # yapf: disable
     assert state_dict == {
         'epoch': {
@@ -214,5 +222,6 @@ def test_epoch_loop_progress_serialization():
         }
     }
     # yapf: enable
+
     new_loop = EpochLoopProgress.from_state_dict(state_dict)
     assert loop == new_loop