remove duplicate tests (#2685)
* remove duplicate test * remove duplicated tests
This commit is contained in:
parent
6780214b27
commit
938ec5a6c1
|
@ -1,38 +1,8 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning import Callback
|
||||
from pytorch_lightning import Trainer, LightningModule
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
def test_early_stopping_functionality(tmpdir):
|
||||
|
||||
class CurrentModel(EvalModelTemplate):
|
||||
def validation_epoch_end(self, outputs):
|
||||
losses = [8, 4, 2, 3, 4, 5, 8, 10]
|
||||
val_loss = losses[self.current_epoch]
|
||||
return {'val_loss': torch.tensor(val_loss)}
|
||||
|
||||
model = CurrentModel()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
early_stop_callback=True,
|
||||
overfit_batches=0.20,
|
||||
max_epochs=20,
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
print(trainer.current_epoch)
|
||||
|
||||
assert trainer.current_epoch == 5, 'early_stopping failed'
|
||||
|
||||
|
||||
def test_trainer_callback_system(tmpdir):
|
||||
"""Test the callback system."""
|
||||
|
||||
|
@ -262,86 +232,3 @@ def test_trainer_callback_system(tmpdir):
|
|||
assert not test_callback.on_validation_end_called
|
||||
assert not test_callback.on_validation_batch_end_called
|
||||
assert not test_callback.on_validation_batch_start_called
|
||||
|
||||
|
||||
def test_early_stopping_no_val_step(tmpdir):
|
||||
"""Test that early stopping callback falls back to training metrics when no validation defined."""
|
||||
|
||||
class CurrentModel(EvalModelTemplate):
|
||||
def training_step(self, *args, **kwargs):
|
||||
output = super().training_step(*args, **kwargs)
|
||||
output.update({'my_train_metric': output['loss']}) # could be anything else
|
||||
return output
|
||||
|
||||
model = CurrentModel()
|
||||
model.validation_step = None
|
||||
model.val_dataloader = None
|
||||
|
||||
stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
early_stop_callback=stopping,
|
||||
overfit_batches=0.20,
|
||||
max_epochs=2,
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1, 'training failed to complete'
|
||||
assert trainer.current_epoch < trainer.max_epochs
|
||||
|
||||
|
||||
def test_pickling(tmpdir):
|
||||
import pickle
|
||||
early_stopping = EarlyStopping()
|
||||
ckpt = ModelCheckpoint(tmpdir)
|
||||
|
||||
early_stopping_pickled = pickle.dumps(early_stopping)
|
||||
ckpt_pickled = pickle.dumps(ckpt)
|
||||
|
||||
early_stopping_loaded = pickle.loads(early_stopping_pickled)
|
||||
ckpt_loaded = pickle.loads(ckpt_pickled)
|
||||
|
||||
assert vars(early_stopping) == vars(early_stopping_loaded)
|
||||
assert vars(ckpt) == vars(ckpt_loaded)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
|
||||
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
|
||||
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
|
||||
tutils.reset_seed()
|
||||
model = EvalModelTemplate()
|
||||
|
||||
checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
checkpoint_callback=checkpoint,
|
||||
overfit_batches=0.20,
|
||||
max_epochs=2,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
# These should be different if the dirpath has be overridden
|
||||
assert trainer.ckpt_path != trainer.default_root_dir
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'logger_version,expected',
|
||||
[(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
|
||||
)
|
||||
def test_model_checkpoint_path(tmpdir, logger_version, expected):
|
||||
"""Test that "version_" prefix is only added when logger's version is an integer"""
|
||||
tutils.reset_seed()
|
||||
model = EvalModelTemplate()
|
||||
logger = TensorBoardLogger(str(tmpdir), version=logger_version)
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
overfit_batches=0.2,
|
||||
max_epochs=2,
|
||||
logger=logger,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
ckpt_version = Path(trainer.ckpt_path).parent.name
|
||||
assert ckpt_version == expected
|
||||
|
|
|
@ -130,3 +130,49 @@ def test_pickling(tmpdir):
|
|||
early_stopping_pickled = cloudpickle.dumps(early_stopping)
|
||||
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
|
||||
assert vars(early_stopping) == vars(early_stopping_loaded)
|
||||
|
||||
|
||||
def test_early_stopping_no_val_step(tmpdir):
|
||||
"""Test that early stopping callback falls back to training metrics when no validation defined."""
|
||||
|
||||
class CurrentModel(EvalModelTemplate):
|
||||
def training_step(self, *args, **kwargs):
|
||||
output = super().training_step(*args, **kwargs)
|
||||
output.update({'my_train_metric': output['loss']}) # could be anything else
|
||||
return output
|
||||
|
||||
model = CurrentModel()
|
||||
model.validation_step = None
|
||||
model.val_dataloader = None
|
||||
|
||||
stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
early_stop_callback=stopping,
|
||||
overfit_batches=0.20,
|
||||
max_epochs=2,
|
||||
)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1, 'training failed to complete'
|
||||
assert trainer.current_epoch < trainer.max_epochs
|
||||
|
||||
|
||||
def test_early_stopping_functionality(tmpdir):
|
||||
|
||||
class CurrentModel(EvalModelTemplate):
|
||||
def validation_epoch_end(self, outputs):
|
||||
losses = [8, 4, 2, 3, 4, 5, 8, 10]
|
||||
val_loss = losses[self.current_epoch]
|
||||
return {'val_loss': torch.tensor(val_loss)}
|
||||
|
||||
model = CurrentModel()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
early_stop_callback=True,
|
||||
overfit_batches=0.20,
|
||||
max_epochs=20,
|
||||
)
|
||||
trainer.fit(model)
|
||||
assert trainer.current_epoch == 5, 'early_stopping failed'
|
||||
|
|
|
@ -15,9 +15,7 @@ from tests.base import EvalModelTemplate
|
|||
|
||||
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
|
||||
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
|
||||
"""
|
||||
Test that None in checkpoint callback is valid and that chkp_path is set correctly
|
||||
"""
|
||||
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
|
||||
tutils.reset_seed()
|
||||
model = EvalModelTemplate()
|
||||
|
||||
|
@ -26,8 +24,8 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
|
|||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
checkpoint_callback=checkpoint,
|
||||
overfit_pct=0.20,
|
||||
max_epochs=(save_top_k + 2),
|
||||
overfit_batches=0.20,
|
||||
max_epochs=2,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
|
@ -47,8 +45,8 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
|
|||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
overfit_pct=0.2,
|
||||
max_epochs=5,
|
||||
overfit_batches=0.2,
|
||||
max_epochs=2,
|
||||
logger=logger,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue