clean ups

This commit is contained in:
mehrad 2019-03-15 13:16:27 -07:00
parent 2eac11805b
commit d69fc8a45c
1 changed files with 0 additions and 46 deletions

View File

@ -188,52 +188,6 @@ def step(model, batch, opt, iteration, field, task, lr=None, grad_clip=None, wri
return loss.item(), {}, grad_norm
def create_mixed_set(args, train_sets, aux_sets, epoch):
assert len(aux_sets) == len(train_sets)
num_tasks = len(train_sets)
mixed_set = train_sets
for i in range(num_tasks):
train_set = train_sets[i]
aux_set = aux_sets[i]
assert aux_set.fields == train_set.fields
train_examples = train_set.examples
aux_examples = aux_set.examples
train_size = len(train_examples)
aux_size = len(aux_examples)
total_size = train_size + aux_size
if args.curriculum_strategy == 'linear':
next_fraction = args.curriculum_rate * epoch
elif args. curriculum_strategy == 'exp':
next_fraction = args.curriculum_rate * np.exp(epoch)
fraction = min(args.curriculum_max_frac, next_fraction)
train_size_target = int((1 - fraction) * total_size)
aux_size_target = int(fraction * total_size)
if aux_size_target > aux_size:
aux_size_target = aux_size
train_size_target = int(aux_size * (1 - fraction) / fraction)
elif train_size_target > train_size:
train_size_target = train_size
aux_size_target = int(train_size * fraction / (1 - fraction))
logging.info(f'at epoch {epoch} we have {train_size_target} examples from training set and {aux_size_target} examples from auxiliary training set')
train_set_indices = np.random.choice(range(train_size), size=train_size_target, replace=False)
aux_set_indices = np.random.choice(range(aux_size), size=aux_size_target, replace=False)
setattr(mixed_set[i], 'examples', [train_examples[i] for i in train_set_indices] + [aux_examples[i] for i in aux_set_indices])
return mixed_set
def update_fraction(args, task_iteration):
if args.curriculum_strategy == 'linear':