deepcopy model state_dict in tests (#2887)

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Rohit Gupta 2020-08-08 21:43:06 +05:30 committed by GitHub
parent 1bb268ad8a
commit 4d0406ec8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 3 deletions

View File

@ -1,3 +1,4 @@
from copy import deepcopy
import pickle
import cloudpickle
@ -24,7 +25,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.saved_states.append(self.state_dict().copy())
self.saved_states.append(deepcopy(self.state_dict()))
class EarlyStoppingTestRestore(EarlyStopping):
def __init__(self, expected_state):

View File

@ -1,3 +1,4 @@
from copy import deepcopy
import pytest
import torch
@ -33,7 +34,7 @@ def test_model_reset_correctly(tmpdir):
max_epochs=1,
)
before_state_dict = model.state_dict()
before_state_dict = deepcopy(model.state_dict())
_ = trainer.lr_find(model, num_training=5)

View File

@ -1,3 +1,4 @@
from copy import deepcopy
import pytest
import torch
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader
@ -148,7 +149,7 @@ def test_model_reset_correctly(tmpdir):
max_epochs=1,
)
before_state_dict = model.state_dict()
before_state_dict = deepcopy(model.state_dict())
trainer.scale_batch_size(model, max_trials=5)