ref: remove obscure forward call in eval + CPU backend *_step (#3123)

* remove obscure forward call in eval

* remove obscure forward call in eval

* remove obscure forward call in eval

* remove obscure forward call in eval

* remove obscure forward call in eval

* remove obscure forward call in eval
This commit is contained in:
William Falcon 2020-08-24 12:31:40 -04:00 committed by GitHub
parent 18160b81b5
commit 6068b29d29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 65 deletions

View File

@ -38,3 +38,12 @@ class CPUBackend(object):
def train(self, model):
    """Run the trainer's pretrain routine on ``model`` and return its results.

    The CPU backend performs no device setup, so this is a straight
    delegation to the trainer.
    """
    return self.trainer.run_pretrain_routine(model)
def training_step(self, args):
    """Forward one training step to the wrapped model (no device moves on CPU)."""
    model = self.trainer.model
    return model.training_step(*args)
def validation_step(self, args):
    """Forward one validation step to the wrapped model (no device moves on CPU)."""
    model = self.trainer.model
    return model.validation_step(*args)
def test_step(self, args):
return self.trainer.model.test_step(*args)

View File

@ -115,22 +115,28 @@ class HorovodBackend(Accelerator):
pass
def training_step(self, args):
    """Run one training step, moving the batch to this worker's GPU first
    when training on GPU.

    NOTE(review): the source interleaved an unconditional batch-to-device
    transfer (pre-change diff residue) before the ``on_gpu`` guard; on
    CPU-only Horovod runs that would attempt a GPU transfer via
    ``hvd.local_rank()``. Only the guarded transfer is kept.
    """
    if self.trainer.on_gpu:
        # hvd.local_rank() identifies the GPU assigned to this Horovod worker.
        batch = self.batch_to_device(args[0], hvd.local_rank())
        args[0] = batch
    output = self.trainer.model.training_step(*args)
    return output
def validation_step(self, args):
    """Run one validation step, moving the batch to this worker's GPU first
    when running on GPU.

    NOTE(review): the source interleaved an unconditional batch-to-device
    transfer (pre-change diff residue) before the ``on_gpu`` guard; only the
    guarded transfer is kept.
    """
    if self.trainer.on_gpu:
        # hvd.local_rank() identifies the GPU assigned to this Horovod worker.
        batch = self.batch_to_device(args[0], hvd.local_rank())
        args[0] = batch
    output = self.trainer.model.validation_step(*args)
    return output
def test_step(self, args):
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
if self.trainer.on_gpu:
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
output = self.trainer.model.test_step(*args)
return output

View File

@ -323,11 +323,20 @@ class TrainerEvaluationLoopMixin(ABC):
# -----------------
# RUN EVALUATION STEP
# -----------------
args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
# TODO: collapse if statement into backends (next)
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
with torch.cuda.amp.autocast():
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
if test_mode:
output = self.accelerator_backend.test_step(args)
else:
output = self.accelerator_backend.validation_step(args)
else:
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
if test_mode:
output = self.accelerator_backend.test_step(args)
else:
output = self.accelerator_backend.validation_step(args)
is_result_obj = isinstance(output, Result)
@ -646,50 +655,11 @@ class TrainerEvaluationLoopMixin(ABC):
return eval_loop_results
def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
    """Assemble the positional args for a validation/test step call.

    ``dataloader_idx`` is appended only when more than one dataloader is in
    use, which keeps that argument optional in user-defined
    ``validation_step``/``test_step`` signatures.

    NOTE(review): the source interleaved the removed ``evaluation_forward``
    def line and its removed dispatch body (diff residue) into this
    function, yielding two ``def`` lines and unreachable code; only the
    surviving ``build_args`` logic is kept.
    """
    args = [batch, batch_idx]
    # make dataloader_idx arg in validation_step optional
    if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1):
        args.append(dataloader_idx)
    return args

View File

@ -1193,8 +1193,8 @@ class TrainerTrainLoopMixin(ABC):
if self.use_ddp or self.use_ddp2 or self.use_dp:
output = self.accelerator_backend.training_step(args)
# Horovod
elif self.use_horovod and self.on_gpu:
# horovod
elif self.use_horovod:
output = self.accelerator_backend.training_step(args)
# single GPU forward
@ -1207,7 +1207,7 @@ class TrainerTrainLoopMixin(ABC):
# CPU forward
else:
output = self.model.training_step(*args)
output = self.accelerator_backend.training_step(args)
is_result_obj = isinstance(output, Result)

View File

@ -926,12 +926,6 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
assert trainer.num_sanity_val_steps == num_sanity_val_steps
val_dataloaders = model.val_dataloader__multiple_mixed_length()
with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
trainer.fit(model, val_dataloaders=val_dataloaders)
assert mocked.call_count == sum(
min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
)
@pytest.mark.parametrize(['limit_val_batches'], [
pytest.param(0.0), # this should run no sanity checks
@ -956,10 +950,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
assert trainer.num_sanity_val_steps == float('inf')
val_dataloaders = model.val_dataloader__multiple()
with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
trainer.fit(model, val_dataloaders=val_dataloaders)
assert mocked.call_count == sum(trainer.num_val_batches)
@pytest.mark.parametrize("trainer_kwargs,expected", [
pytest.param(