ref: remove obscure forward call in eval + CPU backend ___step (#3123)
* remove obscure forward call in eval
This commit is contained in:
parent 18160b81b5
commit 6068b29d29
@@ -38,3 +38,12 @@ class CPUBackend(object):
     def train(self, model):
         results = self.trainer.run_pretrain_routine(model)
         return results
+
+    def training_step(self, args):
+        return self.trainer.model.training_step(*args)
+
+    def validation_step(self, args):
+        return self.trainer.model.validation_step(*args)
+
+    def test_step(self, args):
+        return self.trainer.model.test_step(*args)
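For context on the new pass-through methods above, here is a minimal, self-contained sketch of the pattern; the _ToyModel, _ToyTrainer, and _ToyCPUBackend names are hypothetical stand-ins, not Lightning classes, and they only mimic the attributes the backend actually touches.

# Hedged sketch: a CPU-style backend simply forwards each *_step call
# to trainer.model with the args list unpacked.

class _ToyModel:
    def training_step(self, batch, batch_idx):
        return {"loss": float(batch_idx)}

    def validation_step(self, batch, batch_idx):
        return {"val_loss": float(batch_idx)}

    def test_step(self, batch, batch_idx):
        return {"test_loss": float(batch_idx)}


class _ToyTrainer:
    def __init__(self, model):
        self.model = model


class _ToyCPUBackend:
    def __init__(self, trainer):
        self.trainer = trainer

    def training_step(self, args):
        return self.trainer.model.training_step(*args)

    def validation_step(self, args):
        return self.trainer.model.validation_step(*args)

    def test_step(self, args):
        return self.trainer.model.test_step(*args)


backend = _ToyCPUBackend(_ToyTrainer(_ToyModel()))
print(backend.training_step([[1.0, 2.0], 0]))    # {'loss': 0.0}
print(backend.validation_step([[1.0, 2.0], 3]))  # {'val_loss': 3.0}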
@@ -115,22 +115,28 @@ class HorovodBackend(Accelerator):
         pass
 
     def training_step(self, args):
-        batch = args[0]
-        batch = self.batch_to_device(batch, hvd.local_rank())
-        args[0] = batch
+        if self.trainer.on_gpu:
+            batch = args[0]
+            batch = self.batch_to_device(batch, hvd.local_rank())
+            args[0] = batch
+
         output = self.trainer.model.training_step(*args)
         return output
 
     def validation_step(self, args):
-        batch = args[0]
-        batch = self.batch_to_device(batch, hvd.local_rank())
-        args[0] = batch
+        if self.trainer.on_gpu:
+            batch = args[0]
+            batch = self.batch_to_device(batch, hvd.local_rank())
+            args[0] = batch
+
         output = self.trainer.model.validation_step(*args)
         return output
 
     def test_step(self, args):
-        batch = args[0]
-        batch = self.batch_to_device(batch, hvd.local_rank())
-        args[0] = batch
+        if self.trainer.on_gpu:
+            batch = args[0]
+            batch = self.batch_to_device(batch, hvd.local_rank())
+            args[0] = batch
+
         output = self.trainer.model.test_step(*args)
         return output
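A small hedged sketch of the guard introduced above: the batch is moved to the device only when running on GPU. `maybe_move_batch` and `move_fn` are illustrative names, with `move_fn` standing in for `self.batch_to_device(batch, hvd.local_rank())`; this is not Horovod code.

def maybe_move_batch(args, on_gpu, move_fn):
    # args is [batch, batch_idx, ...]; the batch is replaced in place only
    # when on_gpu is True, mirroring the if-guard added in the hunk above.
    if on_gpu:
        args[0] = move_fn(args[0])
    return args


# CPU run leaves the batch untouched; the "GPU" run passes it through move_fn.
print(maybe_move_batch([[1, 2], 0], on_gpu=False, move_fn=lambda b: b))        # [[1, 2], 0]
print(maybe_move_batch([[1, 2], 0], on_gpu=True, move_fn=lambda b: tuple(b)))  # [(1, 2), 0]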
@@ -323,11 +323,20 @@ class TrainerEvaluationLoopMixin(ABC):
                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
+                args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
+
+                # TODO: collapse if statement into backends (next)
                 if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
                     with torch.cuda.amp.autocast():
-                        output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
+                        if test_mode:
+                            output = self.accelerator_backend.test_step(args)
+                        else:
+                            output = self.accelerator_backend.validation_step(args)
                 else:
-                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
+                    if test_mode:
+                        output = self.accelerator_backend.test_step(args)
+                    else:
+                        output = self.accelerator_backend.validation_step(args)
 
                 is_result_obj = isinstance(output, Result)
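Reading the new control flow as a standalone helper may make the intent clearer: build the args once, then dispatch to the backend's test_step or validation_step, optionally under native-AMP autocast. The helper name and signature below are illustrative, not part of Lightning; `autocast_ctx` would be torch.cuda.amp.autocast in the real loop and defaults to a no-op here so the sketch stays dependency-free.

from contextlib import nullcontext


def run_evaluation_step(backend, args, test_mode, use_native_amp=False, autocast_ctx=None):
    # Dispatch only: args are assumed to have been built beforehand.
    ctx = autocast_ctx() if (use_native_amp and autocast_ctx is not None) else nullcontext()
    with ctx:
        if test_mode:
            return backend.test_step(args)
        return backend.validation_step(args)


class _EchoBackend:
    def validation_step(self, args):
        return ("val", args)

    def test_step(self, args):
        return ("test", args)


print(run_evaluation_step(_EchoBackend(), [[1, 2], 0], test_mode=False))  # ('val', [[1, 2], 0])
print(run_evaluation_step(_EchoBackend(), [[1, 2], 0], test_mode=True))   # ('test', [[1, 2], 0])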
@@ -646,50 +655,11 @@ class TrainerEvaluationLoopMixin(ABC):
 
         return eval_loop_results
 
-    def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
+    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
         # make dataloader_idx arg in validation_step optional
         args = [batch, batch_idx]
 
         if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1):
             args.append(dataloader_idx)
 
-        # handle DP, DDP forward
-        if self.use_ddp or self.use_dp or self.use_ddp2:
-            if test_mode:
-                output = self.accelerator_backend.test_step(args)
-            else:
-                output = self.accelerator_backend.validation_step(args)
-            return output
-
-        # Horovod
-        if self.use_horovod and self.on_gpu:
-            if test_mode:
-                output = self.accelerator_backend.test_step(args)
-            else:
-                output = self.accelerator_backend.validation_step(args)
-            return output
-
-        # single GPU data transfer
-        if self.use_single_gpu:
-            if test_mode:
-                output = self.accelerator_backend.test_step(args)
-            else:
-                output = self.accelerator_backend.validation_step(args)
-            return output
-
-        # TPU data transfer
-        if self.use_tpu:
-            if test_mode:
-                output = self.accelerator_backend.test_step(args)
-            else:
-                output = self.accelerator_backend.validation_step(args)
-            return output
-
-        # CPU, TPU or gpu step
-        # TODO: remove during refactors
-        if test_mode:
-            output = model.test_step(*args)
-        else:
-            output = model.validation_step(*args)
-
-        return output
+        return args
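To make the dataloader_idx handling in the new build_args concrete, here is a free-standing version with the dataloader counts passed in explicitly rather than read from trainer state; the function and its keyword defaults are illustrative only.

def build_args(test_mode, batch, batch_idx, dataloader_idx,
               num_test_dataloaders=1, num_val_dataloaders=1):
    # dataloader_idx stays optional in validation_step/test_step: it is only
    # appended when the relevant loop runs over more than one dataloader
    args = [batch, batch_idx]
    if (test_mode and num_test_dataloaders > 1) or (not test_mode and num_val_dataloaders > 1):
        args.append(dataloader_idx)
    return args


print(build_args(False, batch=[1, 2], batch_idx=0, dataloader_idx=0))
# [[1, 2], 0]        -> validation_step(batch, batch_idx)
print(build_args(False, batch=[1, 2], batch_idx=0, dataloader_idx=1, num_val_dataloaders=2))
# [[1, 2], 0, 1]     -> validation_step(batch, batch_idx, dataloader_idx)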
@@ -1193,8 +1193,8 @@ class TrainerTrainLoopMixin(ABC):
         if self.use_ddp or self.use_ddp2 or self.use_dp:
             output = self.accelerator_backend.training_step(args)
 
-        # Horovod
-        elif self.use_horovod and self.on_gpu:
+        # horovod
+        elif self.use_horovod:
             output = self.accelerator_backend.training_step(args)
 
         # single GPU forward
@@ -1207,7 +1207,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # CPU forward
         else:
-            output = self.model.training_step(*args)
+            output = self.accelerator_backend.training_step(args)
 
         is_result_obj = isinstance(output, Result)
 
@@ -926,12 +926,6 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
     assert trainer.num_sanity_val_steps == num_sanity_val_steps
     val_dataloaders = model.val_dataloader__multiple_mixed_length()
 
-    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
-        trainer.fit(model, val_dataloaders=val_dataloaders)
-        assert mocked.call_count == sum(
-            min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
-        )
-
 
 @pytest.mark.parametrize(['limit_val_batches'], [
     pytest.param(0.0),  # this should run no sanity checks
@@ -956,10 +950,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
     assert trainer.num_sanity_val_steps == float('inf')
     val_dataloaders = model.val_dataloader__multiple()
 
-    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
-        trainer.fit(model, val_dataloaders=val_dataloaders)
-        assert mocked.call_count == sum(trainer.num_val_batches)
-
 
 @pytest.mark.parametrize("trainer_kwargs,expected", [
     pytest.param(