Simplify optimization logic (#4984)

* Rely on the DDP plugin for blocking sync behaviour, and skip it when using manual optimization (see the sketch after the commit metadata below)

* debug

* Revert "debug"

This reverts commit ccca6b6b

* Expose manual reduce for automatic optimization

* Add input arguments

* Enable parity test

* clean imports

* Expose hook after to ensure we reset

* Fix naming

* add

* fix test

* uniformize optimizer logic

* resolve test

* resolve flake8

* resolve amp bug

* update tests

* remove bug

* remove optimizer_step in accelerators

* typo

* update lightning optimizer

* set doesn't work with ddp_spawn

* resolve flake8

* update threshold

* ignore pyright

* correct CodeFactor

* remove useless if

* remove zero_grad function

* simplify step

* remove typo

* resolve bug

* Apply suggestions from code review

* update on comments

* resolve bugs

* remove tests

* Update pytorch_lightning/trainer/configuration_validator.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* simplify testing

* add more tests

Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Author: chaton, committed 2020-12-07 12:55:49 +00:00 via GitHub
parent ab7c947961
commit 02152c1729
21 changed files with 554 additions and 142 deletions
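
The first commit bullet above relies on the DDP plugin to block gradient synchronisation while gradients are merely being accumulated. A minimal sketch of that idea, assuming a torch.nn.parallel.DistributedDataParallel wrapper; the helper name `block_ddp_sync` and its arguments are illustrative, not the exact `block_ddp_sync_behaviour` hook used by the train loop in this diff:

from contextlib import contextmanager, nullcontext

from torch import nn
from torch.nn.parallel import DistributedDataParallel


@contextmanager
def block_ddp_sync(model: nn.Module, should_block: bool):
    # Skip the DDP gradient all-reduce while we are only accumulating;
    # `no_sync()` is the upstream DistributedDataParallel primitive for this.
    if should_block and isinstance(model, DistributedDataParallel):
        context = model.no_sync()
    else:
        context = nullcontext()
    with context:
        yield


# usage sketch: run the training_step + backward closure without syncing,
# and let the all-reduce happen only on the batch that calls optimizer.step()
# with block_ddp_sync(model, should_block=accumulating):
#     closure()

In the diff below, `LightningOptimizer.step` wraps the closure in `trainer.train_loop.block_ddp_sync_behaviour()` whenever the optimizer step is skipped for accumulation.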

View File

@ -39,10 +39,11 @@ steps:
# todo: temporary fix till https://github.com/PyTorchLightning/pytorch-lightning/pull/4922 is resolved
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist "nvidia-dali-cuda100<0.27" --upgrade-strategy only-if-needed
- pip list
- coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
- python -m pytest benchmarks pl_examples -v --maxfail=2 --durations=0 # --flake8
#- cd docs; make doctest; make coverage
- python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
# Running special tests
- sh tests/special_tests.sh
- coverage report
- python -m pytest benchmarks pl_examples -v --maxfail=2 --durations=0
# see: https://docs.codecov.io/docs/merging-reports
- codecov --token $CODECOV_TOKEN --flags=gpu,pytest --name="GPU-coverage" --env=linux --build $DRONE_BUILD_NUMBER --commit $DRONE_COMMIT
# --build $DRONE_BUILD_NUMBER --branch $DRONE_BRANCH --commit $DRONE_COMMIT --tag $DRONE_TAG --pr $DRONE_PULL_REQUEST

View File

@ -16,14 +16,15 @@ from enum import Enum
from typing import Any, Optional, Union
import torch
import torch.distributed as torch_distrib
from torch.optim import Optimizer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.core.lightning import LightningModule
import torch.distributed as torch_distrib
if torch.distributed.is_available():
from torch.distributed import ReduceOp
@ -98,40 +99,6 @@ class Accelerator(object):
closure_loss = closure_loss.detach()
return closure_loss
def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs):
model_ref = self.trainer.get_model()
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
automatic_optimization = self.trainer.train_loop.automatic_optimization
# native amp + lbfgs is a no go right now
if using_native_amp and is_lbfgs:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
# model hook
model_ref.optimizer_step(
epoch=self.trainer.current_epoch,
batch_idx=batch_idx,
optimizer=optimizer,
optimizer_idx=opt_idx,
optimizer_closure=lambda_closure,
on_tpu=False, # TPUAccelerator class sets this as True
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs,
*args,
**kwargs,
)
# scale when native amp
if automatic_optimization and using_native_amp:
self.trainer.scaler.update()
def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
model_ref = self.trainer.get_model()
model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
def clip_gradients(self, optimizer, clip_val=None):
# use the trainer's clip val if none passed
grad_clip_val = self.trainer.gradient_clip_val
@ -160,7 +127,7 @@ class Accelerator(object):
return self.trainer.should_stop
def setup_optimizers(self, model):
if self.trainer.testing is True:
if self.trainer.testing:
return
optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)

View File

@ -23,7 +23,14 @@ from torch.optim import Optimizer
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_only, rank_zero_warn, move_data_to_device
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
TPU_AVAILABLE,
move_data_to_device,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
)
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -245,24 +252,6 @@ class TPUAccelerator(Accelerator):
return closure_loss
def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs):
model_ref = self.trainer.get_model()
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
# model hook
model_ref.optimizer_step(
epoch=self.trainer.current_epoch,
batch_idx=batch_idx,
optimizer=optimizer,
optimizer_idx=opt_idx,
optimizer_closure=lambda_closure,
on_tpu=True,
using_native_amp=False,
using_lbfgs=is_lbfgs,
*args,
**kwargs,
)
def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md

View File

@ -67,6 +67,9 @@ class GradientAccumulationScheduler(Callback):
self.scheduling = scheduling
self.epochs = sorted(scheduling.keys())
def going_to_accumulate_grad_batches(self):
return any([v > 1 for v in self.scheduling.values()])
def on_epoch_start(self, trainer, pl_module):
epoch = trainer.current_epoch
for i in reversed(range(len(self.epochs))):
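
A quick illustration of what the new `going_to_accumulate_grad_batches` helper reports, assuming the `scheduling` dict maps an epoch to its accumulation factor as in the callback above (the values here are made up):

# epoch 0 accumulates 1 batch (no accumulation), epoch 4 accumulates 2
scheduling = {0: 1, 4: 2}
assert any(v > 1 for v in scheduling.values())  # accumulation will happen

# a schedule that never goes above 1 lets the configuration validator accept
# overridden optimizer_step / optimizer_zero_grad hooks (see further down)
assert not any(v > 1 for v in {0: 1}.values())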

View File

@ -33,6 +33,7 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn
@ -1236,15 +1237,10 @@ class LightningModule(
model hook don't forget to add the call to it before ``optimizer.zero_grad()`` yourself.
"""
if on_tpu and TPU_AVAILABLE:
xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure, **kwargs})
elif self.trainer.amp_backend is not None:
self.trainer.precision_connector.backend.optimizer_step(
self.trainer, optimizer, optimizer_closure)
else:
optimizer.step(closure=optimizer_closure, *args, **kwargs)
if not isinstance(optimizer, LightningOptimizer):
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
optimizer.step(closure=optimizer_closure, *args, **kwargs)
def optimizer_zero_grad(
self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int

View File

@ -35,7 +35,7 @@ def do_nothing_closure():
class LightningOptimizer:
"""
This class is used to wrap the user optimizers and handle properly
the backward and optimizer_step logic across accelerators, AMP, accumulated_grad_batches
the backward and optimizer_step logic across accelerators, AMP, accumulate_grad_batches
"""
def __init__(self,
optimizer: Optimizer,
@ -60,17 +60,35 @@ class LightningOptimizer:
self._trainer = None
self._optimizer = optimizer
self._accumulate_grad_batches = accumulate_grad_batches
self._use_accumulate_grad_batches_from_trainer = accumulate_grad_batches is None
self._automatic_optimization = None
self._optimizer_idx = None
@property
def accumulate_grad_batches(self):
return self._accumulate_grad_batches
@accumulate_grad_batches.setter
def accumulate_grad_batches(self, accumulate_grad_batches):
self._accumulate_grad_batches = accumulate_grad_batches
def _on_trainer_init(self, trainer):
self._trainer = proxy(trainer)
self._automatic_optimization = trainer.train_loop.automatic_optimization
for opt_idx, opt in enumerate(trainer.optimizers):
if opt == self._optimizer:
self._optimizer_idx = opt_idx
break
@classmethod
def to_lightning_optimizer(cls, optimizer, trainer):
optimizer = cls(optimizer)
optimizer._on_trainer_init(trainer)
return optimizer
def _accumulated_batches_reached(self):
if self._use_accumulate_grad_batches_from_trainer:
accumulate_grad_batches = self._trainer.accumulate_grad_batches
else:
accumulate_grad_batches = self._accumulate_grad_batches
return (self._trainer.batch_idx + 1) % accumulate_grad_batches == 0
if self.accumulate_grad_batches is None:
return self._trainer.train_loop._accumulated_batches_reached()
return (self._trainer.batch_idx + 1) % self.accumulate_grad_batches == 0
@property
def _should_accumulate(self):
@ -79,6 +97,45 @@ class LightningOptimizer:
is_final_batch = self._trainer.train_loop._num_training_batches_reached()
return not (accumulation_done or is_final_batch)
def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs):
trainer = self._trainer
optimizer = self._optimizer
model = trainer.get_model()
if trainer.on_tpu:
with trainer.profiler.profile(profiler_name):
xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})
elif trainer.amp_backend is not None:
trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)
else:
with trainer.profiler.profile(profiler_name):
optimizer.step(closure=closure, *args, **kwargs)
trainer.train_loop.on_before_zero_grad(self)
model.optimizer_zero_grad(
trainer.current_epoch,
trainer.batch_idx,
optimizer,
self._optimizer_idx
)
def _check_make_optimizer_step(self, make_optimizer_step: Optional[bool]) -> bool:
if make_optimizer_step is not None and self._trainer.overriden_optimizer_zero_grad:
raise MisconfigurationException(
"When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed.")
if self._trainer.train_loop.automatic_optimization:
if self._trainer.overriden_optimizer_step and self._trainer.overriden_optimizer_zero_grad:
return True
if make_optimizer_step is None:
make_optimizer_step = not self._should_accumulate
return make_optimizer_step
def step(self, *args, closure: Optional[Callable] = None, make_optimizer_step: Optional[bool] = None, **kwargs):
"""
Call this directly from your training_step when doing optimizations manually.
@ -173,40 +230,23 @@ class LightningOptimizer:
# Trainer(accumulate_grad_batches=x)
opt_dis.step(closure=optimizer_closure, make_optimizer_step=True)
"""
profiler_name = "optimizer_step_and_closure"
profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
if closure is None:
closure = do_nothing_closure
profile_name = "optimizer_step"
profile_name = f"optimizer_step_{self._optimizer_idx}"
else:
if not isinstance(closure, types.FunctionType):
raise MisconfigurationException("When closure is provided, it should be a function")
if make_optimizer_step is None:
make_optimizer_step = not self._should_accumulate
trainer = self._trainer
optimizer = self._optimizer
make_optimizer_step = self._check_make_optimizer_step(make_optimizer_step)
if make_optimizer_step:
if trainer.on_tpu:
with trainer.profiler.profile(profiler_name):
xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})
elif trainer.amp_backend is not None:
trainer.precision_connector.backend.optimizer_step(
trainer, optimizer, closure)
else:
with trainer.profiler.profile(profiler_name):
optimizer.step(closure=closure, *args, **kwargs)
# perform zero grad
optimizer.zero_grad()
self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
else:
# make sure to call optimizer_closure when accumulating
with trainer.profiler.profile("closure"):
with trainer.train_loop.block_ddp_sync_behaviour():
with self._trainer.profiler.profile(f"closure_{self._optimizer_idx}"):
with self._trainer.train_loop.block_ddp_sync_behaviour():
closure()
def __repr__(self):
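
A hedged usage sketch of what the rewritten `step` enables, pieced together from the docstring fragment above and the tests added further down; the two-optimizer module, the `manual_backward` calls and the Trainer flags (`automatic_optimization=False`, `enable_pl_optimizer=True`) follow that era's manual-optimization API and are illustrative rather than copied from this commit:

import torch
from pytorch_lightning import LightningModule


class ManualTwoOptimizerSketch(LightningModule):
    def __init__(self):
        super().__init__()
        self.gen = torch.nn.Linear(32, 32)
        self.dis = torch.nn.Linear(32, 32)

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        if optimizer_idx not in (None, 0):
            # in this era the loop may still call training_step per optimizer;
            # do all the manual work only on the first call
            return
        # with enable_pl_optimizer=True these are LightningOptimizer wrappers
        opt_gen, opt_dis = self.optimizers()

        def gen_closure():
            loss = self.gen(batch).sum()
            self.manual_backward(loss, opt_gen)

        # step every batch; the wrapper handles profiling and zero_grad
        opt_gen.step(closure=gen_closure)

        def dis_closure():
            loss = self.dis(batch).sum()
            self.manual_backward(loss, opt_dis)

        # skip every other step: the closure still runs (gradients accumulate
        # and DDP sync is blocked), but no optimizer.step()/zero_grad()
        opt_dis.step(closure=dis_closure, make_optimizer_step=batch_idx % 2 == 0)

    def configure_optimizers(self):
        return [torch.optim.SGD(self.gen.parameters(), lr=0.1),
                torch.optim.Adam(self.dis.parameters(), lr=0.1)]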

View File

@ -14,7 +14,7 @@
import itertools
import threading
from collections.abc import Mapping, Iterable
from collections.abc import Iterable, Mapping
from itertools import chain
import torch

View File

@ -137,5 +137,9 @@ class ApexPlugin(PrecisionPlugin):
# TODO: pass the closure to the step ASAP
with trainer.profiler.profile("closure"):
closure()
if not self.trainer.train_loop.automatic_optimization:
trainer.call_hook("on_after_backward")
with trainer.profiler.profile("optimizer_step"):
optimizer.step()

View File

@ -1,9 +1,10 @@
import os
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Optional, Union
import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log
from torch.optim import Optimizer
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.plugin import LightningPlugin

View File

@ -69,6 +69,11 @@ class NativeAMPPlugin(PrecisionPlugin):
# TODO: pass the closure to the step ASAP
with trainer.profiler.profile("closure"):
closure()
if not self.trainer.train_loop.automatic_optimization:
trainer.scaler.unscale_(optimizer)
trainer.call_hook("on_after_backward")
with trainer.profiler.profile("optimizer_step"):
trainer.scaler.step(optimizer)
trainer.scaler.update()
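
The unscale-before-hook ordering matters because gradient-inspecting hooks need real magnitudes. A minimal sketch of the user side, assuming `Trainer(automatic_optimization=False, precision=16, amp_backend='native')` and mirroring the `on_after_backward` assertions added to the manual-optimization test near the end of this diff:

import torch
from pytorch_lightning import LightningModule


class ClipAfterBackwardSketch(LightningModule):
    # only the relevant hook is shown; the rest of the module is omitted
    def on_after_backward(self):
        # by the time this fires in manual optimization, NativeAMPPlugin has
        # already called scaler.unscale_(optimizer), so clipping sees the
        # true gradient norms rather than loss-scaled ones
        norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2.0)
        if not (torch.isinf(norm) or torch.isnan(norm)):
            assert norm.item() < 100, norm.item()

`scaler.step(optimizer)` and `scaler.update()` still run inside the plugin immediately after the hook.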

View File

@ -17,7 +17,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.utilities import AMPType, FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, AMPType, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if FAIRSCALE_AVAILABLE:

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -69,6 +69,37 @@ class ConfigValidator(object):
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
)
trainer = self.trainer
trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)
enable_pl_optimizer = trainer._enable_pl_optimizer
automatic_optimization = trainer.train_loop.automatic_optimization
if trainer.overriden_optimizer_step and not enable_pl_optimizer and automatic_optimization:
rank_zero_warn(
"When overriding `LightningModule` optimizer_step with"
" `Trainer(..., enable_pl_optimizer=False, automatic_optimization=True, ...)`,"
" we won't be calling `.zero_grad` we can't assume when you call your `optimizer.step()`."
" For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`."
)
going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()
has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization:
raise MisconfigurationException(
'When overriding `LightningModule` optimizer_step or optimizer_zero_grad with '
'`Trainer(automatic_optimization=True, ...)`, `accumulate_grad_batches` should be 1.'
' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
)
if (enable_pl_optimizer) and trainer.overriden_optimizer_zero_grad and not automatic_optimization:
raise MisconfigurationException(
'Overriding `LightningModule` optimizer_zero_grad with '
'`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...)` is not supported'
)
def __verify_eval_loop_configuration(self, model, eval_loop_name):
step_name = f'{eval_loop_name}_step'
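
The first of these checks only warns; a sketch of the combination that triggers it, assuming the `BoringModel` test helper used throughout this diff and illustrative Trainer flags:

from pytorch_lightning import Trainer
from tests.base import BoringModel


class CustomStep(BoringModel):
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        optimizer.step(closure=optimizer_closure)
        # with enable_pl_optimizer=False Lightning will not call zero_grad
        # for you (hence the warning above), so do it here
        optimizer.zero_grad()


# overridden optimizer_step + enable_pl_optimizer=False + automatic
# optimization -> rank_zero_warn about `.zero_grad` not being called;
# accumulate_grad_batches stays at 1, so no MisconfigurationException
trainer = Trainer(enable_pl_optimizer=False, automatic_optimization=True, fast_dev_run=True)
trainer.fit(CustomStep())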

View File

@ -42,7 +42,7 @@ class LoggerConnector:
@property
def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results.get(self._current_stage)
return self._cached_results.get(self._current_stage) # type: ignore
def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None:
self._current_stage = LoggerStages.determine_stage(stage_or_testing)

View File

@ -855,7 +855,8 @@ class Trainer(
model.setup(stage_name)
def _reset_result_and_set_hook_fx_name(self, hook_name):
if "batch_start" in hook_name:
# on_before_zero_grad is called within training_step
if "batch_start" in hook_name or "on_before_zero_grad" in hook_name:
return True
model_ref = self.get_model()
if model_ref is not None:

View File

@ -26,7 +26,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
from pytorch_lightning.utilities import AMPType, parsing
from pytorch_lightning.utilities import TPU_AVAILABLE, AMPType, parsing
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
@ -321,7 +321,6 @@ class TrainLoop:
args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
# manually capture logged metrics
model_ref._results = Result()
model_ref._current_fx_name = 'training_step'
model_ref._results = Result()
training_step_output = self.trainer.accelerator_backend.training_step(args)
@ -475,21 +474,34 @@ class TrainLoop:
return training_step_output_for_epoch_end
def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs):
# optimizer step lightningModule hook
if isinstance(optimizer, LightningOptimizer):
optimizer.step(closure=train_step_and_backward_closure)
else:
with self.trainer.profiler.profile("optimizer_step"):
self.trainer.accelerator_backend.optimizer_step(
optimizer, batch_idx, opt_idx, train_step_and_backward_closure, *args, **kwargs
)
model_ref = self.trainer.get_model()
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
# native amp + lbfgs is a no go right now
if using_native_amp and is_lbfgs:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
# model hook
model_ref.optimizer_step(
epoch=self.trainer.current_epoch,
batch_idx=batch_idx,
optimizer=optimizer,
optimizer_idx=opt_idx,
optimizer_closure=train_step_and_backward_closure,
on_tpu=self.trainer.use_tpu and TPU_AVAILABLE,
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs,
*args,
**kwargs,
)
def on_before_zero_grad(self, optimizer):
self.trainer.call_hook('on_before_zero_grad', optimizer)
def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
def track_and_norm_grad(self, optimizer):
# track gradient norms
grad_norm_dic = self._track_gradient_norm()
@ -708,7 +720,6 @@ class TrainLoop:
if self._curr_step_result is None:
# user decided to skip optimization
# make sure to zero grad.
self.zero_grad_handler(batch_idx, optimizer, opt_idx)
continue
batch_outputs = self._process_closure_result(
@ -720,9 +731,6 @@ class TrainLoop:
grad_norm_dic = self._cur_grad_norm_dict
self._cur_grad_norm_dict = None
# hook + clear gradients
self.zero_grad_handler(batch_idx, optimizer, opt_idx)
# update running loss + reset accumulated loss
self.update_running_loss()
@ -947,14 +955,3 @@ class TrainLoop:
# reset for next set of accumulated grads
self.accumulated_loss.reset()
def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
if self.automatic_optimization:
# hook
self.on_before_zero_grad(optimizer)
optimizers = enumerate([optimizer])
else:
optimizers = []
for idx, optimizer in optimizers:
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

View File

@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import MagicMock, call, ANY
from unittest.mock import ANY, MagicMock, call
from pytorch_lightning import Trainer
from tests.base import BoringModel

View File

@ -0,0 +1,102 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from argparse import ArgumentParser
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
import torch
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, random_split
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel
def test_automatic_optimization(tmpdir):
class TestModel(BoringModel):
def optimizer_step(self, *_, **__):
pass
model = TestModel()
try:
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
accumulate_grad_batches=2,
automatic_optimization=True
)
trainer.fit(model)
except MisconfigurationException as e:
assert "It ensures optimizer_step or optimizer_zero_grad are called on every batch" in str(e)
@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
def test_automatic_optimization_num_calls(enable_pl_optimizer, tmpdir):
with patch("torch.optim.SGD.step") as sgd_step, \
patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \
patch("torch.optim.Adam.step") as adam_step, \
patch("torch.optim.Adam.zero_grad") as adam_zero_grad:
class TestModel(BoringModel):
def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
return [optimizer, optimizer_2]
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
assert optimizer_closure.__name__ == "train_step_and_backward_closure"
# update generator opt every 2 steps
if optimizer_idx == 0:
if batch_idx % 2 == 0:
assert isinstance(optimizer, SGD)
optimizer.step(closure=optimizer_closure)
if not enable_pl_optimizer:
optimizer.zero_grad()
# update discriminator opt every 4 steps
if optimizer_idx == 1:
if batch_idx % 4 == 0:
assert isinstance(optimizer, Adam)
optimizer.step(closure=optimizer_closure)
if not enable_pl_optimizer:
optimizer.zero_grad()
model = TestModel()
model.training_epoch_end = None
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=8,
accumulate_grad_batches=1,
automatic_optimization=True,
enable_pl_optimizer=enable_pl_optimizer
)
trainer.fit(model)
assert sgd_step.call_count == 4
assert sgd_zero_grad.call_count == 4
assert adam_step.call_count == 2
assert adam_zero_grad.call_count == 2

View File

@ -21,6 +21,7 @@ from torch.optim import Adam, Optimizer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
@ -188,10 +189,226 @@ def test_state(tmpdir):
assert isinstance(lightning_optimizer, Optimizer)
lightning_dict = {}
special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx",
"_trainer", "_use_accumulate_grad_batches_from_trainer", "_lightning_step"]
"_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization",
"_accumulate_grad_batches"]
for k, v in lightning_optimizer.__dict__.items():
if k not in special_attrs:
lightning_dict[k] = v
assert lightning_dict == optimizer.__dict__
assert optimizer.state_dict() == lightning_optimizer.state_dict()
assert optimizer.state == lightning_optimizer.state
def test_lightning_optimizer_automatic_optimization(tmpdir):
"""
Test lightning optimize works with make_optimizer_step in automatic_optimization
"""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx=None):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}
def training_epoch_end(self, outputs):
outputs = sum(outputs, [])
torch.stack([x["loss"] for x in outputs]).mean()
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
assert optimizer_closure.__name__ == "train_step_and_backward_closure"
optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 2 == 0)
def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1)
optimizer_1 = LightningOptimizer(optimizer_1, 4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
return [optimizer_1, optimizer_2], [lr_scheduler]
model = TestModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=10,
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
enable_pl_optimizer=True,
automatic_optimization=True
)
trainer.fit(model)
def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
"""
Test lightning optimize works with optimizer_zero_grad overrides in automatic_optimization
"""
with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \
patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx=None):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}
def training_epoch_end(self, outputs):
outputs = sum(outputs, [])
torch.stack([x["loss"] for x in outputs]).mean()
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
if optimizer_idx == 0:
if batch_idx % 2 == 0:
optimizer.zero_grad()
if optimizer_idx == 1:
if batch_idx % 5 == 0:
optimizer.zero_grad()
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
assert optimizer_closure.__name__ == "train_step_and_backward_closure"
optimizer.step(closure=optimizer_closure)
def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
return [optimizer_1, optimizer_2], [lr_scheduler]
model = TestModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=10,
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
enable_pl_optimizer=True,
automatic_optimization=True
)
trainer.fit(model)
assert adam_zero_grad.call_count == 2
assert sgd_zero_grad.call_count == 5
def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad_make_optimizer_step(tmpdir):
"""
Test lightning optimize works with optimizer_zero_grad overrides and make_optimizer_step in automatic_optimization
"""
try:
with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \
patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx=None):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}
def training_epoch_end(self, outputs):
outputs = sum(outputs, [])
torch.stack([x["loss"] for x in outputs]).mean()
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
if optimizer_idx == 0:
if batch_idx % 2 == 0:
optimizer.zero_grad()
if optimizer_idx == 1:
if batch_idx % 5 == 0:
optimizer.zero_grad()
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
assert optimizer_closure.__name__ == "train_step_and_backward_closure"
if optimizer_idx == 0:
optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 3 == 0)
return
optimizer.step(closure=optimizer_closure)
def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
return [optimizer_1, optimizer_2], [lr_scheduler]
model = TestModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=20,
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
enable_pl_optimizer=True,
automatic_optimization=True
)
trainer.fit(model)
assert adam_zero_grad.call_count == 4
assert sgd_zero_grad.call_count == 10
except MisconfigurationException as e:
assert "When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed" in str(e)
def test_lightning_optimizer_automatic_optimization_make_optimizer_step_2(tmpdir):
"""
Test lightning optimize works with make_optimizer_step in automatic_optimization
"""
with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \
patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx=None):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}
def training_epoch_end(self, outputs):
outputs = sum(outputs, [])
torch.stack([x["loss"] for x in outputs]).mean()
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
assert optimizer_closure.__name__ == "train_step_and_backward_closure"
make_optimizer_step = None
if optimizer_idx == 0:
make_optimizer_step = batch_idx % 4 == 0
optimizer.step(closure=optimizer_closure, make_optimizer_step=make_optimizer_step)
def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
return [optimizer_1, optimizer_2], [lr_scheduler]
model = TestModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=20,
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
enable_pl_optimizer=True,
automatic_optimization=True,
)
trainer.fit(model)
assert adam_zero_grad.call_count == 20
assert sgd_zero_grad.call_count == 5

View File

@ -1,13 +1,15 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
import pytest
import os
from unittest import mock
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
from tests.base.boring_model import BoringModel
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(os.environ, {

tests/special_tests.sh (new file, 17 lines added)
View File

@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export PL_RUNNING_SPECIAL_TESTS=1
# Running special tests
DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"

View File

@ -22,6 +22,7 @@ import torch.nn.functional as F
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel
@ -563,6 +564,15 @@ def test_multiple_optimizers_step(tmpdir):
Tests that `step` works with several optimizers
"""
class TestModel(BoringModel):
called = False
def on_after_backward(self):
self.called = True
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
if not (torch.isinf(norm) or torch.isnan(norm)):
assert norm.item() < 100, norm.item()
def training_step(self, batch, batch_idx, optimizer_idx):
# manual
(opt_a, opt_b) = self.optimizers()
@ -621,6 +631,7 @@ def test_multiple_optimizers_step(tmpdir):
num_manual_backward_calls = 3
assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls
assert model.called
def test_step_with_optimizer_closure(tmpdir):
@ -891,3 +902,32 @@ def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, m
expected_calls = [call(closure=ANY, optim='adam') for s in range(2)]
mock_adam_step.assert_has_calls(expected_calls)
def test_step_with_misconfiguration_error_when_overriding_optimizer_zero_grad(tmpdir):
"""
Tests that `optimizer_zero_grad` in manual_optimization triggers a MisconfigurationException
"""
try:
class TestModel(BoringModel):
def optimizer_zero_grad(self, *_):
pass
model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None
limit_train_batches = 8
trainer = Trainer(
automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
accumulate_grad_batches=2,
enable_pl_optimizer=True,
)
trainer.fit(model)
except MisconfigurationException as e:
assert "`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...)` is not supported" in str(e)