Loop and test restructuring (#9383)

This commit is contained in:
Carlos Mocholí 2021-09-10 15:18:24 +02:00 committed by GitHub
parent d773407e59
commit 9eccb3148e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 18 additions and 19 deletions

View File

@ -62,10 +62,8 @@ ignore_errors = "True"
[[tool.mypy.overrides]]
module = [
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.loops.closure",
"pytorch_lightning.loops.batch.manual",
"pytorch_lightning.loops.optimizer",
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.loops.optimization.*",
"pytorch_lightning.loops.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector.*",
"pytorch_lightning.trainer.progress",
"pytorch_lightning.tuner.auto_gpu_select",

View File

@ -18,4 +18,4 @@ from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401

View File

@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.batch.manual import ManualOptimization # noqa: F401
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401

View File

@ -19,8 +19,8 @@ from torch import Tensor
from torch.optim import Optimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.batch.manual import ManualOptimization
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT

View File

@ -17,7 +17,7 @@ import torch
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.closure import ClosureResult
from pytorch_lightning.loops.optimization.closure import ClosureResult
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401

View File

@ -14,7 +14,7 @@
from typing import Any, Optional
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.closure import ClosureResult
from pytorch_lightning.loops.optimization.closure import ClosureResult
from pytorch_lightning.loops.utilities import (
_build_training_step_kwargs,
_check_training_step_output,

View File

@ -20,7 +20,7 @@ from torch.optim import Optimizer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.optimization.closure import Closure, ClosureResult
from pytorch_lightning.loops.utilities import (
_block_parallel_sync_behavior,
_build_training_step_kwargs,
@ -43,7 +43,7 @@ class OptimizerLoop(Loop):
This loop implements what is known in Lightning as Automatic Optimization.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
# TODO: use default dict here to simplify logic in loop
self.outputs: _OUTPUTS_TYPE = []
@ -71,7 +71,7 @@ class OptimizerLoop(Loop):
self._batch_idx = batch_idx
self._optimizers = optimizers
def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override]
def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
result = self._run_optimization(
batch,
self._batch_idx,
@ -183,7 +183,7 @@ class OptimizerLoop(Loop):
if not is_first_batch_to_accumulate:
return None
def zero_grad_fn():
def zero_grad_fn() -> None:
self._on_before_zero_grad(optimizer)
self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)
@ -198,7 +198,7 @@ class OptimizerLoop(Loop):
if self._skip_backward:
return None
def backward_fn(loss: Tensor):
def backward_fn(loss: Tensor) -> Tensor:
self.backward(loss, optimizer, opt_idx)
# check if model weights are nan
@ -332,6 +332,7 @@ class OptimizerLoop(Loop):
if self.trainer.move_metrics_to_cpu:
# hiddens and the training step output are not moved as they are not considered "metrics"
assert self.trainer._results is not None
self.trainer._results.cpu()
return result

View File

@ -20,7 +20,7 @@ from torch.optim import Adam, Optimizer, SGD
from pytorch_lightning import Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.closure import Closure
from pytorch_lightning.loops.optimization.closure import Closure
from tests.helpers.boring_model import BoringModel

View File

@ -15,7 +15,7 @@ import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loops.closure import ClosureResult
from pytorch_lightning.loops.optimization.closure import ClosureResult
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel

View File

@ -18,7 +18,7 @@ from torch.utils.data._utils.collate import default_collate
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loops.closure import Closure
from pytorch_lightning.loops.optimization.closure import Closure
from pytorch_lightning.trainer.states import RunningStage
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.deterministic_model import DeterministicModel