bwds compat for old CoVe params; print metrics after preds

This commit is contained in:
Bryan Marcus McCann 2018-09-19 17:07:54 +00:00 committed by Bryan McCann
parent 4a962fa2bc
commit a46578457a
1 changed files with 11 additions and 2 deletions

View File

@ -155,10 +155,10 @@ def run(args, field, val_sets, model):
with open(results_file_name) as results_file:
metrics = json.loads(results_file.readlines()[0])
print(metrics)
if not args.silent:
for i, (p, a) in enumerate(zip(predictions, answers)):
print(f'Prediction {i+1}: {p}\nAnswer {i+1}: {a}\n')
print(metrics)
def get_args():
@ -188,6 +188,8 @@ def get_args():
for r in retrieve:
if r in config:
setattr(args, r, config[r])
elif 'cove' in r:
setattr(args, r, False)
else:
setattr(args, r, None)
args.dropout_ratio = 0.0
@ -259,7 +261,14 @@ if __name__ == '__main__':
print(f'Initializing Model')
Model = getattr(models, args.model)
model = Model(field, args)
model.load_state_dict(save_dict['model_state_dict'])
model_dict = save_dict['model_state_dict']
backwards_compatible_cove_dict = {}
for k, v in model_dict.items():
if 'cove.rnn.' in k:
k = k.replace('cove.rnn.', 'cove.rnn1.')
backwards_compatible_cove_dict[k] = v
model_dict = backwards_compatible_cove_dict
model.load_state_dict(model_dict)
field, splits = prepare_data(args, field)
model.set_embeddings(field.vocab.vectors)