2019-07-09 00:12:27 +00:00
|
|
|
"""
|
2019-08-07 13:01:19 +00:00
|
|
|
The trainer handles all the logic for running a val loop, training loop, distributing, etc.
|
2019-07-09 00:12:27 +00:00
|
|
|
"""
|
2019-08-05 21:57:39 +00:00
|
|
|
|
2019-07-09 00:11:20 +00:00
|
|
|
import os
|
2019-08-05 21:57:39 +00:00
|
|
|
import warnings
|
2019-11-05 13:43:21 +00:00
|
|
|
import logging
|
2019-07-09 00:11:20 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
import torch
|
2019-07-09 00:11:20 +00:00
|
|
|
import torch.distributed as dist
|
2019-10-22 08:32:40 +00:00
|
|
|
import torch.multiprocessing as mp
|
|
|
|
import tqdm
|
2019-08-15 15:31:56 +00:00
|
|
|
from torch.optim.optimizer import Optimizer
|
2019-07-09 00:11:20 +00:00
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
from pytorch_lightning.trainer.amp_mixin import TrainerAMPMixin
|
2019-10-22 08:32:40 +00:00
|
|
|
from pytorch_lightning.trainer.callback_config_mixin import TrainerCallbackConfigMixin
|
2019-10-22 01:16:51 +00:00
|
|
|
from pytorch_lightning.trainer.data_loading_mixin import TrainerDataLoadingMixin
|
2019-10-22 08:32:40 +00:00
|
|
|
from pytorch_lightning.trainer.ddp_mixin import TrainerDDPMixin
|
|
|
|
from pytorch_lightning.trainer.dp_mixin import TrainerDPMixin
|
2019-10-23 09:05:09 +00:00
|
|
|
from pytorch_lightning.trainer.dp_mixin import (
|
|
|
|
parse_gpu_ids,
|
|
|
|
determine_root_gpu_device
|
|
|
|
)
|
2019-10-22 01:16:51 +00:00
|
|
|
from pytorch_lightning.trainer.evaluation_loop_mixin import TrainerEvaluationLoopMixin
|
|
|
|
from pytorch_lightning.trainer.logging_mixin import TrainerLoggingMixin
|
|
|
|
from pytorch_lightning.trainer.model_hooks_mixin import TrainerModelHooksMixin
|
2019-10-22 08:32:40 +00:00
|
|
|
from pytorch_lightning.trainer.train_loop_mixin import TrainerTrainLoopMixin
|
|
|
|
from pytorch_lightning.trainer.trainer_io import TrainerIOMixin
|
|
|
|
from pytorch_lightning.trainer.training_tricks_mixin import TrainerTrainingTricksMixin
|
2019-08-07 14:14:59 +00:00
|
|
|
from pytorch_lightning.utilities.debugging import MisconfigurationException
|
2019-10-04 19:35:02 +00:00
|
|
|
|
2019-05-14 00:40:07 +00:00
|
|
|
# NVIDIA apex is an optional dependency: remember whether its `amp` module
# could be imported instead of failing at import time.
try:
    from apex import amp
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2019-07-09 00:12:27 +00:00
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
class Trainer(TrainerIOMixin,
              TrainerDDPMixin,
              TrainerDPMixin,
              TrainerDataLoadingMixin,
              TrainerAMPMixin,
              TrainerEvaluationLoopMixin,
              TrainerTrainLoopMixin,
              TrainerLoggingMixin,
              TrainerTrainingTricksMixin,
              TrainerCallbackConfigMixin,
              TrainerModelHooksMixin):
    """Entry point that configures and runs training, validation and testing.

    Most of the actual work lives in the mixins listed as base classes; this
    class stores configuration and shared state and wires the pieces together.
    """

    def __init__(self,
                 logger=True,
                 checkpoint_callback=True,
                 early_stop_callback=True,
                 default_save_path=None,
                 gradient_clip_val=0,
                 gradient_clip=None,  # backward compatible
                 process_position=0,
                 nb_gpu_nodes=1,
                 gpus=None,
                 log_gpu_memory=None,
                 show_progress_bar=True,
                 overfit_pct=0.0,
                 track_grad_norm=-1,
                 check_val_every_n_epoch=1,
                 fast_dev_run=False,
                 accumulate_grad_batches=1,
                 max_nb_epochs=1000,
                 min_nb_epochs=1,
                 train_percent_check=1.0,
                 val_percent_check=1.0,
                 test_percent_check=1.0,
                 val_check_interval=1.0,
                 log_save_interval=100,
                 row_log_interval=10,
                 add_row_log_interval=None,  # backward compatible
                 distributed_backend=None,
                 use_amp=False,
                 print_nan_grads=False,
                 weights_summary='full',
                 weights_save_path=None,
                 amp_level='O1',
                 nb_sanity_val_steps=5,
                 truncated_bptt_steps=None):
        """
        :param logger: Logger for experiment tracking. Handed to
            ``configure_early_stopping`` together with ``early_stop_callback``.
        :param checkpoint_callback: Callback for checkpointing. Finalized later by
            ``configure_checkpoint_callback`` in ``run_pretrain_routine``.
        :param early_stop_callback: Callback for early stopping
        :param default_save_path: Default path for logs+weights if no logger/ckpt_callback
            passed. Falls back to the current working directory when None.
        :param gradient_clip_val: number. Gradient clipping value; 0 means don't clip.
        :param gradient_clip: Deprecated alias of gradient_clip_val. When given it
            overrides gradient_clip_val and emits a DeprecationWarning.
        :param process_position: Position of this trainer's tqdm bars (bars are
            placed at ``2 * process_position``).
        :param nb_gpu_nodes: number of GPU nodes
        :param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] OR '0,1'
            OR '-1' / -1 to use all available gpus
        :param log_gpu_memory: str. None, 'min_max', 'all'
        :param show_progress_bar: Bool. If true shows tqdm bar
        :param overfit_pct: float. uses this much of all datasets
        :param track_grad_norm: int. -1 no tracking. Otherwise tracks that norm
        :param check_val_every_n_epoch: int. check val every n train epochs
        :param fast_dev_run: Bool. Runs a full train/val loop using a single batch to
            find bugs; forces max_nb_epochs=1 and nb_sanity_val_steps=1.
        :param accumulate_grad_batches: int. Accumulates grads every k batches
        :param max_nb_epochs: int. Maximum number of training epochs.
        :param min_nb_epochs: int. Minimum number of training epochs.
        :param train_percent_check: float. Fraction of the train set to check
        :param val_percent_check: float. Fraction of the val set to check
        :param test_percent_check: float. Fraction of the test set to check
        :param val_check_interval: float/int. If float, fraction of a training epoch.
            If int, check every n batches.
        :param log_save_interval: int. Writes logs to disk this often
        :param row_log_interval: int. How often to add logging rows
        :param add_row_log_interval: Deprecated alias of row_log_interval. When given
            it overrides row_log_interval and emits a DeprecationWarning.
        :param distributed_backend: str. Options: 'dp', 'ddp', 'ddp2'.
        :param use_amp: Bool. If true uses apex for 16bit precision
        :param print_nan_grads: Bool. Prints nan gradients
        :param weights_summary: str. Options: 'full', 'top', None to not print.
        :param weights_save_path: str or None. Where to save weights if on cluster
        :param amp_level: str. apex optimization level; check NVIDIA apex docs for levels
        :param nb_sanity_val_steps: int. How many val steps before a full train loop.
        :param truncated_bptt_steps: int. Enables multiple backward passes for each batch.
        """
        # Configure logging BEFORE anything logs. The module-level logging.info()
        # implicitly calls basicConfig() (WARNING level) when the root logger has
        # no handlers; any later basicConfig(level=INFO) would then be a no-op and
        # all INFO messages would be silently dropped.
        logging.basicConfig(level=logging.INFO)

        # Transfer params
        self.nb_gpu_nodes = nb_gpu_nodes
        self.log_gpu_memory = log_gpu_memory
        if gradient_clip is not None:
            # Backward compatibility
            warnings.warn("gradient_clip has renamed to gradient_clip_val since v0.5.0",
                          DeprecationWarning, stacklevel=2)
            gradient_clip_val = gradient_clip
        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        self.on_gpu = gpus is not None and torch.cuda.is_available()
        self.process_position = process_position
        self.weights_summary = weights_summary
        self.max_nb_epochs = max_nb_epochs
        self.min_nb_epochs = min_nb_epochs
        self.nb_sanity_val_steps = nb_sanity_val_steps
        self.print_nan_grads = print_nan_grads
        self.truncated_bptt_steps = truncated_bptt_steps
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            self.nb_sanity_val_steps = 1
            self.max_nb_epochs = 1
            m = '''
            Running in fast_dev_run mode: will run a full train,
            val loop using a single batch
            '''
            logging.info(m)

        # set default save path if user didn't provide one
        self.default_save_path = default_save_path
        if self.default_save_path is None:
            self.default_save_path = os.getcwd()

        # training bookkeeping
        self.total_batch_nb = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_nb = 0
        self.tqdm_metrics = {}
        self.callback_metrics = {}
        self.nb_val_batches = 0
        self.nb_training_batches = 0
        self.nb_test_batches = 0
        self.get_train_dataloader = None
        self.get_test_dataloaders = None
        self.get_val_dataloaders = None
        self.is_iterable_train_dataloader = False

        # training state
        self.model = None
        self.testing = False
        self.lr_schedulers = []
        self.optimizers = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # configure early stop callback
        # creates a default one if none passed in
        self.early_stop_callback = None
        self.configure_early_stopping(early_stop_callback, logger)

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.weights_save_path = weights_save_path

        # accumulated grads
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # allow int, string and gpu list
        self.data_parallel_device_ids = parse_gpu_ids(gpus)
        self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)

        # distributed backend choice
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_dp = False
        self.single_gpu = False
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend, nb_gpu_nodes)

        # init flags for SLURM+ddp to work
        self.proc_rank = 0
        self.world_size = 1
        self.node_rank = 0
        self.configure_slurm_ddp(nb_gpu_nodes)

        # nvidia setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

        # can't init progress bar here because starting a new process
        # means the progress_bar won't survive pickling
        self.show_progress_bar = show_progress_bar

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval
        if add_row_log_interval is not None:
            # backward compatibility
            warnings.warn("add_row_log_interval has renamed to row_log_interval since v0.5.0",
                          DeprecationWarning, stacklevel=2)
            row_log_interval = add_row_log_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.determine_data_use_amount(train_percent_check, val_percent_check,
                                       test_percent_check, overfit_pct)

        # 16 bit mixed precision training using apex
        self.amp_level = amp_level
        self.init_amp(use_amp)

    @property
    def slurm_job_id(self):
        """Return ``SLURM_JOB_ID`` as an int.

        Returns None when the variable is unset or is not an integer.
        """
        try:
            return int(os.environ['SLURM_JOB_ID'])
        except (KeyError, ValueError):
            return None

    def __parse_gpu_ids(self, gpus):
        """Legacy GPU-id parser, kept for backward compatibility.

        Not called within this class: ``__init__`` uses ``parse_gpu_ids``.

        :param gpus: None, int, list of ids, or a comma-separated string
            ('-1' expands to every visible CUDA device).
        :return: None for None, the list/int unchanged, or a list of ints for strings.
        :raises Exception: for any other type (including bool).
        """
        if gpus is None:
            return None

        # exact type checks (not isinstance) on purpose: bool must be rejected
        if type(gpus) is str:
            if gpus == '-1':
                return list(range(torch.cuda.device_count()))
            return [int(x.strip()) for x in gpus.split(',')]

        if type(gpus) in (list, int):
            return gpus

        raise Exception('gpus has to be a string, int or list of ints')

    def __set_root_gpu(self, gpus):
        """Legacy root-GPU chooser, kept for backward compatibility.

        Not called within this class: ``__init__`` uses ``determine_root_gpu_device``.

        :return: None if ``gpus`` is None, the first id if it is a list, else 0.
        """
        if gpus is None:
            return None
        return gpus[0] if type(gpus) is list else 0

    @property
    def num_gpus(self):
        """Number of device ids in ``data_parallel_device_ids`` (0 when None)."""
        gpus = self.data_parallel_device_ids
        return 0 if gpus is None else len(gpus)

    @property
    def data_parallel(self):
        """True when any of the dp / ddp / ddp2 backends is active."""
        return self.use_dp or self.use_ddp or self.use_ddp2

    @property
    def training_tqdm_dict(self):
        """
        Read-only for tqdm metrics

        :return: dict of display strings/values: loss, batch_nb, split_nb (when
            truncated BPTT is on), v_nb (logger version), user tqdm_metrics, and
            gpu (current CUDA device) when on GPU.
        """
        tqdm_dict = {
            'loss': '{0:.3f}'.format(self.avg_loss),
            'batch_nb': '{}'.format(self.batch_nb),
        }

        if self.truncated_bptt_steps is not None:
            tqdm_dict['split_nb'] = self.split_nb

        if self.logger is not None and self.logger.version is not None:
            tqdm_dict['v_nb'] = self.logger.version

        tqdm_dict.update(self.tqdm_metrics)

        if self.on_gpu:
            tqdm_dict['gpu'] = '{}'.format(torch.cuda.current_device())

        return tqdm_dict

    @property
    def tng_tqdm_dic(self):
        """
        * Deprecated in v0.5.0. use training_tqdm_dict instead. *

        :return: same as ``training_tqdm_dict``
        """
        warnings.warn("tng_tqdm_dic has renamed to training_tqdm_dict since v0.5.0",
                      DeprecationWarning, stacklevel=2)
        return self.training_tqdm_dict

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    def fit(self, model):
        """Train ``model`` with whichever backend was selected in ``__init__``.

        :param model: the model to train
        :return: 1 when finished (used for testing or to know training succeeded)
        :raises MisconfigurationException: if amp is requested while running on CPU.
        """
        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp2:
            task = int(os.environ['SLURM_LOCALID'])
            self.ddp_train(task, model)

        elif self.use_ddp:
            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
                self.ddp_train(task, model)
            else:
                mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))

        # 1 gpu or dp option triggers training using DP module
        # easier to avoid NCCL issues
        elif self.use_dp:
            self.dp_train(model)

        elif self.single_gpu:
            self.single_gpu_train(model)

        # ON CPU
        else:
            # run through amp wrapper
            if self.use_amp:
                raise MisconfigurationException('amp + cpu is not supported.'
                                                ' Please use a GPU option')

            # CHOOSE OPTIMIZER
            # allow for lr schedulers as well
            self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

            self.run_pretrain_routine(model)

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1

    def init_optimizers(self, optimizers):
        """Normalize the output of ``configure_optimizers`` to ``(optimizers, lr_schedulers)``.

        Accepted forms: a single Optimizer; a pair ``([optimizers], [schedulers])``;
        or a list/tuple of optimizers.

        :raises MisconfigurationException: for any other return value (previously
            this returned None and failed later with an unhelpful unpacking error).
        """
        # single optimizer
        if isinstance(optimizers, Optimizer):
            return [optimizers], []

        # two lists
        elif len(optimizers) == 2 and isinstance(optimizers[0], list):
            optimizers, lr_schedulers = optimizers
            return optimizers, lr_schedulers

        # single list or tuple
        elif isinstance(optimizers, (list, tuple)):
            return optimizers, []

        raise MisconfigurationException(
            'configure_optimizers must return an Optimizer, a list/tuple of '
            'Optimizers, or a pair ([optimizers], [lr_schedulers])')

    def run_pretrain_routine(self, model):
        """Wire up the model and trainer, run sanity checks, then start training.

        Links the trainer/logger to the model, sets up checkpointing, SLURM signal
        handlers and dataloaders, prints the weights summary, restores weights,
        and then either runs only the test evaluation (when ``self.testing``) or
        runs a short validation sanity check followed by ``self.train()``.

        :param model: the model, possibly wrapped (``model.module`` is used when
            ``data_parallel`` is True)
        :raises MisconfigurationException: for an unsupported ``weights_summary``.
        """
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        # give model convenience properties
        ref_model.trainer = self

        # set local properties on the model
        self.copy_trainer_model_properties(ref_model)

        # link up experiment object
        if self.logger is not None:
            ref_model.logger = self.logger

            # save exp to get started
            if hasattr(ref_model, "hparams"):
                self.logger.log_hyperparams(ref_model.hparams)

            self.logger.save()

        if self.use_ddp or self.use_ddp2:
            dist.barrier()

        # set up checkpoint callback
        self.configure_checkpoint_callback()

        # register auto-resubmit when on SLURM
        self.register_slurm_signal_handlers()

        # transfer data loaders from model
        self.get_dataloaders(ref_model)

        # print model summary
        if self.proc_rank == 0 and self.weights_summary is not None:
            if self.weights_summary in ['full', 'top']:
                ref_model.summarize(mode=self.weights_summary)
            else:
                m = "weights_summary can be None, 'full' or 'top'"
                raise MisconfigurationException(m)

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc call
        self.restore_weights(model)

        # when testing requested only run test and return
        if self.testing:
            self.run_evaluation(test=True)
            return

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0:
            # init progress bars for validation sanity check
            pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps,
                             leave=False, position=2 * self.process_position,
                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
            self.main_progress_bar = pbar
            # dummy validation progress bar
            self.val_progress_bar = tqdm.tqdm(disable=True)

            self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing)

            # close progress bars
            self.main_progress_bar.close()
            self.val_progress_bar.close()

        # init progress bar
        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
        self.main_progress_bar = pbar

        # clear cache before training
        if self.on_gpu:
            torch.cuda.empty_cache()

        # CORE TRAINING LOOP
        self.train()

    def test(self, model=None):
        """Run the test evaluation.

        With a model, goes through ``fit`` with ``self.testing`` set, so that
        ``run_pretrain_routine`` only runs the test evaluation. Without one, it
        evaluates directly via ``run_evaluation(test=True)``.

        NOTE(review): ``self.testing`` is never reset to False here, so a later
        ``fit`` call on the same trainer would also only test — confirm intended.
        """
        self.testing = True
        if model is not None:
            self.fit(model)
        else:
            self.run_evaluation(test=True)