From a46578457a6753f8e7ca752b519e7803c4cb6a81 Mon Sep 17 00:00:00 2001 From: Bryan Marcus McCann Date: Wed, 19 Sep 2018 17:07:54 +0000 Subject: [PATCH] bwds compat for old CoVe params; print metrics after preds --- predict.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/predict.py b/predict.py index fa1baade..d2a82bc9 100644 --- a/predict.py +++ b/predict.py @@ -155,10 +155,10 @@ def run(args, field, val_sets, model): with open(results_file_name) as results_file: metrics = json.loads(results_file.readlines()[0]) - print(metrics) if not args.silent: for i, (p, a) in enumerate(zip(predictions, answers)): print(f'Prediction {i+1}: {p}\nAnswer {i+1}: {a}\n') + print(metrics) def get_args(): @@ -188,6 +188,8 @@ def get_args(): for r in retrieve: if r in config: setattr(args, r, config[r]) + elif 'cove' in r: + setattr(args, r, False) else: setattr(args, r, None) args.dropout_ratio = 0.0 @@ -259,7 +261,14 @@ if __name__ == '__main__': print(f'Initializing Model') Model = getattr(models, args.model) model = Model(field, args) - model.load_state_dict(save_dict['model_state_dict']) + model_dict = save_dict['model_state_dict'] + backwards_compatible_cove_dict = {} + for k, v in model_dict.items(): + if 'cove.rnn.' in k: + k = k.replace('cove.rnn.', 'cove.rnn1.') + backwards_compatible_cove_dict[k] = v + model_dict = backwards_compatible_cove_dict + model.load_state_dict(model_dict) field, splits = prepare_data(args, field) model.set_embeddings(field.vocab.vectors)