lightning/tests/plugins/legacy/test_plugin.py

126 lines
4.7 KiB
Python
Raw Normal View History

Allow string plugins (#4888) * Allow plugin to be chosen via string * Fix implementation, add tests * Fix codefactor issues * Added missing env patch * Skip test for windows * Reword reason * Add skip to invalid test * Create required_plugins function, move sharded amp requirement to plugin * Pass AMPType, fix setter for apex * Better doc strings * Add exception when using apex * Add trainer available_plugins function, warn user when plugins have been added automatically with option to override behaviour * Fixed pep8 indent * Fix codefactor issues * Add env variables * Update pytorch_lightning/cluster_environments/cluster_environment.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed code review * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed more code review feedback * Fixed docstrings * Swapped to verbose runtime error * Apply suggestions from code review * Apply suggestions from code review * Update pytorch_lightning/plugins/sharded_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Change name * Pass trainer to plugins that may require it * Fix sharded plugin * Added test to ensure string sharded works * Removed trainer typing as this breaks pep8 * Fixed doc issues * Fixed tests Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-12-01 20:30:49 +00:00
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
import pytest
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.plugins.legacy.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.legacy.native_amp import NativeAMPPlugin
Allow string plugins (#4888) * Allow plugin to be chosen via string * Fix implementation, add tests * Fix codefactor issues * Added missing env patch * Skip test for windows * Reword reason * Add skip to invalid test * Create required_plugins function, move sharded amp requirement to plugin * Pass AMPType, fix setter for apex * Better doc strings * Add exception when using apex * Add trainer available_plugins function, warn user when plugins have been added automatically with option to override behaviour * Fixed pep8 indent * Fix codefactor issues * Add env variables * Update pytorch_lightning/cluster_environments/cluster_environment.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed code review * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed more code review feedback * Fixed docstrings * Swapped to verbose runtime error * Apply suggestions from code review * Apply suggestions from code review * Update pytorch_lightning/plugins/sharded_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Change name * Pass trainer to plugins that may require it * Fix sharded plugin * Added test to ensure string sharded works * Removed trainer typing as this breaks pep8 * Fixed doc issues * Fixed tests Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-12-01 20:30:49 +00:00
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel
@mock.patch.dict(
os.environ,
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0",
},
)
@mock.patch("torch.cuda.device_count", return_value=2)
@pytest.mark.parametrize(
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
Allow string plugins (#4888) * Allow plugin to be chosen via string * Fix implementation, add tests * Fix codefactor issues * Added missing env patch * Skip test for windows * Reword reason * Add skip to invalid test * Create required_plugins function, move sharded amp requirement to plugin * Pass AMPType, fix setter for apex * Better doc strings * Add exception when using apex * Add trainer available_plugins function, warn user when plugins have been added automatically with option to override behaviour * Fixed pep8 indent * Fix codefactor issues * Add env variables * Update pytorch_lightning/cluster_environments/cluster_environment.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed code review * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed more code review feedback * Fixed docstrings * Swapped to verbose runtime error * Apply suggestions from code review * Apply suggestions from code review * Update pytorch_lightning/plugins/sharded_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Change name * Pass trainer to plugins that may require it * Fix sharded plugin * Added test to ensure string sharded works * Removed trainer typing as this breaks pep8 * Fixed doc issues * Fixed tests Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-12-01 20:30:49 +00:00
)
def test_custom_required_plugins(tmpdir, ddp_backend, gpus, num_processes):
"""
Test to ensure that if a plugin requires certain plugin to be added, these are added automatically
"""
class RequiredPlugin(NativeAMPPlugin):
"""
My custom amp plugin that's required with my DDP plugin as default.
This allows us to ensure this plugin is added when using CustomPlugin rather than ensuring
the user passes it manually into the list.
"""
class CustomPlugin(DDPPlugin):
def required_plugins(self, amp_backend: AMPType, trainer: Trainer) -> list:
return [RequiredPlugin(trainer=trainer)]
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.accelerator_backend.ddp_plugin, CustomPlugin)
assert isinstance(trainer.precision_connector.backend, RequiredPlugin)
raise RuntimeError('finished plugin check')
model = BoringModel()
with pytest.warns(UserWarning,
match=f'plugin {type(CustomPlugin())} has added additional '
f'required plugins as default: {[type(RequiredPlugin())]}*'):
trainer = Trainer(
fast_dev_run=True,
gpus=gpus,
num_processes=num_processes,
accelerator=ddp_backend,
Allow string plugins (#4888) * Allow plugin to be chosen via string * Fix implementation, add tests * Fix codefactor issues * Added missing env patch * Skip test for windows * Reword reason * Add skip to invalid test * Create required_plugins function, move sharded amp requirement to plugin * Pass AMPType, fix setter for apex * Better doc strings * Add exception when using apex * Add trainer available_plugins function, warn user when plugins have been added automatically with option to override behaviour * Fixed pep8 indent * Fix codefactor issues * Add env variables * Update pytorch_lightning/cluster_environments/cluster_environment.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed code review * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed more code review feedback * Fixed docstrings * Swapped to verbose runtime error * Apply suggestions from code review * Apply suggestions from code review * Update pytorch_lightning/plugins/sharded_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Change name * Pass trainer to plugins that may require it * Fix sharded plugin * Added test to ensure string sharded works * Removed trainer typing as this breaks pep8 * Fixed doc issues * Fixed tests Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-12-01 20:30:49 +00:00
plugins=[CustomPlugin()],
callbacks=[CB()],
)
with pytest.raises(RuntimeError, match='finished plugin check'):
trainer.fit(model)
@mock.patch.dict(
os.environ,
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0",
},
)
@mock.patch("torch.cuda.device_count", return_value=2)
@pytest.mark.parametrize(
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
Allow string plugins (#4888) * Allow plugin to be chosen via string * Fix implementation, add tests * Fix codefactor issues * Added missing env patch * Skip test for windows * Reword reason * Add skip to invalid test * Create required_plugins function, move sharded amp requirement to plugin * Pass AMPType, fix setter for apex * Better doc strings * Add exception when using apex * Add trainer available_plugins function, warn user when plugins have been added automatically with option to override behaviour * Fixed pep8 indent * Fix codefactor issues * Add env variables * Update pytorch_lightning/cluster_environments/cluster_environment.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed code review * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed more code review feedback * Fixed docstrings * Swapped to verbose runtime error * Apply suggestions from code review * Apply suggestions from code review * Update pytorch_lightning/plugins/sharded_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Change name * Pass trainer to plugins that may require it * Fix sharded plugin * Added test to ensure string sharded works * Removed trainer typing as this breaks pep8 * Fixed doc issues * Fixed tests Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-12-01 20:30:49 +00:00
)
def test_invalid_custom_required_plugins(tmpdir, ddp_backend, gpus, num_processes):
"""
Test to ensure if the user passes a plugin that conflicts with the required defaults of another plugin,
we throw a warning and error.
The user has to override the required defaults plugin.
"""
class RequiredPlugin(NativeAMPPlugin):
"""
My custom amp plugin that's required with my DDP plugin as default.
This allows us to ensure this plugin is added when using CustomPlugin rather than ensuring
the user passes it manually into the list.
"""
class CustomPlugin(DDPPlugin):
def required_plugins(self, amp_backend: AMPType, trainer: Trainer) -> list:
return [RequiredPlugin(trainer=trainer)]
with pytest.warns(UserWarning, match=f'plugin {type(CustomPlugin())} has added additional '
f'required plugins as default: {[type(RequiredPlugin())]}*'), \
pytest.raises(MisconfigurationException, match=f"you can only use one {type(NativeAMPPlugin)}"
f" in plugins. You passed in: {2}"):
Trainer(
fast_dev_run=True,
gpus=gpus,
num_processes=num_processes,
accelerator=ddp_backend,
Allow string plugins (#4888) * Allow plugin to be chosen via string * Fix implementation, add tests * Fix codefactor issues * Added missing env patch * Skip test for windows * Reword reason * Add skip to invalid test * Create required_plugins function, move sharded amp requirement to plugin * Pass AMPType, fix setter for apex * Better doc strings * Add exception when using apex * Add trainer available_plugins function, warn user when plugins have been added automatically with option to override behaviour * Fixed pep8 indent * Fix codefactor issues * Add env variables * Update pytorch_lightning/cluster_environments/cluster_environment.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed code review * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/plugins/plugin_connector.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Addressed more code review feedback * Fixed docstrings * Swapped to verbose runtime error * Apply suggestions from code review * Apply suggestions from code review * Update pytorch_lightning/plugins/sharded_plugin.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Change name * Pass trainer to plugins that may require it * Fix sharded plugin * Added test to ensure string sharded works * Removed trainer typing as this breaks pep8 * Fixed doc issues * Fixed tests Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-12-01 20:30:49 +00:00
plugins=[CustomPlugin(), NativeAMPPlugin()],
)