2021-01-13 19:35:42 +00:00
|
|
|
from unittest.mock import MagicMock
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule
|
2021-01-27 16:38:14 +00:00
|
|
|
from pytorch_lightning.trainer.states import RunningStage
|
2021-01-13 19:35:42 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_lightning_distributed_module_methods():
    """ Test that the LightningDistributedModule redirects .forward() to the LightningModule methods.

    For every running stage, calling the wrapper must invoke exactly the matching step
    method (``training_step`` / ``test_step`` / ``validation_step``) once, with the
    forwarded positional arguments, and must not touch the step methods of other stages.
    """
    pl_module = MagicMock()
    dist_module = LightningDistributedModule(pl_module)

    batch = torch.rand(5)
    batch_idx = 3

    # Stage -> name of the LightningModule method the wrapper is expected to call.
    stage_to_step = (
        (RunningStage.TRAINING, "training_step"),
        (RunningStage.TESTING, "test_step"),
        (RunningStage.EVALUATING, "validation_step"),
    )
    all_steps = [step for _, step in stage_to_step]

    for stage, expected_step in stage_to_step:
        # Clear call history from the previous stage so the "once" / "not called"
        # checks below only see calls made for the current stage.
        pl_module.reset_mock()
        pl_module.running_stage = stage

        dist_module(batch, batch_idx)

        getattr(pl_module, expected_step).assert_called_once_with(batch, batch_idx)
        # Guard against double dispatch / fall-through into another stage's step.
        for other_step in all_steps:
            if other_step != expected_step:
                getattr(pl_module, other_step).assert_not_called()
|
|
|
|
|
|
|
|
|
|
|
|
def test_lightning_distributed_module_warn_none_output():
    """ Test that the LightningDistributedModule warns about forgotten return statement.

    Every step method on the mocked LightningModule is made to return ``None``. The
    wrapper is then called once per running stage, and each call must raise a
    ``UserWarning`` that names the step which returned nothing.
    """
    module_mock = MagicMock()
    wrapper = LightningDistributedModule(module_mock)

    # Simulate a forgotten `return` in each step method before exercising any stage.
    for step_name in ("training_step", "validation_step", "test_step"):
        getattr(module_mock, step_name).return_value = None

    # (running stage, expected warning pattern), checked in this order.
    expectations = (
        (RunningStage.TRAINING, "Your training_step returned None"),
        (RunningStage.TESTING, "Your test_step returned None"),
        (RunningStage.EVALUATING, "Your validation_step returned None"),
    )

    for stage, pattern in expectations:
        with pytest.warns(UserWarning, match=pattern):
            module_mock.running_stage = stage
            wrapper()
|