546 lines
20 KiB
Plaintext
546 lines
20 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"name": "04-transformers-text-classification.ipynb",
|
|
"provenance": [],
|
|
"collapsed_sections": [],
|
|
"toc_visible": true
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"accelerator": "GPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "8ag5ANQPJ_j9",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n",
|
|
"\n",
|
|
"This notebook uses HuggingFace's `datasets` library to get data, which is wrapped in a `LightningDataModule`. Then we write a class to perform text classification on any dataset from the [GLUE Benchmark](https://gluebenchmark.com/). (We only fully train on CoLA and MRPC due to compute/disk constraints.)\n",
|
|
"\n",
|
|
"[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and what tasks they are solving for.\n",
|
|
"\n",
|
|
"---\n",
|
|
" - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
|
|
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
|
|
" - Ask a question on [the forum](https://forums.pytorchlightning.ai/)\n",
|
|
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
|
|
"\n",
|
|
" - [HuggingFace datasets](https://github.com/huggingface/datasets)\n",
|
|
" - [HuggingFace transformers](https://github.com/huggingface/transformers)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "fqlsVTj7McZ3",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"### Setup"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "OIhHrRL-MnKK",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# This notebook relies on the Lightning 0.9 Result API (pl.TrainResult / pl.EvalResult)\n",
"# and on datasets.load_metric, both removed in later releases, so pin those packages.\n",
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"%pip install -q \"pytorch-lightning==0.9.0\" \"datasets<3\" transformers"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "6yuQT_ZQMpCg",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"from argparse import ArgumentParser\n",
|
|
"from datetime import datetime\n",
|
|
"from typing import Optional\n",
|
|
"\n",
|
|
"import datasets\n",
|
|
"import numpy as np\n",
|
|
"import pytorch_lightning as pl\n",
|
|
"import torch\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"from transformers import (\n",
|
|
" AdamW,\n",
|
|
" AutoModelForSequenceClassification,\n",
|
|
" AutoConfig,\n",
|
|
" AutoTokenizer,\n",
|
|
" get_linear_schedule_with_warmup,\n",
|
|
" glue_compute_metrics\n",
|
|
")"
|
|
],
|
|
"execution_count": 2,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "9ORJfiuiNZ_N",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## GLUE DataModule"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "jW9xQhZxMz1G",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"class GLUEDataModule(pl.LightningDataModule):\n",
"    \"\"\"Download, tokenize and serve one GLUE task as torch DataLoaders.\n",
"\n",
"    Every split of the HuggingFace `glue` dataset for `task_name` is tokenized with the\n",
"    tokenizer of `model_name_or_path`; only the columns listed in `loader_columns` are\n",
"    exposed, as torch tensors.\n",
"    \"\"\"\n",
"\n",
"    # Text column(s) that make up one example of each GLUE task.\n",
"    task_text_field_map = {\n",
"        'cola': ['sentence'],\n",
"        'sst2': ['sentence'],\n",
"        'mrpc': ['sentence1', 'sentence2'],\n",
"        'qqp': ['question1', 'question2'],\n",
"        'stsb': ['sentence1', 'sentence2'],\n",
"        'mnli': ['premise', 'hypothesis'],\n",
"        'qnli': ['question', 'sentence'],\n",
"        'rte': ['sentence1', 'sentence2'],\n",
"        'wnli': ['sentence1', 'sentence2'],\n",
"        'ax': ['premise', 'hypothesis']\n",
"    }\n",
"\n",
"    # Number of model outputs per task (stsb uses a single output).\n",
"    glue_task_num_labels = {\n",
"        'cola': 2,\n",
"        'sst2': 2,\n",
"        'mrpc': 2,\n",
"        'qqp': 2,\n",
"        'stsb': 1,\n",
"        'mnli': 3,\n",
"        'qnli': 2,\n",
"        'rte': 2,\n",
"        'wnli': 2,\n",
"        'ax': 3\n",
"    }\n",
"\n",
"    # Columns kept (when present) in the torch-formatted splits; all others are hidden.\n",
"    loader_columns = [\n",
"        'datasets_idx',\n",
"        'input_ids',\n",
"        'token_type_ids',\n",
"        'attention_mask',\n",
"        'start_positions',\n",
"        'end_positions',\n",
"        'labels'\n",
"    ]\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        model_name_or_path: str,\n",
"        task_name: str = 'mrpc',\n",
"        max_seq_length: int = 128,\n",
"        train_batch_size: int = 32,\n",
"        eval_batch_size: int = 32,\n",
"        **kwargs\n",
"    ):\n",
"        \"\"\"\n",
"        Args:\n",
"            model_name_or_path: HuggingFace model id or local path; its tokenizer is loaded here.\n",
"            task_name: GLUE task key (a key of `task_text_field_map`).\n",
"            max_seq_length: every example is padded/truncated to this many tokens.\n",
"            train_batch_size: batch size of the training DataLoader.\n",
"            eval_batch_size: batch size of the validation/test DataLoaders.\n",
"            **kwargs: unused; accepted so the module can be built from a combined argparse namespace.\n",
"\n",
"        Raises:\n",
"            KeyError: if `task_name` is not one of the known GLUE tasks.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.model_name_or_path = model_name_or_path\n",
"        self.task_name = task_name\n",
"        self.max_seq_length = max_seq_length\n",
"        self.train_batch_size = train_batch_size\n",
"        self.eval_batch_size = eval_batch_size\n",
"\n",
"        self.text_fields = self.task_text_field_map[task_name]\n",
"        self.num_labels = self.glue_task_num_labels[task_name]\n",
"        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
"\n",
"    def setup(self, stage):\n",
"        \"\"\"Load the task, tokenize every split and restrict each split to the loader columns.\n",
"\n",
"        Also records `eval_splits`: the names of all splits containing 'validation'\n",
"        (one for most tasks, several for MNLI).\n",
"        \"\"\"\n",
"        self.dataset = datasets.load_dataset('glue', self.task_name)\n",
"\n",
"        for split in self.dataset.keys():\n",
"            # Tokenize in batches; the raw 'label' column is replaced by 'labels'.\n",
"            self.dataset[split] = self.dataset[split].map(\n",
"                self.convert_to_features,\n",
"                batched=True,\n",
"                remove_columns=['label'],\n",
"            )\n",
"            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n",
"            self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n",
"\n",
"        self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n",
"\n",
"    def prepare_data(self):\n",
"        \"\"\"Fetch the GLUE task and the tokenizer ahead of `setup`.\"\"\"\n",
"        datasets.load_dataset('glue', self.task_name)\n",
"        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
"\n",
"    def train_dataloader(self):\n",
"        \"\"\"Training batches; shuffle=True so examples are visited in a new random order every epoch.\"\"\"\n",
"        return DataLoader(self.dataset['train'], batch_size=self.train_batch_size, shuffle=True)\n",
"\n",
"    def val_dataloader(self):\n",
"        \"\"\"One DataLoader for a single validation split, or a list of them (one per split).\"\"\"\n",
"        if len(self.eval_splits) == 1:\n",
"            return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n",
"        elif len(self.eval_splits) > 1:\n",
"            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
"\n",
"    def test_dataloader(self):\n",
"        \"\"\"DataLoader over 'test' when there is one eval split; otherwise the validation splits.\"\"\"\n",
"        # NOTE(review): this asymmetry (test split vs. validation splits) looks unintended - confirm.\n",
"        if len(self.eval_splits) == 1:\n",
"            return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n",
"        elif len(self.eval_splits) > 1:\n",
"            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
"\n",
"    def convert_to_features(self, example_batch, indices=None):\n",
"        \"\"\"Tokenize one batch of raw examples (used as a `map(batched=True)` callback).\n",
"\n",
"        Args:\n",
"            example_batch: mapping of column name -> list of values for the batch.\n",
"            indices: unused.\n",
"\n",
"        Returns:\n",
"            The tokenizer's encoding plus a 'labels' entry copied from 'label'.\n",
"        \"\"\"\n",
"        # Either encode single sentence or sentence pairs\n",
"        if len(self.text_fields) > 1:\n",
"            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n",
"        else:\n",
"            texts_or_text_pairs = example_batch[self.text_fields[0]]\n",
"\n",
"        # Pad every example to max_seq_length so batches stack into fixed-size tensors.\n",
"        # padding='max_length' replaces the deprecated pad_to_max_length=True.\n",
"        features = self.tokenizer.batch_encode_plus(\n",
"            texts_or_text_pairs,\n",
"            max_length=self.max_seq_length,\n",
"            padding='max_length',\n",
"            truncation=True\n",
"        )\n",
"\n",
"        # Rename label to labels to make it easier to pass to model forward\n",
"        features['labels'] = example_batch['label']\n",
"\n",
"        return features"
|
|
],
|
|
"execution_count": 3,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "jQC3a6KuOpX3",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"#### You could use this datamodule with standalone PyTorch if you wanted..."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "JCMH3IAsNffF",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# Build the datamodule by hand and peek at one tokenized training batch.\n",
"dm = GLUEDataModule(model_name_or_path='distilbert-base-uncased')\n",
"dm.prepare_data()\n",
"dm.setup('fit')\n",
"first_batch = next(iter(dm.train_dataloader()))\n",
"first_batch"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "l9fQ_67BO2Lj",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## GLUE Model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "gtn5YGKYO65B",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"class GLUETransformer(pl.LightningModule):\n",
"    \"\"\"Fine-tune a pretrained AutoModelForSequenceClassification on one GLUE task.\"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        model_name_or_path: str,\n",
"        num_labels: int,\n",
"        learning_rate: float = 2e-5,\n",
"        adam_epsilon: float = 1e-8,\n",
"        warmup_steps: int = 0,\n",
"        weight_decay: float = 0.0,\n",
"        train_batch_size: int = 32,\n",
"        eval_batch_size: int = 32,\n",
"        eval_splits: Optional[list] = None,\n",
"        task_name: str = 'mrpc',\n",
"        **kwargs\n",
"    ):\n",
"        \"\"\"\n",
"        Args:\n",
"            model_name_or_path: HuggingFace model id or local path to fine-tune.\n",
"            num_labels: number of model outputs; 1 is treated as a single predicted score.\n",
"            learning_rate: AdamW learning rate.\n",
"            adam_epsilon: AdamW epsilon.\n",
"            warmup_steps: linear warmup steps of the learning-rate schedule.\n",
"            weight_decay: decay for parameters whose names contain neither 'bias' nor 'LayerNorm.weight'.\n",
"            train_batch_size: per-device batch size, used to estimate the total training steps.\n",
"            eval_batch_size: only recorded in `hparams` by this class.\n",
"            eval_splits: validation split names in dataloader order (used to label MNLI metrics).\n",
"            task_name: GLUE task key; selects the `glue` metric config. Defaults to 'mrpc' like\n",
"                GLUEDataModule, so the module can also be built without a CLI namespace.\n",
"            **kwargs: extra values such as the Trainer flags `gpus`, `max_epochs` and\n",
"                `accumulate_grad_batches`; they end up in `hparams` and are read in `setup`.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"\n",
"        self.save_hyperparameters()\n",
"\n",
"        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n",
"        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n",
"        # Timestamped experiment_id, presumably so separate runs do not share metric cache files.\n",
"        self.metric = datasets.load_metric(\n",
"            'glue',\n",
"            task_name,\n",
"            experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n",
"        )\n",
"\n",
"    def forward(self, **inputs):\n",
"        \"\"\"Run the wrapped model; when `labels` are passed, the loss is the first output.\"\"\"\n",
"        return self.model(**inputs)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        \"\"\"Return the loss of one training batch wrapped in a `pl.TrainResult`.\"\"\"\n",
"        outputs = self(**batch)\n",
"        loss = outputs[0]\n",
"        return pl.TrainResult(loss)\n",
"\n",
"    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
"        \"\"\"Return loss, predictions and gold labels for one validation batch.\n",
"\n",
"        With several outputs the predicted class is the argmax; with a single output\n",
"        (num_labels == 1, e.g. stsb) the squeezed logit is used as the prediction.\n",
"        \"\"\"\n",
"        outputs = self(**batch)\n",
"        val_loss, logits = outputs[:2]\n",
"\n",
"        # Must be `> 1`: argmax over a single column would always predict 0 and make the\n",
"        # single-output branch below unreachable.\n",
"        if self.hparams.num_labels > 1:\n",
"            preds = torch.argmax(logits, axis=1)\n",
"        elif self.hparams.num_labels == 1:\n",
"            preds = logits.squeeze()\n",
"\n",
"        labels = batch[\"labels\"]\n",
"\n",
"        return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n",
"\n",
"    def validation_epoch_end(self, outputs):\n",
"        \"\"\"Aggregate step outputs into the mean loss plus GLUE metrics and log them.\n",
"\n",
"        For MNLI, `outputs` holds one list per validation dataloader; every logged value is\n",
"        suffixed with its split name and checkpointing follows the first split's loss.\n",
"        Otherwise `outputs` is a flat list of step dicts.\n",
"        \"\"\"\n",
"        if self.hparams.task_name == 'mnli':\n",
"            for i, output in enumerate(outputs):\n",
"                # matched or mismatched\n",
"                split = self.hparams.eval_splits[i].split('_')[-1]\n",
"                preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n",
"                labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n",
"                loss = torch.stack([x['loss'] for x in output]).mean()\n",
"                if i == 0:\n",
"                    result = pl.EvalResult(checkpoint_on=loss)\n",
"                result.log(f'val_loss_{split}', loss, prog_bar=True)\n",
"                split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n",
"                result.log_dict(split_metrics, prog_bar=True)\n",
"            return result\n",
"\n",
"        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n",
"        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n",
"        loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
"        result = pl.EvalResult(checkpoint_on=loss)\n",
"        result.log('val_loss', loss, prog_bar=True)\n",
"        result.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
"        return result\n",
"\n",
"    def setup(self, stage):\n",
"        \"\"\"In the 'fit' stage, estimate `total_steps` for the learning-rate schedule.\"\"\"\n",
"        if stage == 'fit':\n",
"            # Get dataloader by calling it - train_dataloader() is called after setup() by default\n",
"            train_loader = self.train_dataloader()\n",
"\n",
"            # `gpus` may be None (e.g. --gpus omitted on a CPU run); count that as one device\n",
"            # instead of failing with max(1, None).\n",
"            num_devices = max(1, self.hparams.gpus or 0)\n",
"\n",
"            # Steps per epoch = examples // (batch size x devices), divided by gradient\n",
"            # accumulation, times the number of epochs.\n",
"            self.total_steps = (\n",
"                (len(train_loader.dataset) // (self.hparams.train_batch_size * num_devices))\n",
"                // self.hparams.accumulate_grad_batches\n",
"                * float(self.hparams.max_epochs)\n",
"            )\n",
"\n",
"    def configure_optimizers(self):\n",
"        \"\"\"Prepare optimizer and schedule (linear warmup and decay).\n",
"\n",
"        Parameters whose names contain 'bias' or 'LayerNorm.weight' get no weight decay.\n",
"        The scheduler uses interval 'step' and relies on `total_steps` computed in `setup`.\n",
"        \"\"\"\n",
"        model = self.model\n",
"        no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
"        optimizer_grouped_parameters = [\n",
"            {\n",
"                \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
"                \"weight_decay\": self.hparams.weight_decay,\n",
"            },\n",
"            {\n",
"                \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n",
"                \"weight_decay\": 0.0,\n",
"            },\n",
"        ]\n",
"        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n",
"\n",
"        scheduler = get_linear_schedule_with_warmup(\n",
"            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n",
"        )\n",
"        scheduler = {\n",
"            'scheduler': scheduler,\n",
"            'interval': 'step',\n",
"            'frequency': 1\n",
"        }\n",
"        return [optimizer], [scheduler]\n",
"\n",
"    @staticmethod\n",
"    def add_model_specific_args(parent_parser):\n",
"        \"\"\"Return a parser that extends `parent_parser` with this module's optimizer flags.\"\"\"\n",
"        parser = ArgumentParser(parents=[parent_parser], add_help=False)\n",
"        parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n",
"        parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n",
"        parser.add_argument(\"--warmup_steps\", default=0, type=int)\n",
"        parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n",
"        return parser"
|
|
],
|
|
"execution_count": 5,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ha-NdIP_xbd3",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"### ⚡ Quick Tip \n",
|
|
" - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust configuration"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "3dEHnl3RPlAR",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"def parse_args(args=None):\n",
"    \"\"\"Parse `args` (sys.argv when None) with Trainer, datamodule, model and seed flags combined.\"\"\"\n",
"    parser = ArgumentParser()\n",
"    for extend_parser in (\n",
"        pl.Trainer.add_argparse_args,\n",
"        GLUEDataModule.add_argparse_args,\n",
"        GLUETransformer.add_model_specific_args,\n",
"    ):\n",
"        parser = extend_parser(parser)\n",
"    parser.add_argument('--seed', type=int, default=42)\n",
"    return parser.parse_args(args)\n",
"\n",
"\n",
"def main(args):\n",
"    \"\"\"Seed via pl.seed_everything, then build and return (datamodule, model, trainer).\n",
"\n",
"    The datamodule is prepared and set up first because the model needs its\n",
"    `num_labels` and `eval_splits`.\n",
"    \"\"\"\n",
"    pl.seed_everything(args.seed)\n",
"    datamodule = GLUEDataModule.from_argparse_args(args)\n",
"    datamodule.prepare_data()\n",
"    datamodule.setup('fit')\n",
"    glue_model = GLUETransformer(\n",
"        num_labels=datamodule.num_labels,\n",
"        eval_splits=datamodule.eval_splits,\n",
"        **vars(args)\n",
"    )\n",
"    return datamodule, glue_model, pl.Trainer.from_argparse_args(args)"
|
|
],
|
|
"execution_count": 6,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "PkuLaeec3sJ-",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"# Training"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "QSpueK5UPsN7",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## CoLA\n",
|
|
"\n",
|
|
"See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "NJnFmtpnPu0Y",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# CLI-style arguments for fine-tuning ALBERT on CoLA.\n",
"mocked_args = [\n",
"    '--model_name_or_path', 'albert-base-v2',\n",
"    '--task_name', 'cola',\n",
"    '--max_epochs', '3',\n",
"    '--gpus', '1',\n",
"]\n",
"\n",
"args = parse_args(mocked_args)\n",
"dm, model, trainer = main(args)\n",
"trainer.fit(model, dm)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "_MrNsTnqdz4z",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## MRPC\n",
|
|
"\n",
|
|
"See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "LBwRxg9Cb3d-",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# CLI-style arguments for fine-tuning DistilBERT (cased) on MRPC.\n",
"mocked_args = [\n",
"    '--model_name_or_path', 'distilbert-base-cased',\n",
"    '--task_name', 'mrpc',\n",
"    '--max_epochs', '3',\n",
"    '--gpus', '1',\n",
"]\n",
"\n",
"args = parse_args(mocked_args)\n",
"dm, model, trainer = main(args)\n",
"trainer.fit(model, dm)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "iZhbn0HzfdCu",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"## MNLI\n",
|
|
"\n",
|
|
" - The MNLI dataset is huge, so we aren't going to bother trying to train it here.\n",
|
|
"\n",
|
|
" - Let's just make sure our multi-dataloader logic is right by training on only 10 batches (`--limit_train_batches 10`) and then running validation on both the matched and mismatched splits.\n",
|
|
"\n",
|
|
"See an interactive view of the MNLI dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "AvsZMOggfcWW",
|
|
"colab_type": "code",
|
|
"colab": {}
|
|
},
|
|
"source": [
|
|
"# Smoke test: train on only 10 batches, then validate on every MNLI validation split.\n",
"mocked_args = [\n",
"    '--model_name_or_path', 'distilbert-base-uncased',\n",
"    '--task_name', 'mnli',\n",
"    '--max_epochs', '1',\n",
"    '--gpus', '1',\n",
"    '--limit_train_batches', '10',\n",
"    '--progress_bar_refresh_rate', '20',\n",
"]\n",
"\n",
"args = parse_args(mocked_args)\n",
"dm, model, trainer = main(args)\n",
"trainer.fit(model, dm)"
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
}
|
|
]
|
|
} |