fix result obj dp auto reduce (#3013)
* fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * fix result for dp * added warning when changing monitor and using results obj
This commit is contained in:
parent
51de6802ed
commit
8315a65d0a
|
@ -361,6 +361,14 @@ class Result(Dict):
|
|||
result['meta'] = meta
|
||||
return result
|
||||
|
||||
def dp_reduce(self):
    """
    Average every stored value over its last dimension, in place.

    The training and evaluation loops call this for result objects when
    running with dp / ddp2 and no ``*_step_end`` hook is overridden.

    - The ``'meta'`` entry is never touched.
    - ``list`` values are converted to a tensor before averaging.
    - Integer and boolean tensors are cast to float first, because
      ``Tensor.mean`` cannot infer an output dtype for them and would raise.
    - Any other non-tensor value (e.g. a plain Python number) is left as is,
      since it has no dimension to average over.
    """
    # dtypes that ``Tensor.mean`` rejects
    non_float_dtypes = (
        torch.bool,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    )

    # only existing keys are reassigned, so the dict size never changes
    # while we iterate over it
    for k, value in self.items():
        if k == 'meta':
            continue

        if isinstance(value, list):
            value = torch.tensor(value)

        if not isinstance(value, torch.Tensor):
            # nothing to reduce for plain Python values
            continue

        if value.dtype in non_float_dtypes:
            value = value.float()

        self[k] = value.mean(dim=-1)
|
||||
|
||||
@property
def should_reduce_on_epoch_end(self) -> bool:
    """Return the ``_reduce_on_epoch`` flag kept under ``meta -> _internal``."""
    internal_meta = self['meta']['_internal']
    return internal_meta['_reduce_on_epoch']
|
||||
|
|
|
@ -343,17 +343,20 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
m = 'only EvalResults or dicts are allowed from validation_step'
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
# ------------------
|
||||
# EVAL STEP END
|
||||
# ------------------
|
||||
# on dp / ddp2 might still want to do something with the batch parts
|
||||
if test_mode:
|
||||
if self.is_overridden('test_step_end'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('test_step_end'):
|
||||
output = model_ref.test_step_end(output)
|
||||
else:
|
||||
if self.is_overridden('validation_step_end'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('validation_step_end'):
|
||||
output = model_ref.validation_step_end(output)
|
||||
eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
|
||||
if self.is_overridden(eval_step_end_hook_name):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile(eval_step_end_hook_name):
|
||||
eval_step_end = getattr(model_ref, eval_step_end_hook_name)
|
||||
output = eval_step_end(output)
|
||||
|
||||
elif is_result_obj and (self.use_dp or self.use_ddp2):
|
||||
# result auto reduce
|
||||
output.dp_reduce()
|
||||
|
||||
# callbacks (on __batch_end)
|
||||
if test_mode:
|
||||
|
|
|
@ -1221,6 +1221,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
else:
|
||||
output = self.model.training_step(*args)
|
||||
|
||||
is_result_obj = isinstance(output, Result)
|
||||
|
||||
# allow any mode to define training_step_end
|
||||
# do something with all the dp outputs (like softmax)
|
||||
if self.is_overridden('training_step_end'):
|
||||
|
@ -1229,6 +1231,9 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# TODO: modify when using result obj
|
||||
output = model_ref.training_step_end(output)
|
||||
|
||||
elif is_result_obj and (self.use_dp or self.use_ddp2):
|
||||
output.dp_reduce()
|
||||
|
||||
# allow any mode to define training_end
|
||||
# TODO: remove in 1.0.0
|
||||
if self.is_overridden('training_end'):
|
||||
|
|
|
@ -79,6 +79,28 @@ class TrainingStepVariations(ABC):
|
|||
self.training_step_called = True
|
||||
return result
|
||||
|
||||
def training_step_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
    """
    Training step that returns a ``TrainResult`` (used by the dp tests).

    The loss is minimized through the result object, and its square is
    logged twice: once for the progress bar only (``some_val``) and once
    with the default logging settings (``train_some_val``).
    """
    # forward pass on flattened inputs
    inputs, targets = batch
    flat_inputs = inputs.view(inputs.size(0), -1)
    y_hat = self(flat_inputs.to(self.device))

    # loss is computed on whichever device the prediction lives on
    loss_val = self.loss(targets.to(y_hat.device), y_hat)

    # even batches log a python scalar, odd batches log the tensor itself
    log_val = loss_val.item() if batch_idx % 2 == 0 else loss_val

    result = TrainResult(loss_val)
    result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
    result.log('train_some_val', log_val * log_val)

    # flag inspected by the tests to confirm this hook ran
    self.training_step_called = True

    return result
|
||||
|
||||
def training_step_end_full_loop_result_obj_dp(self, result):
|
||||
"""
|
||||
Full loop flow train step (result obj + dp)
|
||||
|
|
|
@ -52,6 +52,28 @@ class ValidationStepVariations(ABC):
|
|||
})
|
||||
return result
|
||||
|
||||
def validation_step_result_obj_dp(self, batch, batch_idx, *args, **kwargs):
    """
    Validation step that returns an ``EvalResult`` (used by the dp tests).

    The loss drives both checkpointing and early stopping, and is logged
    together with the batch accuracy as ``val_loss`` / ``val_acc``.
    """
    inputs, targets = batch
    inputs = inputs.view(inputs.size(0), -1)
    y_hat = self(inputs.to(self.device))

    # compare on the device the prediction ended up on
    targets = targets.to(y_hat.device)
    loss_val = self.loss(targets, y_hat)

    # acc: fraction of correct argmax predictions, stored with the input's type
    labels_hat = torch.argmax(y_hat, dim=1)
    n_correct = torch.sum(targets == labels_hat).item()
    val_acc = torch.tensor(n_correct / (len(targets) * 1.0)).type_as(inputs)

    result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)
    metrics = {
        'val_loss': loss_val,
        'val_acc': val_acc,
    }
    result.log_dict(metrics)

    # flag inspected by the tests to confirm this hook ran
    self.validation_step_called = True
    return result
|
||||
|
||||
def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
|
||||
"""
|
||||
Lightning calls this inside the validation loop
|
||||
|
|
|
@ -535,6 +535,41 @@ def test_full_train_loop_with_results_obj_dp(tmpdir):
|
|||
assert 'epoch_train_epoch_end_metric' in seen_keys
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_loop_steps_only_dp(tmpdir):
    """
    Fit with dp on two GPUs using only result-object train / val steps.

    All ``*_step_end`` and ``*_epoch_end`` hooks are disabled, and the test
    only checks that both step hooks were actually called.
    """
    os.environ['PL_DEV_DEBUG'] = '1'

    batches = 10
    epochs = 3

    model = EvalModelTemplate()

    # switch off every hook / loader that this scenario must not use
    disabled_attrs = (
        'validation_step',
        'test_step',
        'training_step_end',
        'training_epoch_end',
        'validation_step_end',
        'validation_epoch_end',
        'test_dataloader',
    )
    for attr_name in disabled_attrs:
        setattr(model, attr_name, None)

    # plug in the result-object step implementations
    model.training_step = model.training_step_result_obj_dp
    model.validation_step = model.validation_step_result_obj_dp

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        distributed_backend='dp',
        gpus=[0, 1],
        max_epochs=epochs,
        early_stop_callback=True,
        row_log_interval=2,
        limit_train_batches=batches,
        weights_summary=None,
    )
    trainer = Trainer(**trainer_kwargs)

    trainer.fit(model)

    assert model.training_step_called
    assert model.validation_step_called
|
||||
|
||||
|
||||
def test_result_map(tmpdir):
|
||||
result = TrainResult()
|
||||
result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})
|
||||
|
|
Loading…
Reference in New Issue