Update extensions doc (#10778)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
8bf7f9cce7
commit
ce95891f6a
|
@ -168,7 +168,7 @@ To load a model along with its weights, biases and hyperparameters use the follo
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
model = MyLightingModule.load_from_checkpoint(PATH)
|
||||
model = MyLightningModule.load_from_checkpoint(PATH)
|
||||
|
||||
print(model.learning_rate)
|
||||
# prints the learning_rate you used in this checkpoint
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
############
|
||||
Accelerators
|
||||
############
|
||||
Accelerators connect a Lightning Trainer to arbitrary accelerators (CPUs, GPUs, TPUs, etc). Accelerators
|
||||
Accelerators connect a Lightning Trainer to arbitrary accelerators (CPUs, GPUs, TPUs, IPUs). Accelerators
|
||||
also manage distributed communication through :ref:`Plugins` (like DP, DDP, HPC cluster) and
|
||||
can also be configured to run on arbitrary clusters or to link up to arbitrary
|
||||
computational strategies like 16-bit precision via AMP and Apex.
|
||||
|
@ -26,7 +26,7 @@ One to handle differences from the training routine and one to handle different
|
|||
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin
|
||||
|
||||
accelerator = GPUAccelerator(
|
||||
precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"),
|
||||
precision_plugin=NativeMixedPrecisionPlugin(precision=16, device="cuda"),
|
||||
training_type_plugin=DDPPlugin(),
|
||||
)
|
||||
trainer = Trainer(accelerator=accelerator)
|
||||
|
|
|
@ -21,7 +21,7 @@ Callback
|
|||
|
||||
A callback is a self-contained program that can be reused across projects.
|
||||
|
||||
Lightning has a callback system to execute callbacks when needed. Callbacks should capture NON-ESSENTIAL
|
||||
Lightning has a callback system to execute them when needed. Callbacks should capture NON-ESSENTIAL
|
||||
logic that is NOT required for your :doc:`lightning module <../common/lightning_module>` to run.
|
||||
|
||||
Here's the flow of how the callback hooks are executed:
|
||||
|
@ -47,10 +47,10 @@ Example:
|
|||
|
||||
class MyPrintingCallback(Callback):
|
||||
def on_init_start(self, trainer):
|
||||
print("Starting to init trainer!")
|
||||
print("Starting to initialize the trainer!")
|
||||
|
||||
def on_init_end(self, trainer):
|
||||
print("trainer is init now")
|
||||
print("trainer is initialized now")
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
print("do something when training ends")
|
||||
|
@ -60,8 +60,8 @@ Example:
|
|||
|
||||
.. testoutput::
|
||||
|
||||
Starting to init trainer!
|
||||
trainer is init now
|
||||
Starting to initialize the trainer!
|
||||
trainer is initialized now
|
||||
|
||||
We successfully extended functionality without polluting our super clean
|
||||
:doc:`lightning module <../common/lightning_module>` research code.
|
||||
|
@ -86,7 +86,7 @@ Lightning has a few built-in callbacks.
|
|||
|
||||
.. note::
|
||||
For a richer collection of callbacks, check out our
|
||||
`bolts library <https://lightning-bolts.readthedocs.io/en/latest/callbacks.html>`_.
|
||||
`bolts library <https://lightning-bolts.readthedocs.io/en/latest/index.html>`_.
|
||||
|
||||
.. currentmodule:: pytorch_lightning.callbacks
|
||||
|
||||
|
@ -108,12 +108,13 @@ Lightning has a few built-in callbacks.
|
|||
ModelCheckpoint
|
||||
ModelPruning
|
||||
ModelSummary
|
||||
ProgressBar
|
||||
ProgressBarBase
|
||||
QuantizationAwareTraining
|
||||
RichModelSummary
|
||||
RichProgressBar
|
||||
QuantizationAwareTraining
|
||||
StochasticWeightAveraging
|
||||
Timer
|
||||
TQDMProgressBar
|
||||
XLAStatsMonitor
|
||||
|
||||
----------
|
||||
|
@ -197,245 +198,310 @@ The following are best practices when using/designing callbacks.
|
|||
|
||||
.. _hooks:
|
||||
|
||||
Available Callback hooks
|
||||
------------------------
|
||||
Callback API
|
||||
------------
|
||||
Here is the full API of methods available in the Callback base class.
|
||||
|
||||
The :class:`~pytorch_lightning.callbacks.Callback` class is the base for all the callbacks in Lightning just like the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base for all models.
|
||||
It defines a public interface that each callback implementation must follow, the key ones are:
|
||||
|
||||
Properties
|
||||
^^^^^^^^^^
|
||||
|
||||
state_key
|
||||
~~~~~~~~~
|
||||
|
||||
.. autoattribute:: pytorch_lightning.callbacks.Callback.state_key
|
||||
:noindex:
|
||||
|
||||
|
||||
Hooks
|
||||
^^^^^
|
||||
|
||||
on_configure_sharded_model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_configure_sharded_model
|
||||
:noindex:
|
||||
|
||||
on_before_accelerator_backend_setup
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_before_accelerator_backend_setup
|
||||
:noindex:
|
||||
|
||||
setup
|
||||
^^^^^
|
||||
~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.setup
|
||||
:noindex:
|
||||
|
||||
teardown
|
||||
^^^^^^^^
|
||||
~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.teardown
|
||||
:noindex:
|
||||
|
||||
on_init_start
|
||||
^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_init_start
|
||||
:noindex:
|
||||
|
||||
on_init_end
|
||||
^^^^^^^^^^^
|
||||
~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_init_end
|
||||
:noindex:
|
||||
|
||||
on_fit_start
|
||||
^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_fit_start
|
||||
:noindex:
|
||||
|
||||
on_fit_end
|
||||
^^^^^^^^^^
|
||||
~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_fit_end
|
||||
:noindex:
|
||||
|
||||
on_sanity_check_start
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_sanity_check_start
|
||||
:noindex:
|
||||
|
||||
on_sanity_check_end
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_sanity_check_end
|
||||
:noindex:
|
||||
|
||||
on_train_batch_start
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_batch_start
|
||||
:noindex:
|
||||
|
||||
on_train_batch_end
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_batch_end
|
||||
:noindex:
|
||||
|
||||
on_train_epoch_start
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_epoch_start
|
||||
:noindex:
|
||||
|
||||
on_train_epoch_end
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_epoch_end
|
||||
:noindex:
|
||||
|
||||
on_validation_epoch_start
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_epoch_start
|
||||
:noindex:
|
||||
|
||||
on_validation_epoch_end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_epoch_end
|
||||
:noindex:
|
||||
|
||||
on_test_epoch_start
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_epoch_start
|
||||
:noindex:
|
||||
|
||||
on_test_epoch_end
|
||||
^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_epoch_end
|
||||
:noindex:
|
||||
|
||||
on_predict_epoch_start
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_epoch_start
|
||||
:noindex:
|
||||
|
||||
on_predict_epoch_end
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_epoch_end
|
||||
:noindex:
|
||||
|
||||
on_epoch_start
|
||||
^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_start
|
||||
:noindex:
|
||||
|
||||
on_epoch_end
|
||||
^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_end
|
||||
:noindex:
|
||||
|
||||
on_batch_start
|
||||
^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_start
|
||||
:noindex:
|
||||
|
||||
on_batch_end
|
||||
~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_end
|
||||
:noindex:
|
||||
|
||||
on_validation_batch_start
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_batch_start
|
||||
:noindex:
|
||||
|
||||
on_validation_batch_end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_batch_end
|
||||
:noindex:
|
||||
|
||||
on_test_batch_start
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_batch_start
|
||||
:noindex:
|
||||
|
||||
on_test_batch_end
|
||||
^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_batch_end
|
||||
:noindex:
|
||||
|
||||
on_batch_end
|
||||
^^^^^^^^^^^^
|
||||
on_predict_batch_start
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_end
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_batch_start
|
||||
:noindex:
|
||||
|
||||
on_predict_batch_end
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_batch_end
|
||||
:noindex:
|
||||
|
||||
on_train_start
|
||||
^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_start
|
||||
:noindex:
|
||||
|
||||
on_train_end
|
||||
^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_train_end
|
||||
:noindex:
|
||||
|
||||
on_pretrain_routine_start
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_start
|
||||
:noindex:
|
||||
|
||||
on_pretrain_routine_end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_end
|
||||
:noindex:
|
||||
|
||||
on_validation_start
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_start
|
||||
:noindex:
|
||||
|
||||
on_validation_end
|
||||
^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_end
|
||||
:noindex:
|
||||
|
||||
on_test_start
|
||||
^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_start
|
||||
:noindex:
|
||||
|
||||
on_test_end
|
||||
^^^^^^^^^^^
|
||||
~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_test_end
|
||||
:noindex:
|
||||
|
||||
on_predict_start
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_start
|
||||
:noindex:
|
||||
|
||||
on_predict_end
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_end
|
||||
:noindex:
|
||||
|
||||
on_keyboard_interrupt
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_keyboard_interrupt
|
||||
:noindex:
|
||||
|
||||
on_exception
|
||||
^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_exception
|
||||
:noindex:
|
||||
|
||||
on_save_checkpoint
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint
|
||||
:noindex:
|
||||
|
||||
on_load_checkpoint
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
|
||||
:noindex:
|
||||
|
||||
on_before_backward
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_before_backward
|
||||
:noindex:
|
||||
|
||||
on_after_backward
|
||||
^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
|
||||
:noindex:
|
||||
|
||||
on_before_optimizer_step
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_before_optimizer_step
|
||||
:noindex:
|
||||
|
||||
on_before_zero_grad
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad
|
||||
:noindex:
|
||||
|
|
|
@ -50,7 +50,7 @@ Think of this as swapping out the engine in a car!
|
|||
|
||||
----------
|
||||
|
||||
Understanding the default Trainer loop
|
||||
Understanding the Default Trainer Loop
|
||||
--------------------------------------
|
||||
|
||||
The Lightning :class:`~pytorch_lightning.trainer.trainer.Trainer` automates the standard optimization loop which every PyTorch user is familiar with:
|
||||
|
@ -75,7 +75,7 @@ The core research logic is simply shifted to the :class:`~pytorch_lightning.core
|
|||
# loss = loss_function(y_hat, y) moved to training_step
|
||||
loss = lightning_module.training_step(batch, i)
|
||||
|
||||
# Lighting handles automatically:
|
||||
# Lightning handles automatically:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -109,12 +109,12 @@ Defining a loop within a class interface instead of hard-coding a raw Python for
|
|||
|
||||
.. _override default loops:
|
||||
|
||||
Overriding the default loops
|
||||
Overriding the default Loops
|
||||
----------------------------
|
||||
|
||||
The fastest way to get started with loops, is to override functionality of an existing loop.
|
||||
Lightning has 4 main loops it uses: :class:`~pytorch_lightning.loops.fit_loop.FitLoop` for training and validating,
|
||||
:class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` for testing,
|
||||
Lightning has 4 main loops which relies on : :class:`~pytorch_lightning.loops.fit_loop.FitLoop` for fitting (training and validating),
|
||||
:class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` for validating or testing,
|
||||
:class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop` for predicting.
|
||||
|
||||
For simple changes that don't require a custom loop, you can modify each of these loops.
|
||||
|
@ -166,11 +166,11 @@ Now simply attach the correct loop in the trainer directly:
|
|||
# fit() now uses the new FitLoop!
|
||||
trainer.fit(...)
|
||||
|
||||
# the equivalent for validate(), test(), predict()
|
||||
# the equivalent for validate()
|
||||
val_loop = CustomValLoop()
|
||||
trainer = Trainer()
|
||||
trainer.validate_loop = val_loop
|
||||
trainer.validate(model)
|
||||
trainer.validate(...)
|
||||
|
||||
Now your code is FULLY flexible and you can still leverage ALL the best parts of Lightning!
|
||||
|
||||
|
@ -179,7 +179,7 @@ Now your code is FULLY flexible and you can still leverage ALL the best parts of
|
|||
|
||||
----------
|
||||
|
||||
Creating a new loop from scratch
|
||||
Creating a New Loop From Scratch
|
||||
--------------------------------
|
||||
|
||||
You can also go wild and implement a full loop from scratch by sub-classing the :class:`~pytorch_lightning.loops.base.Loop` base class.
|
||||
|
@ -212,8 +212,7 @@ Finally, attach it into the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
|
|||
# fit() now uses your fancy loop!
|
||||
trainer.fit(...)
|
||||
|
||||
Now you have full control over the Trainer.
|
||||
But beware: The power of loop customization comes with great responsibility.
|
||||
But beware: Loop customization gives you more power and full control over the Trainer and with great power comes great responsibility.
|
||||
We recommend that you familiarize yourself with :ref:`overriding the default loops <override default loops>` first before you start building a new loop from the ground up.
|
||||
|
||||
----------
|
||||
|
@ -222,7 +221,7 @@ Loop API
|
|||
--------
|
||||
Here is the full API of methods available in the Loop base class.
|
||||
|
||||
The :class:`~pytorch_lightning.loops.base.Loop` class is the base for all loops in Lighting just like the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base for all models.
|
||||
The :class:`~pytorch_lightning.loops.base.Loop` class is the base of all loops in the same way as the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base of all models.
|
||||
It defines a public interface that each loop implementation must follow, the key ones are:
|
||||
|
||||
Properties
|
||||
|
@ -348,6 +347,12 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pyt
|
|||
It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step).
|
||||
* - :class:`~pytorch_lightning.loops.optimization.manual_loop.ManualOptimization`
|
||||
- Substitutes the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` in case of :ref:`manual_optimization` and implements the manual optimization step.
|
||||
* - :class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop`
|
||||
- The :class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` is the top-level loop where validation/testing starts.
|
||||
It simply iterates over each evaluation dataloader from one to the next by calling :code:`EvaluationEpochLoop.run()` in its :code:`advance()` method.
|
||||
* - :class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop`
|
||||
- The :class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop` is the top-level loop where prediction starts.
|
||||
It simply iterates over each prediction dataloader from one to the next by calling :code:`PredictionEpochLoop.run()` in its :code:`advance()` method.
|
||||
|
||||
|
||||
----------
|
||||
|
@ -382,6 +387,7 @@ To run the following demo, install Flash and `BaaL <https://github.com/ElementAI
|
|||
# Implement the research use-case where we mask labels from labelled dataset.
|
||||
datamodule = ActiveLearningDataModule(
|
||||
ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
|
||||
initial_num_labels=5,
|
||||
val_split=0.1,
|
||||
)
|
||||
|
||||
|
@ -390,7 +396,8 @@ To run the following demo, install Flash and `BaaL <https://github.com/ElementAI
|
|||
torch.nn.Dropout(p=0.1),
|
||||
torch.nn.Linear(512, datamodule.num_classes),
|
||||
)
|
||||
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities())
|
||||
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=Probabilities())
|
||||
|
||||
|
||||
# 3.1 Create the trainer
|
||||
trainer = flash.Trainer(max_epochs=3)
|
||||
|
@ -410,7 +417,7 @@ To run the following demo, install Flash and `BaaL <https://github.com/ElementAI
|
|||
# 5. Save the model!
|
||||
trainer.save_checkpoint("image_classification_model.pt")
|
||||
|
||||
Here is the `Active Learning Loop example <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py#L31>`_.
|
||||
Here is the `Active Learning Loop example <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py>`_.
|
||||
|
||||
|
||||
----------
|
||||
|
|
|
@ -5,7 +5,7 @@ Loops (Advanced)
|
|||
|
||||
.. _persisting loop state:
|
||||
|
||||
Persisting the state of loops
|
||||
Persisting the State of Loops
|
||||
-----------------------------
|
||||
|
||||
.. note::
|
||||
|
@ -18,7 +18,7 @@ A powerful property of the class-based loop interface is that it can own an inte
|
|||
Loop instances can save their state to the checkpoint through corresponding hooks and if implemented accordingly, resume the state of exectuion at the appropriate place.
|
||||
This design is particularly interesting for fault-tolerant training which is an experimental feature released in Lightning v1.5.
|
||||
|
||||
The two hooks :class:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and :class:`~pytorch_lightning.loops.base.Loop.on_load_checkpoint` function very similarly to how LightningModules and Callbacks save and load state.
|
||||
The two hooks :meth:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and :meth:`~pytorch_lightning.loops.base.Loop.on_load_checkpoint` function very similarly to how LightningModules and Callbacks save and load state.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ PrecisionPlugin
|
|||
|
||||
|
||||
Futhermore, for multi-node training Lightning provides cluster environment plugins that allow the advanced user
|
||||
to configure Lighting to integrate with a :ref:`custom-cluster`.
|
||||
to configure Lightning to integrate with a :ref:`custom-cluster`.
|
||||
|
||||
|
||||
.. image:: ../_static/images/accelerator/overview.svg
|
||||
|
|
|
@ -59,7 +59,7 @@ class YieldLoop(OptimizerLoop):
|
|||
def on_run_start(self, batch, optimizers, batch_idx):
|
||||
super().on_run_start(batch, optimizers, batch_idx)
|
||||
if not inspect.isgeneratorfunction(self.trainer.lightning_module.training_step):
|
||||
raise MisconfigurationException("The LightingModule does not yield anything in the `training_step`.")
|
||||
raise MisconfigurationException("The `LightningModule` does not yield anything in the `training_step`.")
|
||||
assert self.trainer.lightning_module.automatic_optimization
|
||||
|
||||
# We request the generator once and save it for later
|
||||
|
|
|
@ -169,6 +169,10 @@ class Callback(abc.ABC):
|
|||
"""Called when the training batch begins."""
|
||||
pass
|
||||
|
||||
def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Called when the training batch ends."""
|
||||
pass
|
||||
|
||||
def on_validation_batch_start(
|
||||
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
|
@ -223,10 +227,6 @@ class Callback(abc.ABC):
|
|||
"""Called when the predict batch ends."""
|
||||
pass
|
||||
|
||||
def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Called when the training batch ends."""
|
||||
pass
|
||||
|
||||
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Called when the train begins."""
|
||||
pass
|
||||
|
|
|
@ -473,7 +473,7 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
|
||||
for key, val in metrics.items():
|
||||
# `step` is ignored because Neptune expects strictly increasing step values which
|
||||
# Lighting does not always guarantee.
|
||||
# Lightning does not always guarantee.
|
||||
self.experiment[key].log(val)
|
||||
|
||||
@rank_zero_only
|
||||
|
|
|
@ -2016,7 +2016,7 @@ class Trainer(
|
|||
"""Attach a custom fit loop to this Trainer.
|
||||
|
||||
It will run with
|
||||
:meth:`~pytorch_lighting.trainer.trainer.Trainer.fit`.
|
||||
:meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`.
|
||||
"""
|
||||
loop.trainer = self
|
||||
self._fit_loop = loop
|
||||
|
@ -2030,7 +2030,7 @@ class Trainer(
|
|||
"""Attach a custom validation loop to this Trainer.
|
||||
|
||||
It will run with
|
||||
:meth:`~pytorch_lighting.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
|
||||
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
|
||||
running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
|
||||
"""
|
||||
loop.trainer = self
|
||||
|
|
Loading…
Reference in New Issue