lightning/tests/trainer/connectors/test_signal_connector.py

137 lines
4.8 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import os
import signal
from time import sleep
from unittest import mock
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.utilities.exceptions import ExitGracefullyException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@RunIf(skip_windows=True)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_signal_handlers_restored_in_teardown():
"""Test that the SignalConnector restores the previously configured handler on teardown."""
assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL
trainer = Trainer(plugins=SLURMEnvironment())
connector = SignalConnector(trainer)
connector.register_signal_handlers()
assert signal.getsignal(signal.SIGTERM) is not signal.SIG_DFL
connector.teardown()
assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL
@pytest.mark.parametrize("register_handler", [False, True])
@pytest.mark.parametrize("terminate_gracefully", [False, True])
@RunIf(skip_windows=True)
def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir):
if register_handler:
def handler(*_):
pass
signal.signal(signal.SIGTERM, handler)
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
if terminate_gracefully or register_handler:
os.kill(os.getpid(), signal.SIGTERM)
sleep(0.1)
return super().training_step(batch, batch_idx)
model = TestModel()
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0)
if terminate_gracefully and not register_handler:
with pytest.raises(ExitGracefullyException):
trainer.fit(model)
else:
trainer.fit(model)
assert trainer._terminate_gracefully == (False if register_handler else terminate_gracefully)
# reset the signal to system defaults
signal.signal(signal.SIGTERM, signal.SIG_DFL)
@RunIf(skip_windows=True)
@pytest.mark.parametrize("auto_requeue", (True, False))
def test_auto_requeue_flag(auto_requeue):
trainer = Trainer(plugins=[SLURMEnvironment(auto_requeue=auto_requeue)])
connector = SignalConnector(trainer)
connector.register_signal_handlers()
if auto_requeue:
sigterm_handlers = signal.getsignal(signal.SIGTERM).signal_handlers
assert len(sigterm_handlers) == 1
assert sigterm_handlers[0].__qualname__ == "SignalConnector.sigterm_handler_fn"
sigusr1_handlers = signal.getsignal(signal.SIGUSR1).signal_handlers
assert len(sigusr1_handlers) == 1
assert sigusr1_handlers[0].__qualname__ == "SignalConnector.slurm_sigusr1_handler_fn"
else:
assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL
assert signal.getsignal(signal.SIGUSR1) is signal.SIG_DFL
connector.teardown()
def _registering_signals():
trainer = Trainer()
trainer._signal_connector.register_signal_handlers()
@RunIf(skip_windows=True)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_signal_connector_in_thread():
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed([executor.submit(_registering_signals)]):
assert future.exception() is None
def signal_handler():
pass
class SignalHandlers:
def signal_handler(self):
pass
@pytest.mark.parametrize(
["handler", "expected_return"],
[
(None, False),
(signal.Handlers.SIG_IGN, True),
(signal.Handlers.SIG_DFL, False),
(signal_handler, True),
(SignalHandlers().signal_handler, True),
],
)
def test_has_already_handler(handler, expected_return):
"""Test that the SignalConnector detects whether a signal handler is already attached."""
with mock.patch("pytorch_lightning.trainer.connectors.signal_connector.signal.getsignal", return_value=handler):
assert SignalConnector._has_already_handler(signal.SIGTERM) is expected_return