ref: group connectors (#3472)

* ref: accelerator connector methods 3/n

* ref: accelerator connector methods 3/n
William Falcon 2020-09-11 23:33:09 -04:00 committed by GitHub
parent dd324e4086
commit ff0064f956
13 changed files with 39 additions and 16 deletions

View File

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
+import re
 import torch
 import torch.multiprocessing as mp
@@ -22,6 +23,7 @@ from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.utilities.cloud_io import atomic_save
 try:
     import torch_xla
@@ -124,7 +126,7 @@ class TPUBackend(Accelerator):
         self.__save_end_of_training_weights(model, trainer)
         # persist info in spawn
-        trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
     def training_step(self, args):
         batch = args[0]
@@ -294,3 +296,25 @@ class TPUBackend(Accelerator):
             os.remove(path)
         return loaded_model
+
+    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
+        if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.trainer.checkpoint_callback is not None:
+            best_model_path = self.trainer.checkpoint_callback.best_model_path
+
+        if self.trainer.global_rank == 0 and mp_queue is not None:
+            rank_zero_warn('cleaning up ddp environment...')
+            # todo, pass complete checkpoint as state dictionary
+            mp_queue.put(best_model_path)
+            mp_queue.put(results)
+
+            # save the last weights
+            last_path = None
+            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
+                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
+                atomic_save(model.state_dict(), last_path)
+            mp_queue.put(last_path)
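Note: the new transfer_distrib_spawn_state_on_fit_end method hands state back from the spawned process to the parent through mp_queue in a fixed order: best_model_path, then results, then the path of the temporary last-weights checkpoint. A minimal sketch of the consuming side follows; the consume_spawn_results helper and the stand-in values are illustrative assumptions, not part of this commit.

# Sketch (assumed helper, not part of this commit): reading back, in the parent
# process, what transfer_distrib_spawn_state_on_fit_end put on the queue.
import torch.multiprocessing as mp

def consume_spawn_results(mp_queue):
    # items must be read in the same order the spawned rank-0 process put them
    best_model_path = mp_queue.get()
    results = mp_queue.get()
    last_path = mp_queue.get()
    return best_model_path, results, last_path

if __name__ == '__main__':
    smp = mp.get_context('spawn')
    queue = smp.SimpleQueue()
    # in real training, the spawned fit process would fill the queue instead
    queue.put('/tmp/best.ckpt')          # stand-in for best_model_path
    queue.put({'loss': 0.0})             # stand-in for results
    queue.put('/tmp/best.tmp_end.ckpt')  # stand-in for last_path
    print(consume_spawn_results(queue))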

View File

@@ -1,5 +1,5 @@
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.trainer.logger_connector import LoggerConnector
+from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState
 from typing import List, Optional, Union
 from pytorch_lightning.utilities import argparse_utils
@@ -10,7 +10,7 @@ import os
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.callbacks import ProgressBarBase
-from pytorch_lightning.trainer.model_connector import ModelConnector
+from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 class TrainerProperties(ABC):

View File

@@ -14,10 +14,9 @@
 import os
 import warnings
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Union
 import torch
-import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
@@ -26,7 +25,7 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.step_result import EvalResult
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler
+from pytorch_lightning.profiler import BaseProfiler
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
@@ -37,23 +36,23 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.states import TrainerState, trainer_state
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
-from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
-from pytorch_lightning.trainer.logger_connector import LoggerConnector
-from pytorch_lightning.trainer.optimizer_connector import OptimizerConnector
-from pytorch_lightning.trainer.training_trick_connector import TrainingTricksConnector
-from pytorch_lightning.trainer.callback_connector import CallbackConnector
-from pytorch_lightning.trainer.model_connector import ModelConnector
-from pytorch_lightning.trainer.debugging_connector import DebuggingConnector
+from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
+from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
+from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
+from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
+from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
+from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
 from pytorch_lightning import _logger as log
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.trainer.precision_connector import PrecisionConnector
-from pytorch_lightning.trainer.profiler_connector import ProfilerConnector
-from pytorch_lightning.trainer.data_connector import DataConnector
+from pytorch_lightning.trainer.connectors.precision_connector import PrecisionConnector
+from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector
+from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer import docstrings
 from pytorch_lightning.trainer.properties import TrainerProperties
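Downstream code that imported connector classes from their old flat locations needs the same path update shown in the diff above. A sketch, assuming user code that, for whatever reason, imported LoggerConnector directly:

# before this commit (old import path):
#   from pytorch_lightning.trainer.logger_connector import LoggerConnector
# after this commit, connector modules are grouped under trainer.connectors:
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector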