fix result obj dp auto reduce (#3013)

* fix result for dp

* added warning when changing monitor and using results obj
commit 8315a65d0a (parent 51de6802ed)
William Falcon, 2020-08-17 10:29:39 -04:00, committed by GitHub
6 changed files with 105 additions and 10 deletions

View File

@@ -361,6 +361,14 @@ class Result(Dict):
         result['meta'] = meta
         return result
 
+    def dp_reduce(self):
+        for k, value in self.items():
+            if k == 'meta':
+                continue
+            if isinstance(value, list):
+                value = torch.tensor(value)
+            self[k] = value.mean(dim=-1)
+
     @property
     def should_reduce_on_epoch_end(self) -> bool:
         return self['meta']['_internal']['_reduce_on_epoch']
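
A minimal standalone sketch of what dp_reduce does, using a plain dict in place of a Result object (the values are assumptions for illustration): in dp / ddp2 each logged key holds one value per GPU after the gather, and the mean over the last dimension collapses them into a single scalar, while the 'meta' bookkeeping entry is skipped.

    import torch

    # one entry per GPU after the dp gather; 'meta' is bookkeeping and is skipped
    gathered = {'loss': torch.tensor([0.9, 1.1]), 'meta': {'_internal': {}}}

    for k, value in gathered.items():
        if k == 'meta':
            continue
        if isinstance(value, list):
            value = torch.tensor(value)
        gathered[k] = value.mean(dim=-1)

    print(gathered['loss'])  # tensor(1.)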

View File

@@ -343,17 +343,20 @@ class TrainerEvaluationLoopMixin(ABC):
                     m = 'only EvalResults or dicts are allowed from validation_step'
                     raise MisconfigurationException(m)
 
                 # ------------------
                 # EVAL STEP END
                 # ------------------
                 # on dp / ddp2 might still want to do something with the batch parts
-                if test_mode:
-                    if self.is_overridden('test_step_end'):
-                        model_ref = self.get_model()
-                        with self.profiler.profile('test_step_end'):
-                            output = model_ref.test_step_end(output)
-                else:
-                    if self.is_overridden('validation_step_end'):
-                        model_ref = self.get_model()
-                        with self.profiler.profile('validation_step_end'):
-                            output = model_ref.validation_step_end(output)
+                eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
+                if self.is_overridden(eval_step_end_hook_name):
+                    model_ref = self.get_model()
+                    with self.profiler.profile(eval_step_end_hook_name):
+                        eval_step_end = getattr(model_ref, eval_step_end_hook_name)
+                        output = eval_step_end(output)
+
+                elif is_result_obj and (self.use_dp or self.use_ddp2):
+                    # result auto reduce
+                    output.dp_reduce()
 
                 # callbacks (on __batch_end)
                 if test_mode:
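
The rewritten block collapses the duplicated test/validation branches into a single getattr dispatch, and falls back to dp_reduce only when the step returned a Result object and no *_step_end hook is overridden. A minimal sketch of the same dispatch idiom outside the trainer (function and argument names are assumptions):

    def run_eval_step_end(model, output, test_mode):
        # pick the hook by name instead of branching once per mode
        hook_name = 'test_step_end' if test_mode else 'validation_step_end'
        hook = getattr(model, hook_name, None)
        return hook(output) if callable(hook) else output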

View File

@@ -1221,6 +1221,8 @@ class TrainerTrainLoopMixin(ABC):
         else:
             output = self.model.training_step(*args)
 
+        is_result_obj = isinstance(output, Result)
+
         # allow any mode to define training_step_end
         # do something with all the dp outputs (like softmax)
         if self.is_overridden('training_step_end'):
@@ -1229,6 +1231,9 @@ class TrainerTrainLoopMixin(ABC):
             # TODO: modify when using result obj
             output = model_ref.training_step_end(output)
 
+        elif is_result_obj and (self.use_dp or self.use_ddp2):
+            output.dp_reduce()
+
         # allow any mode to define training_end
         # TODO: remove in 1.0.0
         if self.is_overridden('training_end'):
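
The training loop applies the same precedence: a user-defined training_step_end wins, and auto-reduce runs only when that hook is absent. A hypothetical override that reduces the per-GPU parts by hand (assuming the loss sits in the Result's minimize slot, as TrainResult stores it):

    def training_step_end(self, output):
        # in dp mode output.minimize holds one loss per GPU; reduce it manually
        output.minimize = output.minimize.mean()
        return output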

View File

@@ -79,6 +79,28 @@ class TrainingStepVariations(ABC):
         self.training_step_called = True
         return result
 
+    def training_step_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
+        # forward pass
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x.to(self.device))
+
+        # calculate loss
+        loss_val = self.loss(y.to(y_hat.device), y_hat)
+        log_val = loss_val
+
+        # alternate between tensors and scalars for "log" and "progress_bar"
+        if batch_idx % 2 == 0:
+            log_val = log_val.item()
+
+        result = TrainResult(loss_val)
+        result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
+        result.log('train_some_val', log_val * log_val)
+
+        self.training_step_called = True
+        return result
+
     def training_step_end_full_loop_result_obj_dp(self, result):
         """
         Full loop flow train step (result obj + dp)
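
The batch_idx % 2 alternation above is presumably what exercises both branches of dp_reduce: tensor logs gather into a stacked tensor, while Python scalars arrive as a list that must first be converted. A toy illustration of the two shapes (values assumed):

    import torch

    per_gpu_tensor = torch.tensor([0.5, 0.7])  # gathered tensor logs
    per_gpu_scalars = [0.5, 0.7]               # gathered scalar logs arrive as a list

    print(per_gpu_tensor.mean(dim=-1))                  # tensor(0.6000)
    print(torch.tensor(per_gpu_scalars).mean(dim=-1))   # tensor(0.6000)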

View File

@@ -52,6 +52,28 @@ class ValidationStepVariations(ABC):
         })
         return result
 
+    def validation_step_result_obj_dp(self, batch, batch_idx, *args, **kwargs):
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x.to(self.device))
+
+        y = y.to(y_hat.device)
+        loss_val = self.loss(y, y_hat)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        val_acc = torch.tensor(val_acc).type_as(x)
+
+        result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)
+        result.log_dict({
+            'val_loss': loss_val,
+            'val_acc': val_acc,
+        })
+
+        self.validation_step_called = True
+        return result
+
     def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
         """
         Lightning calls this inside the validation loop
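
A minor alternative for the accuracy computation in validation_step_result_obj_dp, staying in tensor ops instead of the .item() round-trip (a sketch, not part of the commit):

    labels_hat = torch.argmax(y_hat, dim=1)
    val_acc = (y == labels_hat).float().mean().type_as(x)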

View File

@@ -535,6 +535,41 @@ def test_full_train_loop_with_results_obj_dp(tmpdir):
     assert 'epoch_train_epoch_end_metric' in seen_keys
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_loop_steps_only_dp(tmpdir):
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    batches = 10
+    epochs = 3
+
+    model = EvalModelTemplate()
+    model.validation_step = None
+    model.test_step = None
+    model.training_step = model.training_step_result_obj_dp
+    model.training_step_end = None
+    model.training_epoch_end = None
+    model.validation_step = model.validation_step_result_obj_dp
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    model.test_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        distributed_backend='dp',
+        gpus=[0, 1],
+        max_epochs=epochs,
+        early_stop_callback=True,
+        row_log_interval=2,
+        limit_train_batches=batches,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+
+    assert model.training_step_called
+    assert model.validation_step_called
+
+
 def test_result_map(tmpdir):
     result = TrainResult()
     result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})
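
test_result_map opens with the log_dict API. A minimal standalone sketch of the behavior it relies on, assuming (as the Result code earlier suggests) that each logged key is stored directly on the dict-like Result (import path assumed for the 0.9-era codebase):

    import torch
    from pytorch_lightning.core.step_result import TrainResult  # path assumed

    result = TrainResult()
    result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})
    assert 'x1' in result and 'x2' in result  # each key lands on the Result dict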