eval step scaling factor (#3136)

This commit is contained in:
William Falcon 2020-08-24 20:26:39 -04:00 committed by GitHub
parent 6c3cec3a3c
commit 82d1128966
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 15 deletions

View File

@@ -50,7 +50,17 @@ class CPUBackend(Accelerator):
return output
def validation_step(self, args):
return self.trainer.model.validation_step(*args)
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.validation_step(*args)
else:
output = self.trainer.model.validation_step(*args)
return output
def test_step(self, args):
return self.trainer.model.test_step(*args)
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.test_step(*args)
else:
output = self.trainer.model.test_step(*args)
return output

View File

@@ -69,6 +69,15 @@ class GPUBackend(Accelerator):
return output
def validation_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.__validation_step(args)
else:
output = self.__validation_step(args)
return output
def __validation_step(self, args):
batch = args[0]
batch = self.to_device(batch)
args[0] = batch
@@ -76,6 +85,15 @@ class GPUBackend(Accelerator):
return output
def test_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.__test_step(args)
else:
output = self.__test_step(args)
return output
def __test_step(self, args):
batch = args[0]
batch = self.to_device(batch)
args[0] = batch

View File

@@ -134,7 +134,12 @@ class HorovodBackend(Accelerator):
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
output = self.trainer.model.validation_step(*args)
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.validation_step(*args)
else:
output = self.trainer.model.validation_step(*args)
return output
def test_step(self, args):
@@ -143,5 +148,9 @@ class HorovodBackend(Accelerator):
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
output = self.trainer.model.test_step(*args)
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.test_step(*args)
else:
output = self.trainer.model.test_step(*args)
return output

View File

@@ -303,18 +303,10 @@ class TrainerEvaluationLoopMixin(ABC):
# -----------------
args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
# TODO: collapse if statement into backends (next)
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
with torch.cuda.amp.autocast():
if test_mode:
output = self.accelerator_backend.test_step(args)
else:
output = self.accelerator_backend.validation_step(args)
if test_mode:
output = self.accelerator_backend.test_step(args)
else:
if test_mode:
output = self.accelerator_backend.test_step(args)
else:
output = self.accelerator_backend.validation_step(args)
output = self.accelerator_backend.validation_step(args)
is_result_obj = isinstance(output, Result)