docs cleaning - testcode (#5595)
* testcode - python
* revert
* simple
* testcode @rst
* pl
* fix
* pip
* update
* conf
* conf
* nn.
* typo
parent c3587d39da
commit f782230412
@@ -18,7 +18,10 @@ references:
pyenv global 3.7.3
python --version
pip install -r requirements/docs.txt
cd docs; make clean; make html --debug --jobs 2 SPHINXOPTS="-W"
pip list
cd docs
make clean
make html --jobs 2 SPHINXOPTS="-W"

checkout_ml_testing: &checkout_ml_testing
run:

@@ -25,9 +25,9 @@ PATH_HERE = os.path.abspath(os.path.dirname(__file__))
PATH_ROOT = os.path.join(PATH_HERE, '..', '..')
sys.path.insert(0, os.path.abspath(PATH_ROOT))

builtins.__LIGHTNING_SETUP__ = True

SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True))
if SPHINX_MOCK_REQUIREMENTS:
builtins.__LIGHTNING_SETUP__ = True

import pytorch_lightning # noqa: E402

@@ -360,7 +360,10 @@ doctest_global_setup = """
import importlib
import os
import torch
from torch import nn

import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.utilities import (
_NATIVE_AMP_AVAILABLE,
_APEX_AVAILABLE,

@@ -369,6 +372,5 @@ from pytorch_lightning.utilities import (
)
_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None

"""
coverage_skip_undoc_in_source = True

@@ -24,8 +24,8 @@ Move the model architecture and forward pass to your :ref:`lightning_module`.
def __init__(self):
super().__init__()
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 10)
self.layer_1 = nn.Linear(28 * 28, 128)
self.layer_2 = nn.Linear(128, 10)

def forward(self, x):
x = x.view(x.size(0), -1)

@@ -49,7 +49,7 @@ To enable it:
- You can customize the callbacks behaviour by changing its parameters.

.. code-block:: python
.. testcode::

early_stop_callback = EarlyStopping(
monitor='val_accuracy',
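For reference (not part of this diff), a minimal sketch of how the ``EarlyStopping`` callback shown above is typically completed and attached to the ``Trainer``; the ``patience`` and ``mode`` values below are illustrative assumptions:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    early_stop_callback = EarlyStopping(
        monitor='val_accuracy',  # quantity logged during validation
        patience=3,              # assumed: stop after 3 epochs without improvement
        mode='max',              # assumed: higher accuracy is better
    )
    trainer = Trainer(callbacks=[early_stop_callback])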
@@ -161,9 +161,9 @@ improve readability and reproducibility.
def __init__(self, hparams, *args, **kwargs):
super().__init__()
self.hparams = hparams
self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)
self.layer_1 = nn.Linear(28 * 28, self.hparams.layer_1_dim)
self.layer_2 = nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
self.layer_3 = nn.Linear(self.hparams.layer_2_dim, 10)

def train_dataloader(self):
return DataLoader(mnist_train, batch_size=self.hparams.batch_size)

@@ -182,9 +182,9 @@ improve readability and reproducibility.
super().__init__()
self.save_hyperparameters(conf)

self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)
self.layer_1 = nn.Linear(28 * 28, self.hparams.layer_1_dim)
self.layer_2 = nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
self.layer_3 = nn.Linear(self.hparams.layer_2_dim, 10)

conf = OmegaConf.create(...)
model = LitMNIST(conf)

@@ -225,7 +225,7 @@ polluting the ``main.py`` file, the ``LightningModule`` lets you define argument
def __init__(self, layer_1_dim, **kwargs):
super().__init__()
self.layer_1 = torch.nn.Linear(28 * 28, layer_1_dim)
self.layer_1 = nn.Linear(28 * 28, layer_1_dim)

@staticmethod
def add_model_specific_args(parent_parser):
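As an aside, a hedged sketch of what the truncated ``add_model_specific_args`` pattern above usually looks like end to end; the argument name mirrors the ``layer_1_dim`` parameter in the hunk, everything else is illustrative:

.. code-block:: python

    from argparse import ArgumentParser

    from torch import nn
    from pytorch_lightning import LightningModule

    class LitMNIST(LightningModule):
        def __init__(self, layer_1_dim=128, **kwargs):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, layer_1_dim)

        @staticmethod
        def add_model_specific_args(parent_parser):
            # extend the script-level parser with model-specific options
            parser = ArgumentParser(parents=[parent_parser], add_help=False)
            parser.add_argument('--layer_1_dim', type=int, default=128)
            return parser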
@@ -80,9 +80,9 @@ Let's first start with the model. In this case, we'll design a 3-layer neural ne
super().__init__()

# mnist images are (1, 28, 28) (channels, width, height)
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 256)
self.layer_3 = torch.nn.Linear(256, 10)
self.layer_1 = nn.Linear(28 * 28, 128)
self.layer_2 = nn.Linear(128, 256)
self.layer_3 = nn.Linear(256, 10)

def forward(self, x):
batch_size, channels, width, height = x.size()

@@ -118,7 +118,7 @@ equivalent to a pure PyTorch Module except it has added functionality. However,
Now we add the training_step which has all our training loop logic

.. testcode:: python
.. testcode::

class LitMNIST(LightningModule):
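The hunk above only shows the class line; a minimal sketch of a ``training_step`` for this model (the loss choice and the logging key are assumptions, not part of the diff):

.. code-block:: python

    import torch.nn.functional as F
    from pytorch_lightning import LightningModule

    class LitMNIST(LightningModule):
        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)                   # uses the forward() defined earlier
            loss = F.cross_entropy(logits, y)
            self.log('train_loss', loss)       # assumed logging key
            return loss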
@@ -225,7 +225,7 @@ In this case, it's better to group the full definition of a dataset into a `Data
- Val dataloader(s)
- Test dataloader(s)

.. testcode:: python
.. testcode::

class MyDataModule(LightningDataModule):

@@ -420,9 +420,9 @@ For clarity, we'll recall that the full LightningModule now looks like this.
class LitMNIST(LightningModule):
def __init__(self):
super().__init__()
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 256)
self.layer_3 = torch.nn.Linear(256, 10)
self.layer_1 = nn.Linear(28 * 28, 128)
self.layer_2 = nn.Linear(128, 256)
self.layer_3 = nn.Linear(256, 10)

def forward(self, x):
batch_size, channels, width, height = x.size()

@@ -96,7 +96,7 @@ Here are the only required methods.
...
... def __init__(self):
... super().__init__()
... self.l1 = torch.nn.Linear(28 * 28, 10)
... self.l1 = nn.Linear(28 * 28, 10)
...
... def forward(self, x):
... return torch.relu(self.l1(x.view(x.size(0), -1)))

@@ -141,9 +141,11 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
.. testcode::

from pytorch_lightning.metrics import Accuracy

def __init__(self):
...
metric = pl.metrics.Accuracy()
metric = Accuracy()
self.train_acc = metric.clone()
self.val_acc = metric.clone()
self.test_acc = metric.clone()

@@ -164,7 +166,6 @@ be moved to the same device as the input of the metric:
.. code-block:: python

import torch
from pytorch_lightning.metrics import Accuracy

target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))

@@ -186,13 +187,15 @@ as child modules. Instead of ``list`` use :class:`~torch.nn.ModuleList` and inst
.. testcode::

from pytorch_lightning.metrics import Accuracy

class MyModule(LightningModule):
def __init__(self):
...
# valid ways metrics will be identified as child modules
self.metric1 = pl.metrics.Accuracy()
self.metric2 = torch.nn.ModuleList(pl.metrics.Accuracy())
self.metric3 = torch.nn.ModuleDict({'accuracy': Accuracy()})
self.metric1 = Accuracy()
self.metric2 = nn.ModuleList(Accuracy())
self.metric3 = nn.ModuleDict({'accuracy': Accuracy()})

def training_step(self, batch, batch_idx):
# all metrics will be on the same device as the input batch

@@ -222,7 +225,7 @@ from the base ``Metric`` class.
Example implementation:

.. code-block:: python
.. testcode::

from pytorch_lightning.metrics import Metric
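For context, a hedged sketch of a subclassed ``Metric`` along the lines the docs describe here; the state names and the accuracy logic are illustrative:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Metric

    class MyAccuracy(Metric):
        def __init__(self):
            super().__init__()
            # states are synced across processes using the given reduction
            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds, target):
            preds = torch.argmax(preds, dim=-1)
            self.correct += torch.sum(preds == target)
            self.total += target.numel()

        def compute(self):
            return self.correct.float() / self.total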
@@ -281,8 +284,8 @@ Example:
.. testoutput::
:options: +NORMALIZE_WHITESPACE

{'Accuracy': tensor(0.1250),
'Precision': tensor(0.0667),
{'Accuracy': tensor(0.1250),
'Precision': tensor(0.0667),
'Recall': tensor(0.1111)}

Similarly it can also reduce the amount of code required to log multiple metrics

@@ -698,7 +698,7 @@ This should be kept within the ``sequential_module`` variable within your ``Ligh
class MyModel(LightningModule):
def __init__(self):
...
self.sequential_module = torch.nn.Sequential(my_layers)
self.sequential_module = nn.Sequential(my_layers)

# Split my module across 4 gpus, one layer each
model = MyModel()

@@ -65,7 +65,8 @@ You could also use conda environments
Import the following:

.. code-block:: python
.. testcode::
:skipif: not _TORCHVISION_AVAILABLE

import os
import torch

@@ -80,9 +81,9 @@ Import the following:
Step 1: Define LightningModule
******************************

.. code-block::
.. testcode::

class LitAutoEncoder(pl.LightningModule):
class LitAutoEncoder(LightningModule):

def __init__(self):
super().__init__()

@@ -147,9 +148,9 @@ Under the hood a LightningModule is still just a :class:`torch.nn.Module` that g
You can customize any part of training (such as the backward pass) by overriding any
of the 20+ hooks found in :ref:`hooks`

.. code-block:: python
.. testcode::

class LitAutoEncoder(pl.LightningModule):
class LitAutoEncoder(LightningModule):

def backward(self, loss, optimizer, optimizer_idx):
loss.backward()

@@ -259,7 +260,7 @@ or an inner loop, you can turn off automatic optimization and fully control the
First, turn off automatic optimization:

.. code-block:: python
.. testcode::

trainer = Trainer(automatic_optimization=False)
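For orientation, a sketch of what a manually optimized ``training_step`` can look like once this flag is set; ``self.optimizers()`` and ``self.manual_backward`` exist in Lightning, but their exact signatures vary by version, so treat this as an assumption-laden outline:

.. code-block:: python

    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        def training_step(self, batch, batch_idx):
            opt = self.optimizers()           # optimizer(s) from configure_optimizers()

            loss = self.compute_loss(batch)   # hypothetical helper
            self.manual_backward(loss, opt)   # some versions also take the optimizer
            opt.step()
            opt.zero_grad()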
@@ -310,17 +311,21 @@ Option 2: Forward
-----------------
You can also add a forward method to do predictions however you want.

.. code-block:: python
.. testcode::

# ----------------------------------
# using the AE to extract embeddings
# ----------------------------------
class LitAutoEncoder(pl.LightningModule):
class LitAutoEncoder(LightningModule):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential()

def forward(self, x):
embedding = self.encoder(x)
return embedding

autoencoder = LitAutoencoder()
autoencoder = LitAutoEncoder()
autoencoder = autoencoder(torch.rand(1, 28 * 28))

@@ -329,14 +334,18 @@ You can also add a forward method to do predictions however you want.
# ----------------------------------
# or using the AE to generate images
# ----------------------------------
class LitAutoEncoder(pl.LightningModule):
class LitAutoEncoder(LightningModule):
def __init__(self):
super().__init__()
self.decoder = nn.Sequential()

def forward(self):
z = torch.rand(1, 3)
image = self.decoder(z)
image = image.view(1, 1, 28, 28)
return image

autoencoder = LitAutoencoder()
autoencoder = LitAutoEncoder()
image_sample = autoencoder()

Option 3: Production

@@ -370,15 +379,15 @@ Using CPUs/GPUs/TPUs
====================
It's trivial to use CPUs, GPUs or TPUs in Lightning. There's **NO NEED** to change your code, simply change the :class:`~pytorch_lightning.trainer.Trainer` options.

.. code-block:: python
.. testcode::

# train on CPU
trainer = pl.Trainer()
trainer = Trainer()

.. code-block:: python
.. testcode::

# train on 8 CPUs
trainer = pl.Trainer(num_processes=8)
trainer = Trainer(num_processes=8)

.. code-block:: python

@@ -583,7 +592,9 @@ Here's an example adding a not-so-fancy learning rate decay rule:
.. testcode::

class DecayLearningRate(pl.callbacks.Callback):
from pytorch_lightning.callbacks import Callback

class DecayLearningRate(Callback):

def __init__(self):
self.old_lrs = []

@@ -605,10 +616,7 @@ Here's an example adding a not-so-fancy learning rate decay rule:
param_group['lr'] = new_lr
self.old_lrs[opt_idx] = new_lr_group

And pass the callback to the Trainer

.. code-block:: python

# And pass the callback to the Trainer
decay_callback = DecayLearningRate()
trainer = Trainer(callbacks=[decay_callback])

@@ -629,9 +637,9 @@ LightningDataModules
DataLoaders and data processing code tends to end up scattered around.
Make your data code reusable by organizing it into a :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

.. code-block:: python
.. testcode::

class MNISTDataModule(pl.LightningDataModule):
class MNISTDataModule(LightningDataModule):

def __init__(self, batch_size=32):
super().__init__()
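Since the hunk cuts off after ``__init__``, here is a hedged sketch of the rest of a typical ``LightningDataModule``; the MNIST transforms and split sizes are illustrative assumptions:

.. code-block:: python

    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST

    class MNISTDataModule(LightningDataModule):
        def __init__(self, batch_size=32):
            super().__init__()
            self.batch_size = batch_size

        def setup(self, stage=None):
            dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
            self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=self.batch_size)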
@@ -679,7 +687,7 @@ tokenizing, processing etc.
Now you can simply pass your :class:`~pytorch_lightning.core.datamodule.LightningDataModule` to
the :class:`~pytorch_lightning.trainer.Trainer`:

.. code-block::
.. code-block:: python

# init model
model = LitModel()

@@ -702,33 +710,33 @@ Debugging
=========
Lightning has many tools for debugging. Here is an example of just a few of them:

.. code-block:: python
.. testcode::

# use only 10 train batches and 3 val batches
trainer = pl.Trainer(limit_train_batches=10, limit_val_batches=3)
trainer = Trainer(limit_train_batches=10, limit_val_batches=3)

.. code-block:: python
.. testcode::

# Automatically overfit the sane batch of your model for a sanity test
trainer = pl.Trainer(overfit_batches=1)
trainer = Trainer(overfit_batches=1)

.. code-block:: python
.. testcode::

# unit test all the code- hits every line of your code once to see if you have bugs,
# instead of waiting hours to crash on validation
trainer = pl.Trainer(fast_dev_run=True)
trainer = Trainer(fast_dev_run=True)

.. code-block:: python
.. testcode::

# train only 20% of an epoch
trainer = pl.Trainer(limit_train_batches=0.2)
trainer = Trainer(limit_train_batches=0.2)

.. code-block:: python
.. testcode::

# run validation every 25% of a training epoch
trainer = pl.Trainer(val_check_interval=0.25)
trainer = Trainer(val_check_interval=0.25)

.. code-block:: python
.. testcode::

# Profile your code to find speed/memory bottlenecks
Trainer(profiler=True)

@@ -34,7 +34,7 @@ To train a model using multiple nodes, do the following:
def main(hparams):
model = LightningTemplateModel(hparams)

trainer = pl.Trainer(
trainer = Trainer(
gpus=8,
num_nodes=4,
accelerator='ddp'

@@ -46,10 +46,10 @@ Here's a LightningModule that defines a model:
Here's a lightningModule that defines a system:

.. code-block:: python
.. testcode::

class LitModel(pl.LightningModule):
def __init__(self, encoder: nn.Module = None, decoder: nn.Module = None)
class LitModel(LightningModule):
def __init__(self, encoder: nn.Module = None, decoder: nn.Module = None):
super().__init__()
self.encoder = encoder
self.decoder = decoder

@@ -74,9 +74,9 @@ sensible defaults in the init so that the user doesn't have to guess.
Here's an example where a user will have to go hunt through files to figure out how to init this LightningModule.

.. code-block:: python
.. testcode::

class LitModel(pl.LightningModule):
class LitModel(LightningModule):
def __init__(self, params):
self.lr = params.lr
self.coef_x = params.coef_x

@@ -85,10 +85,11 @@ Models defined as such leave you with many questions; what is coef_x? is it a st
Instead, be explicit in your init

.. code-block:: python
.. testcode::

class LitModel(pl.LightningModule):
def __init__(self, encoder: nn.Module, coeff_x: float = 0.2, lr: float = 1e-3)
class LitModel(LightningModule):
def __init__(self, encoder: nn.Module, coeff_x: float = 0.2, lr: float = 1e-3):
...

Now the user doesn't have to guess. Instead they know the value type and the model has a sensible default where the
user can see the value immediately.

@@ -1501,10 +1501,10 @@ override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:
.. testcode::

class LitMNIST(LightningModule):
def tbptt_split_batch(self, batch, split_size):
# do your own splitting on the batch
return splits
class LitMNIST(LightningModule):
def tbptt_split_batch(self, batch, split_size):
# do your own splitting on the batch
return splits

val_check_interval
^^^^^^^^^^^^^^^^^^

@@ -58,7 +58,7 @@ Example: Imagenet (computer Vision)
backbone = models.resnet50(pretrained=True)
num_filters = backbone.fc.in_features
layers = list(backbone.children())[:-1]
self.feature_extractor = torch.nn.Sequential(*layers)
self.feature_extractor = nn.Sequential(*layers)

# use the pretrained model to classify cifar-10 (10 image classes)
num_target_classes = 10
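To show where the feature extractor in this hunk usually leads, a sketch of the matching classification head and forward pass, closely following the docs' ImageNet example; the frozen-backbone choice is an assumption:

.. code-block:: python

    import torch
    from torch import nn
    from torchvision import models
    from pytorch_lightning import LightningModule

    class ImagenetTransferLearning(LightningModule):
        def __init__(self, num_target_classes=10):
            super().__init__()
            backbone = models.resnet50(pretrained=True)
            num_filters = backbone.fc.in_features
            layers = list(backbone.children())[:-1]
            self.feature_extractor = nn.Sequential(*layers)
            # new head trained on the target task (e.g. cifar-10)
            self.classifier = nn.Linear(num_filters, num_target_classes)

        def forward(self, x):
            self.feature_extractor.eval()
            with torch.no_grad():
                representations = self.feature_extractor(x).flatten(1)
            return self.classifier(representations)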
@@ -48,11 +48,11 @@ You can customize the checkpointing behavior to monitor any quantity of your tra
3. Initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback, and set `monitor` to be the key of your quantity.
4. Pass the callback to the `callbacks` :class:`~pytorch_lightning.trainer.Trainer` flag.

.. code-block:: python
.. testcode::

from pytorch_lightning.callbacks import ModelCheckpoint

class LitAutoEncoder(pl.LightningModule):
class LitAutoEncoder(LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.backbone(x)

@@ -71,11 +71,11 @@ You can customize the checkpointing behavior to monitor any quantity of your tra
You can also control more advanced options, like `save_top_k`, to save the best k models and the `mode` of the monitored quantity (min/max), `save_weights_only` or `period` to set the interval of epochs between checkpoints, to avoid slowdowns.

.. code-block:: python
.. testcode::

from pytorch_lightning.callbacks import ModelCheckpoint

class LitAutoEncoder(pl.LightningModule):
class LitAutoEncoder(LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.backbone(x)
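For reference, a hedged sketch of the checkpointing options mentioned in the prose above, wired into the ``Trainer``; the monitored key and the concrete values are illustrative:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',      # key logged via self.log in validation_step
        save_top_k=3,            # keep the 3 best checkpoints
        mode='min',              # lower val_loss is better
        save_weights_only=False,
    )
    trainer = Trainer(callbacks=[checkpoint_callback])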
@@ -32,26 +32,24 @@ def auto_move_data(fn: Callable) -> Callable:
fn: A LightningModule method for which the arguments should be moved to the device
the parameters are on.

Example:
Example::

.. code-block:: python
# directly in the source code
class LitModel(LightningModule):

# directly in the source code
class LitModel(LightningModule):
@auto_move_data
def forward(self, x):
return x

@auto_move_data
def forward(self, x):
return x
# or outside
LitModel.forward = auto_move_data(LitModel.forward)

# or outside
LitModel.forward = auto_move_data(LitModel.forward)
model = LitModel()
model = model.to('cuda')
model(torch.zeros(1, 3))

model = LitModel()
model = model.to('cuda')
model(torch.zeros(1, 3))

# input gets moved to device
# tensor([[0., 0., 0.]], device='cuda:0')
# input gets moved to device
# tensor([[0., 0., 0.]], device='cuda:0')

"""
@wraps(fn)
@@ -391,20 +391,19 @@ class DataHooks:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.

Example:
.. code-block:: python
Example::

def train_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=True
)
return loader
def train_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=True
)
return loader

"""
rank_zero_warn(

@@ -443,25 +442,24 @@ class DataHooks:
Return:
Single or multiple PyTorch DataLoaders.

Example:
.. code-block:: python
Example::

def test_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=False
)
def test_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=False
)

return loader
return loader

# can also return multiple dataloaders
def test_dataloader(self):
return [loader_a, loader_b, ..., loader_n]
# can also return multiple dataloaders
def test_dataloader(self):
return [loader_a, loader_b, ..., loader_n]

Note:
If you don't need a test dataset and a :meth:`test_step`, you don't need to implement

@@ -495,25 +493,24 @@ class DataHooks:
Return:
Single or multiple PyTorch DataLoaders.

Examples:
.. code-block:: python
Examples::

def val_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=False,
transform=transform, download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=False
)
def val_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=False,
transform=transform, download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=False
)

return loader
return loader

# can also return multiple dataloaders
def val_dataloader(self):
return [loader_a, loader_b, ..., loader_n]
# can also return multiple dataloaders
def val_dataloader(self):
return [loader_a, loader_b, ..., loader_n]

Note:
If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
@@ -586,12 +583,11 @@ class CheckpointHooks:
checkpoint: Loaded checkpoint

Example:
.. code-block:: python
Example::

def on_load_checkpoint(self, checkpoint):
# 99% of the time you don't need to implement this method
self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
def on_load_checkpoint(self, checkpoint):
# 99% of the time you don't need to implement this method
self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note:
Lightning auto-restores global step, epoch, and train state including amp scaling.

@@ -606,12 +602,11 @@ class CheckpointHooks:
Args:
checkpoint: Checkpoint to be saved

Example:
.. code-block:: python
Example::

def on_save_checkpoint(self, checkpoint):
# 99% of use cases you don't need to implement this method
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
def on_save_checkpoint(self, checkpoint):
# 99% of use cases you don't need to implement this method
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object

Note:
Lightning saves all aspects of training (epoch, global step, etc...)
@@ -161,12 +161,10 @@ class LightningModule(
*args: The thing to print. Will be passed to Python's built-in print function.
**kwargs: Will be passed to Python's built-in print function.

Example:
Example::

.. code-block:: python

def forward(self, x):
self.print(x, 'in forward')
def forward(self, x):
self.print(x, 'in forward')

"""
if self.trainer.is_global_zero:

@@ -409,36 +407,35 @@ class LightningModule(
Return:
Predicted output

Examples:
.. code-block:: python
Examples::

# example if we were using this model as a feature extractor
def forward(self, x):
feature_maps = self.convnet(x)
return feature_maps
# example if we were using this model as a feature extractor
def forward(self, x):
feature_maps = self.convnet(x)
return feature_maps

def training_step(self, batch, batch_idx):
x, y = batch
feature_maps = self(x)
logits = self.classifier(feature_maps)
def training_step(self, batch, batch_idx):
x, y = batch
feature_maps = self(x)
logits = self.classifier(feature_maps)

# ...
return loss
# ...
return loss

# splitting it this way allows model to be used a feature extractor
model = MyModelAbove()
# splitting it this way allows model to be used a feature extractor
model = MyModelAbove()

inputs = server.get_request()
results = model(inputs)
server.write_results(results)
inputs = server.get_request()
results = model(inputs)
server.write_results(results)

# -------------
# This is in stark contrast to torch.nn.Module where normally you would have this:
def forward(self, batch):
x, y = batch
feature_maps = self.convnet(x)
logits = self.classifier(feature_maps)
return logits
# -------------
# This is in stark contrast to torch.nn.Module where normally you would have this:
def forward(self, batch):
x, y = batch
feature_maps = self.convnet(x)
logits = self.classifier(feature_maps)
return logits

"""
return super().forward(*args, **kwargs)
@@ -655,37 +652,36 @@ class LightningModule(
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx)

Examples:
.. code-block:: python
Examples::

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
x, y = batch
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
x, y = batch

# implement your own
out = self(x)
loss = self.loss(out, y)
# implement your own
out = self(x)
loss = self.loss(out, y)

# log 6 example images
# or generated text... or whatever
sample_imgs = x[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('example_images', grid, 0)
# log 6 example images
# or generated text... or whatever
sample_imgs = x[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('example_images', grid, 0)

# calculate acc
labels_hat = torch.argmax(out, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
# calculate acc
labels_hat = torch.argmax(out, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

# log the outputs!
self.log_dict({'val_loss': loss, 'val_acc': val_acc})
# log the outputs!
self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val datasets, validation_step will have an additional argument.
If you pass in multiple val datasets, validation_step will have an additional argument.

.. code-block:: python
.. code-block:: python

# CASE 2: multiple validation datasets
def validation_step(self, batch, batch_idx, dataloader_idx):
# dataloader_idx tells you which dataset this is.
# CASE 2: multiple validation datasets
def validation_step(self, batch, batch_idx, dataloader_idx):
# dataloader_idx tells you which dataset this is.

Note:
If you don't need to validate you don't need to implement this method.

@@ -831,38 +827,37 @@ class LightningModule(
# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx)

Examples:
.. code-block:: python
Examples::

# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
x, y = batch
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
x, y = batch

# implement your own
out = self(x)
loss = self.loss(out, y)
# implement your own
out = self(x)
loss = self.loss(out, y)

# log 6 example images
# or generated text... or whatever
sample_imgs = x[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('example_images', grid, 0)
# log 6 example images
# or generated text... or whatever
sample_imgs = x[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('example_images', grid, 0)

# calculate acc
labels_hat = torch.argmax(out, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
# calculate acc
labels_hat = torch.argmax(out, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

# log the outputs!
self.log_dict({'test_loss': loss, 'test_acc': test_acc})
# log the outputs!
self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple validation datasets, :meth:`test_step` will have an additional
argument.
If you pass in multiple validation datasets, :meth:`test_step` will have an additional
argument.

.. code-block:: python
.. code-block:: python

# CASE 2: multiple test datasets
def test_step(self, batch, batch_idx, dataloader_idx):
# dataloader_idx tells you which dataset this is.
# CASE 2: multiple test datasets
def test_step(self, batch, batch_idx, dataloader_idx):
# dataloader_idx tells you which dataset this is.

Note:
If you don't need to validate you don't need to implement this method.
@@ -1023,47 +1018,46 @@ class LightningModule(
Only the ``scheduler`` key is required, the rest will be set to the defaults above.

Examples:
.. code-block:: python
Examples::

# most cases
def configure_optimizers(self):
opt = Adam(self.parameters(), lr=1e-3)
return opt
# most cases
def configure_optimizers(self):
opt = Adam(self.parameters(), lr=1e-3)
return opt

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
return generator_opt, disriminator_opt
# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
return generator_opt, disriminator_opt

# example with learning rate schedulers
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
return [generator_opt, disriminator_opt], [discriminator_sched]
# example with learning rate schedulers
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
return [generator_opt, disriminator_opt], [discriminator_sched]

# example with step-based learning rate schedulers
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
'interval': 'step'} # called after each training step
dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch
return [gen_opt, dis_opt], [gen_sched, dis_sched]
# example with step-based learning rate schedulers
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
'interval': 'step'} # called after each training step
dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch
return [gen_opt, dis_opt], [gen_sched, dis_sched]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
n_critic = 5
return (
{'optimizer': dis_opt, 'frequency': n_critic},
{'optimizer': gen_opt, 'frequency': 1}
)
# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
n_critic = 5
return (
{'optimizer': dis_opt, 'frequency': n_critic},
{'optimizer': gen_opt, 'frequency': 1}
)

Note:

@@ -1211,50 +1205,49 @@ class LightningModule(
using_native_amp: True if using native amp
using_lbfgs: True if the matching optimizer is lbfgs

Examples:
.. code-block:: python
Examples::

# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
optimizer.step(closure=optimizer_closure)
# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
optimizer.step(closure=optimizer_closure)

# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
# update generator opt every 2 steps
if optimizer_idx == 0:
if batch_idx % 2 == 0 :
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()
# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
# update generator opt every 2 steps
if optimizer_idx == 0:
if batch_idx % 2 == 0 :
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()

# update discriminator opt every 4 steps
if optimizer_idx == 1:
if batch_idx % 4 == 0 :
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()
# update discriminator opt every 4 steps
if optimizer_idx == 1:
if batch_idx % 4 == 0 :
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()

# ...
# add as many optimizers as you want
# ...
# add as many optimizers as you want

Here's another example showing how to use this for more advanced things such as
learning rate warm-up:
Here's another example showing how to use this for more advanced things such as
learning rate warm-up:

.. code-block:: python
.. code-block:: python

# learning rate warm-up
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
for pg in optimizer.param_groups:
pg['lr'] = lr_scale * self.learning_rate
# learning rate warm-up
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
for pg in optimizer.param_groups:
pg['lr'] = lr_scale * self.learning_rate

# update params
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()
# update params
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()

"""
if not isinstance(optimizer, LightningOptimizer):
@@ -1282,26 +1275,25 @@ class LightningModule(
back propagation through time. The default implementation splits root level Tensors and
Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.

Examples:
.. code-block:: python
Examples::

def tbptt_split_batch(self, batch, split_size):
splits = []
for t in range(0, time_dims[0], split_size):
batch_split = []
for i, x in enumerate(batch):
if isinstance(x, torch.Tensor):
split_x = x[:, t:t + split_size]
elif isinstance(x, collections.Sequence):
split_x = [None] * len(x)
for batch_idx in range(len(x)):
split_x[batch_idx] = x[batch_idx][t:t + split_size]
def tbptt_split_batch(self, batch, split_size):
splits = []
for t in range(0, time_dims[0], split_size):
batch_split = []
for i, x in enumerate(batch):
if isinstance(x, torch.Tensor):
split_x = x[:, t:t + split_size]
elif isinstance(x, collections.Sequence):
split_x = [None] * len(x)
for batch_idx in range(len(x)):
split_x[batch_idx] = x[batch_idx][t:t + split_size]

batch_split.append(split_x)
batch_split.append(split_x)

splits.append(batch_split)
splits.append(batch_split)

return splits
return splits

Note:
Called in the training loop after

@@ -1354,11 +1346,10 @@ class LightningModule(
r"""
Freeze all params for inference.

Example:
.. code-block:: python
Example::

model = MyLightningModule(...)
model.freeze()
model = MyLightningModule(...)
model.freeze()

"""
for param in self.parameters():
@@ -93,36 +93,35 @@ class ModelIO(object):
Return:
:class:`LightningModule` with loaded weights and hyperparameters (if available).

Example:
.. code-block:: python
Example::

# load weights without mapping ...
MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
# load weights without mapping ...
MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1':'cuda:0'}
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
map_location=map_location
)
# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1':'cuda:0'}
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
map_location=map_location
)

# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
hparams_file='/path/to/hparams_file.yaml'
)
# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
MyLightningModule.load_from_checkpoint(
PATH,
num_layers=128,
pretrained_ckpt_path: NEW_PATH,
)
# override some of the params with new values
MyLightningModule.load_from_checkpoint(
PATH,
num_layers=128,
pretrained_ckpt_path: NEW_PATH,
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
"""
if map_location is not None:
checkpoint = pl_load(checkpoint_path, map_location=map_location)
@@ -52,7 +52,7 @@ class NeptuneLogger(LightningLoggerBase):
**ONLINE MODE**

.. code-block:: python
.. testcode::

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger

@@ -70,7 +70,7 @@ class NeptuneLogger(LightningLoggerBase):
**OFFLINE MODE**

.. code-block:: python
.. testcode::

from pytorch_lightning.loggers import NeptuneLogger
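For orientation, a sketch of how a logger such as ``NeptuneLogger`` is typically constructed and handed to the ``Trainer``; the constructor arguments shown are assumptions about this logger's API, only ``Trainer(logger=...)`` is generic:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import NeptuneLogger

    neptune_logger = NeptuneLogger(
        api_key='ANONYMOUS',                                  # assumed credential handling
        project_name='shared/pytorch-lightning-integration',  # assumed project id
    )
    trainer = Trainer(logger=neptune_logger)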
@@ -85,16 +85,15 @@ class ApexPlugin(PrecisionPlugin):
Return:
Apex wrapped model and optimizers

Examples:
.. code-block:: python
Examples::

# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
model, optimizers = amp.initialize(
model, optimizers, opt_level=amp_level,
)
# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
model, optimizers = amp.initialize(
model, optimizers, opt_level=amp_level,
)

return model, optimizers
return model, optimizers
"""
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers