rm EvalModel
This commit is contained in:
parent
3c87cc48a4
commit
6c3fb39ebe
|
@ -20,7 +20,6 @@ from pytorch_lightning.callbacks import EarlyStopping
|
|||
from pytorch_lightning.core import memory
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.datamodules import ClassifDataModule
|
||||
from tests.helpers.simple_models import ClassificationModel
|
||||
|
@ -72,7 +71,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
|
|||
"""Make sure DDP works with dataloaders passed to fit()"""
|
||||
tutils.set_random_master_port()
|
||||
|
||||
model = EvalModelTemplate()
|
||||
model = BoringModel()
|
||||
fit_options = dict(train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
|
||||
|
||||
trainer = Trainer(
|
||||
|
|
|
@ -19,7 +19,6 @@ import tests.helpers.pipelines as tpipes
|
|||
import tests.helpers.utils as tutils
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from pytorch_lightning.core import memory
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.datamodules import ClassifDataModule
|
||||
from tests.helpers.simple_models import ClassificationModel
|
||||
|
@ -76,7 +75,8 @@ def test_dp_test(tmpdir):
|
|||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
||||
|
||||
model = EvalModelTemplate()
|
||||
dm = ClassifDataModule()
|
||||
model = ClassificationModel()
|
||||
trainer = pl.Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
|
@ -85,14 +85,14 @@ def test_dp_test(tmpdir):
|
|||
gpus=[0, 1],
|
||||
accelerator='dp',
|
||||
)
|
||||
trainer.fit(model)
|
||||
trainer.fit(model, datamodule=dm)
|
||||
assert 'ckpt' in trainer.checkpoint_callback.best_model_path
|
||||
results = trainer.test()
|
||||
results = trainer.test(datamodule=dm)
|
||||
assert 'test_acc' in results[0]
|
||||
|
||||
old_weights = model.c_d1.weight.clone().detach().cpu()
|
||||
|
||||
results = trainer.test(model)
|
||||
results = trainer.test(model, datamodule=dm)
|
||||
assert 'test_acc' in results[0]
|
||||
|
||||
# make sure weights didn't change
|
||||
|
|
Loading…
Reference in New Issue