simplify accelerator steps (#5015)

* simplify accelerator steps

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Jirka Borovec 2020-12-10 14:06:13 +01:00 committed by GitHub
parent 820d5c7348
commit d5fa02e798
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 52 deletions

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Callable
import torch
@ -61,29 +61,22 @@ class CPUAccelerator(Accelerator):
results = self.train_or_test()
return results
def training_step(self, args):
def _step(self, model_step: Callable, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.training_step(*args)
output = model_step(*args)
else:
output = self.trainer.model.training_step(*args)
output = model_step(*args)
return output
def training_step(self, args):
return self._step(self.trainer.model.training_step, args)
def validation_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.validation_step(*args)
else:
output = self.trainer.model.validation_step(*args)
return output
return self._step(self.trainer.model.validation_step, args)
def test_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.test_step(*args)
else:
output = self.trainer.model.test_step(*args)
return output
return self._step(self.trainer.model.test_step, args)
def sync_tensor(self,
tensor: Union[torch.Tensor],

View File

@ -116,7 +116,7 @@ class DataParallelAccelerator(Accelerator):
self.trainer.model.forward = self.model_autocast_original_forward
self.barrier()
def training_step(self, args):
def _step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
@ -124,13 +124,14 @@ class DataParallelAccelerator(Accelerator):
output = self.trainer.model(*args)
return output
def training_step(self, args):
return self._step(args)
def validation_step(self, args):
output = self.training_step(args)
return output
return self._step(args)
def test_step(self, args):
output = self.training_step(args)
return output
return self._step(args)
def training_step_end(self, output):
if isinstance(output, Result):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Callable
import torch
from torch.optim.lr_scheduler import _LRScheduler
@ -114,46 +114,26 @@ class HorovodAccelerator(Accelerator):
hvd.join()
return results
def training_step(self, args):
def _step(self, model_step: Callable, args):
if self.trainer.on_gpu:
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
args[0] = self.batch_to_device(args[0], hvd.local_rank())
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.training_step(*args)
output = model_step(*args)
else:
output = self.trainer.model.training_step(*args)
output = model_step(*args)
return output
def training_step(self, args):
return self._step(self.trainer.model.training_step, args)
def validation_step(self, args):
if self.trainer.on_gpu:
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.validation_step(*args)
else:
output = self.trainer.model.validation_step(*args)
return output
return self._step(self.trainer.model.validation_step, args)
def test_step(self, args):
if self.trainer.on_gpu:
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.test_step(*args)
else:
output = self.trainer.model.test_step(*args)
return output
return self._step(self.trainer.model.test_step, args)
def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)

View File

@ -54,6 +54,7 @@ class SLURMConnector:
if self.trainer.is_slurm_managing_tasks:
rank_zero_info('Multi-processing is handled by Slurm.')
# todo: the same function as slurm_environment.py `_resolve_root_node_address`
def resolve_root_node_address(self, root_node):
if '[' in root_node:
name, numbers = root_node.split('[', maxsplit=1)
@ -108,8 +109,8 @@ class SLURMConnector:
# save
log.info("bypassing sigterm")
# todo: this is the same func as slurm_environment.py `master_port`
def connect_ddp(self, global_rank: int, world_size: int) -> None:
""""""
"""
Sets up environment variables necessary for pytorch distributed communications
based on slurm environment.