lightning/tests/plugins/legacy/test_rpc_plugin.py

import os
from typing import Optional
from unittest import mock

import pytest
import torch

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.legacy.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import _RPC_AVAILABLE
from tests.base.boring_model import BoringModel


@mock.patch.dict(
os.environ,
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0",
},
)
@mock.patch("torch.cuda.device_count", return_value=2)
@pytest.mark.parametrize(
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available")
def test_rpc_choice(tmpdir, ddp_backend, gpus, num_processes):
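    """Ensure the RPCPlugin passed via ``plugins`` is picked up as the DDP plugin for each parametrized backend."""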
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.accelerator_backend.ddp_plugin, RPCPlugin)
            raise RuntimeError('finished plugin check')

    model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
gpus=gpus,
num_processes=num_processes,
distributed_backend=ddp_backend,
callbacks=[CB()],
plugins=[RPCPlugin()]
)
with pytest.raises(RuntimeError, match='finished plugin check'):
        trainer.fit(model)


class CustomRPCPlugin(RPCPlugin):
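    """RPCPlugin subclass that counts how often each RPC hook is invoked, so the test below can assert on it."""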
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.rpc_save_model_count = 0
        self.on_main_rpc_connect_count = 0
        self.worker_optimizer_step_count = 0
        self.is_main_rpc_process_count = 0
        self.on_exit_rpc_process_count = 0
        self.return_after_exit_rpc_process_count = 0

    def on_accelerator_exit_rpc_process(self, trainer) -> None:
        self.on_exit_rpc_process_count += 1

    def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
        self.rpc_save_model_count += 1

    def on_main_rpc_connection(self, trainer) -> None:
        self.on_main_rpc_connect_count += 1

    def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:
        self.worker_optimizer_step_count += 1

    @property
    def is_main_rpc_process(self) -> bool:
        self.is_main_rpc_process_count += 1
        return torch.distributed.get_rank() == 0

    @property
    def return_after_exit_rpc_process(self) -> bool:
        self.return_after_exit_rpc_process_count += 1
        return False

    def barrier(self, name: Optional[str] = None) -> None:
        return


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available")
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_rpc_function_calls_ddp(tmpdir):
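    """Run a short DDP fit with the counting plugin and check each RPC hook fires the expected number of times."""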
model = BoringModel()
plugin = CustomRPCPlugin()
max_epochs = 2
limit_train_batches = 2
trainer = Trainer(
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=max_epochs,
gpus=2,
distributed_backend='ddp',
plugins=[plugin]
)
    trainer.fit(model)

    if trainer.global_rank == 0: # Main process
assert plugin.rpc_save_model_count == max_epochs
assert plugin.on_main_rpc_connect_count == 1
assert plugin.worker_optimizer_step_count == max_epochs * limit_train_batches
        # Called once at init, plus once per optimizer step
assert plugin.is_main_rpc_process_count == 1 + plugin.worker_optimizer_step_count
assert plugin.on_exit_rpc_process_count == 0
else: # Worker process
assert plugin.rpc_save_model_count == max_epochs
assert plugin.on_main_rpc_connect_count == 0
# Never signaled by worker, only by main process
assert plugin.worker_optimizer_step_count == 0
        # Called once at init, plus once per optimizer step
assert plugin.is_main_rpc_process_count == 1 + (max_epochs * limit_train_batches)
# Called at init
assert plugin.on_exit_rpc_process_count == 1