clean ups
This commit is contained in:
parent
2eac11805b
commit
d69fc8a45c
|
@ -188,52 +188,6 @@ def step(model, batch, opt, iteration, field, task, lr=None, grad_clip=None, wri
|
||||||
return loss.item(), {}, grad_norm
|
return loss.item(), {}, grad_norm
|
||||||
|
|
||||||
|
|
||||||
def create_mixed_set(args, train_sets, aux_sets, epoch):
|
|
||||||
|
|
||||||
assert len(aux_sets) == len(train_sets)
|
|
||||||
num_tasks = len(train_sets)
|
|
||||||
mixed_set = train_sets
|
|
||||||
|
|
||||||
for i in range(num_tasks):
|
|
||||||
|
|
||||||
train_set = train_sets[i]
|
|
||||||
aux_set = aux_sets[i]
|
|
||||||
|
|
||||||
assert aux_set.fields == train_set.fields
|
|
||||||
|
|
||||||
train_examples = train_set.examples
|
|
||||||
aux_examples = aux_set.examples
|
|
||||||
|
|
||||||
train_size = len(train_examples)
|
|
||||||
aux_size = len(aux_examples)
|
|
||||||
|
|
||||||
total_size = train_size + aux_size
|
|
||||||
|
|
||||||
if args.curriculum_strategy == 'linear':
|
|
||||||
next_fraction = args.curriculum_rate * epoch
|
|
||||||
elif args. curriculum_strategy == 'exp':
|
|
||||||
next_fraction = args.curriculum_rate * np.exp(epoch)
|
|
||||||
|
|
||||||
fraction = min(args.curriculum_max_frac, next_fraction)
|
|
||||||
|
|
||||||
train_size_target = int((1 - fraction) * total_size)
|
|
||||||
aux_size_target = int(fraction * total_size)
|
|
||||||
if aux_size_target > aux_size:
|
|
||||||
aux_size_target = aux_size
|
|
||||||
train_size_target = int(aux_size * (1 - fraction) / fraction)
|
|
||||||
elif train_size_target > train_size:
|
|
||||||
train_size_target = train_size
|
|
||||||
aux_size_target = int(train_size * fraction / (1 - fraction))
|
|
||||||
|
|
||||||
logging.info(f'at epoch {epoch} we have {train_size_target} examples from training set and {aux_size_target} examples from auxiliary training set')
|
|
||||||
|
|
||||||
train_set_indices = np.random.choice(range(train_size), size=train_size_target, replace=False)
|
|
||||||
aux_set_indices = np.random.choice(range(aux_size), size=aux_size_target, replace=False)
|
|
||||||
|
|
||||||
setattr(mixed_set[i], 'examples', [train_examples[i] for i in train_set_indices] + [aux_examples[i] for i in aux_set_indices])
|
|
||||||
|
|
||||||
return mixed_set
|
|
||||||
|
|
||||||
def update_fraction(args, task_iteration):
|
def update_fraction(args, task_iteration):
|
||||||
|
|
||||||
if args.curriculum_strategy == 'linear':
|
if args.curriculum_strategy == 'linear':
|
||||||
|
|
Loading…
Reference in New Issue