bwds compat for old CoVe params; print metrics after preds
This commit is contained in:
parent
4a962fa2bc
commit
a46578457a
13
predict.py
13
predict.py
|
@ -155,10 +155,10 @@ def run(args, field, val_sets, model):
|
|||
with open(results_file_name) as results_file:
|
||||
metrics = json.loads(results_file.readlines()[0])
|
||||
|
||||
print(metrics)
|
||||
if not args.silent:
|
||||
for i, (p, a) in enumerate(zip(predictions, answers)):
|
||||
print(f'Prediction {i+1}: {p}\nAnswer {i+1}: {a}\n')
|
||||
print(metrics)
|
||||
|
||||
|
||||
def get_args():
|
||||
|
@ -188,6 +188,8 @@ def get_args():
|
|||
for r in retrieve:
|
||||
if r in config:
|
||||
setattr(args, r, config[r])
|
||||
elif 'cove' in r:
|
||||
setattr(args, r, False)
|
||||
else:
|
||||
setattr(args, r, None)
|
||||
args.dropout_ratio = 0.0
|
||||
|
@ -259,7 +261,14 @@ if __name__ == '__main__':
|
|||
print(f'Initializing Model')
|
||||
Model = getattr(models, args.model)
|
||||
model = Model(field, args)
|
||||
model.load_state_dict(save_dict['model_state_dict'])
|
||||
model_dict = save_dict['model_state_dict']
|
||||
backwards_compatible_cove_dict = {}
|
||||
for k, v in model_dict.items():
|
||||
if 'cove.rnn.' in k:
|
||||
k = k.replace('cove.rnn.', 'cove.rnn1.')
|
||||
backwards_compatible_cove_dict[k] = v
|
||||
model_dict = backwards_compatible_cove_dict
|
||||
model.load_state_dict(model_dict)
|
||||
field, splits = prepare_data(args, field)
|
||||
model.set_embeddings(field.vocab.vectors)
|
||||
|
||||
|
|
Loading…
Reference in New Issue