lightning/tests/tests_pytorch/trainer/flags/test_barebones.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

70 lines
3.0 KiB
Python
Raw Normal View History

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pytest
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.profilers import PassThroughProfiler
def test_barebones_disables_logging():
    """Check that logging calls warn when the attached trainer uses ``barebones=True``.

    Both ``log`` and ``log_dict`` on the module are expected to emit a warning
    whose message matches the barebones notice.
    """
    # The same message is expected from both logging entry points.
    barebones_notice = r"barebones=True\)` is configured"

    model = BoringModel()
    # Attach the trainer directly through the private attribute; no fit is run.
    model._trainer = Trainer(barebones=True)

    with pytest.warns(match=barebones_notice):
        model.log("foo", 1.0)

    with pytest.warns(match=barebones_notice):
        model.log_dict({"foo": 1.0})
def test_barebones_argument_selection(caplog):
    """Check the state of a ``Trainer`` constructed with ``barebones=True``.

    The constructor is expected to log the barebones-mode message at INFO level
    and to leave every feature asserted below switched off or empty.

    Args:
        caplog: pytest's log-capturing fixture.
    """
    with caplog.at_level(logging.INFO):
        barebones_trainer = Trainer(barebones=True)

    # The mode announcement must have been captured during construction.
    assert "running in `Trainer(barebones=True)` mode" in caplog.text
    assert barebones_trainer.barebones

    # Callbacks and loggers: nothing for checkpointing, logging, progress or summary.
    assert not barebones_trainer.checkpoint_callbacks
    assert not barebones_trainer.loggers
    assert not barebones_trainer.progress_bar_callback
    assert all(not isinstance(callback, ModelSummary) for callback in barebones_trainer.callbacks)

    # Scalar flags are all falsy.
    assert not barebones_trainer.log_every_n_steps
    assert not barebones_trainer.num_sanity_val_steps
    assert not barebones_trainer.fast_dev_run
    assert not barebones_trainer._detect_anomaly

    # The profiler is the pass-through implementation.
    assert isinstance(barebones_trainer.profiler, PassThroughProfiler)
def test_barebones_raises():
    """Check that ``Trainer`` rejects options that conflict with ``barebones=True``.

    Each case passes one extra keyword argument alongside ``barebones=True`` and
    expects a ``ValueError`` whose message matches the paired pattern.
    """
    # (extra keyword arguments, regex the error message must match), checked in order.
    conflicting_cases = [
        ({"enable_checkpointing": True}, r"enable_checkpointing=True\)` was passed"),
        ({"logger": True}, r"logger=True\)` was passed"),
        ({"enable_progress_bar": True}, r"enable_progress_bar=True\)` was passed"),
        ({"enable_model_summary": True}, r"enable_model_summary=True\)` was passed"),
        ({"log_every_n_steps": 1}, r"log_every_n_steps=1\)` was passed"),
        ({"num_sanity_val_steps": 1}, r"num_sanity_val_steps=1\)` was passed"),
        ({"fast_dev_run": 1}, r"fast_dev_run=1\)` was passed"),
        ({"detect_anomaly": True}, r"detect_anomaly=True\)` was passed"),
        ({"profiler": "simple"}, r"profiler='simple'\)` was passed"),
    ]

    for extra_kwargs, message_pattern in conflicting_cases:
        with pytest.raises(ValueError, match=message_pattern):
            Trainer(barebones=True, **extra_kwargs)