enumerating predictions; exp with intermediate cove

This commit is contained in:
Bryan Marcus McCann 2018-09-04 15:43:12 +00:00
parent cb96b024a6
commit 09cebc62a3
4 changed files with 9 additions and 6 deletions

View File

@ -63,6 +63,7 @@ def parse():
parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model')
parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)')
parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)')
parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')

View File

@ -61,9 +61,10 @@ RUN apt-get install --yes \
python-lxml
# WikISQL evaluation
RUN pip install -e git+git://github.com/salesforce/cove.git#egg=cove
RUN pip install records
RUN pip install babel
RUN pip install tabulate
RUN pip install -e git+git://github.com/salesforce/cove.git#egg=cove
CMD bash

View File

@ -25,9 +25,10 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
self.decoder_embeddings = Embedding(field, args.dimension,
dropout=args.dropout_ratio, project=True)
if self.args.cove:
self.cove = MTLSTM(model_cache=args.embeddings)
self.project_cove = Feedforward(1000, args.dimension)
if self.args.cove or self.args.intermediate_cove:
self.cove = MTLSTM(model_cache=args.embeddings, layer0=args.intermediate_cove, layer1=args.cove)
cove_dim = int(args.intermediate_cove) * 600 + int(args.cove) * 600 + 400 # the last 400 is for GloVe and char n-gram embeddings
self.project_cove = Feedforward(cove_dim, args.dimension)
self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
batch_first=True, dropout=args.dropout_ratio, bidirectional=True, num_layers=1)

View File

@ -157,8 +157,8 @@ def run(args, field, val_sets, model):
print(metrics)
if not args.silent:
for p, a in zip(predictions, answers):
print(f'Prediction: {p}\nAnswer: {a}\n')
for i, (p, a) in enumerate(zip(predictions, answers)):
print(f'Prediction {i+1}: {p}\nAnswer {i+1}: {a}\n')
def get_args():