eval step scaling factor (#3136)
This commit is contained in:
parent
6c3cec3a3c
commit
82d1128966
|
@ -50,7 +50,17 @@ class CPUBackend(Accelerator):
|
|||
return output
|
||||
|
||||
def validation_step(self, args):
|
||||
return self.trainer.model.validation_step(*args)
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.trainer.model.validation_step(*args)
|
||||
else:
|
||||
output = self.trainer.model.validation_step(*args)
|
||||
return output
|
||||
|
||||
def test_step(self, args):
|
||||
return self.trainer.model.test_step(*args)
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.trainer.model.test_step(*args)
|
||||
else:
|
||||
output = self.trainer.model.test_step(*args)
|
||||
return output
|
||||
|
|
|
@ -69,6 +69,15 @@ class GPUBackend(Accelerator):
|
|||
return output
|
||||
|
||||
def validation_step(self, args):
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.__validation_step(args)
|
||||
else:
|
||||
output = self.__validation_step(args)
|
||||
|
||||
return output
|
||||
|
||||
def __validation_step(self, args):
|
||||
batch = args[0]
|
||||
batch = self.to_device(batch)
|
||||
args[0] = batch
|
||||
|
@ -76,6 +85,15 @@ class GPUBackend(Accelerator):
|
|||
return output
|
||||
|
||||
def test_step(self, args):
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.__test_step(args)
|
||||
else:
|
||||
output = self.__test_step(args)
|
||||
|
||||
return output
|
||||
|
||||
def __test_step(self, args):
|
||||
batch = args[0]
|
||||
batch = self.to_device(batch)
|
||||
args[0] = batch
|
||||
|
|
|
@ -134,7 +134,12 @@ class HorovodBackend(Accelerator):
|
|||
batch = self.batch_to_device(batch, hvd.local_rank())
|
||||
args[0] = batch
|
||||
|
||||
output = self.trainer.model.validation_step(*args)
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.trainer.model.validation_step(*args)
|
||||
else:
|
||||
output = self.trainer.model.validation_step(*args)
|
||||
|
||||
return output
|
||||
|
||||
def test_step(self, args):
|
||||
|
@ -143,5 +148,9 @@ class HorovodBackend(Accelerator):
|
|||
batch = self.batch_to_device(batch, hvd.local_rank())
|
||||
args[0] = batch
|
||||
|
||||
output = self.trainer.model.test_step(*args)
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.trainer.model.test_step(*args)
|
||||
else:
|
||||
output = self.trainer.model.test_step(*args)
|
||||
return output
|
||||
|
|
|
@ -303,18 +303,10 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
# -----------------
|
||||
args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
|
||||
|
||||
# TODO: collapse if statement into backends (next)
|
||||
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
|
||||
with torch.cuda.amp.autocast():
|
||||
if test_mode:
|
||||
output = self.accelerator_backend.test_step(args)
|
||||
else:
|
||||
output = self.accelerator_backend.validation_step(args)
|
||||
if test_mode:
|
||||
output = self.accelerator_backend.test_step(args)
|
||||
else:
|
||||
if test_mode:
|
||||
output = self.accelerator_backend.test_step(args)
|
||||
else:
|
||||
output = self.accelerator_backend.validation_step(args)
|
||||
output = self.accelerator_backend.validation_step(args)
|
||||
|
||||
is_result_obj = isinstance(output, Result)
|
||||
|
||||
|
|
Loading…
Reference in New Issue