Add `Loop.replace` (#10324)

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Carlos Mocholí 2021-12-16 18:41:38 +01:00 committed by GitHub
parent c335a7891d
commit 46d6fbf11b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 111 additions and 11 deletions

View File

@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
- Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324))
- Added support for `--lr_scheduler=ReduceLROnPlateau` to the `LightningCLI` ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860))

View File

@ -266,19 +266,27 @@ run (optional)
Subloops
--------
When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:
When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.replace` method:
.. code-block:: python
# Step 1: create your loop
my_epoch_loop = MyEpochLoop()
# Step 2: use connect()
trainer.fit_loop.connect(epoch_loop=my_epoch_loop)
# This takes care of properly instantiating the new Loop and setting all references
trainer.fit_loop.replace(epoch_loop=MyEpochLoop)
# Trainer runs the fit loop with your new epoch loop!
trainer.fit(model)
Alternatively, for more fine-grained control, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:
.. code-block:: python
# Optional: stitch back the trainer arguments
epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
# Optional: connect children loops as they might have existing state
epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
# Connect the already-instantiated loop to the fit loop.
trainer.fit_loop.connect(epoch_loop=epoch_loop)
trainer.fit(model)
More about the built-in loops and how they are composed is explained in the next section.
.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif

View File

@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, TypeVar
from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union
from deprecate import void
from torchmetrics import Metric
@ -99,6 +99,51 @@ class Loop(ABC, Generic[T]):
Linked loops should form a tree.
"""
def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None:
    """Swap one or more of this loop's sub-loops for new ones.

    For every keyword argument, a ``Loop`` class is instantiated with the values the old loop holds for its
    ``__init__`` arguments, whereas a ``Loop`` instance is used as given. The sub-loops of the old loop are then
    connected to the new loop, the new loop receives this loop's ``Trainer`` reference, and finally all new loops
    are connected to this loop.

    Args:
        **loops: ``Loop`` subclasses or instances, keyed by the name of the loop attribute to replace.

    Raises:
        MisconfigurationException: When passing a ``Loop`` class whose ``__init__`` signature differs from the
            signature of the loop class it replaces.
    """
    replacements = {}
    for attr_name, candidate in loops.items():
        previous = getattr(self, attr_name)

        if not isinstance(candidate, type):
            # already an instance: use it as given
            replacement = candidate
        else:
            # a class can only be instantiated automatically when both signatures agree
            previous_cls = previous.__class__
            expected_params = inspect.signature(previous_cls.__init__).parameters
            given_params = inspect.signature(candidate.__init__).parameters
            if expected_params != given_params:
                raise MisconfigurationException(
                    f"`{self.__class__.__name__}.replace({candidate.__name__})` can only be used if the"
                    f" `__init__` signatures match but `{previous_cls.__name__}` does not."
                )
            # read the constructor arguments back from the same-named attributes of the old loop
            init_kwargs = {}
            for param in expected_params:
                if param != "self":
                    init_kwargs[param] = getattr(previous, param)
            replacement = candidate(**init_kwargs)  # type: ignore[call-arg]

        # hand the old loop's sub-loops over to the new loop
        children = {key: value for key, value in previous.__dict__.items() if isinstance(value, Loop)}
        replacement.connect(**children)
        # set the trainer reference
        replacement.trainer = self.trainer

        replacements[attr_name] = replacement

    # connect to self
    self.connect(**replacements)
def on_skip(self) -> T:
"""The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.

View File

@ -127,6 +127,6 @@ def test_evaluation_loop_doesnt_store_outputs_if_epoch_end_not_overridden(tmpdir
assert not is_overridden("test_epoch_end", model)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3)
trainer.test_loop.connect(TestLoop())
trainer.test_loop.replace(epoch_loop=TestLoop)
trainer.test(model)
assert did_assert

View File

@ -24,8 +24,9 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoad
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.loops import Loop, TrainingBatchLoop
from pytorch_lightning.loops import EvaluationLoop, Loop, TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@ -102,6 +103,49 @@ def test_connect_subloops(tmpdir):
assert new_batch_loop.trainer is trainer
def test_replace_loops():
    """Exercise ``Loop.replace`` with a mismatching class, a matching class, and a class/instance mix."""

    class TestLoop(TrainingEpochLoop):
        # the extra required argument makes the signature differ from `TrainingEpochLoop`
        def __init__(self, foo):
            super().__init__()

    class CompatibleEpochLoop(TrainingEpochLoop):
        ...

    class MyBatchLoop(TrainingBatchLoop):
        ...

    class MyEvalLoop(EvaluationLoop):
        ...

    trainer = Trainer(min_steps=123, max_steps=321)

    # a class with a different `__init__` signature must be rejected
    expected_message = r"FitLoop.replace\(TestLoop\)`.*`__init__`.*`TrainingEpochLoop`"
    with pytest.raises(MisconfigurationException, match=expected_message):
        trainer.fit_loop.replace(epoch_loop=TestLoop)

    # test passing a loop where previous state should be connected
    previous_epoch_loop = trainer.fit_loop.epoch_loop
    trainer.fit_loop.replace(epoch_loop=CompatibleEpochLoop)
    replaced_epoch_loop = trainer.fit_loop.epoch_loop
    assert isinstance(replaced_epoch_loop, CompatibleEpochLoop)
    assert trainer.fit_loop.epoch_loop is replaced_epoch_loop
    assert replaced_epoch_loop.min_steps == 123
    assert replaced_epoch_loop.max_steps == 321
    assert replaced_epoch_loop.batch_loop is previous_epoch_loop.batch_loop
    assert replaced_epoch_loop.val_loop is previous_epoch_loop.val_loop
    assert replaced_epoch_loop.trainer is trainer

    # test passing more than one where one is an instance and the other a class
    epoch_loop = trainer.fit_loop.epoch_loop
    epoch_loop.replace(batch_loop=MyBatchLoop, val_loop=MyEvalLoop())
    assert isinstance(epoch_loop.batch_loop, MyBatchLoop)
    assert isinstance(epoch_loop.val_loop, MyEvalLoop)
class CustomException(Exception):
    """Plain ``Exception`` subclass that adds no behavior of its own."""