Edited using Colaboratory (#3601)

This commit is contained in:
Nathan Raw 2020-09-22 11:15:25 -06:00 committed by GitHub
parent 0b222fc6cf
commit ba01ec9dbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 25 deletions

View File

@ -6,8 +6,7 @@
"name": "04-transformers-text-classification.ipynb", "name": "04-transformers-text-classification.ipynb",
"provenance": [], "provenance": [],
"collapsed_sections": [], "collapsed_sections": [],
"toc_visible": true, "toc_visible": true
"include_colab_link": true
}, },
"kernelspec": { "kernelspec": {
"name": "python3", "name": "python3",
@ -16,16 +15,6 @@
"accelerator": "GPU" "accelerator": "GPU"
}, },
"cells": [ "cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
@ -42,9 +31,10 @@
"---\n", "---\n",
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Ask a question on [the forum](https://forums.pytorchlightning.ai/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
"\n", "\n",
" - [HuggingFace nlp](https://github.com/huggingface/nlp)\n", " - [HuggingFace datasets](https://github.com/huggingface/datasets)\n",
" - [HuggingFace transformers](https://github.com/huggingface/transformers)" " - [HuggingFace transformers](https://github.com/huggingface/transformers)"
] ]
}, },
@ -83,7 +73,7 @@
"from datetime import datetime\n", "from datetime import datetime\n",
"from typing import Optional\n", "from typing import Optional\n",
"\n", "\n",
"import nlp\n", "import datasets\n",
"import numpy as np\n", "import numpy as np\n",
"import pytorch_lightning as pl\n", "import pytorch_lightning as pl\n",
"import torch\n", "import torch\n",
@ -97,7 +87,7 @@
" glue_compute_metrics\n", " glue_compute_metrics\n",
")" ")"
], ],
"execution_count": null, "execution_count": 2,
"outputs": [] "outputs": []
}, },
{ {
@ -147,7 +137,7 @@
" }\n", " }\n",
"\n", "\n",
" loader_columns = [\n", " loader_columns = [\n",
" 'nlp_idx',\n", " 'datasets_idx',\n",
" 'input_ids',\n", " 'input_ids',\n",
" 'token_type_ids',\n", " 'token_type_ids',\n",
" 'attention_mask',\n", " 'attention_mask',\n",
@ -177,7 +167,7 @@
" self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
"\n", "\n",
" def setup(self, stage):\n", " def setup(self, stage):\n",
" self.dataset = nlp.load_dataset('glue', self.task_name)\n", " self.dataset = datasets.load_dataset('glue', self.task_name)\n",
"\n", "\n",
" for split in self.dataset.keys():\n", " for split in self.dataset.keys():\n",
" self.dataset[split] = self.dataset[split].map(\n", " self.dataset[split] = self.dataset[split].map(\n",
@ -191,7 +181,7 @@
" self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n", " self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n",
"\n", "\n",
" def prepare_data(self):\n", " def prepare_data(self):\n",
" nlp.load_dataset('glue', self.task_name)\n", " datasets.load_dataset('glue', self.task_name)\n",
" AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", " AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
" \n", " \n",
" def train_dataloader(self):\n", " def train_dataloader(self):\n",
@ -230,7 +220,7 @@
"\n", "\n",
" return features" " return features"
], ],
"execution_count": null, "execution_count": 3,
"outputs": [] "outputs": []
}, },
{ {
@ -297,7 +287,7 @@
"\n", "\n",
" self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n", " self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n",
" self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n", " self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n",
" self.metric = nlp.load_metric(\n", " self.metric = datasets.load_metric(\n",
" 'glue',\n", " 'glue',\n",
" self.hparams.task_name,\n", " self.hparams.task_name,\n",
" experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n", " experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n",
@ -335,7 +325,7 @@
" if i == 0:\n", " if i == 0:\n",
" result = pl.EvalResult(checkpoint_on=loss)\n", " result = pl.EvalResult(checkpoint_on=loss)\n",
" result.log(f'val_loss_{split}', loss, prog_bar=True)\n", " result.log(f'val_loss_{split}', loss, prog_bar=True)\n",
" split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(preds, labels).items()}\n", " split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n",
" result.log_dict(split_metrics, prog_bar=True)\n", " result.log_dict(split_metrics, prog_bar=True)\n",
" return result\n", " return result\n",
"\n", "\n",
@ -344,7 +334,7 @@
" loss = torch.stack([x['loss'] for x in outputs]).mean()\n", " loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
" result = pl.EvalResult(checkpoint_on=loss)\n", " result = pl.EvalResult(checkpoint_on=loss)\n",
" result.log('val_loss', loss, prog_bar=True)\n", " result.log('val_loss', loss, prog_bar=True)\n",
" result.log_dict(self.metric.compute(preds, labels), prog_bar=True)\n", " result.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
" return result\n", " return result\n",
"\n", "\n",
" def setup(self, stage):\n", " def setup(self, stage):\n",
@ -394,7 +384,7 @@
" parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n", " parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n",
" return parser" " return parser"
], ],
"execution_count": null, "execution_count": 5,
"outputs": [] "outputs": []
}, },
{ {
@ -434,7 +424,7 @@
" trainer = pl.Trainer.from_argparse_args(args)\n", " trainer = pl.Trainer.from_argparse_args(args)\n",
" return dm, model, trainer" " return dm, model, trainer"
], ],
"execution_count": null, "execution_count": 6,
"outputs": [] "outputs": []
}, },
{ {