diff --git a/docs/source/lightning_module.rst b/docs/source/lightning_module.rst
index 011ccb5f68..509811c836 100644
--- a/docs/source/lightning_module.rst
+++ b/docs/source/lightning_module.rst
@@ -122,8 +122,306 @@ Which you can train by doing:
 
 ----------
 
-LightningModule for research
-----------------------------
+Training
+--------
+
+Training loop
+^^^^^^^^^^^^^
+To add a training loop, use the `training_step` method:
+
+.. code-block:: python
+
+    class LitClassifier(pl.LightningModule):
+
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def training_step(self, batch, batch_idx):
+            x, y = batch
+            y_hat = self.model(x)
+            loss = F.cross_entropy(y_hat, y)
+            return loss
+
+Under the hood, Lightning does the following (pseudocode):
+
+.. code-block:: python
+
+    # put model in train mode
+    model.train()
+    torch.set_grad_enabled(True)
+
+    for batch in train_dataloader:
+        # forward
+        loss = training_step(batch)
+
+        # backward
+        loss.backward()
+
+        # apply and clear grads
+        optimizer.step()
+        optimizer.zero_grad()
+
+Training epoch-level metrics
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+If you want to calculate epoch-level metrics and log them, use the `.log` method:
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.cross_entropy(y_hat, y)
+
+        # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger
+        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        return loss
+
+The `.log` method automatically reduces the requested metrics across the full epoch.
+Here's the pseudocode of what it does under the hood:
+
+.. code-block:: python
+
+    outs = []
+    for batch in train_dataloader:
+        # forward
+        loss = training_step(batch)
+        # Lightning keeps track of what each step logged
+        outs.append({'train_loss': loss})
+
+        # backward
+        loss.backward()
+
+        # apply and clear grads
+        optimizer.step()
+        optimizer.zero_grad()
+
+    # the logged values are reduced across the epoch (mean, in this case)
+    epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))
+
+Train epoch-level operations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+If you need to do something with all the outputs of each `training_step`, override `training_epoch_end` yourself.
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.cross_entropy(y_hat, y)
+        preds = ...
+        return {'loss': loss, 'other_stuff': preds}
+
+    def training_epoch_end(self, training_step_outputs):
+        for out in training_step_outputs:
+            # do something with each dict returned by training_step
+
+The matching pseudocode is:
+
+.. code-block:: python
+
+    outs = []
+    for batch in train_dataloader:
+        # forward
+        out = training_step(batch)
+        outs.append(out)
+
+        # backward
+        out['loss'].backward()
+
+        # apply and clear grads
+        optimizer.step()
+        optimizer.zero_grad()
+
+    training_epoch_end(outs)
+
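+For example, `training_epoch_end` could average the losses returned above and log the result.
+This is only a sketch; the metric name here is an arbitrary choice, not a Lightning default.
+
+.. code-block:: python
+
+    def training_epoch_end(self, training_step_outputs):
+        # each element is the dict returned by `training_step`
+        avg_loss = torch.stack([out['loss'] for out in training_step_outputs]).mean()
+        self.log('train_loss_epoch_avg', avg_loss)
+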
+Training with DataParallel
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+When training with a `distributed_backend` that splits each batch across GPUs (dp or ddp2), you might
+need to aggregate the outputs of each sub-batch on the master GPU for processing.
+
+In this case, implement the `training_step_end` method:
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.cross_entropy(y_hat, y)
+        pred = ...
+        return {'loss': loss, 'pred': pred}
+
+    def training_step_end(self, batch_parts):
+        gpu_0_prediction = batch_parts[0]['pred']
+        gpu_1_prediction = batch_parts[1]['pred']
+
+        # do something with both outputs
+        return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2
+
+    def training_epoch_end(self, training_step_outputs):
+        for out in training_step_outputs:
+            # do something with the aggregated outputs
+
+The full pseudocode of what Lightning does under the hood is:
+
+.. code-block:: python
+
+    outs = []
+    for train_batch in train_dataloader:
+        batches = split_batch(train_batch)
+        dp_outs = []
+        for sub_batch in batches:
+            # 1
+            dp_out = training_step(sub_batch)
+            dp_outs.append(dp_out)
+
+        # 2
+        out = training_step_end(dp_outs)
+        outs.append(out)
+
+    # do something with the outputs for all batches
+    # 3
+    training_epoch_end(outs)
+
+------------------
+
+Validation loop
+^^^^^^^^^^^^^^^
+To add a validation loop, override the `validation_step` method of the :class:`~LightningModule`:
+
+.. code-block:: python
+
+    class LitModel(pl.LightningModule):
+        def validation_step(self, batch, batch_idx):
+            x, y = batch
+            y_hat = self.model(x)
+            loss = F.cross_entropy(y_hat, y)
+            self.log('val_loss', loss)
+
+Under the hood, Lightning does the following:
+
+.. code-block:: python
+
+    # ...
+    for batch in train_dataloader:
+        loss = model.training_step(batch)
+        loss.backward()
+        # ...
+
+        if validate_at_some_point:
+            # disable grads + batchnorm + dropout
+            torch.set_grad_enabled(False)
+            model.eval()
+
+            # ----------------- VAL LOOP ---------------
+            for val_batch in model.val_dataloader():
+                val_out = model.validation_step(val_batch)
+            # ----------------- VAL LOOP ---------------
+
+            # enable grads + batchnorm + dropout
+            torch.set_grad_enabled(True)
+            model.train()
+
+Validation epoch-level metrics
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+If you need to do something with all the outputs of each `validation_step`, override `validation_epoch_end`.
+
+.. code-block:: python
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.cross_entropy(y_hat, y)
+        pred = ...
+        return pred
+
+    def validation_epoch_end(self, validation_step_outputs):
+        for pred in validation_step_outputs:
+            # do something with a pred
+
+Validating with DataParallel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When training with a `distributed_backend` that splits each batch across GPUs (dp or ddp2), you might
+need to aggregate the outputs of each sub-batch on the master GPU for processing.
+
+In this case, implement the `validation_step_end` method:
+
+.. code-block:: python
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.cross_entropy(y_hat, y)
+        pred = ...
+        return {'loss': loss, 'pred': pred}
+
+    def validation_step_end(self, batch_parts):
+        gpu_0_prediction = batch_parts[0]['pred']
+        gpu_1_prediction = batch_parts[1]['pred']
+
+        # do something with both outputs
+        return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2
+
+    def validation_epoch_end(self, validation_step_outputs):
+        for out in validation_step_outputs:
+            # do something with the aggregated outputs
+
+The full pseudocode of what Lightning does under the hood is:
+
+.. code-block:: python
+
+    outs = []
+    for batch in dataloader:
+        batches = split_batch(batch)
+        dp_outs = []
+        for sub_batch in batches:
+            # 1
+            dp_out = validation_step(sub_batch)
+            dp_outs.append(dp_out)
+
+        # 2
+        out = validation_step_end(dp_outs)
+        outs.append(out)
+
+    # do something with the outputs for all batches
+    # 3
+    validation_epoch_end(outs)
+
+----------------
+
+Test loop
+^^^^^^^^^
+The process for adding a test loop is the same as the process for adding a validation loop. Please refer to
+the section above for details.
+
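+For example, a minimal `test_step` can mirror the `validation_step` shown above (this is just a
+sketch; the logged metric name is an arbitrary choice):
+
+.. code-block:: python
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('test_loss', loss)
+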
+The only difference is that the test loop is run only when `.test()` is called:
+
+.. code-block:: python
+
+    model = Model()
+    trainer = Trainer()
+    trainer.fit(model)
+
+    # automatically loads the best weights for you
+    trainer.test(model)
+
+There are two ways to call `test()`:
+
+.. code-block:: python
+
+    # call after training
+    trainer = Trainer()
+    trainer.fit(model)
+
+    # automatically loads the best weights
+    trainer.test(test_dataloaders=test_dataloader)
+
+    # or call with pretrained model
+    model = MyLightningModule.load_from_checkpoint(PATH)
+    trainer = Trainer()
+    trainer.test(model, test_dataloaders=test_dataloader)
+
+----------
+
+Inference
+---------
 For research, LightningModules are best structured as systems.
 A model (colloquially) refers to something like a resnet or RNN. A system,
 may be a collection of models. Here
@@ -233,7 +531,7 @@ Note that in this case, the train loop and val loop are exactly the same. We can
 We create a new method called `shared_step` that all loops can use. This method
 name is arbitrary and NOT reserved.
 
-Inference in Research
+Inference in research
 ^^^^^^^^^^^^^^^^^^^^^
 In the case where we want to perform inference with the system we can add a `forward` method to the LightningModule.
 
@@ -258,10 +556,8 @@ such as text generation:
         ...
         return decoded
----------------------
-
-LightningModule for production
-------------------------------
+Inference in production
+^^^^^^^^^^^^^^^^^^^^^^^
 For cases like production, you might want to iterate different models inside a
 LightningModule.
 
 .. code-block:: python
@@ -322,8 +618,6 @@ Tasks can be arbitrarily complex such as implementing GAN training, self-supervi
         self.discriminator = discriminator
         ...
 
-Inference in production
-^^^^^^^^^^^^^^^^^^^^^^^
 When used like this, the model can be separated from the Task and thus used in production without needing to keep it in
 a `LightningModule`.
 
@@ -342,360 +636,55 @@ a `LightningModule`.
     model.eval()
     y_hat = model(x)
 
-
-Training loop
--------------
-To add a training loop use the `training_step` method
-
-.. code-block:: python
-
-    class LitClassifier(pl.LightningModule):
-
-        def __init__(self, model):
-            super().__init__()
-            self.model = model
-
-        def training_step(self, batch, batch_idx):
-            x, y = batch
-            y_hat = self.model(x)
-            loss = F.cross_entropy(y_hat, y)
-            return loss
-
-Under the hood, Lightning does the following (pseudocode):
-
-.. code-block:: python
-
-    # put model in train mode
-    model.train()
-    torch.set_grad_enabled(True)
-
-    outs = []
-    for batch in train_dataloader:
-        # forward
-        out = training_step(val_batch)
-
-        # backward
-        loss.backward()
-
-        # apply and clear grads
-        optimizer.step()
-        optimizer.zero_grad()
-
-Training epoch-level metrics
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-If you want to calculate epoch-level metrics and log them, use the `.log` method
-
-..
code-block:: python - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) - - # logs metrics for each training_step, and the average across the epoch, to the progress bar and logger - self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - return loss - -The `.log` object automatically reduces the requested metrics across the full epoch. -Here's the pseudocode of what it does under the hood: - -.. code-block:: python - - outs = [] - for batch in train_dataloader: - # forward - out = training_step(val_batch) - - # backward - loss.backward() - - # apply and clear grads - optimizer.step() - optimizer.zero_grad() - - epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs])) - -Train epoch-level operations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -If you need to do something with all the outputs of each `training_step`, override `training_epoch_end` yourself. - -.. code-block:: python - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) - preds = ... - return {'loss': loss, 'other_stuff': preds} - - def training_epoch_end(self, training_step_outputs): - for pred in training_step_outputs: - # do something - -The matching pseudocode is: - -.. code-block:: python - - outs = [] - for batch in train_dataloader: - # forward - out = training_step(val_batch) - - # backward - loss.backward() - - # apply and clear grads - optimizer.step() - optimizer.zero_grad() - - training_epoch_end(outs) - -Training with DataParallel -^^^^^^^^^^^^^^^^^^^^^^^^^^ -When training using a `distributed_backend` that splits data from each batch across GPUs, sometimes you might -need to aggregate them on the master GPU for processing (dp, or ddp2). - -In this case, implement the `training_step_end` method - -.. code-block:: python - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) - pred = ... - return {'loss': loss, 'pred': pred} - - def training_step_end(self, batch_parts): - gpu_0_prediction = batch_parts.pred[0]['pred'] - gpu_1_prediction = batch_parts.pred[1]['pred'] - - # do something with both outputs - return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2 - - def training_epoch_end(self, training_step_outputs): - for out in training_step_outputs: - # do something with preds - -The full pseudocode that lighting does under the hood is: - -.. code-block:: python - - outs = [] - for train_batch in train_dataloader: - batches = split_batch(train_batch) - dp_outs = [] - for sub_batch in batches: - # 1 - dp_out = training_step(sub_batch) - dp_outs.append(dp_out) - - # 2 - out = training_step_end(dp_outs) - outs.append(out) - - # do something with the outputs for all batches - # 3 - training_epoch_end(outs) - ------------------- - -Validation loop ---------------- -To add a validation loop, override the `validation_step` method of the :class:`~LightningModule`: - -.. code-block:: python - - class LitModel(pl.LightningModule): - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) - self.log('val_loss', loss) - -Under the hood, Lightning does the following: - -.. code-block:: python - - # ... - for batch in train_dataloader: - loss = model.training_step() - loss.backward() - # ... 
- - if validate_at_some_point: - # disable grads + batchnorm + dropout - torch.set_grad_enabled(False) - model.eval() - - # ----------------- VAL LOOP --------------- - for val_batch in model.val_dataloader: - val_out = model.validation_step(val_batch) - # ----------------- VAL LOOP --------------- - - # enable grads + batchnorm + dropout - torch.set_grad_enabled(True) - model.train() - -Validation epoch-level metrics -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -If you need to do something with all the outputs of each `validation_step`, override `validation_epoch_end`. - -.. code-block:: python - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) - pred = ... - return pred - - def validation_epoch_end(self, validation_step_outputs): - for pred in validation_step_outputs: - # do something with a pred - -Validating with DataParallel -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -When training using a `distributed_backend` that splits data from each batch across GPUs, sometimes you might -need to aggregate them on the master GPU for processing (dp, or ddp2). - -In this case, implement the `validation_step_end` method - -.. code-block:: python - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self.model(x) - loss = F.cross_entropy(y_hat, y) - pred = ... - return {'loss': loss, 'pred': pred} - - def validation_step_end(self, batch_parts): - gpu_0_prediction = batch_parts.pred[0]['pred'] - gpu_1_prediction = batch_parts.pred[1]['pred'] - - # do something with both outputs - return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2 - - def validation_epoch_end(self, validation_step_outputs): - for out in validation_step_outputs: - # do something with preds - -The full pseudocode that lighting does under the hood is: - -.. code-block:: python - - outs = [] - for batch in dataloader: - batches = split_batch(batch) - dp_outs = [] - for sub_batch in batches: - # 1 - dp_out = validation_step(sub_batch) - dp_outs.append(dp_out) - - # 2 - out = validation_step_end(dp_outs) - outs.append(out) - - # do something with the outputs for all batches - # 3 - validation_epoch_end(outs) - ----------------- - -Test loop ---------- -The process for adding a test loop is the same as the process for adding a validation loop. Please refer to -the section above for details. - -The only difference is that the test loop is only called when `.test()` is used: - -.. code-block:: python - - model = Model() - trainer = Trainer() - trainer.fit() - - # automatically loads the best weights for you - trainer.test(model) - -There are two ways to call `test()`: - -.. code-block:: python - - # call after training - trainer = Trainer() - trainer.fit(model) - - # automatically auto-loads the best weights - trainer.test(test_dataloaders=test_dataloader) - - # or call with pretrained model - model = MyLightningModule.load_from_checkpoint(PATH) - trainer = Trainer() - trainer.test(model, test_dataloaders=test_dataloader) - ----------- - -Live demo ---------- -Check out this -`COLAB `_ -for a live demo. - ----------- LightningModule API ------------------- -Training loop methods -^^^^^^^^^^^^^^^^^^^^^ +Methods +^^^^^^^ -training_step -~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step - :noindex: - -training_step_end -~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step_end - :noindex: - -training_epoch_end -~~~~~~~~~~~~~~~~~~ -.. 
automethod:: pytorch_lightning.core.lightning.LightningModule.training_epoch_end - :noindex: - ---------------- - -Validation loop methods -^^^^^^^^^^^^^^^^^^^^^^^ - -validation_step -~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step - :noindex: - -validation_step_end -~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step_end - :noindex: - -validation_epoch_end +configure_optimizers ~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_epoch_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_optimizers :noindex: ----------------- +forward +~~~~~~~ -test loop methods -^^^^^^^^^^^^^^^^^ +.. automethod:: pytorch_lightning.core.lightning.LightningModule.forward + :noindex: + +freeze +~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.freeze + :noindex: + +log +~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.log + :noindex: + +log_dict +~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.log_dict + :noindex: + +print +~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.print + :noindex: + +save_hyperparameters +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.save_hyperparameters + :noindex: test_step ~~~~~~~~~ @@ -715,78 +704,6 @@ test_epoch_end .. automethod:: pytorch_lightning.core.lightning.LightningModule.test_epoch_end :noindex: --------------- - -configure_optimizers -^^^^^^^^^^^^^^^^^^^^ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_optimizers - :noindex: - --------------- - -Manual optimization -^^^^^^^^^^^^^^^^^^^ -Use these methods when doing manual optimization - -manual_backward -~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward - :noindex: - -Convenience methods -^^^^^^^^^^^^^^^^^^^ -Use these methods for convenience - -print -~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.print - :noindex: - -save_hyperparameters -~~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.save_hyperparameters - :noindex: - ------------- - -Logging methods -^^^^^^^^^^^^^^^ -Use these methods to interact with the loggers - -log -~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.log - :noindex: - -log_dict -~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.log_dict - :noindex: - ------------- - -Inference methods -^^^^^^^^^^^^^^^^^ -Use these hooks for inference with a lightning module - -forward -~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.forward - :noindex: - -freeze -~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.freeze - :noindex: - to_onnx ~~~~~~~ @@ -799,12 +716,47 @@ to_torchscript .. automethod:: pytorch_lightning.core.lightning.LightningModule.to_torchscript :noindex: +training_step +~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step + :noindex: + +training_step_end +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step_end + :noindex: + +training_epoch_end +~~~~~~~~~~~~~~~~~~ +.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_epoch_end + :noindex: + unfreeze ~~~~~~~~ .. 
automethod:: pytorch_lightning.core.lightning.LightningModule.unfreeze :noindex: +validation_step +~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step + :noindex: + +validation_step_end +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step_end + :noindex: + +validation_epoch_end +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_epoch_end + :noindex: + ------------ Properties @@ -950,10 +902,7 @@ True if using TPUs -------------- Hooks ------ - -Hook lifecycle pseudocode -^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^ This is the pseudocode to describe how all the hooks are called during a call to `.fit()` .. code-block:: python @@ -1029,11 +978,11 @@ This is the pseudocode to describe how all the hooks are called during a call to model.train() torch.set_grad_enabled(True) +backward +~~~~~~~~ -Advanced hooks -^^^^^^^^^^^^^^ -Use these hooks to modify advanced functionality - +.. automethod:: pytorch_lightning.core.lightning.LightningModule.backward + :noindex: get_progress_bar_dict ~~~~~~~~~~~~~~~~~~~~~ @@ -1041,70 +990,10 @@ get_progress_bar_dict .. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict :noindex: -tbptt_split_batch -~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch - :noindex: - -Checkpoint hooks -^^^^^^^^^^^^^^^^ -These hooks allow you to modify checkpoints - -on_load_checkpoint -~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint - :noindex: - -on_save_checkpoint -~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint - :noindex: - -------------- - -Data hooks -^^^^^^^^^^ -Use these hooks if you want to couple a LightningModule to a dataset. - -.. note:: The same collection of hooks is available in a DataModule class to decouple the data from the model. - -train_dataloader -~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader - :noindex: - -val_dataloader -~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader - :noindex: - -test_dataloader +manual_backward ~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader - :noindex: - -prepare_data -~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.prepare_data - :noindex: - ------------- - -Optimization hooks -^^^^^^^^^^^^^^^^^^ -These are hooks related to the optimization procedure. - -backward -~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.backward +.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward :noindex: on_after_backward @@ -1118,22 +1007,6 @@ on_before_zero_grad .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad :noindex: -optimizer_step -~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_step - :noindex: - -optimizer_zero_grad -~~~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad - :noindex: - -Training lifecycle hooks -^^^^^^^^^^^^^^^^^^^^^^^^^ -These hooks are called during training - on_fit_start ~~~~~~~~~~~~ @@ -1146,6 +1019,20 @@ on_fit_end .. 
automethod:: pytorch_lightning.core.hooks.ModelHooks.on_fit_end :noindex: + +on_load_checkpoint +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint + :noindex: + +on_save_checkpoint +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint + :noindex: + + on_pretrain_routine_start ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1182,6 +1069,7 @@ on_test_epoch_end .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end :noindex: + on_train_batch_start ~~~~~~~~~~~~~~~~~~~~ @@ -1230,18 +1118,60 @@ on_validation_epoch_end .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end :noindex: +optimizer_step +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_step + :noindex: + +optimizer_zero_grad +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad + :noindex: + +prepare_data +~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.prepare_data + :noindex: + setup ~~~~~ .. automethod:: pytorch_lightning.core.hooks.ModelHooks.setup :noindex: +tbptt_split_batch +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch + :noindex: + teardown ~~~~~~~~ .. automethod:: pytorch_lightning.core.hooks.ModelHooks.teardown :noindex: +train_dataloader +~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader + :noindex: + +val_dataloader +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader + :noindex: + +test_dataloader +~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader + :noindex: + transfer_batch_to_device ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst index 357469e6d5..f6189244fa 100644 --- a/docs/source/tpu.rst +++ b/docs/source/tpu.rst @@ -17,12 +17,6 @@ on Google Cloud (GCP), Google Colab and Kaggle Environments. For more informatio ---------------- -Live demo ----------- -Check out this `Google Colab `_ to see how to train MNIST on TPUs. - ----------------- - TPU Terminology --------------- A TPU is a Tensor processing unit. Each TPU has 8 cores where each