rm EvalModel
This commit is contained in:
parent
3c87cc48a4
commit
6c3fb39ebe
|
@ -20,7 +20,6 @@ from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.core import memory
|
from pytorch_lightning.core import memory
|
||||||
from pytorch_lightning.trainer import Trainer
|
from pytorch_lightning.trainer import Trainer
|
||||||
from pytorch_lightning.trainer.states import TrainerState
|
from pytorch_lightning.trainer.states import TrainerState
|
||||||
from tests.base import EvalModelTemplate
|
|
||||||
from tests.helpers import BoringModel
|
from tests.helpers import BoringModel
|
||||||
from tests.helpers.datamodules import ClassifDataModule
|
from tests.helpers.datamodules import ClassifDataModule
|
||||||
from tests.helpers.simple_models import ClassificationModel
|
from tests.helpers.simple_models import ClassificationModel
|
||||||
|
@ -72,7 +71,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
|
||||||
"""Make sure DDP works with dataloaders passed to fit()"""
|
"""Make sure DDP works with dataloaders passed to fit()"""
|
||||||
tutils.set_random_master_port()
|
tutils.set_random_master_port()
|
||||||
|
|
||||||
model = EvalModelTemplate()
|
model = BoringModel()
|
||||||
fit_options = dict(train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
|
fit_options = dict(train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|
|
@ -19,7 +19,6 @@ import tests.helpers.pipelines as tpipes
|
||||||
import tests.helpers.utils as tutils
|
import tests.helpers.utils as tutils
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.core import memory
|
from pytorch_lightning.core import memory
|
||||||
from tests.base import EvalModelTemplate
|
|
||||||
from tests.helpers import BoringModel
|
from tests.helpers import BoringModel
|
||||||
from tests.helpers.datamodules import ClassifDataModule
|
from tests.helpers.datamodules import ClassifDataModule
|
||||||
from tests.helpers.simple_models import ClassificationModel
|
from tests.helpers.simple_models import ClassificationModel
|
||||||
|
@ -76,7 +75,8 @@ def test_dp_test(tmpdir):
|
||||||
import os
|
import os
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
||||||
|
|
||||||
model = EvalModelTemplate()
|
dm = ClassifDataModule()
|
||||||
|
model = ClassificationModel()
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
default_root_dir=tmpdir,
|
default_root_dir=tmpdir,
|
||||||
max_epochs=2,
|
max_epochs=2,
|
||||||
|
@ -85,14 +85,14 @@ def test_dp_test(tmpdir):
|
||||||
gpus=[0, 1],
|
gpus=[0, 1],
|
||||||
accelerator='dp',
|
accelerator='dp',
|
||||||
)
|
)
|
||||||
trainer.fit(model)
|
trainer.fit(model, datamodule=dm)
|
||||||
assert 'ckpt' in trainer.checkpoint_callback.best_model_path
|
assert 'ckpt' in trainer.checkpoint_callback.best_model_path
|
||||||
results = trainer.test()
|
results = trainer.test(datamodule=dm)
|
||||||
assert 'test_acc' in results[0]
|
assert 'test_acc' in results[0]
|
||||||
|
|
||||||
old_weights = model.c_d1.weight.clone().detach().cpu()
|
old_weights = model.c_d1.weight.clone().detach().cpu()
|
||||||
|
|
||||||
results = trainer.test(model)
|
results = trainer.test(model, datamodule=dm)
|
||||||
assert 'test_acc' in results[0]
|
assert 'test_acc' in results[0]
|
||||||
|
|
||||||
# make sure weights didn't change
|
# make sure weights didn't change
|
||||||
|
|
Loading…
Reference in New Issue