rm EvalModel

This commit is contained in:
rohitgr7 2021-02-20 15:26:04 +05:30
parent 3c87cc48a4
commit 6c3fb39ebe
2 changed files with 6 additions and 7 deletions

View File

@ -20,7 +20,6 @@ from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core import memory from pytorch_lightning.core import memory
from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.states import TrainerState
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel from tests.helpers.simple_models import ClassificationModel
@ -72,7 +71,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
"""Make sure DDP works with dataloaders passed to fit()""" """Make sure DDP works with dataloaders passed to fit()"""
tutils.set_random_master_port() tutils.set_random_master_port()
model = EvalModelTemplate() model = BoringModel()
fit_options = dict(train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader()) fit_options = dict(train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
trainer = Trainer( trainer = Trainer(

View File

@ -19,7 +19,6 @@ import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils import tests.helpers.utils as tutils
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core import memory from pytorch_lightning.core import memory
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel from tests.helpers.simple_models import ClassificationModel
@ -76,7 +75,8 @@ def test_dp_test(tmpdir):
import os import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
model = EvalModelTemplate() dm = ClassifDataModule()
model = ClassificationModel()
trainer = pl.Trainer( trainer = pl.Trainer(
default_root_dir=tmpdir, default_root_dir=tmpdir,
max_epochs=2, max_epochs=2,
@ -85,14 +85,14 @@ def test_dp_test(tmpdir):
gpus=[0, 1], gpus=[0, 1],
accelerator='dp', accelerator='dp',
) )
trainer.fit(model) trainer.fit(model, datamodule=dm)
assert 'ckpt' in trainer.checkpoint_callback.best_model_path assert 'ckpt' in trainer.checkpoint_callback.best_model_path
results = trainer.test() results = trainer.test(datamodule=dm)
assert 'test_acc' in results[0] assert 'test_acc' in results[0]
old_weights = model.c_d1.weight.clone().detach().cpu() old_weights = model.c_d1.weight.clone().detach().cpu()
results = trainer.test(model) results = trainer.test(model, datamodule=dm)
assert 'test_acc' in results[0] assert 'test_acc' in results[0]
# make sure weights didn't change # make sure weights didn't change