.. role:: hidden
    :class: hidden-section

.. _lightning_module:

LightningModule
===============

A :class:`~LightningModule` organizes your PyTorch code into 5 sections:

- Computations (init)
- Train loop (training_step)
- Validation loop (validation_step)
- Test loop (test_step)
- Optimizers (configure_optimizers)

Notice a few things.

1. It's the SAME code.

2. The PyTorch code IS NOT abstracted - just organized.

3. All the other code that's not in the :class:`~LightningModule` has been automated for you by the Trainer.

   .. code-block:: python

       net = Net()
       trainer = Trainer()
       trainer.fit(net)

4. There are no ``.cuda()`` or ``.to()`` calls... Lightning does these for you.

   .. code-block:: python

       # don't do this in Lightning
       x = torch.Tensor(2, 3)
       x = x.cuda()
       x = x.to(device)

       # do this instead
       x = x  # leave it alone!

       # or to init a new tensor
       new_x = torch.Tensor(2, 3)
       new_x = new_x.type_as(x)

5. Lightning handles the distributed sampler for you by default.

   .. code-block:: python

       # don't do this in Lightning...
       data = MNIST(...)
       sampler = DistributedSampler(data)
       DataLoader(data, sampler=sampler)

       # do this instead
       data = MNIST(...)
       DataLoader(data)

6. A :class:`~LightningModule` is a :class:`torch.nn.Module` but with added functionality. Use it as such!

   .. code-block:: python

       net = Net.load_from_checkpoint(PATH)
       net.freeze()
       out = net(x)

Thus, to use Lightning, you just need to organize your code, which takes about 30 minutes (and let's be real, you probably should do this anyway).

------------

Minimal Example
---------------

Here are the only required methods.

.. code-block:: python

    import torch
    from torch import nn
    from torch.nn import functional as F

    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(28 * 28, 10)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.02)

Which you can train by doing:

.. code-block:: python

    import os

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST

    train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))

    trainer = pl.Trainer()
    model = LitModel()

    trainer.fit(model, train_loader)

The LightningModule has many convenience methods, but the core ones you need to know about are:

.. list-table::
    :widths: 50 50
    :header-rows: 1

    * - Name
      - Description
    * - ``__init__``
      - Define computations here
    * - ``forward``
      - Use for inference only (separate from ``training_step``)
    * - ``training_step``
      - The full training loop
    * - ``validation_step``
      - The full validation loop
    * - ``test_step``
      - The full test loop
    * - ``configure_optimizers``
      - Define optimizers and LR schedulers

----------

Training
--------

Training loop
^^^^^^^^^^^^^

To add a training loop, use the `training_step` method:

.. code-block:: python

    class LitClassifier(pl.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.model(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

Under the hood, Lightning does the following (pseudocode):

.. code-block:: python

    # put model in train mode
    model.train()
    torch.set_grad_enabled(True)

    losses = []
    for batch in train_dataloader:
        # forward
        loss = training_step(batch)
        losses.append(loss.detach())

        # clear gradients
        optimizer.zero_grad()

        # backward
        loss.backward()

        # update parameters
        optimizer.step()
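`training_step` can also return a dictionary instead of a bare loss tensor, as long as it contains the ``loss`` key; the DataParallel and epoch-level examples below build on this form. A minimal sketch:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.model(x), y)

        # Lightning uses the "loss" entry for the backward pass;
        # any extra entries are passed through to training_step_end / training_epoch_end
        return {"loss": loss}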
Training epoch-level metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you want to calculate epoch-level metrics and log them, use the `log` method.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)

        # logs metrics for each training_step,
        # and the average across the epoch, to the progress bar and logger
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

The `log` method automatically reduces the requested metric across the full epoch. Here's the pseudocode of what it does under the hood:

.. code-block:: python

    outs = []
    for batch in train_dataloader:
        # forward
        loss = training_step(batch)
        outs.append(loss.detach())

        # clear gradients
        optimizer.zero_grad()

        # backward
        loss.backward()

        # update parameters
        optimizer.step()

    # the log() call above reduces the logged values to their epoch-level mean
    epoch_metric = torch.mean(torch.stack(outs))

Train epoch-level operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you need to do something with all the outputs of each `training_step`, override the `training_epoch_end` method.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        preds = ...
        return {"loss": loss, "other_stuff": preds}


    def training_epoch_end(self, training_step_outputs):
        for pred in training_step_outputs:
            ...

The matching pseudocode is:

.. code-block:: python

    outs = []
    for batch in train_dataloader:
        # forward
        out = training_step(batch)
        outs.append(out)

        # clear gradients
        optimizer.zero_grad()

        # backward
        out["loss"].backward()

        # update parameters
        optimizer.step()

    training_epoch_end(outs)

Training with DataParallel
~~~~~~~~~~~~~~~~~~~~~~~~~~

When training using an `accelerator` that splits data from each batch across GPUs (dp or ddp2), sometimes you might need to aggregate the results on the main GPU for processing. In this case, implement the `training_step_end` method:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        pred = ...
        return {"loss": loss, "pred": pred}


    def training_step_end(self, batch_parts):
        # predictions from each GPU
        predictions = batch_parts["pred"]
        # losses from each GPU
        losses = batch_parts["loss"]

        gpu_0_prediction = predictions[0]
        gpu_1_prediction = predictions[1]

        # do something with both outputs
        return (losses[0] + losses[1]) / 2


    def training_epoch_end(self, training_step_outputs):
        for out in training_step_outputs:
            ...

The full pseudocode that Lightning runs under the hood is:

.. code-block:: python

    outs = []
    for train_batch in train_dataloader:
        batches = split_batch(train_batch)
        dp_outs = []
        for sub_batch in batches:
            # 1
            dp_out = training_step(sub_batch)
            dp_outs.append(dp_out)

        # 2
        out = training_step_end(dp_outs)
        outs.append(out)

    # do something with the outputs for all batches
    # 3
    training_epoch_end(outs)
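Note that indexing ``predictions[0]`` and ``predictions[1]`` as above assumes exactly two GPUs. A device-count-agnostic sketch, assuming the per-GPU losses arrive stacked along the first dimension as in the example above:

.. code-block:: python

    def training_step_end(self, batch_parts):
        # batch_parts["loss"] holds one loss per GPU; reduce to a single scalar
        return batch_parts["loss"].mean()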
------------------

Validation loop
^^^^^^^^^^^^^^^

To add a validation loop, override the `validation_step` method of the :class:`~LightningModule`:

.. code-block:: python

    class LitModel(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.model(x)
            loss = F.cross_entropy(y_hat, y)
            self.log("val_loss", loss)

Under the hood, Lightning does the following:

.. code-block:: python

    # ...
    for batch in train_dataloader:
        loss = model.training_step()
        loss.backward()
        # ...

        if validate_at_some_point:
            # disable grads + batchnorm + dropout
            torch.set_grad_enabled(False)
            model.eval()

            # ----------------- VAL LOOP ---------------
            for val_batch in val_dataloader:
                val_out = model.validation_step(val_batch)
            # ----------------- VAL LOOP ---------------

            # enable grads + batchnorm + dropout
            torch.set_grad_enabled(True)
            model.train()

Validation epoch-level metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you need to do something with all the outputs of each `validation_step`, override the `validation_epoch_end` method.

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        pred = ...
        return pred


    def validation_epoch_end(self, validation_step_outputs):
        for pred in validation_step_outputs:
            ...

Validating with DataParallel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When training using an `accelerator` that splits data from each batch across GPUs (dp or ddp2), sometimes you might need to aggregate the results on the main GPU for processing. In this case, implement the `validation_step_end` method:

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        pred = ...
        return {"loss": loss, "pred": pred}


    def validation_step_end(self, batch_parts):
        # predictions from each GPU
        predictions = batch_parts["pred"]
        # losses from each GPU
        losses = batch_parts["loss"]

        gpu_0_prediction = predictions[0]
        gpu_1_prediction = predictions[1]

        # do something with both outputs
        return (losses[0] + losses[1]) / 2


    def validation_epoch_end(self, validation_step_outputs):
        for out in validation_step_outputs:
            ...

The full pseudocode that Lightning runs under the hood is:

.. code-block:: python

    outs = []
    for batch in dataloader:
        batches = split_batch(batch)
        dp_outs = []
        for sub_batch in batches:
            # 1
            dp_out = validation_step(sub_batch)
            dp_outs.append(dp_out)

        # 2
        out = validation_step_end(dp_outs)
        outs.append(out)

    # do something with the outputs for all batches
    # 3
    validation_epoch_end(outs)

----------------

Test loop
^^^^^^^^^

The process for adding a test loop is the same as the process for adding a validation loop; refer to the section above for details (a minimal `test_step` is sketched below). The only difference is that the test loop is only run when `.test()` is called:

.. code-block:: python

    model = Model()
    trainer = Trainer()
    trainer.fit(model)

    # automatically loads the best weights for you
    trainer.test(model)

There are two ways to call `test()`:

.. code-block:: python

    # call after training
    trainer = Trainer()
    trainer.fit(model)

    # automatically loads the best weights
    trainer.test(dataloaders=test_dataloader)

    # or call with a pretrained model
    model = MyLightningModule.load_from_checkpoint(PATH)
    trainer = Trainer()
    trainer.test(model, dataloaders=test_dataloader)
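For reference, a minimal `test_step` mirrors the `validation_step` shown earlier; a sketch:

.. code-block:: python

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)

        # logged metrics are printed and returned by trainer.test()
        self.log("test_loss", loss)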
----------

Inference
---------

For research, LightningModules are best structured as systems.

.. code-block:: python

    import torch
    from torch import nn

    import pytorch_lightning as pl


    class Autoencoder(pl.LightningModule):
        def __init__(self, latent_dim=2):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
            self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

        def training_step(self, batch, batch_idx):
            x, _ = batch

            # encode
            x = x.view(x.size(0), -1)
            z = self.encoder(x)

            # decode
            recons = self.decoder(z)

            # reconstruction
            reconstruction_loss = nn.functional.mse_loss(recons, x)
            return reconstruction_loss

        def validation_step(self, batch, batch_idx):
            x, _ = batch
            x = x.view(x.size(0), -1)
            z = self.encoder(x)
            recons = self.decoder(z)
            reconstruction_loss = nn.functional.mse_loss(recons, x)
            self.log("val_reconstruction", reconstruction_loss)

        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            x, _ = batch

            # encode
            # for predictions, we could return the embedding or the reconstruction or both, based on our needs
            x = x.view(x.size(0), -1)
            return self.encoder(x)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.0002)

Which can be trained like this:

.. code-block:: python

    autoencoder = Autoencoder()
    trainer = pl.Trainer(gpus=1)
    trainer.fit(autoencoder, train_dataloader, val_dataloader)

This simple model generates examples that look like this (the encoders and decoders are too weak):

.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/ae_docs.png
    :width: 300

The methods above are part of the LightningModule interface:

- training_step
- validation_step
- test_step
- predict_step
- configure_optimizers

Note that in this case, the train loop and val loop are exactly the same. We can, of course, reuse this code.

.. code-block:: python

    class Autoencoder(pl.LightningModule):
        def __init__(self, latent_dim=2):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
            self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

        def training_step(self, batch, batch_idx):
            loss = self.shared_step(batch)
            return loss

        def validation_step(self, batch, batch_idx):
            loss = self.shared_step(batch)
            self.log("val_loss", loss)

        def shared_step(self, batch):
            x, _ = batch

            # encode
            x = x.view(x.size(0), -1)
            z = self.encoder(x)

            # decode
            recons = self.decoder(z)

            # loss
            return nn.functional.mse_loss(recons, x)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.0002)

We create a new method called `shared_step` that all loops can use. This method name is arbitrary and NOT reserved.

Inference in research
^^^^^^^^^^^^^^^^^^^^^

In the case where we want to perform inference with the system, we can add a `forward` method to the LightningModule.

.. note:: When using ``forward``, you are responsible for calling :func:`~torch.nn.Module.eval` and using the :func:`~torch.no_grad` context manager.

.. code-block:: python

    class Autoencoder(pl.LightningModule):
        def forward(self, x):
            return self.decoder(x)


    model = Autoencoder()
    model.eval()
    with torch.no_grad():
        reconstruction = model(embedding)

The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure, such as text generation:

.. code-block:: python

    class Seq2Seq(pl.LightningModule):
        def forward(self, x):
            embeddings = self.embedding(x)  # embedding layer defined in __init__ (omitted here)
            hidden_states = self.encoder(embeddings)
            for h in hidden_states:
                # decode
                ...
            return decoded
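Calling this ``forward`` for generation is then an ordinary module call. A sketch, where ``PATH`` and ``tokens`` are illustrative placeholders:

.. code-block:: python

    # load a trained system and generate from it
    seq2seq = Seq2Seq.load_from_checkpoint(PATH)
    seq2seq.eval()

    with torch.no_grad():
        decoded = seq2seq(tokens)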
In the case where you want to scale your inference, you should be using :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step`.

.. code-block:: python

    class Autoencoder(pl.LightningModule):
        def forward(self, x):
            return self.decoder(x)

        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            # this calls forward
            return self(batch)


    data_module = ...
    model = Autoencoder()
    trainer = Trainer(gpus=2)
    trainer.predict(model, data_module)

Inference in production
^^^^^^^^^^^^^^^^^^^^^^^

For cases like production, you might want to iterate over different models inside the same LightningModule.

.. code-block:: python

    import torch
    from torch.nn import functional as F

    import pytorch_lightning as pl
    from pytorch_lightning.metrics import functional as FM


    class ClassificationTask(pl.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.model(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

        def validation_step(self, batch, batch_idx):
            loss, acc = self._shared_eval_step(batch, batch_idx)
            metrics = {"val_acc": acc, "val_loss": loss}
            self.log_dict(metrics)
            return metrics

        def test_step(self, batch, batch_idx):
            loss, acc = self._shared_eval_step(batch, batch_idx)
            metrics = {"test_acc": acc, "test_loss": loss}
            self.log_dict(metrics)
            return metrics

        def _shared_eval_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.model(x)
            loss = F.cross_entropy(y_hat, y)
            acc = FM.accuracy(y_hat, y)
            return loss, acc

        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            x, y = batch
            y_hat = self.model(x)
            return y_hat

        def configure_optimizers(self):
            return torch.optim.Adam(self.model.parameters(), lr=0.02)

Then pass in any arbitrary model to be fit with this task:

.. code-block:: python

    for model in [resnet50(), vgg16(), BidirectionalRNN()]:
        task = ClassificationTask(model)

        trainer = Trainer(gpus=2)
        trainer.fit(task, train_dataloader, val_dataloader)

Tasks can be arbitrarily complex, such as implementing GAN training, self-supervised learning, or even RL.

.. code-block:: python

    class GANTask(pl.LightningModule):
        def __init__(self, generator, discriminator):
            super().__init__()
            self.generator = generator
            self.discriminator = discriminator

        ...

When used like this, the model can be separated from the Task and thus used in production without needing to keep it in a `LightningModule`:

- You can export it to ONNX.
- Or trace it using JIT.
- Or run it in the Python runtime.

.. code-block:: python

    task = ClassificationTask(model)

    trainer = Trainer(gpus=2)
    trainer.fit(task, train_dataloader, val_dataloader)

    # use model after training or load weights and drop into the production system
    model.eval()
    y_hat = model(x)
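For example, because ``model`` is a plain :class:`torch.nn.Module`, the standard PyTorch export paths apply. A sketch, where ``example_batch`` and the file names are illustrative:

.. code-block:: python

    model.eval()

    # TorchScript via tracing (needs a representative input)
    scripted = torch.jit.trace(model, example_batch)
    scripted.save("model.pt")

    # ONNX
    torch.onnx.export(model, example_batch, "model.onnx")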
-----------

LightningModule API
-------------------

Methods
^^^^^^^

configure_callbacks
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_callbacks
    :noindex:

configure_optimizers
~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_optimizers
    :noindex:

forward
~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.forward
    :noindex:

freeze
~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.freeze
    :noindex:

log
~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.log
    :noindex:

log_dict
~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.log_dict
    :noindex:

manual_backward
~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
    :noindex:

print
~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.print
    :noindex:

predict_step
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.predict_step
    :noindex:

save_hyperparameters
~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.save_hyperparameters
    :noindex:

test_step
~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_step
    :noindex:

test_step_end
~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_step_end
    :noindex:

test_epoch_end
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_epoch_end
    :noindex:

to_onnx
~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.to_onnx
    :noindex:

to_torchscript
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.to_torchscript
    :noindex:

training_step
~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step
    :noindex:

training_step_end
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step_end
    :noindex:

training_epoch_end
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_epoch_end
    :noindex:

unfreeze
~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.unfreeze
    :noindex:

validation_step
~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step
    :noindex:

validation_step_end
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step_end
    :noindex:

validation_epoch_end
~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_epoch_end
    :noindex:

------------

Properties
^^^^^^^^^^

These are properties available in a LightningModule.

-----------

current_epoch
~~~~~~~~~~~~~

The current epoch.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        if self.current_epoch == 0:
            ...

-------------

device
~~~~~~

The device the module is on. Use it to keep your code device agnostic.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        z = torch.rand(2, 3, device=self.device)

-------------

global_rank
~~~~~~~~~~~

The global rank of this LightningModule. Lightning saves logs, weights, etc. only from ``global_rank = 0``. You normally do not need to use this property.

Global rank refers to the index of that GPU across ALL GPUs. For example, if using 10 machines, each with 4 GPUs, the 4th GPU on the 10th machine has ``global_rank = 39``.

-------------

global_step
~~~~~~~~~~~

The current step (does not reset each epoch).

.. code-block:: python

    def training_step(self, batch, batch_idx):
        self.logger.experiment.log_image(..., step=self.global_step)

-------------

hparams
~~~~~~~

Arguments passed to ``__init__()`` and saved by calling ``save_hyperparameters()`` can be accessed via the ``hparams`` attribute.

.. code-block:: python

    def __init__(self, learning_rate):
        super().__init__()
        self.save_hyperparameters()


    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.learning_rate)

--------------

logger
~~~~~~

The current logger being used (tensorboard or other supported logger).

.. code-block:: python

    def training_step(self, batch, batch_idx):
        # the generic logger (same no matter if tensorboard or other supported logger)
        self.logger

        # the particular logger
        tensorboard_logger = self.logger.experiment

--------------

local_rank
~~~~~~~~~~

The local rank of this LightningModule. Lightning saves logs, weights, etc. only from ``global_rank = 0``. You normally do not need to use this property.

Local rank refers to the rank on that machine. For example, if using 10 machines, the GPU at index 0 on each machine has ``local_rank = 0``.
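A common pattern is to guard side effects on these ranks. A minimal sketch:

.. code-block:: python

    def on_train_start(self):
        # run once for the whole job, e.g. uploading artifacts
        if self.global_rank == 0:
            ...

        # run once per machine, e.g. writing to local disk
        if self.local_rank == 0:
            ...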
-----------

precision
~~~~~~~~~

The type of precision used:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        if self.precision == 16:
            ...

------------

trainer
~~~~~~~

Pointer to the trainer.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        # any Trainer flag can be read this way
        max_steps = self.trainer.max_steps
        any_flag = self.trainer.any_flag

------------

use_amp
~~~~~~~

``True`` if using Automatic Mixed Precision (AMP).

--------------

automatic_optimization
~~~~~~~~~~~~~~~~~~~~~~

When set to ``False``, Lightning does not automate the optimization process. This means you are responsible for handling your optimizers. However, we do take care of precision and any accelerators used.

See :ref:`manual optimization` for details.

.. code-block:: python

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False


    def training_step(self, batch, batch_idx):
        opt = self.optimizers(use_pl_optimizer=True)

        loss = ...
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN training.

.. code-block:: python

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False


    def training_step(self, batch, batch_idx):
        # access your optimizers (use_pl_optimizer=True is the default)
        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

        gen_loss = ...
        opt_a.zero_grad()
        self.manual_backward(gen_loss)
        opt_a.step()

        disc_loss = ...
        opt_b.zero_grad()
        self.manual_backward(disc_loss)
        opt_b.step()

--------------

example_input_array
~~~~~~~~~~~~~~~~~~~

Set and access ``example_input_array``, which is essentially a single batch of input data.

.. code-block:: python

    def __init__(self):
        super().__init__()
        self.example_input_array = ...
        self.generator = ...


    def on_train_epoch_end(self):
        # generate some images using the example_input_array
        gen_images = self.generator(self.example_input_array)

--------------

datamodule
~~~~~~~~~~

Set or access your datamodule.

.. code-block:: python

    def configure_optimizers(self):
        num_training_samples = len(self.trainer.datamodule.train_dataloader())
        ...

--------------

model_size
~~~~~~~~~~

Get the model file size (in megabytes) using ``self.model_size`` inside a LightningModule.
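For example, you might print it when fitting starts. A sketch, assuming the property returns a float in megabytes:

.. code-block:: python

    def on_fit_start(self):
        print(f"model size: {self.model_size:.1f} MB")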
--------------

truncated_bptt_steps
~~~~~~~~~~~~~~~~~~~~

Truncated backpropagation through time performs backward passes every k steps of a much longer sequence. This works by splitting each training batch along the time dimension into chunks of size k and passing them to the ``training_step`` one at a time. To keep the forward propagation behavior the same, all hidden states are carried over between the time-dimension splits.

If this is enabled, your batches are truncated automatically and the Trainer applies truncated backprop to them. (Williams et al., "An efficient gradient-based algorithm for on-line training of recurrent network trajectories.")

.. testcode:: python

    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):
        def __init__(self, input_size, hidden_size, num_layers):
            super().__init__()
            # batch_first has to be set to True
            self.lstm = nn.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                batch_first=True,
            )

            ...

            # Important: This property activates truncated backpropagation through time
            # Setting this value to 2 splits the batch into sequences of size 2
            self.truncated_bptt_steps = 2

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            x, y = batch

            # the training step must be updated to accept a ``hiddens`` argument
            # hiddens are the hiddens from the previous truncated backprop step
            out, hiddens = self.lstm(x, hiddens)

            ...

            return {"loss": ..., "hiddens": hiddens}

Lightning takes care of splitting your batch along the time dimension. It is assumed to be the second dimension of your batches. Therefore, in the example above we have set ``batch_first=True``.

.. code-block:: python

    # we use the second dimension as the time dimension
    # (batch, time, ...)
    sub_batch = batch[0, 0:t, ...]

To modify how the batch is split, override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:

.. testcode:: python

    class LitMNIST(LightningModule):
        def tbptt_split_batch(self, batch, split_size):
            # do your own splitting on the batch
            return splits

--------------

Hooks
^^^^^

This is the pseudocode that describes the structure of :meth:`~pytorch_lightning.trainer.Trainer.fit`. The inputs and outputs of each function are not represented for simplicity. Please check each function's API reference for more information.

.. code-block:: python

    def fit(self):
        if global_rank == 0:
            # prepare_data is called on GLOBAL_ZERO only
            prepare_data()

        configure_callbacks()

        with parallel(devices):
            # devices can be GPUs, TPUs, ...
            train_on_device(model)


    def train_on_device(model):
        # called PER DEVICE
        on_fit_start()
        setup("fit")
        configure_optimizers()

        on_pretrain_routine_start()
        on_pretrain_routine_end()

        # the sanity check runs here

        on_train_start()
        for epoch in epochs:
            train_loop()
        on_train_end()

        on_fit_end()
        teardown("fit")


    def train_loop():
        on_epoch_start()
        on_train_epoch_start()

        for batch in train_dataloader():
            on_train_batch_start()

            on_before_batch_transfer()
            transfer_batch_to_device()
            on_after_batch_transfer()

            training_step()

            on_before_zero_grad()
            optimizer_zero_grad()

            on_before_backward()
            backward()
            on_after_backward()

            on_before_optimizer_step()
            configure_gradient_clipping()
            optimizer_step()

            on_train_batch_end()

            if should_check_val:
                val_loop()

        # end training epoch
        training_epoch_end()

        on_train_epoch_end()
        on_epoch_end()


    def val_loop():
        on_validation_model_eval()  # calls `model.eval()`
        torch.set_grad_enabled(False)

        on_validation_start()
        on_epoch_start()
        on_validation_epoch_start()

        for batch in val_dataloader():
            on_validation_batch_start()

            on_before_batch_transfer()
            transfer_batch_to_device()
            on_after_batch_transfer()

            validation_step()

            on_validation_batch_end()

        validation_epoch_end()
        on_validation_epoch_end()
        on_epoch_end()
        on_validation_end()

        # set up for train
        on_validation_model_train()  # calls `model.train()`
        torch.set_grad_enabled(True)
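Any of these hooks can be overridden directly on your LightningModule to run code at that point in the loop. A minimal sketch:

.. code-block:: python

    class LitModel(pl.LightningModule):
        def on_fit_start(self):
            # runs once per device, before sanity checking and training
            print("fit starts")

        def on_train_start(self):
            # runs once, right before the first training epoch
            print("training starts")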
backward
~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.backward
    :noindex:

on_before_backward
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_backward
    :noindex:

on_after_backward
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward
    :noindex:

on_before_zero_grad
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad
    :noindex:

on_fit_start
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_fit_start
    :noindex:

on_fit_end
~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_fit_end
    :noindex:

on_load_checkpoint
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint
    :noindex:

on_save_checkpoint
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint
    :noindex:

on_train_start
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start
    :noindex:

on_train_end
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end
    :noindex:

on_validation_start
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start
    :noindex:

on_validation_end
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end
    :noindex:

on_pretrain_routine_start
~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_start
    :noindex:

on_pretrain_routine_end
~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_end
    :noindex:

on_test_batch_start
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_batch_start
    :noindex:

on_test_batch_end
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_batch_end
    :noindex:

on_test_epoch_start
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_start
    :noindex:

on_test_epoch_end
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
    :noindex:

on_test_start
~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_start
    :noindex:

on_test_end
~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end
    :noindex:

on_train_batch_start
~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start
    :noindex:

on_train_batch_end
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
    :noindex:

on_epoch_start
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start
    :noindex:

on_epoch_end
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end
    :noindex:

on_train_epoch_start
~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_start
    :noindex:

on_train_epoch_end
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_end
    :noindex:

on_validation_batch_start
~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_start
    :noindex:

on_validation_batch_end
~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_end
    :noindex:

on_validation_epoch_start
~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_start
    :noindex:

on_validation_epoch_end
~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
    :noindex:

on_post_move_to_device
~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device
    :noindex:

on_validation_model_eval
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval
    :noindex:

on_validation_model_train
~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train
    :noindex:

on_test_model_eval
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval
    :noindex:
on_test_model_train
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
    :noindex:

on_before_optimizer_step
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
    :noindex:

configure_gradient_clipping
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping
    :noindex:

optimizer_step
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_step
    :noindex:

optimizer_zero_grad
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad
    :noindex:

prepare_data
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.prepare_data
    :noindex:

setup
~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup
    :noindex:

tbptt_split_batch
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch
    :noindex:

teardown
~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown
    :noindex:

train_dataloader
~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader
    :noindex:

val_dataloader
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader
    :noindex:

test_dataloader
~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader
    :noindex:

transfer_batch_to_device
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device
    :noindex:

on_before_batch_transfer
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_before_batch_transfer
    :noindex:

on_after_batch_transfer
~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer
    :noindex:

add_to_queue
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.add_to_queue
    :noindex:

get_from_queue
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.get_from_queue
    :noindex: