satisfy travis
This commit is contained in:
parent
cf534b3523
commit
35fe30272b
|
@ -169,7 +169,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
batch_size = targets.size(0)
|
||||
reference_lengths = [l-1 for l in answer_lengths]
|
||||
translation_len = max(reference_lengths)
|
||||
translation_lengths = torch.Tensor([translation_len] * batch_size, device=self.device)
|
||||
translation_lengths = torch.tensor([translation_len] * batch_size, device=self.device)
|
||||
|
||||
bleu_loss_smoothed = expectedMultiBleu.bleu(probs, targets, translation_lengths, reference_lengths, max_order=max_order, smooth=True)
|
||||
loss = -1 * bleu_loss_smoothed[0]
|
||||
|
|
|
@ -613,6 +613,7 @@ class WikiSQL(CQA, data.Dataset):
|
|||
lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token)
|
||||
fields.append(('wikisql_id', FIELD))
|
||||
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', 'query_as_question' if query_as_question else 'query_as_context', os.path.basename(path), str(subsample))
|
||||
skip_cache_bool = kwargs.pop('skip_cache_bool')
|
||||
if os.path.exists(cache_name) and not skip_cache_bool:
|
||||
|
|
Loading…
Reference in New Issue