# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Lambda Callback
^^^^^^^^^^^^^^^
Create a simple callback on the fly using lambda functions.
"""
from typing import Callable, Optional

from pytorch_lightning.callbacks.base import Callback


class LambdaCallback(Callback):
r"""
Create a simple callback on the fly using lambda functions.
Args:
**kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback`
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LambdaCallback
>>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
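
    A single ``LambdaCallback`` can also bundle several hooks at once; each
    keyword below is one of the ``__init__`` arguments of this class::

        >>> trainer = Trainer(callbacks=[LambdaCallback(
        ...     on_train_start=lambda *args: print('train start'),
        ...     on_train_end=lambda *args: print('train end'),
        ... )])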
"""
def __init__(
self,
        on_before_accelerator_backend_setup: Optional[Callable] = None,
        setup: Optional[Callable] = None,
        on_configure_sharded_model: Optional[Callable] = None,
        teardown: Optional[Callable] = None,
        on_init_start: Optional[Callable] = None,
        on_init_end: Optional[Callable] = None,
        on_fit_start: Optional[Callable] = None,
        on_fit_end: Optional[Callable] = None,
        on_sanity_check_start: Optional[Callable] = None,
        on_sanity_check_end: Optional[Callable] = None,
        on_train_batch_start: Optional[Callable] = None,
        on_train_batch_end: Optional[Callable] = None,
        on_train_epoch_start: Optional[Callable] = None,
        on_train_epoch_end: Optional[Callable] = None,
        on_validation_epoch_start: Optional[Callable] = None,
        on_validation_epoch_end: Optional[Callable] = None,
        on_test_epoch_start: Optional[Callable] = None,
        on_test_epoch_end: Optional[Callable] = None,
        on_epoch_start: Optional[Callable] = None,
        on_epoch_end: Optional[Callable] = None,
        on_batch_start: Optional[Callable] = None,
        on_validation_batch_start: Optional[Callable] = None,
        on_validation_batch_end: Optional[Callable] = None,
        on_test_batch_start: Optional[Callable] = None,
        on_test_batch_end: Optional[Callable] = None,
        on_batch_end: Optional[Callable] = None,
        on_train_start: Optional[Callable] = None,
        on_train_end: Optional[Callable] = None,
        on_pretrain_routine_start: Optional[Callable] = None,
        on_pretrain_routine_end: Optional[Callable] = None,
        on_validation_start: Optional[Callable] = None,
        on_validation_end: Optional[Callable] = None,
        on_test_start: Optional[Callable] = None,
        on_test_end: Optional[Callable] = None,
        on_keyboard_interrupt: Optional[Callable] = None,
        on_exception: Optional[Callable] = None,
        on_save_checkpoint: Optional[Callable] = None,
        on_load_checkpoint: Optional[Callable] = None,
        on_before_backward: Optional[Callable] = None,
        on_after_backward: Optional[Callable] = None,
        on_before_optimizer_step: Optional[Callable] = None,
        on_before_zero_grad: Optional[Callable] = None,
        on_predict_start: Optional[Callable] = None,
        on_predict_end: Optional[Callable] = None,
        on_predict_batch_start: Optional[Callable] = None,
        on_predict_batch_end: Optional[Callable] = None,
        on_predict_epoch_start: Optional[Callable] = None,
        on_predict_epoch_end: Optional[Callable] = None,
    ):
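        # Bind every hook that was provided (i.e. is not None) onto the
        # instance. Each keyword argument shares its name with a ``Callback``
        # hook, so the instance attribute shadows the no-op method inherited
        # from the base class and is what the Trainer ends up calling. Note
        # that instance attributes bypass the descriptor protocol, so the
        # given function is invoked without a ``self`` argument.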
        for k, v in locals().items():
            if k == "self":
                continue
            if v is not None:
                setattr(self, k, v)
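

# A minimal sketch of the binding behaviour (illustrative comments only, not
# part of the library): hooks that are passed shadow the base-class methods on
# the instance, while hooks left as ``None`` fall back to the inherited no-op
# defaults.
#
#     cb = LambdaCallback(setup=lambda *args: print("setup"))
#     cb.setup("trainer", "pl_module", "fit")            # any args; prints "setup"
#     assert cb.teardown.__func__ is Callback.teardown   # inherited no-op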