diff --git a/decanlp/train.py b/decanlp/train.py
index 30411b52..7b142b27 100644
--- a/decanlp/train.py
+++ b/decanlp/train.py
@@ -32,30 +32,25 @@
 import os
 import math
 import time
-import random
-import collections
 import sys
 from copy import deepcopy
 import logging
 from pprint import pformat
 from logging import handlers
-import ujson as json
 import torch
-import numpy as np
 from .text import torchtext
 from tensorboardX import SummaryWriter
-import string
-from decanlp import arguments
+from . import arguments
 from . import models
 from .validate import validate
 from .multiprocess import Multiprocess, DistributedDataParallel
-from .metrics import compute_metrics
 from .util import elapsed_time, get_splits, batch_fn, set_seed, preprocess_examples, get_trainable_params, count_params
+from .utils.saver import Saver


 def initialize_logger(args, rank='main'):
     # set up file logger
@@ -185,6 +180,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
     local_train_metric_dict = {}

     train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
+    saver = Saver(args.log_dir)

     while True:
         # For some number of rounds, we 'jump start' some subset of the tasks
@@ -254,7 +250,7 @@
             save_state_dict = {'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
                                'field': field,
                                'best_decascore': best_decascore}
-            torch.save(save_state_dict, os.path.join(args.log_dir, f'iteration_{iteration}.pth'))
+            saver.save(save_state_dict, global_step=iteration)
             if should_save_best:
                 logger.info(f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}found new best model')
                 torch.save(save_state_dict, os.path.join(args.log_dir, 'best.pth'))
diff --git a/decanlp/utils/saver.py b/decanlp/utils/saver.py
new file mode 100644
index 00000000..8224c12e
--- /dev/null
+++ b/decanlp/utils/saver.py
@@ -0,0 +1,95 @@
+#
+# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+Created on Mar 3, 2019
+
+@author: gcampagn
+'''
+
+import torch
+import ujson as json
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class Saver(object):
+    '''
+    Wrap pytorch's save functionality into an interface similar to tf.train.Saver
+
+    In particular, this class takes care of automatically cleaning up old checkpoints,
+    and maintains a checkpoint manifest that records which saves are valid and which are not.
+    '''
+
+    def __init__(self, savedir, max_to_keep=5):
+        assert max_to_keep >= 1
+        self._savedir = savedir
+        self._max_to_keep = max_to_keep
+
+        self._loaded_last_checkpoints = False
+        self._latest_checkpoint = None
+        self._all_checkpoints = None
+
+    def _maybe_load_last_checkpoints(self):
+        if self._loaded_last_checkpoints:
+            return
+
+        try:
+            with open(os.path.join(self._savedir, 'checkpoint.json')) as fp:
+                data = json.load(fp)
+            self._all_checkpoints = data['all']
+            self._latest_checkpoint = data['latest']
+        except FileNotFoundError:
+            self._all_checkpoints = []
+            self._latest_checkpoint = None
+        self._loaded_last_checkpoints = True
+
+    def save(self, save_dict, global_step):
+        self._maybe_load_last_checkpoints()
+
+        # keep the .pth extension that train.py uses for best.pth
+        filename = 'iteration_' + str(global_step) + '.pth'
+        abspath = os.path.join(self._savedir, filename)
+
+        self._latest_checkpoint = filename
+        self._all_checkpoints.append(filename)
+        if len(self._all_checkpoints) > self._max_to_keep:
+            # drop the oldest checkpoint to bound disk usage
+            try:
+                todelete = self._all_checkpoints.pop(0)
+                os.unlink(os.path.join(self._savedir, todelete))
+            except (OSError, IOError) as e:
+                logger.warning('Failed to delete old checkpoint: %s', e)
+        torch.save(save_dict, abspath)
+
+        # write the manifest only after the save succeeds, so checkpoint.json
+        # never points at a file that was not fully written
+        with open(os.path.join(self._savedir, 'checkpoint.json'), 'w') as fp:
+            json.dump({'all': self._all_checkpoints, 'latest': self._latest_checkpoint}, fp)
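
Reviewer note, not part of the patch: since Saver.save now records a checkpoint.json manifest of the form {'all': [...], 'latest': '...'} next to the weights, resuming from the most recent checkpoint reduces to reading the manifest and loading the file it names. The restore_latest helper below is a hypothetical sketch written against that format; it is not part of the Saver API introduced in this diff.

    import os
    import torch
    import ujson as json

    def restore_latest(savedir):
        # read the manifest that Saver.save writes after each torch.save
        manifest = os.path.join(savedir, 'checkpoint.json')
        if not os.path.exists(manifest):
            return None  # nothing saved yet; caller starts from scratch
        with open(manifest) as fp:
            latest = json.load(fp)['latest']
        # the loaded dict mirrors what train.py saves:
        # {'model_state_dict': ..., 'field': ..., 'best_decascore': ...}
        return torch.load(os.path.join(savedir, latest), map_location='cpu')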