2020-05-05 02:16:54 +00:00
|
|
|
.. testsetup:: *
|
|
|
|
|
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
|
|
|
|
2020-03-03 15:52:16 +00:00
|
|
|
Transfer Learning
|
|
|
|
-----------------
|
|
|
|
|
|
|
|
Using Pretrained Models
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
|
|
|
|
Sometimes we want to use a LightningModule as a pretrained model. This is fine because
|
|
|
|
a LightningModule is just a ``torch.nn.Module``!
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. note:: Remember that a LightningModule is EXACTLY a ``torch.nn.Module``, just with more capabilities.
|
2020-03-03 21:42:49 +00:00
|
|
|
|
2020-03-03 15:52:16 +00:00
|
|
|
Let's use the ``AutoEncoder`` as a feature extractor in a separate model.
|
|
|
|
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
class Encoder(torch.nn.Module):
    """Placeholder encoder; substitute your own architecture here."""

    ...


class Decoder(torch.nn.Module):
    """Placeholder decoder; ``AutoEncoder`` instantiates it, so it must be defined."""

    ...
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
class AutoEncoder(LightningModule):
    """Encoder/decoder pair whose encoder output can be reused as features."""

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned as an
        # attribute; otherwise PyTorch raises an AttributeError.
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        # When used as a feature extractor, only the encoder's representation
        # is returned; the decoder is needed only for reconstruction training.
        return self.encoder(x)
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
class CIFAR10Classifier(LightningModule):
    """Classify CIFAR-10 images from features produced by a pretrained AutoEncoder."""

    def __init__(self):
        # Required before assigning submodules such as feature_extractor.
        super().__init__()

        # init the pretrained LightningModule
        # (PATH is a placeholder for the AutoEncoder checkpoint file)
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        # freeze the pretrained weights so only the new classifier head is trained
        self.feature_extractor.freeze()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        return x

    # training_step, configure_optimizers, dataloaders, etc. omitted for brevity
|
|
|
|
|
|
|
|
We used our pretrained AutoEncoder (a LightningModule) for transfer learning!
|
|
|
|
|
2020-03-03 21:42:49 +00:00
|
|
|
Example: ImageNet (Computer Vision)
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
|
|
|
:skipif: not TORCHVISION_AVAILABLE
|
2020-03-03 21:42:49 +00:00
|
|
|
|
|
|
|
import torchvision.models as models
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
class ImagenetTransferLearning(LightningModule):
    """Classify CIFAR-10 images using an ImageNet-pretrained ResNet-50 backbone."""

    def __init__(self):
        # Required before assigning submodules.
        super().__init__()

        # init a pretrained resnet; keep its 1000-class ImageNet head so the
        # pretrained weights load without a size mismatch
        backbone = models.resnet50(pretrained=True)
        # drop the final fc layer and keep everything before it as the feature
        # extractor; its width gives the size of the extracted features
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        # freeze the backbone so only the new classifier head is trained
        self.feature_extractor.requires_grad_(False)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        # The training loop may switch the whole model back to train mode; keep
        # the frozen backbone in eval mode so its BatchNorm statistics stay fixed.
        self.feature_extractor.eval()
        # flatten the pooled backbone output to (batch, num_filters)
        representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        return x

    # training_step, configure_optimizers, dataloaders, etc. omitted for brevity
|
|
|
|
|
|
|
|
Finetune
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
# a Trainer with default settings drives the finetuning loop
trainer = Trainer()

# pretrained backbone + new CIFAR-10 head, then train it
model = ImagenetTransferLearning()
trainer.fit(model)
|
|
|
|
|
|
|
|
And use it to make predictions on your data of interest
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
# some_images_from_cifar10() and PATH are placeholders for your data and checkpoint
x = some_images_from_cifar10()

# restore the finetuned weights and switch the model to inference mode
model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()

predictions = model(x)
|
|
|
|
|
|
|
|
We took a model pretrained on ImageNet, finetuned it on CIFAR-10, and used it to predict on CIFAR-10.
|
|
|
|
In the non-academic world, you would finetune on a small dataset of your own and then predict on your own data.
|
|
|
|
|
|
|
|
Example: BERT (NLP)
|
|
|
|
^^^^^^^^^^^^^^^^^^^
|
|
|
|
Lightning is completely agnostic to what's used for transfer learning so long
|
2020-03-03 15:52:16 +00:00
|
|
|
as it is a ``torch.nn.Module`` subclass.
|
|
|
|
|
2020-03-03 21:42:49 +00:00
|
|
|
Here's a model that uses `Hugging Face Transformers <https://github.com/huggingface/transformers>`_.
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
class BertMNLIFinetuner(LightningModule):
    """Pretrained BERT with a linear 3-class classification head (MNLI)."""

    def __init__(self):
        super().__init__()

        self.num_classes = 3
        self.bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)
        # size the head from the loaded model's config
        self.W = nn.Linear(self.bert.config.hidden_size, self.num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return ``(logits, attentions)`` for a batch of tokenized inputs."""
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        # Index rather than tuple-unpack, so this works whether BertModel returns
        # a plain tuple or a ModelOutput: [0] is the last hidden state and [-1]
        # the attentions requested via output_attentions=True.
        h = outputs[0]
        attn = outputs[-1]

        # classify from the hidden state of the first ([CLS]) token
        h_cls = h[:, 0]
        logits = self.W(h_cls)
        return logits, attn
|