deepcopy model state_dict in tests (#2887)
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
parent
1bb268ad8a
commit
4d0406ec8b
|
@ -1,3 +1,4 @@
|
|||
from copy import deepcopy
|
||||
import pickle
|
||||
|
||||
import cloudpickle
|
||||
|
@ -24,7 +25,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
|
|||
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
super().on_validation_end(trainer, pl_module)
|
||||
self.saved_states.append(self.state_dict().copy())
|
||||
self.saved_states.append(deepcopy(self.state_dict()))
|
||||
|
||||
class EarlyStoppingTestRestore(EarlyStopping):
|
||||
def __init__(self, expected_state):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from copy import deepcopy
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
@ -33,7 +34,7 @@ def test_model_reset_correctly(tmpdir):
|
|||
max_epochs=1,
|
||||
)
|
||||
|
||||
before_state_dict = model.state_dict()
|
||||
before_state_dict = deepcopy(model.state_dict())
|
||||
|
||||
_ = trainer.lr_find(model, num_training=5)
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from copy import deepcopy
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader
|
||||
|
@ -148,7 +149,7 @@ def test_model_reset_correctly(tmpdir):
|
|||
max_epochs=1,
|
||||
)
|
||||
|
||||
before_state_dict = model.state_dict()
|
||||
before_state_dict = deepcopy(model.state_dict())
|
||||
|
||||
trainer.scale_batch_size(model, max_trials=5)
|
||||
|
||||
|
|
Loading…
Reference in New Issue