Clean up old checkpoints as we go along
Introduce an utility Saver class, that does what tensorflow's Saver does: keeps track of saved checkpoints in a separate file, and deletes the old ones before saving a new one.
This commit is contained in:
parent
410c6cd8ec
commit
b950927a2b
|
@ -32,30 +32,25 @@
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
import collections
|
|
||||||
import sys
|
import sys
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from logging import handlers
|
from logging import handlers
|
||||||
import ujson as json
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from .text import torchtext
|
from .text import torchtext
|
||||||
|
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
import string
|
|
||||||
|
|
||||||
from decanlp import arguments
|
from . import arguments
|
||||||
from . import models
|
from . import models
|
||||||
from .validate import validate
|
from .validate import validate
|
||||||
from .multiprocess import Multiprocess, DistributedDataParallel
|
from .multiprocess import Multiprocess, DistributedDataParallel
|
||||||
from .metrics import compute_metrics
|
|
||||||
from .util import elapsed_time, get_splits, batch_fn, set_seed, preprocess_examples, get_trainable_params, count_params
|
from .util import elapsed_time, get_splits, batch_fn, set_seed, preprocess_examples, get_trainable_params, count_params
|
||||||
|
from .utils.saver import Saver
|
||||||
|
|
||||||
def initialize_logger(args, rank='main'):
|
def initialize_logger(args, rank='main'):
|
||||||
# set up file logger
|
# set up file logger
|
||||||
|
@ -185,6 +180,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
|
||||||
local_train_metric_dict = {}
|
local_train_metric_dict = {}
|
||||||
|
|
||||||
train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
|
train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
|
||||||
|
saver = Saver(args.log_dir)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# For some number of rounds, we 'jump start' some subset of the tasks
|
# For some number of rounds, we 'jump start' some subset of the tasks
|
||||||
|
@ -254,7 +250,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
|
||||||
save_state_dict = {'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'field': field,
|
save_state_dict = {'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'field': field,
|
||||||
'best_decascore': best_decascore}
|
'best_decascore': best_decascore}
|
||||||
|
|
||||||
torch.save(save_state_dict, os.path.join(args.log_dir, f'iteration_{iteration}.pth'))
|
saver.save(save_state_dict, global_step=iteration)
|
||||||
if should_save_best:
|
if should_save_best:
|
||||||
logger.info(f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}found new best model')
|
logger.info(f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}found new best model')
|
||||||
torch.save(save_state_dict, os.path.join(args.log_dir, 'best.pth'))
|
torch.save(save_state_dict, os.path.join(args.log_dir, 'best.pth'))
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
#
|
||||||
|
# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
#
|
||||||
|
# * Redistributions of source code must retain the above copyright notice, this
|
||||||
|
# list of conditions and the following disclaimer.
|
||||||
|
#
|
||||||
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
#
|
||||||
|
# * Neither the name of the copyright holder nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived from
|
||||||
|
# this software without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
'''
|
||||||
|
Created on Mar 3, 2019
|
||||||
|
|
||||||
|
@author: gcampagn
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import ujson as json
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class Saver(object):
|
||||||
|
'''
|
||||||
|
Wrap pytorch's save functionality into an interface similar to tensorflow.train.Saver
|
||||||
|
|
||||||
|
In particular, this class takes care of automatically cleaning up old checkpoints,
|
||||||
|
and creating checkpoint files to keep track of which saves are valid and which are not.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, savedir, max_to_keep=5):
|
||||||
|
self._savedir = savedir
|
||||||
|
self._max_to_keep = max_to_keep
|
||||||
|
assert max_to_keep >= 1
|
||||||
|
|
||||||
|
self._loaded_last_checkpoints = False
|
||||||
|
self._latest_checkpoint = None
|
||||||
|
self._all_checkpoints = None
|
||||||
|
|
||||||
|
def _maybe_load_last_checkpoints(self):
|
||||||
|
if self._loaded_last_checkpoints:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(os.path.join(self._savedir, 'checkpoint.json')) as fp:
|
||||||
|
data = json.load(fp)
|
||||||
|
self._loaded_last_checkpoints = True
|
||||||
|
self._all_checkpoints = data['all']
|
||||||
|
self._latest_checkpoint = data['latest']
|
||||||
|
except FileNotFoundError:
|
||||||
|
self._loaded_last_checkpoints = True
|
||||||
|
self._all_checkpoints = []
|
||||||
|
self._latest_checkpoint = None
|
||||||
|
|
||||||
|
def save(self, save_dict, global_step):
|
||||||
|
self._maybe_load_last_checkpoints()
|
||||||
|
|
||||||
|
filename = 'iteration_' + str(global_step)
|
||||||
|
abspath = os.path.join(self._savedir, filename)
|
||||||
|
|
||||||
|
self._latest_checkpoint = filename
|
||||||
|
self._all_checkpoints.append(filename)
|
||||||
|
if len(self._all_checkpoints) > self._max_to_keep:
|
||||||
|
try:
|
||||||
|
todelete = self._all_checkpoints.pop(0)
|
||||||
|
os.unlink(os.path.join(self._savedir, todelete))
|
||||||
|
except (OSError, IOError) as e:
|
||||||
|
logging.warn('Failed to delete old checkpoint: %s', e)
|
||||||
|
torch.save(save_dict, abspath)
|
||||||
|
|
Loading…
Reference in New Issue