lightning/pytorch_lightning/loops/base.py


# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from deprecate import void

import pytorch_lightning as pl
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class Loop(ABC):
"""
Basic Loops interface. All classes derived from this must implement the following properties and methods:
* :attr:`done` (property): Condition to break the loop
* :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
* :attr:`advance` (method): Implements one step of the loop
This class implements the following loop structure:
.. codeblock:: python
on_run_start()
while not done:
on_advance_start()
advance()
on_advance_end()
on_run_end()
"""

    def __init__(self) -> None:
        self.iteration_count: int = 0
        self.trainer: Optional['pl.Trainer'] = None
        self._restarting = False

    @property
    def restarting(self) -> bool:
        return self._restarting

    @restarting.setter
    def restarting(self, restarting: bool) -> None:
        self._restarting = restarting

    @property
    @abstractmethod
    def done(self) -> bool:
"""Property indicating when loop is finished"""

    @property
    def skip(self) -> bool:
        """Determine whether to return immediately from the call to :meth:`run`."""
        return False

    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
        """Connects the loop with all the necessary things like connectors and accelerators."""
        # TODO(@justusschock): Make the trainer a weakref/proxy
        if not isinstance(trainer, pl.Trainer):
            raise MisconfigurationException(
                f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
            )
        self.trainer = trainer

    def on_skip(self) -> Optional[Any]:
        """
        The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.

        Returns:
            the default output value of :meth:`on_run_end`
        """

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        """
        The main entry point to the loop.

        Repeatedly checks the :attr:`done` condition and calls :attr:`advance`
        until :attr:`done` evaluates to ``True``.

        Returns:
            the output of :attr:`on_run_end` (often outputs collected from each step of the loop)
        """
        if self.skip:
            return self.on_skip()

        if self.restarting:
            self.restore()
            self.restarting = False
        else:
            self.reset()

        self.on_run_start(*args, **kwargs)

        while not self.done:
            try:
                self.on_advance_start(*args, **kwargs)
                self.advance(*args, **kwargs)
                self.on_advance_end()
                self.iteration_count += 1
            except StopIteration:
                # an iterator consumed by ``advance`` was exhausted: end the loop early
                break

        output = self.on_run_end()
        return output

    def restore(self) -> None:
        """Restore the internal state of the loop at the beginning of :meth:`run` if :attr:`restarting` is ``True``."""

    @abstractmethod
    def reset(self) -> None:
        """Resets the internal state of the loop at the beginning of each call to :attr:`run`."""

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """
        Hook to be called as the first thing after entering :attr:`run` (except the state reset).

        Accepts all arguments passed to :attr:`run`.
        """
        void(*args, **kwargs)

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """
        Hook to be called each time before :attr:`advance` is called.

        Accepts all arguments passed to :attr:`run`.
        """
        void(*args, **kwargs)

    @abstractmethod
    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Performs a single step. Accepts all arguments passed to :attr:`run`."""

    def on_advance_end(self) -> None:
        """Hook to be called each time after :attr:`advance` is called."""

    def on_run_end(self) -> Any:
        """Hook to be called at the end of the run. Its return value is also the return value of :meth:`run`."""

    def teardown(self) -> None:
        """Use to release memory and free other resources."""

    def load_state_dict(self, state_dict: Dict) -> None:
        """Restore the loop state from the provided ``state_dict``."""

    def state_dict(self) -> Dict:
        """Return the current state of the loop."""
        return {}
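

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the Lightning API: a hypothetical minimal
# ``Loop`` subclass showing how ``done``, ``reset`` and ``advance`` compose
# under :meth:`Loop.run`. The name ``_SquaresLoop`` and its behaviour are
# invented here purely for documentation purposes.
class _SquaresLoop(Loop):
    """Collects the squares of the first ``max_iterations`` integers.

    >>> _SquaresLoop(max_iterations=3).run()
    [0, 1, 4]
    """

    def __init__(self, max_iterations: int) -> None:
        super().__init__()
        self.max_iterations = max_iterations
        self.outputs: list = []

    @property
    def done(self) -> bool:
        # ``run`` exits its while-loop as soon as this evaluates to ``True``
        return self.iteration_count >= self.max_iterations

    def reset(self) -> None:
        # called by ``run`` before iteration starts (unless restarting)
        self.iteration_count = 0
        self.outputs = []

    def advance(self, *args: Any, **kwargs: Any) -> None:
        # one unit of work per iteration; ``run`` increments ``iteration_count``
        self.outputs.append(self.iteration_count ** 2)

    def on_run_end(self) -> Any:
        # whatever is returned here becomes the return value of ``run``
        return self.outputs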