diff --git a/.gitignore b/.gitignore
index 676e1d44..a5e9fe66 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,6 +22,15 @@ text/.DS_Store
 src/
 workdir/
 *save*/
+test_py.sh
+
+# paraphraser
+*/paraphrase/*.txt
+lightning_logs/
+pytorch_model.bin
+*/paraphrase/bart-*
+cnn*
+/tests/*/
 
 # C extensions
 
diff --git a/genienlp/paraphrase/finetune_bart.py b/genienlp/paraphrase/finetune_bart.py
index 94a41e7a..8de57a3e 100644
--- a/genienlp/paraphrase/finetune_bart.py
+++ b/genienlp/paraphrase/finetune_bart.py
@@ -10,8 +10,6 @@ from torch.utils.data import DataLoader
 
 from genienlp.paraphrase.transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup, MODEL_MODES
 from genienlp.paraphrase.utils import Seq2SeqDataset, sort_checkpoints
-from transformers import BartTokenizer, MBartTokenizer
-
 
 logger = logging.getLogger(__name__)
 
diff --git a/genienlp/paraphrase/transformer_base.py b/genienlp/paraphrase/transformer_base.py
index 0ed10123..569cde44 100644
--- a/genienlp/paraphrase/transformer_base.py
+++ b/genienlp/paraphrase/transformer_base.py
@@ -251,7 +251,8 @@ def add_generic_args(parser, root_dir):
         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
         "See details at https://nvidia.github.io/apex/amp.html",
     )
-
+
+    parser.add_argument("--max_to_keep", type=int, default=2, help="Number of checkpoints to keep")
     parser.add_argument("--n_gpu", type=int, default=1)
     parser.add_argument("--n_tpu_cores", type=int, default=0)
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
@@ -289,8 +290,8 @@ def generic_train(model, args):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
 
     checkpoint_callback = pl.callbacks.ModelCheckpoint(
-        filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
-    )
+        filepath=os.path.join(args.output_dir, 'mbart-{epoch:02d}'), monitor="val_loss", mode="min", save_top_k=args.max_to_keep
+)
 
     train_params = dict(
         accumulate_grad_batches=args.gradient_accumulation_steps,
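
For context, a minimal sketch of how the new --max_to_keep flag is meant to drive checkpoint retention, assuming the pytorch-lightning 0.7-style ModelCheckpoint API (filepath/save_top_k) that the hunks above target; the argparse defaults and Trainer wiring below are illustrative only, not the repository's actual generic_train code.

import argparse
import os

import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=str, default="workdir/mbart")  # hypothetical default
parser.add_argument("--max_to_keep", type=int, default=2, help="Number of checkpoints to keep")
args = parser.parse_args([])  # use defaults, for illustration

# Keep only the args.max_to_keep checkpoints with the lowest val_loss
# (previously hard-coded to save_top_k=5); checkpoints are written as
# mbart-<epoch>.ckpt under args.output_dir.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath=os.path.join(args.output_dir, 'mbart-{epoch:02d}'),
    monitor="val_loss",
    mode="min",
    save_top_k=args.max_to_keep,
)

trainer = pl.Trainer(checkpoint_callback=checkpoint_callback, max_epochs=1)
# trainer.fit(model)  # model would be the BaseTransformer subclass built in finetune_bart.py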