Test deepcopy for progress tracking dataclasses (#8265)
This commit is contained in:
parent
ea88105b88
commit
83039ba470
|
@ -18,19 +18,16 @@ from typing import Optional
|
|||
@dataclass
|
||||
class _DataclassStateDictMixin:
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
def state_dict(self) -> dict:
|
||||
return asdict(self)
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self.__dict__.update(state)
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return self.__getstate__()
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, state_dict: dict) -> "_DataclassStateDictMixin":
|
||||
obj = cls()
|
||||
obj.__setstate__(state_dict)
|
||||
obj.load_state_dict(state_dict)
|
||||
return obj
|
||||
|
||||
|
||||
|
@ -115,9 +112,9 @@ class Progress(_DataclassStateDictMixin):
|
|||
def from_defaults(cls, **kwargs: Optional[int]) -> "Progress":
|
||||
return cls(total=Tracker(**kwargs), current=Tracker(**kwargs))
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self.total.__setstate__(state["total"])
|
||||
self.current.__setstate__(state["current"])
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.total.load_state_dict(state_dict["total"])
|
||||
self.current.load_state_dict(state_dict["current"])
|
||||
|
||||
|
||||
class BatchProgress(Progress):
|
||||
|
@ -147,9 +144,9 @@ class EpochProgress(Progress):
|
|||
def reset_on_epoch(self) -> None:
|
||||
self.batch.current.reset()
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
super().__setstate__(state)
|
||||
self.batch.__setstate__(state["batch"])
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
super().load_state_dict(state_dict)
|
||||
self.batch.load_state_dict(state_dict["batch"])
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -169,9 +166,9 @@ class OptimizerProgress(_DataclassStateDictMixin):
|
|||
self.step.current.reset()
|
||||
self.zero_grad.current.reset()
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self.step.__setstate__(state["step"])
|
||||
self.zero_grad.__setstate__(state["zero_grad"])
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.step.load_state_dict(state_dict["step"])
|
||||
self.zero_grad.load_state_dict(state_dict["zero_grad"])
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -200,9 +197,9 @@ class OptimizationProgress(_DataclassStateDictMixin):
|
|||
self.optimizer.reset_on_epoch()
|
||||
self.scheduler.current.reset()
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self.optimizer.__setstate__(state["optimizer"])
|
||||
self.scheduler.__setstate__(state["scheduler"])
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.optimizer.load_state_dict(state_dict["optimizer"])
|
||||
self.scheduler.load_state_dict(state_dict["scheduler"])
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -225,8 +222,8 @@ class EpochLoopProgress(_DataclassStateDictMixin):
|
|||
self.epoch.reset_on_epoch()
|
||||
self.epoch.current.reset()
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self.epoch.__setstate__(state["epoch"])
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.epoch.load_state_dict(state_dict["epoch"])
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -245,10 +242,10 @@ class TrainingEpochProgress(EpochProgress):
|
|||
optim: OptimizationProgress = field(default_factory=OptimizationProgress)
|
||||
val: EpochLoopProgress = field(default_factory=EpochLoopProgress)
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
super().__setstate__(state)
|
||||
self.optim.__setstate__(state["optim"])
|
||||
self.val.__setstate__(state["val"])
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
super().load_state_dict(state_dict)
|
||||
self.optim.load_state_dict(state_dict["optim"])
|
||||
self.val.load_state_dict(state_dict["val"])
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning.trainer.progress import (
|
||||
|
@ -135,14 +137,17 @@ def test_optimizer_progress_default_factory():
|
|||
|
||||
def test_fit_loop_progress_serialization():
|
||||
fit_loop = FitLoopProgress()
|
||||
_ = deepcopy(fit_loop)
|
||||
fit_loop.epoch.increment_completed() # check `TrainingEpochProgress.load_state_dict` calls `super`
|
||||
|
||||
state_dict = fit_loop.state_dict()
|
||||
# yapf: disable
|
||||
assert state_dict == {
|
||||
'epoch': {
|
||||
# number of epochs across `fit` calls
|
||||
'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
|
||||
'total': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0},
|
||||
# number of epochs this `fit` call
|
||||
'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
|
||||
'current': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0},
|
||||
'batch': {
|
||||
# number of batches across `fit` calls
|
||||
'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
|
||||
|
@ -191,13 +196,16 @@ def test_fit_loop_progress_serialization():
|
|||
}
|
||||
}
|
||||
# yapf: enable
|
||||
|
||||
new_loop = FitLoopProgress.from_state_dict(state_dict)
|
||||
assert fit_loop == new_loop
|
||||
|
||||
|
||||
def test_epoch_loop_progress_serialization():
|
||||
loop = EpochLoopProgress()
|
||||
_ = deepcopy(loop)
|
||||
state_dict = loop.state_dict()
|
||||
|
||||
# yapf: disable
|
||||
assert state_dict == {
|
||||
'epoch': {
|
||||
|
@ -214,5 +222,6 @@ def test_epoch_loop_progress_serialization():
|
|||
}
|
||||
}
|
||||
# yapf: enable
|
||||
|
||||
new_loop = EpochLoopProgress.from_state_dict(state_dict)
|
||||
assert loop == new_loop
|
||||
|
|
Loading…
Reference in New Issue