Add `Loop.replace` (#10324)

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
CHANGELOG.md

@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
 
 
+- Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324))
+
+
 - Added support for `--lr_scheduler=ReduceLROnPlateau` to the `LightningCLI` ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860))

docs/source/extensions/loops.rst

@@ -266,19 +266,27 @@ run (optional)
 Subloops
 --------
 
-When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:
+When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.replace` method:
 
 .. code-block:: python
 
-    # Step 1: create your loop
-    my_epoch_loop = MyEpochLoop()
-
-    # Step 2: use connect()
-    trainer.fit_loop.connect(epoch_loop=my_epoch_loop)
+    # This takes care of properly instantiating the new Loop and setting all references
+    trainer.fit_loop.replace(epoch_loop=MyEpochLoop)
 
     # Trainer runs the fit loop with your new epoch loop!
     trainer.fit(model)
 
+Alternatively, for more fine-grained control, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:
+
+.. code-block:: python
+
+    # Optional: stitch back the trainer arguments
+    epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
+    # Optional: connect children loops as they might have existing state
+    epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
+    # Instantiate and connect the loop.
+    trainer.fit_loop.connect(epoch_loop=epoch_loop)
+    trainer.fit(model)
+
 More about the built-in loops and how they are composed is explained in the next section.
 
 .. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif

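The `MyEpochLoop` referenced in the docs above is never defined on this page. A minimal sketch of what such a loop could look like (hypothetical, not part of this commit; the only hard requirement `replace` imposes is that its `__init__` signature matches `TrainingEpochLoop`'s):

```python
from pytorch_lightning.loops import TrainingEpochLoop


class MyEpochLoop(TrainingEpochLoop):
    # Inherits `TrainingEpochLoop.__init__(min_steps, max_steps)` unchanged,
    # so `Loop.replace` can re-instantiate it from the old loop's arguments.

    def advance(self, *args, **kwargs):
        # Illustrative hook: add custom per-iteration behavior, then defer
        # to the built-in training logic.
        print("custom epoch loop: advancing")
        return super().advance(*args, **kwargs)
```
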
pytorch_lightning/loops/base.py

@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, Optional, TypeVar
+from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union
 
 from deprecate import void
 from torchmetrics import Metric

@@ -99,6 +99,51 @@ class Loop(ABC, Generic[T]):
         Linked loops should form a tree.
         """
 
+    def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None:
+        """Optionally replace one or multiple of this loop's sub-loops.
+
+        This method takes care of instantiating the class (if necessary) with all existing arguments, connecting
+        all sub-loops of the old loop to the new instance, setting the ``Trainer`` reference, and connecting the
+        new loop to the parent.
+
+        Args:
+            **loops: ``Loop`` subclasses or instances. The name used should match the loop attribute name you want
+                to replace.
+
+        Raises:
+            MisconfigurationException: When passing a ``Loop`` class, if the ``__init__`` arguments do not match
+                those of the Loop class it replaces.
+        """
+        new_loops = {}
+
+        for name, type_or_object in loops.items():
+            old_loop = getattr(self, name)
+
+            if isinstance(type_or_object, type):
+                # compare the signatures
+                old_parameters = inspect.signature(old_loop.__class__.__init__).parameters
+                current_parameters = inspect.signature(type_or_object.__init__).parameters
+                if old_parameters != current_parameters:
+                    raise MisconfigurationException(
+                        f"`{self.__class__.__name__}.replace({type_or_object.__name__})` can only be used if the"
+                        f" `__init__` signatures match but `{old_loop.__class__.__name__}` does not."
+                    )
+                # instantiate the loop
+                kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"}
+                loop = type_or_object(**kwargs)  # type: ignore[call-arg]
+            else:
+                loop = type_or_object
+
+            # connect sub-loops
+            kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)}
+            loop.connect(**kwargs)
+            # set the trainer reference
+            loop.trainer = self.trainer
+
+            new_loops[name] = loop
+        # connect to self
+        self.connect(**new_loops)
+
     def on_skip(self) -> T:
         """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.
 
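For orientation, a hedged usage sketch of the new method (`CustomEpochLoop` is a hypothetical name; the behavior noted in the comments paraphrases the docstring above):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loops import TrainingEpochLoop


class CustomEpochLoop(TrainingEpochLoop):
    ...  # keeps `TrainingEpochLoop.__init__`'s signature, so the check passes


trainer = Trainer(min_steps=5, max_steps=100)

# Passing a class: `replace` compares `__init__` signatures, instantiates the
# class from the old loop's arguments (`min_steps`, `max_steps`), connects the
# existing sub-loops, and sets the trainer reference.
trainer.fit_loop.replace(epoch_loop=CustomEpochLoop)

# Passing an instance skips the signature check and instantiation; sub-loops
# and the trainer reference are still carried over.
trainer.fit_loop.replace(epoch_loop=CustomEpochLoop(min_steps=5, max_steps=100))
```
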
tests/loops/test_evaluation_loop.py

@@ -127,6 +127,6 @@ def test_evaluation_loop_doesnt_store_outputs_if_epoch_end_not_overridden(tmpdir):
     assert not is_overridden("test_epoch_end", model)
 
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3)
-    trainer.test_loop.connect(TestLoop())
+    trainer.test_loop.replace(epoch_loop=TestLoop)
     trainer.test(model)
     assert did_assert

tests/loops/test_loops.py

@@ -24,8 +24,9 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint
-from pytorch_lightning.loops import Loop, TrainingBatchLoop
+from pytorch_lightning.loops import EvaluationLoop, Loop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.trainer.progress import BaseProgress
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf

@@ -102,6 +103,49 @@ def test_connect_subloops(tmpdir):
     assert new_batch_loop.trainer is trainer
 
 
+def test_replace_loops():
+    class TestLoop(TrainingEpochLoop):
+        def __init__(self, foo):
+            super().__init__()
+
+    trainer = Trainer(min_steps=123, max_steps=321)
+
+    with pytest.raises(
+        MisconfigurationException, match=r"FitLoop.replace\(TestLoop\)`.*`__init__`.*`TrainingEpochLoop`"
+    ):
+        trainer.fit_loop.replace(epoch_loop=TestLoop)
+
+    class TestLoop(TrainingEpochLoop):
+        ...
+
+    # test passing a loop where previous state should be connected
+    old_loop = trainer.fit_loop.epoch_loop
+    trainer.fit_loop.replace(epoch_loop=TestLoop)
+    new_loop = trainer.fit_loop.epoch_loop
+
+    assert isinstance(new_loop, TestLoop)
+    assert trainer.fit_loop.epoch_loop is new_loop
+    assert new_loop.min_steps == 123
+    assert new_loop.max_steps == 321
+    assert new_loop.batch_loop is old_loop.batch_loop
+    assert new_loop.val_loop is old_loop.val_loop
+    assert new_loop.trainer is trainer
+
+    class MyBatchLoop(TrainingBatchLoop):
+        ...
+
+    class MyEvalLoop(EvaluationLoop):
+        ...
+
+    # test passing more than one where one is an instance and the other a class
+    trainer.fit_loop.epoch_loop.replace(batch_loop=MyBatchLoop, val_loop=MyEvalLoop())
+    new_batch_loop = trainer.fit_loop.epoch_loop.batch_loop
+    new_val_loop = trainer.fit_loop.epoch_loop.val_loop
+
+    assert isinstance(new_batch_loop, MyBatchLoop)
+    assert isinstance(new_val_loop, MyEvalLoop)
+
+
 class CustomException(Exception):
     pass