fix ckpt names
parent 6b56b4f2cb
commit 4734a3a47c
@@ -22,6 +22,15 @@ text/.DS_Store
 src/
 workdir/
 *save*/
 test_py.sh
 
+# paraphraser
+*/paraphrase/*.txt
+lightning_logs/
+pytorch_model.bin
+*/paraphrase/bart-*
+cnn*
+/tests/*/
+
+
 # C extensions
@@ -10,8 +10,6 @@ from torch.utils.data import DataLoader
 from genienlp.paraphrase.transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup, MODEL_MODES
 from genienlp.paraphrase.utils import Seq2SeqDataset, sort_checkpoints
 
 from transformers import BartTokenizer, MBartTokenizer
-
-
 logger = logging.getLogger(__name__)
 
@@ -251,7 +251,8 @@ def add_generic_args(parser, root_dir):
         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
         "See details at https://nvidia.github.io/apex/amp.html",
     )
 
+    parser.add_argument("--max_to_keep", type=int, default=2, help="Number of checkpoints to keep")
     parser.add_argument("--n_gpu", type=int, default=1)
     parser.add_argument("--n_tpu_cores", type=int, default=0)
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
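For illustration only, a minimal sketch of the new flag in isolation; apart from --max_to_keep and its default of 2, the parser setup below is simplified stand-in code, not the repo's actual add_generic_args:

import argparse

# Simplified stand-in for add_generic_args: only the new flag is wired up.
parser = argparse.ArgumentParser()
parser.add_argument("--max_to_keep", type=int, default=2, help="Number of checkpoints to keep")

args = parser.parse_args([])                        # flag omitted: default of 2 applies
assert args.max_to_keep == 2
args = parser.parse_args(["--max_to_keep", "4"])    # explicit override
assert args.max_to_keep == 4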
@@ -289,8 +290,8 @@ def generic_train(model, args):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
 
     checkpoint_callback = pl.callbacks.ModelCheckpoint(
-        filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
-    )
+        filepath=os.path.join(args.output_dir, 'mbart-{epoch:02d}'), monitor="val_loss", mode="min", save_top_k=args.max_to_keep
+    )
 
     train_params = dict(
         accumulate_grad_batches=args.gradient_accumulation_steps,
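For illustration, a sketch of what this rename changes. Previously every file carried the generic "checkpoint" prefix and a hard-coded save_top_k=5; now the filepath acts as a format string under the pytorch-lightning 0.x ModelCheckpoint API used here (later releases split it into dirpath and filename), so files are named per model and epoch, and save_top_k=args.max_to_keep keeps only the best checkpoints by val_loss. A plain-Python approximation of both behaviours, with made-up epoch numbers and losses:

import heapq
import os

output_dir = "workdir"  # stands in for args.output_dir
max_to_keep = 2         # stands in for args.max_to_keep

# 1. Naming: 'mbart-{epoch:02d}' is filled in per epoch by the callback.
names = [os.path.join(output_dir, "mbart-{epoch:02d}".format(epoch=e)) for e in range(4)]
print(names)  # on POSIX: ['workdir/mbart-00', 'workdir/mbart-01', 'workdir/mbart-02', 'workdir/mbart-03']

# 2. Retention: with monitor="val_loss" and mode="min", only the
#    max_to_keep checkpoints with the lowest val_loss survive.
val_loss = {"mbart-00": 1.9, "mbart-01": 1.4, "mbart-02": 1.6, "mbart-03": 1.2}
kept = heapq.nsmallest(max_to_keep, val_loss, key=val_loss.get)
print(sorted(kept))  # ['mbart-01', 'mbart-03']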