simplify accelerator steps (#5015)
* simplify accelerator steps
* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
parent 820d5c7348
commit d5fa02e798
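Every hunk below applies the same refactor: the copy-pasted `training_step` / `validation_step` / `test_step` bodies collapse into a single `_step` helper, so the AMP branching lives in exactly one place per accelerator. Here is a minimal standalone sketch of the pattern, assuming a toy class (`ToyAccelerator` and its `use_native_amp` flag are illustrative stand-ins, not Lightning API):

```python
from typing import Callable

import torch


class ToyAccelerator:
    def __init__(self, model, use_native_amp: bool = False):
        self.model = model
        self.use_native_amp = use_native_amp

    def _step(self, model_step: Callable, args):
        # One shared body replaces three copies of the same AMP branch.
        if self.use_native_amp:
            with torch.cuda.amp.autocast():
                return model_step(*args)
        return model_step(*args)

    def training_step(self, args):
        return self._step(self.model.training_step, args)

    def validation_step(self, args):
        return self._step(self.model.validation_step, args)

    def test_step(self, args):
        return self._step(self.model.test_step, args)
```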
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable

 import torch

@@ -61,29 +61,22 @@ class CPUAccelerator(Accelerator):
         results = self.train_or_test()
         return results

-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.trainer.model.training_step(*args)
+                output = model_step(*args)
         else:
-            output = self.trainer.model.training_step(*args)
+            output = model_step(*args)
         return output

+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)
+
     def validation_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.validation_step(*args)
-        else:
-            output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)

     def test_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.test_step(*args)
-        else:
-            output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)

     def sync_tensor(self,
                     tensor: Union[torch.Tensor],
@@ -116,7 +116,7 @@ class DataParallelAccelerator(Accelerator):
         self.trainer.model.forward = self.model_autocast_original_forward
         self.barrier()

-    def training_step(self, args):
+    def _step(self, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -124,13 +124,14 @@ class DataParallelAccelerator(Accelerator):
             output = self.trainer.model(*args)
         return output

+    def training_step(self, args):
+        return self._step(args)
+
     def validation_step(self, args):
-        output = self.training_step(args)
-        return output
+        return self._step(args)

     def test_step(self, args):
-        output = self.training_step(args)
-        return output
+        return self._step(args)

     def training_step_end(self, output):
         if isinstance(output, Result):
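In the DataParallel variant, `_step` takes no `model_step` argument: `nn.DataParallel` replicates the module and scatters inputs through `forward()` only, so all three hooks call `self.trainer.model(*args)` and the wrapper decides internally which LightningModule step to run. A hedged sketch of that routing idea (`StepRoutingWrapper` is hypothetical, not the code in this diff):

```python
import torch.nn as nn


class StepRoutingWrapper(nn.Module):
    """Hypothetical stand-in for Lightning's DP wrapper: every call goes
    through forward(), which picks the step to run."""

    def __init__(self, module, mode: str = "train"):
        super().__init__()
        self.module = module
        self.mode = mode  # "train" | "val" | "test"

    def forward(self, *args):
        if self.mode == "val":
            return self.module.validation_step(*args)
        if self.mode == "test":
            return self.module.test_step(*args)
        return self.module.training_step(*args)
```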
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import ExitStack
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable

 import torch
 from torch.optim.lr_scheduler import _LRScheduler
@@ -114,46 +114,26 @@ class HorovodAccelerator(Accelerator):
             hvd.join()
         return results

-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
         if self.trainer.on_gpu:
-            batch = args[0]
-            batch = self.batch_to_device(batch, hvd.local_rank())
-            args[0] = batch
+            args[0] = self.batch_to_device(args[0], hvd.local_rank())

         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.trainer.model.training_step(*args)
+                output = model_step(*args)
         else:
-            output = self.trainer.model.training_step(*args)
+            output = model_step(*args)

         return output

+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)
+
     def validation_step(self, args):
-        if self.trainer.on_gpu:
-            batch = args[0]
-            batch = self.batch_to_device(batch, hvd.local_rank())
-            args[0] = batch
-
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.validation_step(*args)
-        else:
-            output = self.trainer.model.validation_step(*args)
-
-        return output
+        return self._step(self.trainer.model.validation_step, args)

     def test_step(self, args):
-        if self.trainer.on_gpu:
-            batch = args[0]
-            batch = self.batch_to_device(batch, hvd.local_rank())
-            args[0] = batch
-
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.test_step(*args)
-        else:
-            output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)

     def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
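Beyond the shared `_step` extraction, the Horovod path also collapses the three-line batch transfer into the one-liner `args[0] = self.batch_to_device(args[0], hvd.local_rank())`. A minimal sketch of the device-placement idea, assuming a simplified `batch_to_device` helper (the real one lives on the accelerator base class and handles arbitrary collections, not just tensors):

```python
import horovod.torch as hvd
import torch


def batch_to_device(batch, device_idx):
    # Simplified stand-in: Lightning's helper also walks dicts/lists.
    device = torch.device("cuda", device_idx)
    return batch.to(device)


hvd.init()
if torch.cuda.is_available():
    # Horovod pins each worker to the GPU matching its local rank.
    torch.cuda.set_device(hvd.local_rank())
    batch = batch_to_device(torch.randn(4, 3), hvd.local_rank())
```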
@@ -54,6 +54,7 @@ class SLURMConnector:
         if self.trainer.is_slurm_managing_tasks:
             rank_zero_info('Multi-processing is handled by Slurm.')

+    # todo: the same function as slurm_environment.py `_resolve_root_node_address`
     def resolve_root_node_address(self, root_node):
         if '[' in root_node:
             name, numbers = root_node.split('[', maxsplit=1)
@@ -108,8 +109,8 @@ class SLURMConnector:
         # save
         log.info("bypassing sigterm")

+    # todo: this is the same func as slurm_environment.py `master_port`
     def connect_ddp(self, global_rank: int, world_size: int) -> None:
-        """"""
+        """
         Sets up environment variables necessary for pytorch distributed communications
         based on slurm environment.
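The one-character-looking docstring fix is easy to misread: `""""""` is six quotes, i.e. a triple-quoted string that opens and immediately closes, yielding an empty docstring rather than introducing the description below. Trimming it to `"""` lets the following lines form the actual docstring body. A quick illustration of the quoting behavior (function names hypothetical):

```python
def before():
    """"""


def after():
    """
    Sets up environment variables necessary for pytorch distributed communications
    based on slurm environment.
    """


assert before.__doc__ == ""
assert "environment variables" in after.__doc__
```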