36 lines
984 B
Python
36 lines
984 B
Python
import pytest
|
|
|
|
from pytorch_lightning import Trainer
|
|
from tests.helpers import BoringModel
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["min_epochs", "max_epochs", "min_steps", "max_steps"],
    [
        (None, 3, None, None),
        (None, None, None, 20),
        (None, 3, None, 20),
        (None, None, 10, 20),
        (1, 3, None, None),
        (1, None, None, 20),
        (None, 3, 10, None),
    ],
)
def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_steps):
    """Fit a ``BoringModel`` under each parametrized mix of epoch/step limits.

    Every combination must complete ``trainer.fit`` without raising. When the
    trainer reports a truthy ``max_steps`` and a falsy ``max_epochs``, the run
    must additionally finish with ``global_step`` equal to ``max_steps``; no
    further assertion is made for the remaining combinations.
    """
    boring_model = BoringModel()

    # Collect the limits under test so the Trainer call reads as
    # "fixed settings + parametrized limits".
    limit_kwargs = {
        "min_epochs": min_epochs,
        "max_epochs": max_epochs,
        "min_steps": min_steps,
        "max_steps": max_steps,
    }
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_model_summary=False,
        **limit_kwargs,
    )
    trainer.fit(boring_model)

    # Only the step-limited case (a step limit without an epoch limit) is
    # checked for its exact stopping point.
    if not trainer.max_steps or trainer.max_epochs:
        return
    assert trainer.global_step == trainer.max_steps
|