lightning/tests/trainer/flags/test_min_max_epochs.py

36 lines
984 B
Python

import pytest
from pytorch_lightning import Trainer
from tests.helpers import BoringModel
@pytest.mark.parametrize(
["min_epochs", "max_epochs", "min_steps", "max_steps"],
[
(None, 3, None, None),
(None, None, None, 20),
(None, 3, None, 20),
(None, None, 10, 20),
(1, 3, None, None),
(1, None, None, 20),
(None, 3, 10, None),
],
)
def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_steps):
"""Tests that max_steps can be used without max_epochs."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
min_epochs=min_epochs,
max_epochs=max_epochs,
min_steps=min_steps,
max_steps=max_steps,
enable_model_summary=False,
)
trainer.fit(model)
# check training stopped at max_epochs or max_steps
if trainer.max_steps and not trainer.max_epochs:
assert trainer.global_step == trainer.max_steps