fix ckpt names

mehrad 2020-04-27 16:30:08 -07:00
parent 6b56b4f2cb
commit 4734a3a47c
3 changed files with 13 additions and 5 deletions

.gitignore

@@ -22,6 +22,15 @@ text/.DS_Store
src/
workdir/
*save*/
test_py.sh
# paraphraser
*/paraphrase/*.txt
lightning_logs/
pytorch_model.bin
*/paraphrase/bart-*
cnn*
/tests/*/
# C extensions
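
The new rules can be sanity-checked with `git check-ignore`, which reports the pattern that matched a given path. A quick sketch, run from the repo root; the paths below are hypothetical examples chosen to exercise the new patterns, not files tracked in this repo:

import subprocess

# Hypothetical paths probing */paraphrase/*.txt, workdir/, and cnn*.
for path in ["genienlp/paraphrase/train.txt", "workdir/model.pt", "cnn_dm"]:
    # `git check-ignore -v` prints "<source>:<line>:<pattern>\t<path>" when
    # the path is ignored, and nothing when it is not.
    result = subprocess.run(["git", "check-ignore", "-v", path],
                            capture_output=True, text=True)
    print(path, "->", result.stdout.strip() or "not ignored")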


@@ -10,8 +10,6 @@ from torch.utils.data import DataLoader
from genienlp.paraphrase.transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup, MODEL_MODES
from genienlp.paraphrase.utils import Seq2SeqDataset, sort_checkpoints
from transformers import BartTokenizer, MBartTokenizer
logger = logging.getLogger(__name__)

genienlp/paraphrase/transformer_base.py

@@ -251,7 +251,8 @@ def add_generic_args(parser, root_dir):
         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
         "See details at https://nvidia.github.io/apex/amp.html",
     )
+    parser.add_argument("--max_to_keep", type=int, default=2, help="Number of checkpoints to keep")
     parser.add_argument("--n_gpu", type=int, default=1)
     parser.add_argument("--n_tpu_cores", type=int, default=0)
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
@@ -289,8 +290,8 @@ def generic_train(model, args):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
     checkpoint_callback = pl.callbacks.ModelCheckpoint(
-        filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
-    )
+        filepath=os.path.join(args.output_dir, 'mbart-{epoch:02d}'), monitor="val_loss", mode="min", save_top_k=args.max_to_keep
+    )
     train_params = dict(
         accumulate_grad_batches=args.gradient_accumulation_steps,
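
In effect, the commit swaps the generic "checkpoint" prefix for an epoch-stamped mBART name and makes the retention count configurable instead of hard-coding save_top_k=5. A minimal sketch of the resulting behavior, assuming the pytorch-lightning 0.7-era ModelCheckpoint API that this code targets; the output_dir default below is illustrative:

import os
from argparse import ArgumentParser

import pytorch_lightning as pl

parser = ArgumentParser()
# The new flag: how many of the best checkpoints (by val_loss) to retain.
parser.add_argument("--max_to_keep", type=int, default=2, help="Number of checkpoints to keep")
parser.add_argument("--output_dir", type=str, default="workdir")  # illustrative default
args = parser.parse_args([])

# `filepath` doubles as a name template: '{epoch:02d}' is filled in at save
# time, so checkpoints land at paths like workdir/mbart-epoch=03.ckpt
# (the exact rendering depends on the pytorch-lightning version).
# save_top_k=args.max_to_keep keeps only the best checkpoints by val_loss,
# deleting worse ones instead of accumulating an unbounded set.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath=os.path.join(args.output_dir, 'mbart-{epoch:02d}'),
    monitor="val_loss",
    mode="min",
    save_top_k=args.max_to_keep,
)

Downstream, generic_train presumably hands this callback to the pl.Trainer it builds from train_params, so the renamed checkpoints are what evaluation code later loads by name.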