{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "04-transformers-text-classification.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true, "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "8ag5ANQPJ_j9", "colab_type": "text" }, "source": [ "# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n", "\n", "This notebook uses HuggingFace's `nlp` library (since renamed to `datasets`) to load data, which is wrapped in a `LightningDataModule`. We then write a class that performs text classification on any dataset from the [GLUE Benchmark](https://gluebenchmark.com/). (We only train on CoLA and MRPC here because of compute and disk constraints.)\n", "\n", "[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and the tasks they pose.\n", "\n", "---\n", " - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", "\n", " - [HuggingFace nlp](https://github.com/huggingface/nlp)\n", " - [HuggingFace transformers](https://github.com/huggingface/transformers)" ] }, { "cell_type": "markdown", "metadata": { "id": "fqlsVTj7McZ3", "colab_type": "text" }, "source": [ "### Setup" ] }, { "cell_type": "code", "metadata": { "id": "OIhHrRL-MnKK", "colab_type": "code", "colab": {} }, "source": [ "!pip install pytorch-lightning nlp transformers" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "6yuQT_ZQMpCg", "colab_type": "code", "colab": {} }, "source": [ 
"from argparse import ArgumentParser\n", "from datetime import datetime\n", "from typing import Optional\n", "\n", "import nlp\n", "import numpy as np\n", "import pytorch_lightning as pl\n", "import torch\n", "from torch.utils.data import DataLoader\n", "from transformers import (\n", "    AdamW,\n", "    AutoModelForSequenceClassification,\n", "    AutoConfig,\n", "    AutoTokenizer,\n", "    get_linear_schedule_with_warmup,\n", "    glue_compute_metrics\n", ")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "9ORJfiuiNZ_N", "colab_type": "text" }, "source": [ "## GLUE DataModule" ] }, { "cell_type": "code", "metadata": { "id": "jW9xQhZxMz1G", "colab_type": "code", "colab": {} }, "source": [ "class GLUEDataModule(pl.LightningDataModule):\n", "\n", "    task_text_field_map = {\n", "        'cola': ['sentence'],\n", "        'sst2': ['sentence'],\n", "        'mrpc': ['sentence1', 'sentence2'],\n", "        'qqp': ['question1', 'question2'],\n", "        'stsb': ['sentence1', 'sentence2'],\n", "        'mnli': ['premise', 'hypothesis'],\n", "        'qnli': ['question', 'sentence'],\n", "        'rte': ['sentence1', 'sentence2'],\n", "        'wnli': ['sentence1', 'sentence2'],\n", "        'ax': ['premise', 'hypothesis']\n", "    }\n", "\n", "    glue_task_num_labels = {\n", "        'cola': 2,\n", "        'sst2': 2,\n", "        'mrpc': 2,\n", "        'qqp': 2,\n", "        'stsb': 1,\n", "        'mnli': 3,\n", "        'qnli': 2,\n", "        'rte': 2,\n", "        'wnli': 2,\n", "        'ax': 3\n", "    }\n", "\n", "    loader_columns = [\n", "        'nlp_idx',\n", "        'input_ids',\n", "        'token_type_ids',\n", "        'attention_mask',\n", "        'start_positions',\n", "        'end_positions',\n", "        'labels'\n", "    ]\n", "\n", "    def __init__(\n", "        self,\n", "        model_name_or_path: str,\n", "        task_name: str = 'mrpc',\n", "        max_seq_length: int = 128,\n", "        train_batch_size: int = 32,\n", "        eval_batch_size: int = 32,\n", "        **kwargs\n", "    ):\n", "        super().__init__()\n", "        self.model_name_or_path = model_name_or_path\n", "        self.task_name = task_name\n", "        self.max_seq_length = max_seq_length\n", "        
self.train_batch_size = train_batch_size\n", "        self.eval_batch_size = eval_batch_size\n", "\n", "        self.text_fields = self.task_text_field_map[task_name]\n", "        self.num_labels = self.glue_task_num_labels[task_name]\n", "        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", "\n", "    def setup(self, stage):\n", "        self.dataset = nlp.load_dataset('glue', self.task_name)\n", "\n", "        for split in self.dataset.keys():\n", "            self.dataset[split] = self.dataset[split].map(\n", "                self.convert_to_features,\n", "                batched=True,\n", "                remove_columns=['label'],\n", "            )\n", "            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n", "            self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n", "\n", "        self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n", "\n", "    def prepare_data(self):\n", "        nlp.load_dataset('glue', self.task_name)\n", "        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)\n", "\n", "    def val_dataloader(self):\n", "        if len(self.eval_splits) == 1:\n", "            return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n", "        elif len(self.eval_splits) > 1:\n", "            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n", "\n", "    def test_dataloader(self):\n", "        # Collect the test split(s); MNLI has matched and mismatched test sets\n", "        test_splits = [x for x in self.dataset.keys() if 'test' in x]\n", "        if len(test_splits) == 1:\n", "            return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n", "        elif len(test_splits) > 1:\n", "            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in test_splits]\n", "\n", "    def convert_to_features(self, example_batch, indices=None):\n", "\n", "        # Either encode single sentence or sentence pairs\n", "        if len(self.text_fields) > 1:\n", "            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], 
example_batch[self.text_fields[1]]))\n", "        else:\n", "            texts_or_text_pairs = example_batch[self.text_fields[0]]\n", "\n", "        # Tokenize the text/text pairs\n", "        features = self.tokenizer.batch_encode_plus(\n", "            texts_or_text_pairs,\n", "            max_length=self.max_seq_length,\n", "            pad_to_max_length=True,\n", "            truncation=True\n", "        )\n", "\n", "        # Rename label to labels to make it easier to pass to model forward\n", "        features['labels'] = example_batch['label']\n", "\n", "        return features" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "jQC3a6KuOpX3", "colab_type": "text" }, "source": [ "#### You could use this datamodule with standalone PyTorch if you wanted..." ] }, { "cell_type": "code", "metadata": { "id": "JCMH3IAsNffF", "colab_type": "code", "colab": {} }, "source": [ "dm = GLUEDataModule('distilbert-base-uncased')\n", "dm.prepare_data()\n", "dm.setup('fit')\n", "next(iter(dm.train_dataloader()))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "l9fQ_67BO2Lj", "colab_type": "text" }, "source": [ "## GLUE Model" ] }, { "cell_type": "code", "metadata": { "id": "gtn5YGKYO65B", "colab_type": "code", "colab": {} }, "source": [ "class GLUETransformer(pl.LightningModule):\n", "    def __init__(\n", "        self,\n", "        model_name_or_path: str,\n", "        num_labels: int,\n", "        learning_rate: float = 2e-5,\n", "        adam_epsilon: float = 1e-8,\n", "        warmup_steps: int = 0,\n", "        weight_decay: float = 0.0,\n", "        train_batch_size: int = 32,\n", "        eval_batch_size: int = 32,\n", "        eval_splits: Optional[list] = None,\n", "        **kwargs\n", "    ):\n", "        super().__init__()\n", "\n", "        self.save_hyperparameters()\n", "\n", "        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n", "        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n", "        self.metric = nlp.load_metric(\n", "            'glue',\n", "            self.hparams.task_name,\n", "            
experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n", "        )\n", "\n", "    def forward(self, **inputs):\n", "        return self.model(**inputs)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        outputs = self(**batch)\n", "        loss = outputs[0]\n", "        return pl.TrainResult(loss)\n", "\n", "    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n", "        outputs = self(**batch)\n", "        val_loss, logits = outputs[:2]\n", "\n", "        # Classification tasks take the argmax; regression (STS-B, num_labels == 1) uses the raw output\n", "        if self.hparams.num_labels > 1:\n", "            preds = torch.argmax(logits, axis=1)\n", "        elif self.hparams.num_labels == 1:\n", "            preds = logits.squeeze()\n", "\n", "        labels = batch[\"labels\"]\n", "\n", "        return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n", "\n", "    def validation_epoch_end(self, outputs):\n", "        if self.hparams.task_name == 'mnli':\n", "            for i, output in enumerate(outputs):\n", "                # matched or mismatched\n", "                split = self.hparams.eval_splits[i].split('_')[-1]\n", "                preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n", "                labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n", "                loss = torch.stack([x['loss'] for x in output]).mean()\n", "                if i == 0:\n", "                    result = pl.EvalResult(checkpoint_on=loss)\n", "                result.log(f'val_loss_{split}', loss, prog_bar=True)\n", "                split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(preds, labels).items()}\n", "                result.log_dict(split_metrics, prog_bar=True)\n", "            return result\n", "\n", "        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n", "        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n", "        loss = torch.stack([x['loss'] for x in outputs]).mean()\n", "        result = pl.EvalResult(checkpoint_on=loss)\n", "        result.log('val_loss', loss, prog_bar=True)\n", "        result.log_dict(self.metric.compute(preds, labels), prog_bar=True)\n", "        return result\n", "\n", "    def setup(self, stage):\n", "        if stage == 'fit':\n", "            # Get dataloader by calling it - train_dataloader() is called after 
setup() by default\n", "            train_loader = self.train_dataloader()\n", "\n", "            # Calculate total steps\n", "            self.total_steps = (\n", "                (len(train_loader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))\n", "                // self.hparams.accumulate_grad_batches\n", "                * float(self.hparams.max_epochs)\n", "            )\n", "\n", "    def configure_optimizers(self):\n", "        \"Prepare optimizer and schedule (linear warmup and decay)\"\n", "        model = self.model\n", "        no_decay = [\"bias\", \"LayerNorm.weight\"]\n", "        optimizer_grouped_parameters = [\n", "            {\n", "                \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n", "                \"weight_decay\": self.hparams.weight_decay,\n", "            },\n", "            {\n", "                \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n", "                \"weight_decay\": 0.0,\n", "            },\n", "        ]\n", "        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n", "\n", "        scheduler = get_linear_schedule_with_warmup(\n", "            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n", "        )\n", "        scheduler = {\n", "            'scheduler': scheduler,\n", "            'interval': 'step',\n", "            'frequency': 1\n", "        }\n", "        return [optimizer], [scheduler]\n", "\n", "    @staticmethod\n", "    def add_model_specific_args(parent_parser):\n", "        parser = ArgumentParser(parents=[parent_parser], add_help=False)\n", "        parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n", "        parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n", "        parser.add_argument(\"--warmup_steps\", default=0, type=int)\n", "        parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n", "        return parser" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "ha-NdIP_xbd3", "colab_type": "text" }, "source": [ "### ⚡ Quick Tip \n", " - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust 
configuration" ] }, { "cell_type": "code", "metadata": { "id": "3dEHnl3RPlAR", "colab_type": "code", "colab": {} }, "source": [ "def parse_args(args=None):\n", "    parser = ArgumentParser()\n", "    parser = pl.Trainer.add_argparse_args(parser)\n", "    parser = GLUEDataModule.add_argparse_args(parser)\n", "    parser = GLUETransformer.add_model_specific_args(parser)\n", "    parser.add_argument('--seed', type=int, default=42)\n", "    return parser.parse_args(args)\n", "\n", "\n", "def main(args):\n", "    pl.seed_everything(args.seed)\n", "    dm = GLUEDataModule.from_argparse_args(args)\n", "    dm.prepare_data()\n", "    dm.setup('fit')\n", "    model = GLUETransformer(num_labels=dm.num_labels, eval_splits=dm.eval_splits, **vars(args))\n", "    trainer = pl.Trainer.from_argparse_args(args)\n", "    return dm, model, trainer" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "PkuLaeec3sJ-", "colab_type": "text" }, "source": [ "# Training" ] }, { "cell_type": "markdown", "metadata": { "id": "QSpueK5UPsN7", "colab_type": "text" }, "source": [ "## CoLA\n", "\n", "See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)" ] }, { "cell_type": "code", "metadata": { "id": "NJnFmtpnPu0Y", "colab_type": "code", "colab": {} }, "source": [ "mocked_args = \"\"\"\n", "    --model_name_or_path albert-base-v2\n", "    --task_name cola\n", "    --max_epochs 3\n", "    --gpus 1\"\"\".split()\n", "\n", "args = parse_args(mocked_args)\n", "dm, model, trainer = main(args)\n", "trainer.fit(model, dm)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "_MrNsTnqdz4z", "colab_type": "text" }, "source": [ "## MRPC\n", "\n", "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)" ] }, { "cell_type": "code", "metadata": { "id": "LBwRxg9Cb3d-", "colab_type": "code", "colab": {} }, "source": [ "mocked_args = \"\"\"\n", "    
--model_name_or_path distilbert-base-cased\n", "    --task_name mrpc\n", "    --max_epochs 3\n", "    --gpus 1\"\"\".split()\n", "\n", "args = parse_args(mocked_args)\n", "dm, model, trainer = main(args)\n", "trainer.fit(model, dm)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "iZhbn0HzfdCu", "colab_type": "text" }, "source": [ "## MNLI\n", "\n", " - The MNLI dataset is huge, so we aren't going to train on all of it here.\n", "\n", " - Instead, we check that our multi-dataloader logic is right by limiting training to a few batches and moving straight on to validation.\n", "\n", "See an interactive view of the MNLI dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)" ] }, { "cell_type": "code", "metadata": { "id": "AvsZMOggfcWW", "colab_type": "code", "colab": {} }, "source": [ "mocked_args = \"\"\"\n", "    --model_name_or_path distilbert-base-uncased\n", "    --task_name mnli\n", "    --max_epochs 1\n", "    --gpus 1\n", "    --limit_train_batches 10\n", "    --progress_bar_refresh_rate 20\"\"\".split()\n", "\n", "args = parse_args(mocked_args)\n", "dm, model, trainer = main(args)\n", "trainer.fit(model, dm)" ], "execution_count": null, "outputs": [] } ] }