#!/usr/bin/env python3 import os import math import time import random import collections from copy import deepcopy import logging from pprint import pformat from logging import handlers import ujson as json import torch import numpy as np from text import torchtext from tensorboardX import SummaryWriter import string import arguments import models from validate import validate from multiprocess import Multiprocess, DistributedDataParallel from metrics import compute_metrics from util import elapsed_time, get_splits, batch_fn, set_seed, preprocess_examples, get_trainable_params, count_params def initialize_logger(args, rank='main'): # set up file logger logger = logging.getLogger(f'process_{rank}') logger.setLevel(logging.DEBUG) handler = handlers.RotatingFileHandler(os.path.join(args.log_dir, f'process_{rank}.log'), maxBytes=1024*1024*10, backupCount=1) handler.setLevel(logging.DEBUG) formatter = logging.Formatter('%(name)s - %(message)s') handler.setFormatter(formatter) logger.addHandler(handler) handler = logging.StreamHandler() handler.setFormatter(formatter) handler.setLevel(logging.DEBUG) logger.addHandler(handler) logger.propagate = False return logger def log(rank='main'): return logging.getLogger(f'process_{rank}') def prepare_data(args, field, logger): if field is None: logger.info(f'Constructing field') FIELD = torchtext.data.ReversibleField(batch_first=True, init_token='', eos_token='', lower=args.lower, include_lengths=True) else: FIELD = field train_sets, val_sets, vocab_sets = [], [], [] for task in args.train_tasks: logger.info(f'Loading {task}') kwargs = {'test': None} kwargs['subsample'] = args.subsample kwargs['validation'] = None logger.info(f'Adding {task} to training datasets') split = get_splits(args, task, FIELD, **kwargs)[0] logger.info(f'{task} has {len(split)} training examples') train_sets.append(split) if args.vocab_tasks is not None and task in args.vocab_tasks: vocab_sets.extend(split) for task in args.val_tasks: logger.info(f'Loading {task}') kwargs = {'test': None} kwargs['subsample'] = args.subsample kwargs['train'] = None logger.info(f'Adding {task} to validation datasets') split = get_splits(args, task, FIELD, **kwargs)[0] logger.info(f'{task} has {len(split)} validation examples') val_sets.append(split) if args.vocab_tasks is not None and task in args.vocab_tasks: vocab_sets.extend(split) if args.load is None: logger.info(f'Getting pretrained word vectors') char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings) glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings) vectors = [char_vectors, glove_vectors] vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets logger.info(f'Building vocabulary') FIELD.build_vocab(*vocab_sets, max_size=args.max_effective_vocab, vectors=vectors) FIELD.decoder_itos = FIELD.vocab.itos[:args.max_generative_vocab] FIELD.decoder_stoi = {word: idx for idx, word in enumerate(FIELD.decoder_itos)} FIELD.decoder_to_vocab = {idx: FIELD.vocab.stoi[word] for idx, word in enumerate(FIELD.decoder_itos)} FIELD.vocab_to_decoder = {idx: FIELD.decoder_stoi[word] for idx, word in enumerate(FIELD.vocab.itos) if word in FIELD.decoder_stoi} logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens') logger.info(f'The first 500 tokens:') print(FIELD.vocab.itos[:500]) logger.info('Preprocessing training data') preprocess_examples(args, args.train_tasks, train_sets, FIELD, logger, train=True) logger.info('Preprocessing validation data') preprocess_examples(args, args.val_tasks, val_sets, FIELD, logger, train=args.val_filter) return FIELD, train_sets, val_sets def to_iter(args, world_size, val_batch_size, data, train=True, token_testing=False, sort=None): sort = sort if not token_testing else True shuffle = None if not token_testing else False reverse = args.reverse Iterator = torchtext.data.BucketIterator if train else torchtext.data.Iterator it = Iterator(data, batch_size=val_batch_size, device=0 if world_size > 0 else -1, batch_size_fn=batch_fn if train else None, distributed=world_size>1, train=train, repeat=train, sort=sort, shuffle=shuffle, reverse=args.reverse) return it def get_learning_rate(i, args): return 0.1 * 10 / math.sqrt(args.dimension) * min( 1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup))) def step(model, batch, opt, iteration, field, task, lr=None, grad_clip=None, writer=None, it=None): model.train() opt.zero_grad() loss, predictions = model(batch) loss.backward() if lr is not None: opt.param_groups[0]['lr'] = lr if grad_clip > 0.0: torch.nn.utils.clip_grad_norm(model.params, grad_clip) opt.step() return loss.data[0], {} def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_size=1, log_every=10, val_every=100, save_every=1000, rounds=False, val_iters=[], writer=None, start_iteration=1, rnd=1): """main training function""" logger = log(rank) local_loss, num_examples, len_contexts, len_answers, iteration = 0, 0, 0, 0, start_iteration train_iter_deep = deepcopy(train_iterations) local_train_metric_dict = {} train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters] while True: # For some number of rounds, we 'jump start' some subset of the tasks # by training them and not others # once the specified number of rounds is completed, # switch to normal round robin training if rnd start_iteration: task_progress = f'{task_iteration}/{task_iterations}:' if task_iterations is not None else '' round_progress = f'round_{rnd}:' if rounds else '' # validate if (val_every is not None and ((iteration % args.val_every == 0 % args.val_every) or (args.load and iteration == start_iteration + 1))): train_task_val_metric = None for val_task_idx, (val_task, val_iter) in enumerate(val_iters): val_loss, metric_dict = validate(val_task, val_iter, model, logger, field, world_size, rank, num_print=args.num_print, args=args) if val_loss is not None: log_entry = f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}val_{val_task}:val_loss{val_loss.data[0]:.4f}:' writer.add_scalars(f'loss/val', {val_task: val_loss.data[0]}, iteration) else: log_entry = f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}val_{val_task}:' metric_entry = '' for metric_key, metric_value in metric_dict.items(): metric_entry += f'{metric_key}_{metric_value:.2f}:' metric_entry = metric_entry[:-1] # val log logger.info(log_entry + metric_entry) if writer is not None: for metric_key, metric_value in metric_dict.items(): writer.add_scalars(f'val/{metric_key}', {val_task: metric_value}, iteration) writer.add_scalars(f'{metric_key}/val', {val_task: metric_value}, iteration) writer.add_scalars(f'{metric_key}/{val_task}', {'val': metric_value}, iteration) writer.add_scalars(f'{val_task}/{metric_key}', {'val': metric_value}, iteration) writer.add_scalars(f'{val_task}/val', {f'{metric_key}': metric_value}, iteration) # saving if save_every is not None and (iteration % args.save_every == 0 % args.save_every): if world_size > 1: torch.distributed.barrier() if rank is not None and rank == 0: torch.save({'model_state_dict': model.state_dict(), 'field': field}, os.path.join(args.log_dir, f'iteration_{iteration}.pth')) if world_size > 1: torch.distributed.barrier() torch.save(opt.state_dict(), os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth')) if world_size > 1: torch.distributed.barrier() # lr update lr = opt.param_groups[0]['lr'] if args.warmup > 0 and args.transformer_lr: lr = get_learning_rate(iteration, args) # param update loss, train_metric_dict = step(model, batch, opt, iteration, field, task, lr=lr, grad_clip=args.grad_clip, writer=writer, it=train_iter) # train metrics local_loss += loss for metric_name, metric_val in train_metric_dict.items(): if metric_name in local_train_metric_dict: local_train_metric_dict[metric_name] += metric_val / args.log_every else: local_train_metric_dict[metric_name] = metric_val / args.log_every # train logs num_examples += batch.context.size(0) len_contexts += batch.context.size(1) len_answers += batch.answer.size(1) if log_every is not None and (iteration % log_every == 0 % log_every): local_loss /= args.log_every num_examples /= args.log_every len_contexts /= args.log_every len_answers /= args.log_every avg_batch_size = f'avbatch_{num_examples:.0f}_{len_contexts:.0f}_{len_answers:.0f}:' metric_entry = '' for metric_key, metric_value in local_train_metric_dict.items(): metric_entry += f'{metric_key}_{metric_value:.2f}:' metric_entry = f'{metric_entry[:-1]}' logger.info(f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}{avg_batch_size}loss_{local_loss:.4f}{metric_entry}') num_examples = 0 len_contexts = 0 len_answers = 0 if writer is not None: writer.add_scalars(f'loss/train', {f'{task}': local_loss}, iteration) for metric_key, metric_value in local_train_metric_dict.items(): writer.add_scalars(f'train/{metric_key}', {task: metric_value}, iteration) writer.add_scalars(f'{metric_key}/train', {task: metric_value}, iteration) writer.add_scalars(f'{metric_key}/{task}', {'train': metric_value}, iteration) writer.add_scalars(f'{task}/{metric_key}', {'train': metric_value}, iteration) writer.add_scalars(f'{task}/train', {f'{metric_key}': metric_value}, iteration) local_loss = 0 local_train_metric_dict = {} num_examples = 0 # book keeping task_iteration += 1 iteration += 1 if task_iterations is not None and task_iteration > task_iterations: break # book keeping rnd += 1 if not rounds: break def run(args, run_args, rank=0, world_size=1): set_seed(args, rank=rank) logger = initialize_logger(args, rank) field, train_sets, val_sets, save_dict = run_args logger.start = time.time() logger.info(f'Preparing iterators') train_iters = [(name, to_iter(args, world_size, tok, x, token_testing=args.token_testing)) for name, x, tok in zip(args.train_tasks, train_sets, args.train_batch_tokens)] val_iters = [(name, to_iter(args, world_size, tok, x, train=False, token_testing=args.token_testing, sort=False if 'sql' in name else None)) for name, x, tok in zip(args.val_tasks, val_sets, args.val_batch_size)] logger.info(f'Initializing Writer') writer = SummaryWriter(log_dir=args.log_dir) model = init_model(args, field, logger, world_size) opt = init_opt(args, model) start_iteration = 1 if save_dict is not None: logger.info(f'Loading model from {os.path.join(args.save, args.load)}') save_dict = torch.load(os.path.join(args.save, args.load)) model.load_state_dict(save_dict['model_state_dict']) if args.resume: logger.info(f'Resuming Training from {os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth') opt.load_state_dict(torch.load(os.path.join(args.save, f'{os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth'))) start_iteration = int(os.path.splitext(os.path.basename(args.load))[0].split('_')[1]) logger.info(f'Begin Training') train(args, model, opt, train_iters, args.train_iterations, field, val_iters=val_iters, rank=rank, world_size=world_size, log_every=args.log_every, val_every=args.val_every, rounds=len(train_iters)>1, writer=writer if rank==0 else None, save_every=args.save_every, start_iteration=start_iteration) def init_model(args, field, logger, world_size): logger.info(f'Initializing {args.model}') Model = getattr(models, args.model) model = Model(field, args) params = get_trainable_params(model) num_param = count_params(params) logger.info(f'{args.model} has {num_param:,} parameters') if args.gpus[0] > -1: model.cuda() if world_size > 1: logger.info(f'Wrapping model for distributed') model = DistributedDataParallel(model) model.params = params return model def init_opt(args, model): opt = None if args.transformer_lr: opt = torch.optim.Adam(model.params, betas=(0.9, 0.98), eps=1e-9) else: opt = torch.optim.Adam(model.params, betas=(args.beta0, 0.999)) return opt def main(): args = arguments.parse() if args is None: return set_seed(args) logger = initialize_logger(args) logger.info(f'Arguments:\n{pformat(vars(args))}') field, save_dict = None, None if args.load is not None: logger.info(f'Loading field from {os.path.join(args.save, args.load)}') save_dict = torch.load(os.path.join(args.save, args.load)) field = save_dict['field'] field, train_sets, val_sets = prepare_data(args, field, logger) run_args = (field, train_sets, val_sets, save_dict) if len(args.gpus) > 1: logger.info(f'Multiprocessing') mp = Multiprocess(run, args) mp.run(run_args) else: logger.info(f'Processing') run(args, run_args, world_size=args.world_size) if __name__ == '__main__': main()