satisfy travis

This commit is contained in:
mehrad 2019-03-01 15:08:31 -08:00
parent cf534b3523
commit 35fe30272b
2 changed files with 2 additions and 1 deletions

View File

@ -169,7 +169,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
batch_size = targets.size(0)
reference_lengths = [l-1 for l in answer_lengths]
translation_len = max(reference_lengths)
translation_lengths = torch.Tensor([translation_len] * batch_size, device=self.device)
translation_lengths = torch.tensor([translation_len] * batch_size, device=self.device)
bleu_loss_smoothed = expectedMultiBleu.bleu(probs, targets, translation_lengths, reference_lengths, max_order=max_order, smooth=True)
loss = -1 * bleu_loss_smoothed[0]

View File

@ -613,6 +613,7 @@ class WikiSQL(CQA, data.Dataset):
lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token)
fields.append(('wikisql_id', FIELD))
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', 'query_as_question' if query_as_question else 'query_as_context', os.path.basename(path), str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool: