{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "04-transformers-text-classification.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "8ag5ANQPJ_j9",
"colab_type": "text"
},
"source": [
"# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n",
"\n",
"This notebook uses HuggingFace's `datasets` library to load data, which is wrapped in a `LightningDataModule`. We then write a class to perform text classification on any dataset from the [GLUE Benchmark](https://gluebenchmark.com/). (We only show CoLA and MRPC in full due to compute and disk constraints.)\n",
"\n",
"[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and what tasks they are solving for.\n",
"\n",
"---\n",
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Ask a question on [the forum](https://forums.pytorchlightning.ai/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
"\n",
" - [HuggingFace datasets](https://github.com/huggingface/datasets)\n",
" - [HuggingFace transformers](https://github.com/huggingface/transformers)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fqlsVTj7McZ3",
"colab_type": "text"
},
"source": [
"### Setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "OIhHrRL-MnKK",
"colab_type": "code",
"colab": {}
},
"source": [
"!pip install pytorch-lightning datasets transformers"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6yuQT_ZQMpCg",
"colab_type": "code",
"colab": {}
},
"source": [
"from argparse import ArgumentParser\n",
"from datetime import datetime\n",
"from typing import Optional\n",
"\n",
"import datasets\n",
"import numpy as np\n",
"import pytorch_lightning as pl\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from transformers import (\n",
"    AdamW,\n",
"    AutoModelForSequenceClassification,\n",
"    AutoConfig,\n",
"    AutoTokenizer,\n",
"    get_linear_schedule_with_warmup,\n",
"    glue_compute_metrics\n",
")"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "9ORJfiuiNZ_N",
"colab_type": "text"
},
"source": [
"## GLUE DataModule"
]
},
{
"cell_type": "code",
"metadata": {
"id": "jW9xQhZxMz1G",
"colab_type": "code",
"colab": {}
},
"source": [
"class GLUEDataModule(pl.LightningDataModule):\n",
"\n",
"    # Text column(s) to tokenize for each GLUE task\n",
"    task_text_field_map = {\n",
"        'cola': ['sentence'],\n",
"        'sst2': ['sentence'],\n",
"        'mrpc': ['sentence1', 'sentence2'],\n",
"        'qqp': ['question1', 'question2'],\n",
"        'stsb': ['sentence1', 'sentence2'],\n",
"        'mnli': ['premise', 'hypothesis'],\n",
"        'qnli': ['question', 'sentence'],\n",
"        'rte': ['sentence1', 'sentence2'],\n",
"        'wnli': ['sentence1', 'sentence2'],\n",
"        'ax': ['premise', 'hypothesis']\n",
"    }\n",
"\n",
"    # Number of output labels per task (1 means regression, as in STS-B)\n",
"    glue_task_num_labels = {\n",
"        'cola': 2,\n",
"        'sst2': 2,\n",
"        'mrpc': 2,\n",
"        'qqp': 2,\n",
"        'stsb': 1,\n",
"        'mnli': 3,\n",
"        'qnli': 2,\n",
"        'rte': 2,\n",
"        'wnli': 2,\n",
"        'ax': 3\n",
"    }\n",
"\n",
"    # Columns that are kept and converted to tensors for the dataloaders\n",
"    loader_columns = [\n",
"        'datasets_idx',\n",
"        'input_ids',\n",
"        'token_type_ids',\n",
"        'attention_mask',\n",
"        'start_positions',\n",
"        'end_positions',\n",
"        'labels'\n",
"    ]\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        model_name_or_path: str,\n",
"        task_name: str = 'mrpc',\n",
"        max_seq_length: int = 128,\n",
"        train_batch_size: int = 32,\n",
"        eval_batch_size: int = 32,\n",
"        **kwargs\n",
"    ):\n",
"        super().__init__()\n",
"        self.model_name_or_path = model_name_or_path\n",
"        self.task_name = task_name\n",
"        self.max_seq_length = max_seq_length\n",
"        self.train_batch_size = train_batch_size\n",
"        self.eval_batch_size = eval_batch_size\n",
"\n",
"        self.text_fields = self.task_text_field_map[task_name]\n",
"        self.num_labels = self.glue_task_num_labels[task_name]\n",
"        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
"\n",
"    def setup(self, stage):\n",
"        self.dataset = datasets.load_dataset('glue', self.task_name)\n",
"\n",
"        for split in self.dataset.keys():\n",
"            self.dataset[split] = self.dataset[split].map(\n",
"                self.convert_to_features,\n",
"                batched=True,\n",
"                remove_columns=['label'],\n",
"            )\n",
"            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n",
"            self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n",
"\n",
"        self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n",
"\n",
"    def prepare_data(self):\n",
"        # Download and cache the dataset and tokenizer once, before setup()\n",
"        datasets.load_dataset('glue', self.task_name)\n",
"        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
"\n",
"    def train_dataloader(self):\n",
"        return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)\n",
"\n",
"    def val_dataloader(self):\n",
"        if len(self.eval_splits) == 1:\n",
"            return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n",
"        elif len(self.eval_splits) > 1:\n",
"            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
"\n",
"    def test_dataloader(self):\n",
"        if len(self.eval_splits) == 1:\n",
"            return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n",
"        elif len(self.eval_splits) > 1:\n",
"            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
"\n",
"    def convert_to_features(self, example_batch, indices=None):\n",
"\n",
"        # Either encode single sentence or sentence pairs\n",
"        if len(self.text_fields) > 1:\n",
"            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n",
"        else:\n",
"            texts_or_text_pairs = example_batch[self.text_fields[0]]\n",
"\n",
"        # Tokenize the text/text pairs, padding every example to max_seq_length\n",
"        # (padding='max_length' replaces the deprecated pad_to_max_length=True)\n",
"        features = self.tokenizer.batch_encode_plus(\n",
"            texts_or_text_pairs,\n",
"            max_length=self.max_seq_length,\n",
"            padding='max_length',\n",
"            truncation=True\n",
"        )\n",
"\n",
"        # Rename label to labels to make it easier to pass to model forward\n",
"        features['labels'] = example_batch['label']\n",
"\n",
"        return features"
],
"execution_count": 3,
"outputs": []
},
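{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the pairing step inside `convert_to_features`, assuming a toy batch in the column-oriented layout that a batched `map` call passes in. The sentences and labels below are made up for illustration and are not taken from GLUE; no tokenizer is involved."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical toy batch: one list per column, as in a batched map call\n",
"example_batch = {\n",
"    'sentence1': ['A cat sat.', 'It rained.'],\n",
"    'sentence2': ['A cat was sitting.', 'The sun shone.'],\n",
"    'label': [1, 0],\n",
"}\n",
"text_fields = ['sentence1', 'sentence2']  # same shape as the 'mrpc' entry above\n",
"\n",
"# Two text fields: zip the columns into (first, second) tuples, one per example\n",
"if len(text_fields) > 1:\n",
"    texts_or_text_pairs = list(zip(example_batch[text_fields[0]], example_batch[text_fields[1]]))\n",
"else:\n",
"    texts_or_text_pairs = example_batch[text_fields[0]]\n",
"\n",
"print(texts_or_text_pairs)"
],
"execution_count": null,
"outputs": []
},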
{
"cell_type": "markdown",
"metadata": {
"id": "jQC3a6KuOpX3",
"colab_type": "text"
},
"source": [
"#### You could use this datamodule with standalone PyTorch if you wanted..."
]
},
{
"cell_type": "code",
"metadata": {
"id": "JCMH3IAsNffF",
"colab_type": "code",
"colab": {}
},
"source": [
"dm = GLUEDataModule('distilbert-base-uncased')\n",
"dm.prepare_data()\n",
"dm.setup('fit')\n",
"next(iter(dm.train_dataloader()))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "l9fQ_67BO2Lj",
"colab_type": "text"
},
"source": [
"## GLUE Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "gtn5YGKYO65B",
"colab_type": "code",
"colab": {}
},
"source": [
"class GLUETransformer(pl.LightningModule):\n",
"    def __init__(\n",
"        self,\n",
"        model_name_or_path: str,\n",
"        num_labels: int,\n",
"        learning_rate: float = 2e-5,\n",
"        adam_epsilon: float = 1e-8,\n",
"        warmup_steps: int = 0,\n",
"        weight_decay: float = 0.0,\n",
"        train_batch_size: int = 32,\n",
"        eval_batch_size: int = 32,\n",
"        eval_splits: Optional[list] = None,\n",
"        **kwargs\n",
"    ):\n",
"        super().__init__()\n",
"\n",
"        # Stores all __init__ arguments (including **kwargs) under self.hparams\n",
"        self.save_hyperparameters()\n",
"\n",
"        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n",
"        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n",
"        self.metric = datasets.load_metric(\n",
"            'glue',\n",
"            self.hparams.task_name,\n",
"            experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n",
"        )\n",
"\n",
"    def forward(self, **inputs):\n",
"        return self.model(**inputs)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        outputs = self(**batch)\n",
"        loss = outputs[0]\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
"        outputs = self(**batch)\n",
"        val_loss, logits = outputs[:2]\n",
"\n",
"        # Classification: take the argmax; regression (a single label): use the raw output\n",
"        if self.hparams.num_labels > 1:\n",
"            preds = torch.argmax(logits, axis=1)\n",
"        elif self.hparams.num_labels == 1:\n",
"            preds = logits.squeeze()\n",
"\n",
"        labels = batch[\"labels\"]\n",
"\n",
"        return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n",
"\n",
"    def validation_epoch_end(self, outputs):\n",
"        if self.hparams.task_name == 'mnli':\n",
"            for i, output in enumerate(outputs):\n",
"                # matched or mismatched\n",
"                split = self.hparams.eval_splits[i].split('_')[-1]\n",
"                preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n",
"                labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n",
"                loss = torch.stack([x['loss'] for x in output]).mean()\n",
"                self.log(f'val_loss_{split}', loss, prog_bar=True)\n",
"                split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n",
"                self.log_dict(split_metrics, prog_bar=True)\n",
"            return loss\n",
"\n",
"        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n",
"        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n",
"        loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
"        self.log('val_loss', loss, prog_bar=True)\n",
"        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
"        return loss\n",
"\n",
"    def setup(self, stage):\n",
"        if stage == 'fit':\n",
"            # Get dataloader by calling it - train_dataloader() is called after setup() by default\n",
"            train_loader = self.train_dataloader()\n",
"\n",
"            # Calculate total steps\n",
"            self.total_steps = (\n",
"                (len(train_loader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))\n",
"                // self.hparams.accumulate_grad_batches\n",
"                * float(self.hparams.max_epochs)\n",
"            )\n",
"\n",
"    def configure_optimizers(self):\n",
"        \"Prepare optimizer and schedule (linear warmup and decay)\"\n",
"        model = self.model\n",
"        # Biases and LayerNorm weights are excluded from weight decay\n",
"        no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
"        optimizer_grouped_parameters = [\n",
"            {\n",
"                \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
"                \"weight_decay\": self.hparams.weight_decay,\n",
"            },\n",
"            {\n",
"                \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n",
"                \"weight_decay\": 0.0,\n",
"            },\n",
"        ]\n",
"        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n",
"\n",
"        scheduler = get_linear_schedule_with_warmup(\n",
"            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n",
"        )\n",
"        # Step the scheduler after every optimizer step rather than every epoch\n",
"        scheduler = {\n",
"            'scheduler': scheduler,\n",
"            'interval': 'step',\n",
"            'frequency': 1\n",
"        }\n",
"        return [optimizer], [scheduler]\n",
"\n",
"    @staticmethod\n",
"    def add_model_specific_args(parent_parser):\n",
"        parser = ArgumentParser(parents=[parent_parser], add_help=False)\n",
"        parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n",
"        parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n",
"        parser.add_argument(\"--warmup_steps\", default=0, type=int)\n",
"        parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n",
"        return parser"
],
"execution_count": 5,
"outputs": []
},
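{
"cell_type": "markdown",
"metadata": {},
"source": [
"A worked example of the `total_steps` arithmetic in `setup`, under assumed values: the dataset size of 1,000 is hypothetical, while the other numbers mirror the defaults and flags used in this notebook. Because of the floor division, a final partial batch is not counted, so the result may slightly underestimate the number of optimizer steps."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"dataset_size = 1000          # hypothetical number of training examples\n",
"train_batch_size = 32        # default used by GLUETransformer\n",
"gpus = 1                     # matches --gpus 1\n",
"accumulate_grad_batches = 1  # assumed: no gradient accumulation\n",
"max_epochs = 3               # matches --max_epochs 3\n",
"\n",
"# Same expression as in GLUETransformer.setup, with plain variables\n",
"total_steps = (\n",
"    (dataset_size // (train_batch_size * max(1, gpus)))\n",
"    // accumulate_grad_batches\n",
"    * float(max_epochs)\n",
")\n",
"\n",
"# 1000 // 32 = 31 full batches per epoch, times 3 epochs\n",
"print(total_steps)  # 93.0"
],
"execution_count": null,
"outputs": []
},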
{
"cell_type": "markdown",
"metadata": {
"id": "ha-NdIP_xbd3",
"colab_type": "text"
},
"source": [
"### ⚡ Quick Tip \n",
" - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust configuration"
]
},
{
"cell_type": "code",
"metadata": {
"id": "3dEHnl3RPlAR",
"colab_type": "code",
"colab": {}
},
"source": [
"def parse_args(args=None):\n",
"    parser = ArgumentParser()\n",
"    parser = pl.Trainer.add_argparse_args(parser)\n",
"    parser = GLUEDataModule.add_argparse_args(parser)\n",
"    parser = GLUETransformer.add_model_specific_args(parser)\n",
"    parser.add_argument('--seed', type=int, default=42)\n",
"    return parser.parse_args(args)\n",
"\n",
"\n",
"def main(args):\n",
"    pl.seed_everything(args.seed)\n",
"    dm = GLUEDataModule.from_argparse_args(args)\n",
"    dm.prepare_data()\n",
"    dm.setup('fit')\n",
"    model = GLUETransformer(num_labels=dm.num_labels, eval_splits=dm.eval_splits, **vars(args))\n",
"    trainer = pl.Trainer.from_argparse_args(args)\n",
"    return dm, model, trainer"
],
"execution_count": 6,
"outputs": []
},
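{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `parents=[...]` pattern used by `add_model_specific_args` can be sketched with the standard library alone. The parser and argument names below are illustrative stand-ins, not the full set of Trainer or DataModule flags."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from argparse import ArgumentParser\n",
"\n",
"# Stand-in for the parser that already holds shared arguments\n",
"parent = ArgumentParser(add_help=False)\n",
"parent.add_argument('--seed', type=int, default=42)\n",
"\n",
"# The child inherits every argument defined on the parent, then adds its own\n",
"child = ArgumentParser(parents=[parent], add_help=False)\n",
"child.add_argument('--learning_rate', default=2e-5, type=float)\n",
"\n",
"# Passing a list mocks the command line, as the training cells in this notebook do\n",
"args = child.parse_args('--learning_rate 3e-5'.split())\n",
"print(vars(args))"
],
"execution_count": null,
"outputs": []
},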
{
"cell_type": "markdown",
"metadata": {
"id": "PkuLaeec3sJ-",
"colab_type": "text"
},
"source": [
"# Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QSpueK5UPsN7",
"colab_type": "text"
},
"source": [
"## CoLA\n",
"\n",
"See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "NJnFmtpnPu0Y",
"colab_type": "code",
"colab": {}
},
"source": [
"mocked_args = \"\"\"\n",
" --model_name_or_path albert-base-v2\n",
" --task_name cola\n",
" --max_epochs 3\n",
" --gpus 1\"\"\".split()\n",
"\n",
"args = parse_args(mocked_args)\n",
"dm, model, trainer = main(args)\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "_MrNsTnqdz4z",
"colab_type": "text"
},
"source": [
"## MRPC\n",
"\n",
"See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "LBwRxg9Cb3d-",
"colab_type": "code",
"colab": {}
},
"source": [
"mocked_args = \"\"\"\n",
" --model_name_or_path distilbert-base-cased\n",
" --task_name mrpc\n",
" --max_epochs 3\n",
" --gpus 1\"\"\".split()\n",
"\n",
"args = parse_args(mocked_args)\n",
"dm, model, trainer = main(args)\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "iZhbn0HzfdCu",
"colab_type": "text"
},
"source": [
"## MNLI\n",
"\n",
" - The MNLI dataset is huge, so we aren't going to bother trying to train it here.\n",
"\n",
" - Let's just make sure our multi-dataloader logic is right by limiting training to a few batches (`--limit_train_batches 10`) and moving quickly on to validation.\n",
"\n",
"See an interactive view of the MNLI dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AvsZMOggfcWW",
"colab_type": "code",
"colab": {}
},
"source": [
"mocked_args = \"\"\"\n",
" --model_name_or_path distilbert-base-uncased\n",
" --task_name mnli\n",
" --max_epochs 1\n",
" --gpus 1\n",
" --limit_train_batches 10\n",
" --progress_bar_refresh_rate 20\"\"\".split()\n",
"\n",
"args = parse_args(mocked_args)\n",
"dm, model, trainer = main(args)\n",
"trainer.fit(model, dm)"
],
"execution_count": null,
"outputs": []
}
]
}