Loop and test restructuring (#9383)
This commit is contained in:
parent
d773407e59
commit
9eccb3148e
|
@ -62,10 +62,8 @@ ignore_errors = "True"
|
|||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"pytorch_lightning.callbacks.pruning",
|
||||
"pytorch_lightning.loops.closure",
|
||||
"pytorch_lightning.loops.batch.manual",
|
||||
"pytorch_lightning.loops.optimizer",
|
||||
"pytorch_lightning.trainer.evaluation_loop",
|
||||
"pytorch_lightning.loops.optimization.*",
|
||||
"pytorch_lightning.loops.evaluation_loop",
|
||||
"pytorch_lightning.trainer.connectors.logger_connector.*",
|
||||
"pytorch_lightning.trainer.progress",
|
||||
"pytorch_lightning.tuner.auto_gpu_select",
|
||||
|
|
|
@ -18,4 +18,4 @@ from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
|
|||
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
|
||||
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
|
||||
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
|
||||
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401
|
||||
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401
|
||||
|
|
|
@ -12,5 +12,5 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pytorch_lightning.loops.batch.manual import ManualOptimization # noqa: F401
|
||||
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401
|
||||
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401
|
||||
|
|
|
@ -19,8 +19,8 @@ from torch import Tensor
|
|||
from torch.optim import Optimizer
|
||||
|
||||
from pytorch_lightning.loops.base import Loop
|
||||
from pytorch_lightning.loops.batch.manual import ManualOptimization
|
||||
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
|
||||
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
|
||||
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
|
||||
from pytorch_lightning.trainer.supporters import TensorRunningAccum
|
||||
from pytorch_lightning.utilities import AttributeDict
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
|
|
@ -17,7 +17,7 @@ import torch
|
|||
|
||||
from pytorch_lightning import loops # import as loops to avoid circular imports
|
||||
from pytorch_lightning.loops.batch import TrainingBatchLoop
|
||||
from pytorch_lightning.loops.closure import ClosureResult
|
||||
from pytorch_lightning.loops.optimization.closure import ClosureResult
|
||||
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
|
||||
|
|
|
@ -12,4 +12,4 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401
|
||||
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401
|
|
@ -14,7 +14,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from pytorch_lightning.loops import Loop
|
||||
from pytorch_lightning.loops.closure import ClosureResult
|
||||
from pytorch_lightning.loops.optimization.closure import ClosureResult
|
||||
from pytorch_lightning.loops.utilities import (
|
||||
_build_training_step_kwargs,
|
||||
_check_training_step_output,
|
|
@ -20,7 +20,7 @@ from torch.optim import Optimizer
|
|||
|
||||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||
from pytorch_lightning.loops import Loop
|
||||
from pytorch_lightning.loops.closure import Closure, ClosureResult
|
||||
from pytorch_lightning.loops.optimization.closure import Closure, ClosureResult
|
||||
from pytorch_lightning.loops.utilities import (
|
||||
_block_parallel_sync_behavior,
|
||||
_build_training_step_kwargs,
|
||||
|
@ -43,7 +43,7 @@ class OptimizerLoop(Loop):
|
|||
This loop implements what is known in Lightning as Automatic Optimization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# TODO: use default dict here to simplify logic in loop
|
||||
self.outputs: _OUTPUTS_TYPE = []
|
||||
|
@ -71,7 +71,7 @@ class OptimizerLoop(Loop):
|
|||
self._batch_idx = batch_idx
|
||||
self._optimizers = optimizers
|
||||
|
||||
def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override]
|
||||
def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
|
||||
result = self._run_optimization(
|
||||
batch,
|
||||
self._batch_idx,
|
||||
|
@ -183,7 +183,7 @@ class OptimizerLoop(Loop):
|
|||
if not is_first_batch_to_accumulate:
|
||||
return None
|
||||
|
||||
def zero_grad_fn():
|
||||
def zero_grad_fn() -> None:
|
||||
self._on_before_zero_grad(optimizer)
|
||||
self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)
|
||||
|
||||
|
@ -198,7 +198,7 @@ class OptimizerLoop(Loop):
|
|||
if self._skip_backward:
|
||||
return None
|
||||
|
||||
def backward_fn(loss: Tensor):
|
||||
def backward_fn(loss: Tensor) -> Tensor:
|
||||
self.backward(loss, optimizer, opt_idx)
|
||||
|
||||
# check if model weights are nan
|
||||
|
@ -332,6 +332,7 @@ class OptimizerLoop(Loop):
|
|||
|
||||
if self.trainer.move_metrics_to_cpu:
|
||||
# hiddens and the training step output are not moved as they are not considered "metrics"
|
||||
assert self.trainer._results is not None
|
||||
self.trainer._results.cpu()
|
||||
|
||||
return result
|
|
@ -20,7 +20,7 @@ from torch.optim import Adam, Optimizer, SGD
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||
from pytorch_lightning.loops.closure import Closure
|
||||
from pytorch_lightning.loops.optimization.closure import Closure
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loops.closure import ClosureResult
|
||||
from pytorch_lightning.loops.optimization.closure import ClosureResult
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers import BoringModel
|
||||
|
|
@ -18,7 +18,7 @@ from torch.utils.data._utils.collate import default_collate
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.loops.closure import Closure
|
||||
from pytorch_lightning.loops.optimization.closure import Closure
|
||||
from pytorch_lightning.trainer.states import RunningStage
|
||||
from tests.helpers.boring_model import BoringModel, RandomDataset
|
||||
from tests.helpers.deterministic_model import DeterministicModel
|
Loading…
Reference in New Issue