diff --git a/train.py b/train.py index 42b416c5..b6e88c73 100644 --- a/train.py +++ b/train.py @@ -142,7 +142,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_ """main training function""" logger = log(rank) - local_loss, num_examples, len_contexts, len_answers, iteration = 0, 0, 0, 0, 1 + local_loss, num_examples, len_contexts, len_answers, iteration = 0, 0, 0, 0, start_iteration train_iter_deep = deepcopy(train_iterations) local_train_metric_dict = {}