90 lines
2.7 KiB
ReStructuredText
90 lines
2.7 KiB
ReStructuredText
Bolts
|
|
=====
|
|
`PyTorch Lightning Bolts <https://pytorch-lightning-bolts.readthedocs.io/en/latest/>`_, is our official collection
|
|
of prebuilt models across many research domains.
|
|
|
|
.. code-block:: bash
|
|
|
|
pip install pytorch-lightning-bolts
|
|
|
|
In bolts we have:
|
|
|
|
- A collection of pretrained state-of-the-art models.
|
|
- A collection of models designed to bootstrap your research.
|
|
- A collection of Callbacks, transforms, full datasets.
|
|
- All models work on CPUs, TPUs, GPUs and 16-bit precision.
|
|
|
|
-----------------
|
|
|
|
Quality control
|
|
---------------
|
|
Bolts are built-by the Lightning community and contributed to bolts.
|
|
The lightning team guarantees that contributions are:
|
|
|
|
- Rigorously Tested (CPUs, GPUs, TPUs)
|
|
- Rigorously Documented
|
|
- Standardized via PyTorch Lightning
|
|
- Optimized for speed
|
|
- Checked for correctness
|
|
|
|
---------
|
|
|
|
Example 1: Pretrained, prebuilt models
|
|
--------------------------------------
|
|
|
|
.. code-block:: python
|
|
|
|
from pl_bolts.models import VAE, GPT2, ImageGPT, PixelCNN
|
|
from pl_bolts.models.self_supervised import AMDIM, CPCV2, SimCLR, MocoV2
|
|
from pl_bolts.models import LinearRegression, LogisticRegression
|
|
from pl_bolts.models.gans import GAN
|
|
from pl_bolts.callbacks import PrintTableMetricsCallback
|
|
from pl_bolts.datamodules import FashionMNISTDataModule, CIFAR10DataModule, ImagenetDataModule
|
|
|
|
------------
|
|
|
|
Example 2: Extend for faster research
|
|
-------------------------------------
|
|
Bolts are contributed with benchmarks and continuous-integration tests. This means
|
|
you can trust the implementations and use them to bootstrap your resarch much faster.
|
|
|
|
.. code-block:: python
|
|
|
|
from pl_bolts.models import ImageGPT
|
|
from pl_bolts.self_supervised import SimCLR
|
|
|
|
class VideoGPT(ImageGPT):
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
x, y = batch
|
|
x = _shape_input(x)
|
|
|
|
logits = self.gpt(x)
|
|
simclr_features = self.simclr(x)
|
|
|
|
# -----------------
|
|
# do something new with GPT logits + simclr_features
|
|
# -----------------
|
|
|
|
loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1).long())
|
|
|
|
logs = {"loss": loss}
|
|
return {"loss": loss, "log": logs}
|
|
|
|
----------
|
|
|
|
Example 3: Callbacks
|
|
--------------------
|
|
We also have a collection of callbacks.
|
|
|
|
.. code-block:: python
|
|
|
|
from pl_bolts.callbacks import PrintTableMetricsCallback
|
|
import pytorch_lightning as pl
|
|
|
|
trainer = pl.Trainer(callbacks=[PrintTableMetricsCallback()])
|
|
|
|
# loss│train_loss│val_loss│epoch
|
|
# ──────────────────────────────
|
|
# 2.2541470527648926│2.2541470527648926│2.2158432006835938│0
|