Simplify optimization logic (#4984)
* Rely on ddp plugin for blocking sync behaviour, and skip if we're using manual optimization
* debug
* Revert "debug"
This reverts commit ccca6b6b
* Expose manual reduce for automatic optimization
* Add input arguments
* Enable parity test
* clean imports
* Expose hook after to ensure we reset
* Fix naming
* add
* fix test
* unify optimizer logic
* resolve test
* resolve flake8
* resolve amp bug
* update tests
* remove bug
* remove optimizer_step in accelerators
* typo
* update lightning optimizer
* set doesn't work with ddp_spawn
* resolve flake8
* update threshold
* ignore pyright
* correct CodeFactor
* remove useless if
* remove zero_grad function
* simplify step
* remove typo
* resolve bug
* Apply suggestions from code review
* update on comments
* resolve bugs
* remove tests
* Update pytorch_lightning/trainer/configuration_validator.py
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
* simplify testing
* add more tests
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
parent ab7c947961
commit 02152c1729
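For context, the user-facing pattern this change enables looks roughly like the sketch below. It is a minimal illustration, not code from this commit: `DemoModel` is hypothetical, it assumes `enable_pl_optimizer=True` (so the optimizer handed to `optimizer_step` is a `LightningOptimizer`), and it mirrors the tests added further down in this diff.

```python
# Hedged sketch: skip the real optimizer.step() on odd batches while still
# running the closure (forward/backward), letting gradients accumulate.
from pytorch_lightning import Trainer
from tests.base import BoringModel  # test helper used throughout this diff


class DemoModel(BoringModel):
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # make_optimizer_step=False still runs the closure (under the DDP
        # sync-blocking context), but defers the actual parameter update
        optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 2 == 0)


trainer = Trainer(max_epochs=1, limit_train_batches=8, enable_pl_optimizer=True)
trainer.fit(DemoModel())
```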
@@ -39,10 +39,11 @@ steps:
  # todo: temporary fix till https://github.com/PyTorchLightning/pytorch-lightning/pull/4922 is resolved
  - pip install --extra-index-url https://developer.download.nvidia.com/compute/redist "nvidia-dali-cuda100<0.27" --upgrade-strategy only-if-needed
  - pip list
  - coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
  - python -m pytest benchmarks pl_examples -v --maxfail=2 --durations=0 # --flake8
  #- cd docs; make doctest; make coverage
  - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
  # Running special tests
  - sh tests/special_tests.sh
  - coverage report
  - python -m pytest benchmarks pl_examples -v --maxfail=2 --durations=0
  # see: https://docs.codecov.io/docs/merging-reports
  - codecov --token $CODECOV_TOKEN --flags=gpu,pytest --name="GPU-coverage" --env=linux --build $DRONE_BUILD_NUMBER --commit $DRONE_COMMIT
  # --build $DRONE_BUILD_NUMBER --branch $DRONE_BRANCH --commit $DRONE_COMMIT --tag $DRONE_TAG --pr $DRONE_PULL_REQUEST

@@ -16,14 +16,15 @@ from enum import Enum
from typing import Any, Optional, Union

import torch
import torch.distributed as torch_distrib
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.core.lightning import LightningModule
import torch.distributed as torch_distrib

if torch.distributed.is_available():
    from torch.distributed import ReduceOp

@@ -98,40 +99,6 @@ class Accelerator(object):
            closure_loss = closure_loss.detach()
        return closure_loss

    def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs):
        model_ref = self.trainer.get_model()
        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
        using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
        automatic_optimization = self.trainer.train_loop.automatic_optimization

        # native amp + lbfgs is a no go right now
        if using_native_amp and is_lbfgs:
            raise MisconfigurationException(
                'native PyTorch amp and lbfgs are not compatible.'
                ' To request, please file a Github issue in PyTorch and tag @mcarilli')

        # model hook
        model_ref.optimizer_step(
            epoch=self.trainer.current_epoch,
            batch_idx=batch_idx,
            optimizer=optimizer,
            optimizer_idx=opt_idx,
            optimizer_closure=lambda_closure,
            on_tpu=False,  # TPUAccelerator class sets this as True
            using_native_amp=using_native_amp,
            using_lbfgs=is_lbfgs,
            *args,
            **kwargs,
        )

        # scale when native amp
        if automatic_optimization and using_native_amp:
            self.trainer.scaler.update()

    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
        model_ref = self.trainer.get_model()
        model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

    def clip_gradients(self, optimizer, clip_val=None):
        # use the trainer's clip val if none passed
        grad_clip_val = self.trainer.gradient_clip_val

@@ -160,7 +127,7 @@ class Accelerator(object):
        return self.trainer.should_stop

    def setup_optimizers(self, model):
        if self.trainer.testing is True:
        if self.trainer.testing:
            return

        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)

@@ -23,7 +23,14 @@ from torch.optim import Optimizer
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_only, rank_zero_warn, move_data_to_device
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
    TPU_AVAILABLE,
    move_data_to_device,
    rank_zero_info,
    rank_zero_only,
    rank_zero_warn,
)
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -245,24 +252,6 @@ class TPUAccelerator(Accelerator):

        return closure_loss

    def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs):
        model_ref = self.trainer.get_model()
        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)

        # model hook
        model_ref.optimizer_step(
            epoch=self.trainer.current_epoch,
            batch_idx=batch_idx,
            optimizer=optimizer,
            optimizer_idx=opt_idx,
            optimizer_closure=lambda_closure,
            on_tpu=True,
            using_native_amp=False,
            using_lbfgs=is_lbfgs,
            *args,
            **kwargs,
        )

    def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
        # this code is a modification of torch.nn.utils.clip_grad_norm_
        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md

@@ -67,6 +67,9 @@ class GradientAccumulationScheduler(Callback):
        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def going_to_accumulate_grad_batches(self):
        return any([v > 1 for v in self.scheduling.values()])

    def on_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        for i in reversed(range(len(self.epochs))):

@@ -33,6 +33,7 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn

@@ -1236,15 +1237,10 @@ class LightningModule(
            model hook don't forget to add the call to it before ``optimizer.zero_grad()`` yourself.

        """
        if on_tpu and TPU_AVAILABLE:
            xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure, **kwargs})

        elif self.trainer.amp_backend is not None:
            self.trainer.precision_connector.backend.optimizer_step(
                self.trainer, optimizer, optimizer_closure)

        else:
            optimizer.step(closure=optimizer_closure, *args, **kwargs)
        if not isinstance(optimizer, LightningOptimizer):
            # wraps into LightningOptimizer only for running step
            optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
        optimizer.step(closure=optimizer_closure, *args, **kwargs)

    def optimizer_zero_grad(
        self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int

@@ -35,7 +35,7 @@ def do_nothing_closure():
class LightningOptimizer:
    """
    This class is used to wrap the user optimizers and handle properly
    the backward and optimizer_step logic across accelerators, AMP, accumulated_grad_batches
    the backward and optimizer_step logic across accelerators, AMP, accumulate_grad_batches
    """
    def __init__(self,
                 optimizer: Optimizer,

@@ -60,17 +60,35 @@ class LightningOptimizer:
        self._trainer = None
        self._optimizer = optimizer
        self._accumulate_grad_batches = accumulate_grad_batches
        self._use_accumulate_grad_batches_from_trainer = accumulate_grad_batches is None
        self._automatic_optimization = None
        self._optimizer_idx = None

    @property
    def accumulate_grad_batches(self):
        return self._accumulate_grad_batches

    @accumulate_grad_batches.setter
    def accumulate_grad_batches(self, accumulate_grad_batches):
        self._accumulate_grad_batches = accumulate_grad_batches

    def _on_trainer_init(self, trainer):
        self._trainer = proxy(trainer)
        self._automatic_optimization = trainer.train_loop.automatic_optimization
        for opt_idx, opt in enumerate(trainer.optimizers):
            if opt == self._optimizer:
                self._optimizer_idx = opt_idx
                break

    @classmethod
    def to_lightning_optimizer(cls, optimizer, trainer):
        optimizer = cls(optimizer)
        optimizer._on_trainer_init(trainer)
        return optimizer

    def _accumulated_batches_reached(self):
        if self._use_accumulate_grad_batches_from_trainer:
            accumulate_grad_batches = self._trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = self._accumulate_grad_batches
        return (self._trainer.batch_idx + 1) % accumulate_grad_batches == 0
        if self.accumulate_grad_batches is None:
            return self._trainer.train_loop._accumulated_batches_reached()
        return (self._trainer.batch_idx + 1) % self.accumulate_grad_batches == 0

    @property
    def _should_accumulate(self):

@@ -79,6 +97,45 @@ class LightningOptimizer:
        is_final_batch = self._trainer.train_loop._num_training_batches_reached()
        return not (accumulation_done or is_final_batch)

    def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs):
        trainer = self._trainer
        optimizer = self._optimizer
        model = trainer.get_model()

        if trainer.on_tpu:
            with trainer.profiler.profile(profiler_name):
                xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})

        elif trainer.amp_backend is not None:
            trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)

        else:
            with trainer.profiler.profile(profiler_name):
                optimizer.step(closure=closure, *args, **kwargs)

        trainer.train_loop.on_before_zero_grad(self)

        model.optimizer_zero_grad(
            trainer.current_epoch,
            trainer.batch_idx,
            optimizer,
            self._optimizer_idx
        )

    def _check_make_optimizer_step(self, make_optimizer_step: Optional[bool]) -> bool:
        if make_optimizer_step is not None and self._trainer.overriden_optimizer_zero_grad:
            raise MisconfigurationException(
                "When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed.")

        if self._trainer.train_loop.automatic_optimization:
            if self._trainer.overriden_optimizer_step and self._trainer.overriden_optimizer_zero_grad:
                return True

        if make_optimizer_step is None:
            make_optimizer_step = not self._should_accumulate

        return make_optimizer_step

    def step(self, *args, closure: Optional[Callable] = None, make_optimizer_step: Optional[bool] = None, **kwargs):
        """
        Call this directly from your training_step when doing optimizations manually.

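To make the decision encoded by `_check_make_optimizer_step` and `_should_accumulate` easier to follow, here is a small self-contained sketch in plain Python. The function name and arguments are invented for illustration; it mirrors the modulo logic above rather than calling into Lightning.

```python
from typing import Optional


def should_make_optimizer_step(batch_idx: int,
                               accumulate_grad_batches: int,
                               is_final_batch: bool,
                               make_optimizer_step: Optional[bool] = None) -> bool:
    # An explicit user choice always wins (subject to the zero_grad-override check above).
    if make_optimizer_step is not None:
        return make_optimizer_step
    # Otherwise step only once every `accumulate_grad_batches` batches, or on the final batch.
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    return accumulation_done or is_final_batch


# e.g. with accumulate_grad_batches=4, only every 4th batch triggers a real step
assert [should_make_optimizer_step(i, 4, False) for i in range(8)] == \
       [False, False, False, True, False, False, False, True]
```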
@@ -173,40 +230,23 @@ class LightningOptimizer:
            # Trainer(accumulate_grad_batches=x)
            opt_dis.step(closure=optimizer_closure, make_optimizer_step=True)
        """
        profiler_name = "optimizer_step_and_closure"
        profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"

        if closure is None:
            closure = do_nothing_closure
            profile_name = "optimizer_step"
            profile_name = f"optimizer_step_{self._optimizer_idx}"
        else:
            if not isinstance(closure, types.FunctionType):
                raise MisconfigurationException("When closure is provided, it should be a function")

        if make_optimizer_step is None:
            make_optimizer_step = not self._should_accumulate

        trainer = self._trainer
        optimizer = self._optimizer
        make_optimizer_step = self._check_make_optimizer_step(make_optimizer_step)

        if make_optimizer_step:
            if trainer.on_tpu:
                with trainer.profiler.profile(profiler_name):
                    xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})

            elif trainer.amp_backend is not None:
                trainer.precision_connector.backend.optimizer_step(
                    trainer, optimizer, closure)

            else:
                with trainer.profiler.profile(profiler_name):
                    optimizer.step(closure=closure, *args, **kwargs)

            # perform zero grad
            optimizer.zero_grad()
            self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
        else:
            # make sure to call optimizer_closure when accumulating
            with trainer.profiler.profile("closure"):
                with trainer.train_loop.block_ddp_sync_behaviour():
            with self._trainer.profiler.profile(f"closure_{self._optimizer_idx}"):
                with self._trainer.train_loop.block_ddp_sync_behaviour():
                    closure()

    def __repr__(self):

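The `block_ddp_sync_behaviour()` context used in the accumulation branch above corresponds to the generic DistributedDataParallel `no_sync()` pattern: run backward without the gradient all-reduce while accumulating, and let the final backward before the real step synchronize. A hedged, Lightning-free sketch follows; the helper name, loss, and data shapes are invented for illustration.

```python
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def accumulate_then_step(ddp_model: DDP, optimizer, batches, accumulate: int):
    for i, (x, y) in enumerate(batches):
        loss = F.mse_loss(ddp_model(x), y) / accumulate
        if (i + 1) % accumulate != 0:
            # gradients accumulate locally, DDP skips the all-reduce
            with ddp_model.no_sync():
                loss.backward()
        else:
            loss.backward()      # this backward triggers gradient synchronization
            optimizer.step()
            optimizer.zero_grad()
```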
@@ -14,7 +14,7 @@

import itertools
import threading
from collections.abc import Mapping, Iterable
from collections.abc import Iterable, Mapping
from itertools import chain

import torch

@@ -137,5 +137,9 @@ class ApexPlugin(PrecisionPlugin):
        # TODO: pass the closure to the step ASAP
        with trainer.profiler.profile("closure"):
            closure()

        if not self.trainer.train_loop.automatic_optimization:
            trainer.call_hook("on_after_backward")

        with trainer.profiler.profile("optimizer_step"):
            optimizer.step()

@@ -1,9 +1,10 @@
import os
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Optional, Union

import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log
from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.plugin import LightningPlugin

@@ -69,6 +69,11 @@ class NativeAMPPlugin(PrecisionPlugin):
        # TODO: pass the closure to the step ASAP
        with trainer.profiler.profile("closure"):
            closure()

        if not self.trainer.train_loop.automatic_optimization:
            trainer.scaler.unscale_(optimizer)
            trainer.call_hook("on_after_backward")

        with trainer.profiler.profile("optimizer_step"):
            trainer.scaler.step(optimizer)
            trainer.scaler.update()

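For reference, the scaler calls above follow the standard native AMP stepping pattern in plain PyTorch. The sketch below is standalone and assumes a CUDA device; the tiny model, data, and loop are placeholders, not Lightning code.

```python
import torch

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(2):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # needed if gradients are clipped or inspected before the step
    scaler.step(optimizer)      # skips the step if inf/nan gradients were found
    scaler.update()
```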
@@ -17,7 +17,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.utilities import AMPType, FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, AMPType, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if FAIRSCALE_AVAILABLE:

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -69,6 +69,37 @@ class ConfigValidator(object):
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        trainer = self.trainer

        trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
        trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)

        enable_pl_optimizer = trainer._enable_pl_optimizer
        automatic_optimization = trainer.train_loop.automatic_optimization
        if trainer.overriden_optimizer_step and not enable_pl_optimizer and automatic_optimization:
            rank_zero_warn(
                "When overriding `LightningModule` optimizer_step with"
                " `Trainer(..., enable_pl_optimizer=False, automatic_optimization=True, ...)`,"
                " we won't be calling `.zero_grad` as we can't assume when you call your `optimizer.step()`."
                " For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`."
            )

        going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

        has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
        if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization:
            raise MisconfigurationException(
                'When overriding `LightningModule` optimizer_step or optimizer_zero_grad with '
                '`Trainer(automatic_optimization=True, ...)`, `accumulate_grad_batches` should be 1.'
                ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
            )

        if (enable_pl_optimizer) and trainer.overriden_optimizer_zero_grad and not automatic_optimization:
            raise MisconfigurationException(
                'When overriding `LightningModule` optimizer_zero_grad with '
                '`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported'
            )

    def __verify_eval_loop_configuration(self, model, eval_loop_name):
        step_name = f'{eval_loop_name}_step'

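A quick illustration of the new validation: overriding `optimizer_step` while also asking the Trainer to accumulate gradients in automatic optimization is rejected before training starts. The model below is illustrative only; the same combination is exercised by `test_automatic_optimization` later in this diff.

```python
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel


class SkipStepModel(BoringModel):
    def optimizer_step(self, *args, **kwargs):
        pass  # user takes over stepping, so accumulation scheduling cannot be guaranteed


with pytest.raises(MisconfigurationException, match="called on every batch"):
    Trainer(accumulate_grad_batches=2, automatic_optimization=True).fit(SkipStepModel())
```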
@@ -42,7 +42,7 @@ class LoggerConnector:

    @property
    def cached_results(self) -> Union[EpochResultStore, None]:
        return self._cached_results.get(self._current_stage)
        return self._cached_results.get(self._current_stage)  # type: ignore

    def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None:
        self._current_stage = LoggerStages.determine_stage(stage_or_testing)

@@ -855,7 +855,8 @@ class Trainer(
        model.setup(stage_name)

    def _reset_result_and_set_hook_fx_name(self, hook_name):
        if "batch_start" in hook_name:
        # on_before_zero_grad is called within training_step
        if "batch_start" in hook_name or "on_before_zero_grad" in hook_name:
            return True
        model_ref = self.get_model()
        if model_ref is not None:

@@ -26,7 +26,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
from pytorch_lightning.utilities import AMPType, parsing
from pytorch_lightning.utilities import TPU_AVAILABLE, AMPType, parsing
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach

@@ -321,7 +321,6 @@ class TrainLoop:
            args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)

            # manually capture logged metrics
            model_ref._results = Result()
            model_ref._current_fx_name = 'training_step'
            model_ref._results = Result()
            training_step_output = self.trainer.accelerator_backend.training_step(args)

@@ -475,21 +474,34 @@ class TrainLoop:
        return training_step_output_for_epoch_end

    def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs):
        # optimizer step lightningModule hook
        if isinstance(optimizer, LightningOptimizer):
            optimizer.step(closure=train_step_and_backward_closure)
        else:
            with self.trainer.profiler.profile("optimizer_step"):
                self.trainer.accelerator_backend.optimizer_step(
                    optimizer, batch_idx, opt_idx, train_step_and_backward_closure, *args, **kwargs
                )
        model_ref = self.trainer.get_model()

        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
        using_native_amp = self.trainer.amp_backend == AMPType.NATIVE

        # native amp + lbfgs is a no go right now
        if using_native_amp and is_lbfgs:
            raise MisconfigurationException(
                'native PyTorch amp and lbfgs are not compatible.'
                ' To request, please file a Github issue in PyTorch and tag @mcarilli')

        # model hook
        model_ref.optimizer_step(
            epoch=self.trainer.current_epoch,
            batch_idx=batch_idx,
            optimizer=optimizer,
            optimizer_idx=opt_idx,
            optimizer_closure=train_step_and_backward_closure,
            on_tpu=self.trainer.use_tpu and TPU_AVAILABLE,
            using_native_amp=using_native_amp,
            using_lbfgs=is_lbfgs,
            *args,
            **kwargs,
        )

    def on_before_zero_grad(self, optimizer):
        self.trainer.call_hook('on_before_zero_grad', optimizer)

    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
        self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

    def track_and_norm_grad(self, optimizer):
        # track gradient norms
        grad_norm_dic = self._track_gradient_norm()

@@ -708,7 +720,6 @@ class TrainLoop:
                    if self._curr_step_result is None:
                        # user decided to skip optimization
                        # make sure to zero grad.
                        self.zero_grad_handler(batch_idx, optimizer, opt_idx)
                        continue

                    batch_outputs = self._process_closure_result(

@@ -720,9 +731,6 @@ class TrainLoop:
                grad_norm_dic = self._cur_grad_norm_dict
                self._cur_grad_norm_dict = None

                # hook + clear gradients
                self.zero_grad_handler(batch_idx, optimizer, opt_idx)

                # update running loss + reset accumulated loss
                self.update_running_loss()

@@ -947,14 +955,3 @@ class TrainLoop:

        # reset for next set of accumulated grads
        self.accumulated_loss.reset()

    def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
        if self.automatic_optimization:
            # hook
            self.on_before_zero_grad(optimizer)
            optimizers = enumerate([optimizer])
        else:
            optimizers = []

        for idx, optimizer in optimizers:
            self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock
from unittest.mock import MagicMock, call, ANY
from unittest.mock import ANY, MagicMock, call

from pytorch_lightning import Trainer
from tests.base import BoringModel

@@ -0,0 +1,102 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from argparse import ArgumentParser
from typing import Optional
from unittest.mock import MagicMock, patch

import pytest
import torch
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, random_split

from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel


def test_automatic_optimization(tmpdir):
    class TestModel(BoringModel):
        def optimizer_step(self, *_, **__):
            pass

    model = TestModel()

    try:
        trainer = Trainer(
            default_root_dir=tmpdir,
            limit_train_batches=2,
            accumulate_grad_batches=2,
            automatic_optimization=True
        )

        trainer.fit(model)
    except MisconfigurationException as e:
        assert "It ensures optimizer_step or optimizer_zero_grad are called on every batch" in str(e)


@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
def test_automatic_optimization_num_calls(enable_pl_optimizer, tmpdir):

    with patch("torch.optim.SGD.step") as sgd_step, \
            patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \
            patch("torch.optim.Adam.step") as adam_step, \
            patch("torch.optim.Adam.zero_grad") as adam_zero_grad:

        class TestModel(BoringModel):

            def configure_optimizers(self):
                optimizer = SGD(self.layer.parameters(), lr=0.1)
                optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
                return [optimizer, optimizer_2]

            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                               optimizer_closure, on_tpu, using_native_amp, using_lbfgs):

                assert optimizer_closure.__name__ == "train_step_and_backward_closure"

                # update generator opt every 2 steps
                if optimizer_idx == 0:
                    if batch_idx % 2 == 0:
                        assert isinstance(optimizer, SGD)
                        optimizer.step(closure=optimizer_closure)
                        if not enable_pl_optimizer:
                            optimizer.zero_grad()

                # update discriminator opt every 4 steps
                if optimizer_idx == 1:
                    if batch_idx % 4 == 0:
                        assert isinstance(optimizer, Adam)
                        optimizer.step(closure=optimizer_closure)
                        if not enable_pl_optimizer:
                            optimizer.zero_grad()

        model = TestModel()
        model.training_epoch_end = None

        trainer = Trainer(
            max_epochs=1,
            default_root_dir=tmpdir,
            limit_train_batches=8,
            accumulate_grad_batches=1,
            automatic_optimization=True,
            enable_pl_optimizer=enable_pl_optimizer
        )

        trainer.fit(model)

        assert sgd_step.call_count == 4
        assert sgd_zero_grad.call_count == 4
        assert adam_step.call_count == 2
        assert adam_zero_grad.call_count == 2

@@ -21,6 +21,7 @@ from torch.optim import Adam, Optimizer

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset

@@ -188,10 +189,226 @@ def test_state(tmpdir):
    assert isinstance(lightning_optimizer, Optimizer)
    lightning_dict = {}
    special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx",
                     "_trainer", "_use_accumulate_grad_batches_from_trainer", "_lightning_step"]
                     "_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization",
                     "_accumulate_grad_batches"]
    for k, v in lightning_optimizer.__dict__.items():
        if k not in special_attrs:
            lightning_dict[k] = v
    assert lightning_dict == optimizer.__dict__
    assert optimizer.state_dict() == lightning_optimizer.state_dict()
    assert optimizer.state == lightning_optimizer.state


def test_lightning_optimizer_automatic_optimization(tmpdir):
    """
    Test that the lightning optimizer works with make_optimizer_step in automatic_optimization
    """
    class TestModel(BoringModel):

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"loss": loss}

        def training_epoch_end(self, outputs):
            outputs = sum(outputs, [])
            torch.stack([x["loss"] for x in outputs]).mean()

        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                           optimizer_closure, on_tpu, using_native_amp, using_lbfgs):

            assert optimizer_closure.__name__ == "train_step_and_backward_closure"

            optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 2 == 0)

        def configure_optimizers(self):
            optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1)
            optimizer_1 = LightningOptimizer(optimizer_1, 4)

            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
            return [optimizer_1, optimizer_2], [lr_scheduler]

    model = TestModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=10,
        limit_val_batches=1,
        max_epochs=1,
        weights_summary=None,
        enable_pl_optimizer=True,
        automatic_optimization=True
    )
    trainer.fit(model)


def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
    """
    Test that the lightning optimizer works with optimizer_zero_grad overrides in automatic_optimization
    """

    with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \
            patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:

        class TestModel(BoringModel):

            def training_step(self, batch, batch_idx, optimizer_idx=None):
                output = self.layer(batch)
                loss = self.loss(batch, output)
                return {"loss": loss}

            def training_epoch_end(self, outputs):
                outputs = sum(outputs, [])
                torch.stack([x["loss"] for x in outputs]).mean()

            def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
                if optimizer_idx == 0:
                    if batch_idx % 2 == 0:
                        optimizer.zero_grad()

                if optimizer_idx == 1:
                    if batch_idx % 5 == 0:
                        optimizer.zero_grad()

            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                               optimizer_closure, on_tpu, using_native_amp, using_lbfgs):

                assert optimizer_closure.__name__ == "train_step_and_backward_closure"

                optimizer.step(closure=optimizer_closure)

            def configure_optimizers(self):
                optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
                optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1)
                lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
                return [optimizer_1, optimizer_2], [lr_scheduler]

        model = TestModel()
        trainer = Trainer(
            default_root_dir=os.getcwd(),
            limit_train_batches=10,
            limit_val_batches=1,
            max_epochs=1,
            weights_summary=None,
            enable_pl_optimizer=True,
            automatic_optimization=True
        )
        trainer.fit(model)

        assert adam_zero_grad.call_count == 2
        assert sgd_zero_grad.call_count == 5


def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad_make_optimizer_step(tmpdir):
    """
    Test that the lightning optimizer works with optimizer_zero_grad overrides and make_optimizer_step in automatic_optimization
    """

    try:
        with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \
                patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:

            class TestModel(BoringModel):

                def training_step(self, batch, batch_idx, optimizer_idx=None):
                    output = self.layer(batch)
                    loss = self.loss(batch, output)
                    return {"loss": loss}

                def training_epoch_end(self, outputs):
                    outputs = sum(outputs, [])
                    torch.stack([x["loss"] for x in outputs]).mean()

                def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
                    if optimizer_idx == 0:
                        if batch_idx % 2 == 0:
                            optimizer.zero_grad()

                    if optimizer_idx == 1:
                        if batch_idx % 5 == 0:
                            optimizer.zero_grad()

                def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs):

                    assert optimizer_closure.__name__ == "train_step_and_backward_closure"

                    if optimizer_idx == 0:
                        optimizer.step(closure=optimizer_closure, make_optimizer_step=batch_idx % 3 == 0)
                        return
                    optimizer.step(closure=optimizer_closure)

                def configure_optimizers(self):
                    optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
                    optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1)
                    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
                    return [optimizer_1, optimizer_2], [lr_scheduler]

            model = TestModel()
            trainer = Trainer(
                default_root_dir=os.getcwd(),
                limit_train_batches=20,
                limit_val_batches=1,
                max_epochs=1,
                weights_summary=None,
                enable_pl_optimizer=True,
                automatic_optimization=True
            )
            trainer.fit(model)

            assert adam_zero_grad.call_count == 4
            assert sgd_zero_grad.call_count == 10

    except MisconfigurationException as e:
        assert "When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed" in str(e)


def test_lightning_optimizer_automatic_optimization_make_optimizer_step_2(tmpdir):
    """
    Test that the lightning optimizer works with make_optimizer_step in automatic_optimization
    """

    with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \
            patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:

        class TestModel(BoringModel):

            def training_step(self, batch, batch_idx, optimizer_idx=None):
                output = self.layer(batch)
                loss = self.loss(batch, output)
                return {"loss": loss}

            def training_epoch_end(self, outputs):
                outputs = sum(outputs, [])
                torch.stack([x["loss"] for x in outputs]).mean()

            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                               optimizer_closure, on_tpu, using_native_amp, using_lbfgs):

                assert optimizer_closure.__name__ == "train_step_and_backward_closure"

                make_optimizer_step = None
                if optimizer_idx == 0:
                    make_optimizer_step = batch_idx % 4 == 0
                optimizer.step(closure=optimizer_closure, make_optimizer_step=make_optimizer_step)

            def configure_optimizers(self):
                optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
                optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1)
                lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
                return [optimizer_1, optimizer_2], [lr_scheduler]

        model = TestModel()
        trainer = Trainer(
            default_root_dir=os.getcwd(),
            limit_train_batches=20,
            limit_val_batches=1,
            max_epochs=1,
            weights_summary=None,
            enable_pl_optimizer=True,
            automatic_optimization=True,
        )
        trainer.fit(model)

        assert adam_zero_grad.call_count == 20
        assert sgd_zero_grad.call_count == 5

@@ -1,13 +1,15 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
import pytest
import os
from unittest import mock
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
from tests.base.boring_model import BoringModel


@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(os.environ, {

@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export PL_RUNNING_SPECIAL_TESTS=1
# Running special tests
DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"

@@ -22,6 +22,7 @@ import torch.nn.functional as F

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel

@@ -563,6 +564,15 @@ def test_multiple_optimizers_step(tmpdir):
    Tests that `step` works with several optimizers
    """
    class TestModel(BoringModel):

        called = False

        def on_after_backward(self):
            self.called = True
            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
            if not (torch.isinf(norm) or torch.isnan(norm)):
                assert norm.item() < 100, norm.item()

        def training_step(self, batch, batch_idx, optimizer_idx):
            # manual
            (opt_a, opt_b) = self.optimizers()

@@ -621,6 +631,7 @@ def test_multiple_optimizers_step(tmpdir):

    num_manual_backward_calls = 3
    assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls
    assert model.called


def test_step_with_optimizer_closure(tmpdir):

@@ -891,3 +902,32 @@ def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, m

    expected_calls = [call(closure=ANY, optim='adam') for s in range(2)]
    mock_adam_step.assert_has_calls(expected_calls)


def test_step_with_misconfiguraiton_error_when_overriding_optimizer_zero_grad(tmpdir):
    """
    Tests that `optimizer_zero_grad` in manual_optimization triggers a MisconfigurationException
    """
    try:
        class TestModel(BoringModel):

            def optimizer_zero_grad(self, *_):
                pass

        model = TestModel()
        model.val_dataloader = None
        model.training_epoch_end = None

        limit_train_batches = 8
        trainer = Trainer(
            automatic_optimization=False,
            default_root_dir=tmpdir,
            limit_train_batches=limit_train_batches,
            limit_val_batches=2,
            max_epochs=1,
            log_every_n_steps=1,
            accumulate_grad_batches=2,
            enable_pl_optimizer=True,
        )
    except MisconfigurationException as e:
        assert "`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported" in str(e)