70 lines
3.0 KiB
Python
70 lines
3.0 KiB
Python
![]() |
# Copyright The Lightning AI team.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import logging
|
||
|
|
||
|
import pytest
|
||
|
from lightning.pytorch import Trainer
|
||
|
from lightning.pytorch.callbacks import ModelSummary
|
||
|
from lightning.pytorch.demos.boring_classes import BoringModel
|
||
|
from lightning.pytorch.profilers import PassThroughProfiler
|
||
|
|
||
|
|
||
|
def test_barebones_disables_logging():
|
||
|
pl_module = BoringModel()
|
||
|
trainer = Trainer(barebones=True)
|
||
|
pl_module._trainer = trainer
|
||
|
|
||
|
with pytest.warns(match=r"barebones=True\)` is configured"):
|
||
|
pl_module.log("foo", 1.0)
|
||
|
with pytest.warns(match=r"barebones=True\)` is configured"):
|
||
|
pl_module.log_dict({"foo": 1.0})
|
||
|
|
||
|
|
||
|
def test_barebones_argument_selection(caplog):
|
||
|
with caplog.at_level(logging.INFO):
|
||
|
trainer = Trainer(barebones=True)
|
||
|
|
||
|
assert "running in `Trainer(barebones=True)` mode" in caplog.text
|
||
|
assert trainer.barebones
|
||
|
assert not trainer.checkpoint_callbacks
|
||
|
assert not trainer.loggers
|
||
|
assert not trainer.progress_bar_callback
|
||
|
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
|
||
|
assert not trainer.log_every_n_steps
|
||
|
assert not trainer.num_sanity_val_steps
|
||
|
assert not trainer.fast_dev_run
|
||
|
assert not trainer._detect_anomaly
|
||
|
assert isinstance(trainer.profiler, PassThroughProfiler)
|
||
|
|
||
|
|
||
|
def test_barebones_raises():
|
||
|
with pytest.raises(ValueError, match=r"enable_checkpointing=True\)` was passed"):
|
||
|
Trainer(barebones=True, enable_checkpointing=True)
|
||
|
with pytest.raises(ValueError, match=r"logger=True\)` was passed"):
|
||
|
Trainer(barebones=True, logger=True)
|
||
|
with pytest.raises(ValueError, match=r"enable_progress_bar=True\)` was passed"):
|
||
|
Trainer(barebones=True, enable_progress_bar=True)
|
||
|
with pytest.raises(ValueError, match=r"enable_model_summary=True\)` was passed"):
|
||
|
Trainer(barebones=True, enable_model_summary=True)
|
||
|
with pytest.raises(ValueError, match=r"log_every_n_steps=1\)` was passed"):
|
||
|
Trainer(barebones=True, log_every_n_steps=1)
|
||
|
with pytest.raises(ValueError, match=r"num_sanity_val_steps=1\)` was passed"):
|
||
|
Trainer(barebones=True, num_sanity_val_steps=1)
|
||
|
with pytest.raises(ValueError, match=r"fast_dev_run=1\)` was passed"):
|
||
|
Trainer(barebones=True, fast_dev_run=1)
|
||
|
with pytest.raises(ValueError, match=r"detect_anomaly=True\)` was passed"):
|
||
|
Trainer(barebones=True, detect_anomaly=True)
|
||
|
with pytest.raises(ValueError, match=r"profiler='simple'\)` was passed"):
|
||
|
Trainer(barebones=True, profiler="simple")
|