"""
The trainer handles all the logic for running a val loop, training loop, distributing, etc.
"""

import os
import re
import warnings

import numpy as np
import tqdm
import torch
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.optim.optimizer import Optimizer

from pytorch_lightning.root_module.root_module import LightningModule
from pytorch_lightning.root_module import memory
from pytorch_lightning.logging import TestTubeLogger
from pytorch_lightning.trainer.trainer_io import TrainerIOMixin
from pytorch_lightning.pt_overrides.override_data_parallel import (
    LightningDistributedDataParallel, LightningDataParallel)
from pytorch_lightning.callbacks import GradientAccumulationScheduler, \
    ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities.debugging import MisconfigurationException
import pdb
from pytorch_lightning.trainer import ignored_warnings


try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


def reduce_distributed_output(output, nb_gpus):
    if nb_gpus <= 1:
        return output

    # when using DP, we get one output per gpu
    # average outputs and return
    if type(output) is torch.Tensor:
        return output.mean()

    for k, v in output.items():
        # recurse on nested dicts
        if isinstance(output[k], dict):
            output[k] = reduce_distributed_output(output[k], nb_gpus)

        # reduce only metrics that have the same nb of gpus
        elif output[k].size(0) == nb_gpus:
            reduced = torch.mean(output[k])
            output[k] = reduced
    return output
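

# Illustrative sketch (not part of the training flow, values assumed): with 2 GPUs under DP,
# an output such as {'loss': tensor([0.9, 1.1]), 'log': {'acc': tensor([0.5, 0.7])}} passed to
# reduce_distributed_output(output, nb_gpus=2) is averaged per key (recursing into the nested
# dict) and comes back as {'loss': tensor(1.0), 'log': {'acc': tensor(0.6)}}.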


class Trainer(TrainerIOMixin):

    def __init__(self,
                 logger=True,
                 checkpoint_callback=True,
                 early_stop_callback=True,
                 default_save_path=None,
                 gradient_clip_val=0,
                 process_position=0,
                 nb_gpu_nodes=1,
                 gpus=None,
                 log_gpu_memory=None,
                 show_progress_bar=True,
                 overfit_pct=0.0,
                 track_grad_norm=-1,
                 check_val_every_n_epoch=1,
                 fast_dev_run=False,
                 accumulate_grad_batches=1,
                 max_nb_epochs=1000,
                 min_nb_epochs=1,
                 train_percent_check=1.0,
                 val_percent_check=1.0,
                 test_percent_check=1.0,
                 val_check_interval=1.0,
                 log_save_interval=100,
                 row_log_interval=10,
                 distributed_backend=None,
                 use_amp=False,
                 print_nan_grads=False,
                 weights_summary='full',
                 weights_save_path=None,
                 amp_level='O1',
                 nb_sanity_val_steps=5):
        """
        :param logger: Logger for experiment tracking
        :param checkpoint_callback: Callback for checkpointing
        :param early_stop_callback: Callback for early stopping
        :param default_save_path: Default path for logs+weights if no logger/ckpt_callback passed
        :param gradient_clip_val: int. 0 means don't clip.
        :param process_position: shown in the tqdm bar
        :param nb_gpu_nodes: number of GPU nodes
        :param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1'
        :param log_gpu_memory: str. None, 'min_max', 'all'
        :param show_progress_bar: Bool. If true shows tqdm bar
        :param overfit_pct: float. uses this much of all datasets
        :param track_grad_norm: int. -1 no tracking. Otherwise tracks that norm
        :param check_val_every_n_epoch: int. check val every n train epochs
        :param fast_dev_run: Bool. runs full iteration over everything to find bugs
        :param accumulate_grad_batches: int. Accumulates grads every k batches
        :param max_nb_epochs: int.
        :param min_nb_epochs: int.
        :param train_percent_check: float. How much of train set to check
        :param val_percent_check: float. How much of val set to check
        :param test_percent_check: float. How much of test set to check
        :param val_check_interval: float. Check val this frequently within a train epoch
        :param log_save_interval: int. Writes logs to disk this often
        :param row_log_interval: int. How often to add logging rows
        :param distributed_backend: str. Options: 'dp', 'ddp', 'ddp2'.
        :param use_amp: Bool. If true uses apex for 16bit precision
        :param print_nan_grads: Bool. Prints nan gradients
        :param weights_summary: str. Options: 'full', 'top', None to not print.
        :param weights_save_path: str. Where to save weights if on cluster
        :param amp_level: str. Check nvidia docs for level
        :param nb_sanity_val_steps: int. How many val steps before a full train loop.
        """
        # Transfer params
        self.nb_gpu_nodes = nb_gpu_nodes
        self.log_gpu_memory = log_gpu_memory
        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        self.on_gpu = gpus is not None and torch.cuda.is_available()
        self.process_position = process_position
        self.weights_summary = weights_summary
        self.max_nb_epochs = max_nb_epochs
        self.min_nb_epochs = min_nb_epochs
        self.nb_sanity_val_steps = nb_sanity_val_steps
        self.print_nan_grads = print_nan_grads

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            self.nb_sanity_val_steps = 1
            self.max_nb_epochs = 1
            m = '''
            Running in fast_dev_run mode: will run a full train,
            val loop using a single batch
            '''
            print(m)

        # set default save path if user didn't provide one
        self.default_save_path = default_save_path
        if self.default_save_path is None:
            self.default_save_path = os.getcwd()

        # training bookkeeping
        self.total_batch_nb = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_nb = 0
        self.tqdm_metrics = {}
        self.callback_metrics = {}
        self.nb_val_batches = 0
        self.nb_training_batches = 0
        self.nb_test_batches = 0
        self.get_train_dataloader = None
        self.get_test_dataloaders = None
        self.get_val_dataloaders = None

        # training state
        self.model = None
        self.testing = False
        self.lr_schedulers = []
        self.optimizers = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # configure early stop callback
        # creates a default one if none passed in
        self.early_stop_callback = None
        if early_stop_callback is True:
            self.early_stop_callback = EarlyStopping(
                monitor='val_loss',
                patience=3,
                verbose=True,
                mode='min'
            )
            self.enable_early_stop = True
        elif not early_stop_callback:
            self.early_stop_callback = None
            self.enable_early_stop = False
        else:
            self.early_stop_callback = early_stop_callback
            self.enable_early_stop = True

        # configure logger
        if logger is True:
            # default logger
            self.logger = TestTubeLogger(
                save_dir=self.default_save_path,
                version=self.slurm_job_id,
                name='lightning_logs'
            )
            self.logger.rank = 0
        elif logger is False:
            self.logger = None
        else:
            self.logger = logger
            self.logger.rank = 0

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback

        self.weights_save_path = weights_save_path

        # accumulated grads
        self.__configure_accumulated_gradients(accumulate_grad_batches)

        # allow int, string and gpu list
        self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)
        self.root_gpu = self.__set_root_gpu(self.data_parallel_device_ids)

        # distributed backend choice
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_dp = False
        self.single_gpu = False
        self.distributed_backend = distributed_backend
        self.__set_distributed_mode(distributed_backend, nb_gpu_nodes)

        # init flags for SLURM+ddp to work
        self.proc_rank = 0
        self.world_size = 1
        self.node_rank = 0
        self.__configure_slurm_ddp(nb_gpu_nodes)

        # nvidia setup
        self.__set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

        # can't init progress bar here because starting a new process
        # means the progress_bar won't survive pickling
        self.show_progress_bar = show_progress_bar

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.__determine_data_use_amount(train_percent_check, val_percent_check,
                                         test_percent_check, overfit_pct)

        # 16 bit mixed precision training using apex
        self.amp_level = amp_level
        self.__init_amp(use_amp)

    @property
    def slurm_job_id(self):
        try:
            job_id = os.environ['SLURM_JOB_ID']
            job_id = int(job_id)
        except Exception as e:
            job_id = None
        return job_id

    def __configure_checkpoint_callback(self):
        """
        Weight path set in this priority:
        Checkpoint_callback's path (if passed in).
        User provided weights_save_path
        Otherwise use os.getcwd()
        """
        if self.checkpoint_callback is True:
            # init a default one
            if isinstance(self.logger, TestTubeLogger):
                ckpt_path = '{}/{}/version_{}/{}'.format(
                    self.default_save_path,
                    self.logger.experiment.name,
                    self.logger.experiment.version,
                    'checkpoints')
            else:
                ckpt_path = self.default_save_path

            self.checkpoint_callback = ModelCheckpoint(
                filepath=ckpt_path
            )
        elif self.checkpoint_callback is False:
            self.checkpoint_callback = None

        if self.checkpoint_callback:
            # set the path for the callbacks
            self.checkpoint_callback.save_function = self.save_checkpoint

            # if checkpoint callback used, then override the weights path
            self.weights_save_path = self.checkpoint_callback.filepath

        # if weights_save_path is still none here, set to current working dir
        if self.weights_save_path is None:
            self.weights_save_path = self.default_save_path
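
    # Illustrative sketch (assumed example values): with the default TestTubeLogger and
    # default_save_path='/tmp/runs', the auto-created checkpoint directory would look like
    #   /tmp/runs/lightning_logs/version_0/checkpoints
    # before it is handed to the default ModelCheckpoint above.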

    def __init_amp(self, use_amp):
        self.use_amp = use_amp and APEX_AVAILABLE
        if self.use_amp:
            print('using 16bit precision')

        if use_amp and not APEX_AVAILABLE:  # pragma: no cover
            msg = """
            You set use_amp=True but do not have apex installed.
            Install apex first using this guide and rerun with use_amp=True:
            https://github.com/NVIDIA/apex#linux

            this run will NOT use 16 bit precision
            """
            raise ModuleNotFoundError(msg)

    def __configure_accumulated_gradients(self, accumulate_grad_batches):
        self.accumulate_grad_batches = None

        if isinstance(accumulate_grad_batches, dict):
            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
        elif isinstance(accumulate_grad_batches, int):
            schedule = {1: accumulate_grad_batches}
            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
        else:
            raise TypeError("Gradient accumulation supports only int and dict types")
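
    # Illustrative sketch (assumed scheduler semantics): a dict such as
    # accumulate_grad_batches={5: 3, 10: 20} asks GradientAccumulationScheduler to accumulate
    # 3 batches from epoch 5 and 20 batches from epoch 10, while a plain int such as 4 simply
    # accumulates 4 batches starting at epoch 1.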

    def __parse_gpu_ids(self, gpus):
        """
        :param gpus: Int, string or list of ids
        :return:
        """
        # if gpus = -1 then use all available devices
        # otherwise, split the string using commas
        if gpus is not None:
            if type(gpus) is list:
                gpus = gpus
            elif type(gpus) is str:
                if gpus == '-1':
                    gpus = list(range(0, torch.cuda.device_count()))
                else:
                    gpus = [int(x.strip()) for x in gpus.split(',')]
            elif type(gpus) is int:
                gpus = gpus
            else:
                raise Exception('gpus has to be a string, int or list of ints')

        return gpus
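
    # Illustrative mapping (sketch): gpus=[0, 1] stays [0, 1]; gpus='0,1' becomes [0, 1];
    # gpus='-1' expands to every visible device; gpus=3 is kept as the int 3 and later
    # interpreted as "use 3 GPUs" by the CUDA_VISIBLE_DEVICES logic in __set_nvidia_flags.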

    def __set_root_gpu(self, gpus):
        if gpus is None:
            return None

        # set root gpu
        root_gpu = 0
        if type(gpus) is list:
            root_gpu = gpus[0]

        return root_gpu

    @property
    def num_gpus(self):
        gpus = self.data_parallel_device_ids
        if gpus is None:
            return 0

        if type(gpus) is list:
            return len(gpus)
        if type(gpus) is int:
            return gpus

        m = 'gpus must be int, none or list of ints'
        raise MisconfigurationException(m)

    def __set_distributed_mode(self, distributed_backend, nb_gpu_nodes):
        # skip for CPU
        if self.num_gpus == 0:
            return

        # single GPU case
        # in single gpu case we allow ddp so we can train on multiple
        # nodes, 1 gpu per node
        if self.num_gpus == 1:
            self.single_gpu = True

            if distributed_backend is not None:
                self.use_dp = distributed_backend == 'dp'
                self.use_ddp = distributed_backend == 'ddp'
                self.use_ddp2 = distributed_backend == 'ddp2'

                # disable single gpu when using ddp2
                if self.use_ddp2:
                    self.single_gpu = False

        # multiple GPU case
        elif self.num_gpus > 1:
            if distributed_backend is not None:
                # DP, DDP case
                self.use_dp = distributed_backend == 'dp'
                self.use_ddp = distributed_backend == 'ddp'
                self.use_ddp2 = distributed_backend == 'ddp2'

            elif distributed_backend is None:
                m = 'When using multiple GPUs set ' \
                    'Trainer(distributed_backend=dp) (or ddp)'
                raise MisconfigurationException(m)

        # use ddp automatically if nb_gpu_nodes > 1
        if nb_gpu_nodes > 1 and self.use_dp:  # pragma: no cover
            self.use_ddp = True
            self.use_dp = False
            w = 'DataParallel does not support nb_gpu_nodes > 1. ' \
                'Switching to DistributedDataParallel for you. ' \
                'To silence this warning set distributed_backend=ddp'
            warnings.warn(w)

        print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
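
    # Decision summary (sketch): a single GPU defaults to single_gpu mode, multiple GPUs
    # require an explicit distributed_backend ('dp', 'ddp' or 'ddp2'), and dp spread over
    # several nodes is promoted to ddp with a warning since DataParallel cannot span machines.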

    def __configure_slurm_ddp(self, nb_gpu_nodes):
        self.is_slurm_managing_tasks = False

        # extract SLURM flag vars
        # whenever we have the correct number of tasks, we let slurm manage processes
        # otherwise we launch the required number of processes
        if self.use_ddp:
            self.nb_requested_gpus = self.num_gpus * nb_gpu_nodes
            self.nb_slurm_tasks = 0
            try:
                self.nb_slurm_tasks = int(os.environ['SLURM_NTASKS'])
                self.is_slurm_managing_tasks = self.nb_slurm_tasks == self.nb_requested_gpus

                # in interactive mode we don't manage tasks
                job_name = os.environ['SLURM_JOB_NAME']
                if job_name == 'bash':
                    self.is_slurm_managing_tasks = False

            except Exception:
                # likely not on slurm, so set the slurm managed flag to false
                self.is_slurm_managing_tasks = False

        # used for tests only, set this flag to simulate slurm managing a task
        try:
            should_fake = int(os.environ['FAKE_SLURM_MANAGING_TASKS'])
            if should_fake:
                self.is_slurm_managing_tasks = True
        except Exception as e:
            pass

    def __set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
        if data_parallel_device_ids is None:
            return

        # set the correct cuda visible devices (using pci order)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

        # when slurm is managing the task it sets the visible devices
        if not is_slurm_managing_tasks:
            if type(data_parallel_device_ids) is int:
                id_str = ','.join(str(x) for x in list(range(data_parallel_device_ids)))
                os.environ["CUDA_VISIBLE_DEVICES"] = id_str
            else:
                gpu_str = ','.join([str(x) for x in data_parallel_device_ids])
                os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str

        print(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}')

    @property
    def data_parallel(self):
        return self.use_dp or self.use_ddp or self.use_ddp2

    def __determine_data_use_amount(self, train_percent_check, val_percent_check,
                                    test_percent_check, overfit_pct):
        """
        Use less data for debugging purposes
        """
        self.train_percent_check = train_percent_check
        self.val_percent_check = val_percent_check
        self.test_percent_check = test_percent_check
        if overfit_pct > 0:
            self.train_percent_check = overfit_pct
            self.val_percent_check = overfit_pct
            self.test_percent_check = overfit_pct

    def __get_model(self):
        return self.model.module if self.data_parallel else self.model

    def __is_function_implemented(self, f_name):
        model = self.__get_model()
        f_op = getattr(model, f_name, None)
        return callable(f_op)

    def __is_overriden(self, f_name):
        model = self.__get_model()
        super_object = LightningModule

        # when code pointers are different, it was overridden
        is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__
        return is_overriden

    @property
    def __training_tqdm_dict(self):
        tqdm_dict = {
            'loss': '{0:.3f}'.format(self.avg_loss),
            'epoch': '{}'.format(self.current_epoch),
            'batch_nb': '{}'.format(self.batch_nb),
        }

        if self.logger is not None and self.logger.version is not None:
            tqdm_dict['v_nb'] = self.logger.version

        tqdm_dict.update(self.tqdm_metrics)

        if self.on_gpu:
            tqdm_dict['gpu'] = '{}'.format(torch.cuda.current_device())

        return tqdm_dict

    @property
    def training_tqdm_dict(self):
        """
        Read-only for tqdm metrics
        :return:
        """
        return self.__training_tqdm_dict

    def __layout_bookeeping(self):

        # determine number of training batches
        self.nb_training_batches = len(self.get_train_dataloader())
        self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)

        # determine number of validation batches
        # val datasets could be none, 1 or 2+
        if self.get_val_dataloaders() is not None:
            self.nb_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
            self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
            self.nb_val_batches = max(1, self.nb_val_batches)

        # determine number of test batches
        if self.get_test_dataloaders() is not None:
            self.nb_test_batches = sum(
                len(dataloader) for dataloader in self.get_test_dataloaders()
            )
            self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check)
            self.nb_test_batches = max(1, self.nb_test_batches)

        # determine when to check validation
        self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
        self.val_check_batch = max(1, self.val_check_batch)
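
    # Worked example (sketch, assumed sizes): with 1000 training batches, train_percent_check=1.0
    # and val_check_interval=0.25, nb_training_batches stays 1000 and val_check_batch becomes 250,
    # so validation runs after every 250th training batch.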

    def __add_tqdm_metrics(self, metrics):
        for k, v in metrics.items():
            if type(v) is torch.Tensor:
                v = v.item()

            self.tqdm_metrics[k] = v

    def __evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        if test and len(self.get_test_dataloaders()) > 1:
            args.append(dataloader_idx)

        elif not test and len(self.get_val_dataloaders()) > 1:
            args.append(dataloader_idx)

        # handle DP, DDP forward
        if self.use_ddp or self.use_dp or self.use_ddp2:
            output = model(*args)
            return output

        # single GPU
        if self.single_gpu:
            # for single GPU put inputs on gpu manually
            root_gpu = 0
            if type(self.data_parallel_device_ids) is list:
                root_gpu = self.data_parallel_device_ids[0]
            batch = self.transfer_batch_to_gpu(batch, root_gpu)
            args[0] = batch

        # CPU
        if test:
            output = model.test_step(*args)
        else:
            output = model.validation_step(*args)

        return output

    def evaluate(self, model, dataloaders, max_batches, test=False):
        """
        Run evaluation code
        :param model: PT model
        :param dataloaders: list of PT dataloaders
        :param max_batches: Scalar
        :param test: boolean
        :return:
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.__copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        outputs = []

        # run evaluation over the dataloaders
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []
            for batch_idx, batch in enumerate(dataloader):

                if batch is None:  # pragma: no cover
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= max_batches:
                    break

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                output = self.__evaluation_forward(model,
                                                   batch,
                                                   batch_idx,
                                                   dataloader_idx,
                                                   test)

                # track outputs for collation
                dl_outputs.append(output)

                # batch done
                if self.show_progress_bar:
                    self.progress_bar.update(1)
            outputs.append(dl_outputs)

        eval_results = {}

        # with a single dataloader don't pass an array
        if len(dataloaders) == 1:
            outputs = outputs[0]

        # give model a chance to do something with the outputs (and method defined)
        model = self.__get_model()
        if test and self.__is_overriden('test_end'):
            eval_results = model.test_end(outputs)
        elif self.__is_overriden('validation_end'):
            eval_results = model.validation_end(outputs)

        # enable train mode again
        model.train()

        # re-enable gradients for training
        torch.set_grad_enabled(True)

        return eval_results

    def get_dataloaders(self, model):
        """
        Dataloaders are provided by the model
        :param model:
        :return:
        """
        self.get_train_dataloader = model.train_dataloader
        self.get_test_dataloaders = model.test_dataloader
        self.get_val_dataloaders = model.val_dataloader

        # call warnings from proc zero only which triggers dataloaders
        # if those have to download data it will only happen on proc 0
        if self.proc_rank == 0:
            on_ddp = self.use_ddp or self.use_ddp2
            if on_ddp and not isinstance(self.get_train_dataloader().sampler, DistributedSampler):
                msg = """
                You're using multiple gpus and multiple nodes without using a DistributedSampler
                to assign a subset of your data to each process. To silence this warning, pass a
                DistributedSampler to your DataLoader.

                ie: this:
                dataset = myDataset()
                dataloader = Dataloader(dataset)

                becomes:
                dataset = myDataset()
                dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
                dataloader = Dataloader(dataset, sampler=dist_sampler)

                If you want each process to load the full dataset, ignore this warning.
                """
                warnings.warn(msg)

            if on_ddp and self.get_val_dataloaders() is not None:
                for dataloader in self.get_val_dataloaders():
                    if not isinstance(dataloader.sampler, DistributedSampler):
                        msg = """
                        Your val_dataloader(s) don't use DistributedSampler.

                        You're using multiple gpus and multiple nodes without using a
                        DistributedSampler to assign a subset of your data to each process.
                        To silence this warning, pass a DistributedSampler to your DataLoader.

                        ie: this:
                        dataset = myDataset()
                        dataloader = Dataloader(dataset)

                        becomes:
                        dataset = myDataset()
                        dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
                        dataloader = Dataloader(dataset, sampler=dist_sampler)

                        If you want each process to load the full dataset, ignore this warning.
                        """
                        warnings.warn(msg)
                        break

            if on_ddp and self.get_test_dataloaders() is not None:
                for dataloader in self.get_test_dataloaders():
                    if not isinstance(dataloader.sampler, DistributedSampler):
                        msg = """
                        Your test_dataloader(s) don't use DistributedSampler.

                        You're using multiple gpus and multiple nodes without using a
                        DistributedSampler to assign a subset of your data to each process.
                        To silence this warning, pass a DistributedSampler to your DataLoader.

                        ie: this:
                        dataset = myDataset()
                        dataloader = Dataloader(dataset)

                        becomes:
                        dataset = myDataset()
                        dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
                        dataloader = Dataloader(dataset, sampler=dist_sampler)

                        If you want each process to load the full dataset, ignore this warning.
                        """
                        warnings.warn(msg)
                        break

        if self.use_ddp or self.use_ddp2:
            # wait for all processes to catch up
            dist.barrier()

            # load each dataloader
            self.get_train_dataloader()
            self.get_test_dataloaders()
            self.get_val_dataloaders()

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    def fit(self, model):
        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp2:
            task = int(os.environ['SLURM_LOCALID'])
            self.ddp_train(task, model)

        elif self.use_ddp:
            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
                self.ddp_train(task, model)
            else:
                mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model, ))

        # 1 gpu or dp option triggers training using DP module
        # easier to avoid NCCL issues
        elif self.use_dp:
            self.__dp_train(model)

        elif self.single_gpu:
            self.__single_gpu_train(model)

        # ON CPU
        else:
            # run through amp wrapper
            if self.use_amp:
                raise MisconfigurationException('amp + cpu is not supported.'
                                                ' Please use a GPU option')

            # CHOOSE OPTIMIZER
            # allow for lr schedulers as well
            self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

            self.__run_pretrain_routine(model)

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1

    def init_optimizers(self, optimizers):

        # single optimizer
        if isinstance(optimizers, Optimizer):
            return [optimizers], []

        # two lists
        elif len(optimizers) == 2 and isinstance(optimizers[0], list):
            optimizers, lr_schedulers = optimizers
            return optimizers, lr_schedulers

        # single list or tuple
        elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
            return optimizers, []
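
    # Accepted return shapes from configure_optimizers (illustrative sketch):
    #   a single torch.optim.Optimizer               -> ([optimizer], [])
    #   a list or tuple of optimizers                -> (optimizers, [])
    #   a pair of lists ([optimizers], [schedulers]) -> (optimizers, schedulers)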

    def __single_gpu_train(self, model):
        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

        model.cuda(self.root_gpu)

        if self.use_amp:
            # wrap the model and optimizers in apex amp
            model, optimizers = amp.initialize(
                model, self.optimizers, opt_level=self.amp_level,
            )
            self.optimizers = optimizers

        self.__run_pretrain_routine(model)

    def __dp_train(self, model):

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

        model.cuda(self.root_gpu)

        # check for this bug (amp + dp does not play well together)
        # https://github.com/NVIDIA/apex/issues/227
        if self.use_dp and self.use_amp:
            m = f"""
            Amp level {self.amp_level} with DataParallel is not supported.
            See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.
            We recommend you switch to ddp if you want to use amp
            """
            raise MisconfigurationException(m)

        # create list of device ids
        device_ids = self.data_parallel_device_ids
        if type(device_ids) is int:
            device_ids = list(range(device_ids))

        model = LightningDataParallel(model, device_ids=device_ids)

        self.__run_pretrain_routine(model)

    def __copy_trainer_model_properties(self, model):
        if isinstance(model, LightningDataParallel):
            ref_model = model.module
        elif isinstance(model, LightningDistributedDataParallel):
            ref_model = model.module
        else:
            ref_model = model

        for m in [model, ref_model]:
            m.trainer = self
            m.on_gpu = self.on_gpu
            m.use_dp = self.use_dp
            m.use_ddp2 = self.use_ddp2
            m.use_ddp = self.use_ddp
            m.use_amp = self.use_amp
            m.testing = self.testing
            m.single_gpu = self.single_gpu

    def ddp_train(self, gpu_nb, model):
        """
        Entry point into a DDP process
        :param gpu_nb:
        :param model:
        :return:
        """
        # node rank using relative slurm id
        # otherwise default to node rank 0
        try:
            node_id = os.environ['SLURM_NODEID']
            self.node_rank = int(node_id)
        except Exception:
            self.node_rank = 0

        # show progressbar only on progress_rank 0
        self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_nb == 0

        # determine which process we are and world size
        if self.use_ddp:
            self.proc_rank = self.node_rank * self.num_gpus + gpu_nb
            self.world_size = self.nb_gpu_nodes * self.num_gpus

        elif self.use_ddp2:
            self.proc_rank = self.node_rank
            self.world_size = self.nb_gpu_nodes

        # let the exp know the rank to avoid overwriting logs
        if self.logger is not None:
            self.logger.rank = self.proc_rank

        # set up server using proc 0's ip address
        self.__init_tcp_connection()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

        # MODEL
        # copy model to each gpu
        if self.distributed_backend == 'ddp':
            torch.cuda.set_device(gpu_nb)
        model.cuda(gpu_nb)

        # set model properties before going into wrapper
        self.__copy_trainer_model_properties(model)

        # override root GPU
        self.root_gpu = gpu_nb

        # AMP
        # run through amp wrapper before going to distributed DP
        if self.use_amp:
            # wrap the model and optimizers in apex amp
            model, optimizers = amp.initialize(
                model, self.optimizers, opt_level=self.amp_level,
            )
            self.optimizers = optimizers

        # DDP2 uses all GPUs on the machine
        if self.distributed_backend == 'ddp':
            device_ids = [gpu_nb]
        elif self.use_ddp2:
            device_ids = None

        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )

        # continue training routine
        self.__run_pretrain_routine(model)

    def __init_tcp_connection(self):
        """
        Connect all procs in the world using the env:// init
        Use the first node as the root address
        :return:
        """

        # use slurm job id for the port number
        # guarantees unique ports across jobs from same grid search
        try:
            # use the last 4 numbers in the job id as the id
            default_port = os.environ['SLURM_JOB_ID']
            default_port = default_port[-4:]

            # all ports should be in the 10k+ range
            default_port = int(default_port) + 15000

        except Exception as e:
            default_port = 12910

        # if user gave a port number, use that one instead
        try:
            default_port = os.environ['MASTER_PORT']
        except Exception:
            os.environ['MASTER_PORT'] = str(default_port)

        # figure out the root node addr
        try:
            root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
        except Exception:
            root_node = '127.0.0.2'

        root_node = self.resolve_root_node_address(root_node)
        os.environ['MASTER_ADDR'] = root_node
        dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)
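
    # Illustrative port derivation (assumed job id): SLURM_JOB_ID='1234567' keeps its last four
    # digits '4567', so MASTER_PORT becomes 4567 + 15000 = 19567 unless the user has already
    # exported MASTER_PORT themselves.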

    def resolve_root_node_address(self, root_node):
        if '[' in root_node:
            name = root_node.split('[')[0]
            number = root_node.split(',')[0]
            if '-' in number:
                number = number.split('-')[0]

            number = re.sub('[^0-9]', '', number)
            root_node = name + number

        return root_node
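
    # Worked example (sketch): a SLURM nodelist entry like 'node[023-045]' resolves to
    # 'node023', while a plain hostname such as 'node001' is returned unchanged.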

    def __run_pretrain_routine(self, model):
        """
        Sanity check a few things before starting actual training
        :param model:
        :return:
        """
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        # give model convenience properties
        ref_model.trainer = self

        # set local properties on the model
        self.__copy_trainer_model_properties(ref_model)

        # link up experiment object
        if self.logger is not None:
            ref_model.logger = self.logger

            # save exp to get started
            if hasattr(ref_model, "hparams"):
                self.logger.log_hyperparams(ref_model.hparams)

            self.logger.save()

        if self.use_ddp or self.use_ddp2:
            dist.barrier()

        # set up checkpoint callback
        self.__configure_checkpoint_callback()

        # register auto-resubmit when on SLURM
        self.register_slurm_signal_handlers()

        # transfer data loaders from model
        self.get_dataloaders(ref_model)

        # init training constants
        self.__layout_bookeeping()

        # print model summary
        if self.proc_rank == 0 and self.weights_summary is not None:
            if self.weights_summary in ['full', 'top']:
                ref_model.summarize(mode=self.weights_summary)
            else:
                m = "weights_summary can be None, 'full' or 'top'"
                raise MisconfigurationException(m)

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc call
        self.restore_weights(model)

        # progress bar init
        if self.show_progress_bar:
            self.progress_bar = tqdm.tqdm(0, position=self.process_position)

        # when testing requested only run test and return
        if self.testing:
            self.__run_evaluation(test=True)
            return

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0:
            # reset progress_bar limit for sanity check
            if self.show_progress_bar:
                self.progress_bar.reset(self.nb_sanity_val_steps)

            self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing)

        # ---------------------------
        # CORE TRAINING LOOP
        # ---------------------------
        self.__train()

    def __train(self):
        # run all epochs
        for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
            # set seed for distributed sampler (enables shuffling for each epoch)
            if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
                self.get_train_dataloader().sampler.set_epoch(epoch_nb)

            # get model
            model = self.__get_model()

            # update training progress in trainer and model
            model.current_epoch = epoch_nb
            self.current_epoch = epoch_nb
            self.total_batches = self.nb_training_batches + self.nb_val_batches
            self.batch_loss_value = 0  # accumulated grads

            # limit the number of batches to 1 in fast_dev_run
            if self.fast_dev_run:
                self.total_batches = 1

            # init progress_bar when requested
            if self.show_progress_bar:
                self.progress_bar.reset(self.total_batches)

            # adjust gradient accumulation according to the accumulation_scheduler
            self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)

            # -----------------
            # RUN TNG EPOCH
            # -----------------
            self.run_training_epoch()

            # update LR schedulers
            if self.lr_schedulers is not None:
                for lr_scheduler in self.lr_schedulers:
                    lr_scheduler.step(self.current_epoch)

            # early stopping
            met_min_epochs = epoch_nb > self.min_nb_epochs
            if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb,
                                                                    logs=self.callback_metrics)
                # stop training
                stop = should_stop and met_min_epochs
                if stop:
                    return

        if self.logger is not None:
            self.logger.finalize("success")

    def run_training_epoch(self):
        # before epoch hook
        if self.__is_function_implemented('on_epoch_start'):
            model = self.__get_model()
            model.on_epoch_start()

        # run epoch
        for batch_nb, batch in enumerate(self.get_train_dataloader()):
            self.batch_nb = batch_nb
            self.global_step += 1

            model = self.__get_model()
            model.global_step = self.global_step

            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            self.total_batch_nb += 1
            met_batch_limit = batch_nb >= self.nb_training_batches
            if met_batch_limit:
                break

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            output = self.__run_training_batch(batch, batch_nb)
            batch_result, grad_norm_dic, batch_step_metrics = output
            early_stop_epoch = batch_result == -1

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
            should_check_val = ((is_val_check_batch or early_stop_epoch) and can_check_epoch)

            # fast_dev_run always forces val checking after train batch
            if self.fast_dev_run or should_check_val:
                self.__run_evaluation(test=self.testing)

            # when logs should be saved
            should_save_log = (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch
            if should_save_log or self.fast_dev_run:
                if self.proc_rank == 0 and self.logger is not None:
                    self.logger.save()

            # when metrics should be logged
            should_log_metrics = batch_nb % self.row_log_interval == 0 or early_stop_epoch
            if should_log_metrics or self.fast_dev_run:

                # logs user requested information to logger
                self.__log_metrics(batch_step_metrics, grad_norm_dic)

            # end epoch early
            if early_stop_epoch or self.fast_dev_run:
                break

        # epoch end hook
        if self.__is_function_implemented('on_epoch_end'):
            model = self.__get_model()
            model.on_epoch_end()

    def __log_metrics(self, metrics, grad_norm_dic):
        """
        Logs the metric dict passed in
        :param metrics:
        :param grad_norm_dic:
        :return:
        """
        # added metrics by Lightning for convenience
        metrics['epoch'] = self.current_epoch

        # add gpu memory
        if self.on_gpu and self.log_gpu_memory:
            mem_map = memory.get_memory_profile(self.log_gpu_memory)
            metrics.update(mem_map)

        # add norms
        metrics.update(grad_norm_dic)

        # turn all tensors to scalars
        scalar_metrics = self.__metrics_to_scalars(metrics)

        # log actual metrics
        if self.proc_rank == 0 and self.logger is not None:
            self.logger.log_metrics(scalar_metrics, step_num=self.global_step)
            self.logger.save()
2019-08-30 22:56:09 +00:00
|
|
|
    def test(self, model=None):
        self.testing = True
        if model is not None:
            self.fit(model)
        else:
            self.__run_evaluation(test=True)

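    # Minimal usage sketch for .test() (illustrative; assumes a LightningModule that
    # overrides test_step and test_end, which __run_evaluation checks for below):
    #   trainer = Trainer()
    #   trainer.fit(model)
    #   trainer.test()            # reuse the fitted model
    #   # or: trainer.test(model) -> routes through fit() with self.testing = True
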
    def __metrics_to_scalars(self, metrics):
        new_metrics = {}
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            if type(v) is dict:
                v = self.__metrics_to_scalars(v)

            new_metrics[k] = v

        return new_metrics

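    # Sketch of the recursive conversion above (values made up):
    #   {'loss': tensor(0.5), 'log': {'acc': tensor(0.9)}, 'epoch': 3}
    #   -> {'loss': 0.5, 'log': {'acc': 0.9}, 'epoch': 3}
    # Note: .item() assumes single-element tensors; a multi-element tensor would raise.
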
    def __log_vals_blacklist(self):
        """avoid logging some vals lightning uses to maintain state"""
        blacklist = {'batch_nb', 'v_nb', 'gpu'}
        return blacklist

    def transfer_batch_to_gpu(self, batch, gpu_id):
        # base case: object can be directly moved using `cuda` or `to`
        if callable(getattr(batch, 'cuda', None)):
            return batch.cuda(gpu_id)

        elif callable(getattr(batch, 'to', None)):
            return batch.to(torch.device('cuda', gpu_id))

        # when list
        elif isinstance(batch, list):
            for i, x in enumerate(batch):
                batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
            return batch

        # when tuple
        elif isinstance(batch, tuple):
            batch = list(batch)
            for i, x in enumerate(batch):
                batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
            return tuple(batch)

        # when dict
        elif isinstance(batch, dict):
            for k, v in batch.items():
                batch[k] = self.transfer_batch_to_gpu(v, gpu_id)
            return batch

        # nothing matches, return the value as is without transform
        return batch

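    # Sketch of the recursion above on a nested batch (illustrative only):
    #   batch = {'images': torch.rand(8, 3, 32, 32), 'targets': [torch.tensor(1)] * 8, 'ids': ['a', 'b']}
    #   batch = trainer.transfer_batch_to_gpu(batch, gpu_id=0)
    #   # tensors end up on cuda:0; plain strings fall through the final `return batch` unchanged
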
    def __training_forward(self, batch, batch_nb, opt_idx):
        """
        Handle forward for each training case (distributed, single gpu, etc...)
        :param batch: the current batch of data
        :param batch_nb: index of the batch within the epoch
        :param opt_idx: index of the optimizer to use for this step
        :return: loss, progress bar metrics, log metrics, callback metrics
        """
        # ---------------
        # FORWARD
        # ---------------
        # enable not needing to add opt_idx to training_step
        args = [batch, batch_nb]
        if len(self.optimizers) > 1:
            args.append(opt_idx)

        if self.use_ddp or self.use_ddp2:
            output = self.model(*args)
        elif self.use_dp:
            output = self.model(*args)
        elif self.single_gpu:
            gpu_id = 0
            if type(self.data_parallel_device_ids) is list:
                gpu_id = self.data_parallel_device_ids[0]
            batch = self.transfer_batch_to_gpu(batch, gpu_id)
            args[0] = batch
            output = self.model.training_step(*args)
        else:
            output = self.model.training_step(*args)

        # format and reduce outputs accordingly
        output = self.__process_output(output, train=True)
        loss, progress_bar_metrics, log_metrics, callback_metrics = output
        return loss, progress_bar_metrics, log_metrics, callback_metrics

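    # Sketch of the training_step signatures this dispatch supports (hypothetical user code;
    # the extra argument name is the user's choice):
    #   def training_step(self, batch, batch_nb): ...                 # single optimizer
    #   def training_step(self, batch, batch_nb, optimizer_idx): ...  # multiple optimizers
    # opt_idx is only appended to args when more than one optimizer is configured.
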
    def __process_output(self, output, train=False):
        """
        Reduces output according to the training mode.
        Separates loss from logging and tqdm metrics.
        :param output: dict (or tensor) returned by training_step or the evaluation loop
        :param train: True when processing training_step output
        :return: loss, progress bar metrics, log metrics, callback metrics
        """
        # ---------------
        # EXTRACT CALLBACK KEYS
        # ---------------
        # all keys not progress_bar or log are candidates for callbacks
        callback_metrics = {}
        for k, v in output.items():
            if k not in ['progress_bar', 'log']:
                callback_metrics[k] = v

        if train and (self.use_dp or self.use_ddp2):
            nb_gpus = self.num_gpus
            callback_metrics = reduce_distributed_output(callback_metrics, nb_gpus)

        for k, v in callback_metrics.items():
            if isinstance(v, torch.Tensor):
                callback_metrics[k] = v.item()

        # ---------------
        # EXTRACT PROGRESS BAR KEYS
        # ---------------
        try:
            progress_output = output['progress_bar']

            # reduce progress metrics for tqdm when using dp
            if train and (self.use_dp or self.use_ddp2):
                nb_gpus = self.num_gpus
                progress_output = reduce_distributed_output(progress_output, nb_gpus)

            progress_bar_metrics = progress_output
        except Exception:
            progress_bar_metrics = {}

        # ---------------
        # EXTRACT LOGGING KEYS
        # ---------------
        # extract metrics to log to experiment
        try:
            log_output = output['log']

            # reduce log metrics when using dp
            if train and (self.use_dp or self.use_ddp2):
                nb_gpus = self.num_gpus
                log_output = reduce_distributed_output(log_output, nb_gpus)

            log_metrics = log_output
        except Exception:
            log_metrics = {}

        # ---------------
        # EXTRACT LOSS
        # ---------------
        # if output dict doesn't have the keyword loss
        # then assume the output=loss if scalar
        loss = None
        if train:
            try:
                loss = output['loss']
            except Exception:
                if type(output) is torch.Tensor:
                    loss = output
                else:
                    raise RuntimeError(
                        'No `loss` value in the dictionary returned from `model.training_step()`.'
                    )

            # when using dp need to reduce the loss
            if self.use_dp or self.use_ddp2:
                loss = reduce_distributed_output(loss, self.num_gpus)

        # use every metric passed in as a candidate for callback
        callback_metrics.update(progress_bar_metrics)
        callback_metrics.update(log_metrics)

        # convert tensors to python scalars
        for k, v in callback_metrics.items():
            if isinstance(v, torch.Tensor):
                callback_metrics[k] = v.item()

        return loss, progress_bar_metrics, log_metrics, callback_metrics

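    # Sketch of how a training_step return value is split by __process_output
    # (hypothetical user dict; key names inside 'progress_bar'/'log' are user-chosen):
    #   output = {'loss': loss, 'progress_bar': {'train_acc': acc}, 'log': {'train_loss': loss}}
    #   -> loss:                 output['loss'] (mean-reduced across GPUs under dp/ddp2)
    #   -> progress_bar_metrics: {'train_acc': ...}   shown in the tqdm bar
    #   -> log_metrics:          {'train_loss': ...}  sent to the logger
    #   -> callback_metrics:     union of the above, tensors converted to python scalars
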
    def __clip_gradients(self):
        if self.gradient_clip_val > 0:
            model = self.__get_model()
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)

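    # Note: clipping only happens when gradients are actually applied (see the accumulation
    # check in __run_training_batch). A rough standalone equivalent of the call above:
    #   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clip_val)
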
    def __print_nan_grads(self):
        model = self.__get_model()
        for param in model.parameters():
            if param.grad is not None and torch.isnan(param.grad.float()).any():
                print(param, param.grad)

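    # When print_nan_grads is enabled, any parameter whose gradient contains NaN is printed
    # together with that gradient after each backward pass -- useful for spotting exploding
    # losses or bad inputs, at the cost of extra console output.
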
    def __run_training_batch(self, batch, batch_nb):
        # track grad norms
        grad_norm_dic = {}

        # track all metrics for callbacks
        all_callback_metrics = []

        # track metrics to log
        all_log_metrics = []

        if batch is None:
            return 0, grad_norm_dic, {}

        # hook
        if self.__is_function_implemented('on_batch_start'):
            model_ref = self.__get_model()
            response = model_ref.on_batch_start(batch)

            if response == -1:
                return -1, grad_norm_dic, {}

        if self.show_progress_bar:
            self.progress_bar.update(1)

        # call training_step once per optimizer
        for opt_idx, optimizer in enumerate(self.optimizers):

            # wrap the forward step in a closure so second order methods work
            def optimizer_closure():
                # forward pass
                output = self.__training_forward(batch, batch_nb, opt_idx)
                closure_loss, progress_bar_metrics, log_metrics, callback_metrics = output

                # track metrics for callbacks
                all_callback_metrics.append(callback_metrics)

                # track progress bar metrics
                self.__add_tqdm_metrics(progress_bar_metrics)
                all_log_metrics.append(log_metrics)

                # accumulate loss
                # (if accumulate_grad_batches = 1 no effect)
                closure_loss = closure_loss / self.accumulate_grad_batches

                # backward pass
                if self.use_amp:
                    with amp.scale_loss(closure_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    closure_loss.backward()

                # on_after_backward hook
                if self.__is_function_implemented('on_after_backward'):
                    model_ref = self.__get_model()
                    model_ref.on_after_backward()

                return closure_loss

            # calculate loss
            loss = optimizer_closure()

            # nan grads
            if self.print_nan_grads:
                self.__print_nan_grads()

            # track total loss for logging (avoid mem leaks)
            self.batch_loss_value += loss.item()

            # gradient update with accumulated gradients
            if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:

                # track gradient norms when requested
                if batch_nb % self.row_log_interval == 0:
                    if self.track_grad_norm > 0:
                        model = self.__get_model()
                        grad_norm_dic = model.grad_norm(self.track_grad_norm)

                # clip gradients
                self.__clip_gradients()

                # calls .step(), .zero_grad()
                # override function to modify this behavior
                model = self.__get_model()
                model.optimizer_step(self.current_epoch, batch_nb,
                                     optimizer, opt_idx, optimizer_closure)

                # calculate running loss for display
                self.running_loss.append(self.batch_loss_value)
                self.batch_loss_value = 0
                self.avg_loss = np.mean(self.running_loss[-100:])

                # update progress bar
                if self.show_progress_bar:
                    # add model specific metrics
                    tqdm_metrics = self.__training_tqdm_dict
                    self.progress_bar.set_postfix(**tqdm_metrics)

        # activate batch end hook
        if self.__is_function_implemented('on_batch_end'):
            model = self.__get_model()
            model.on_batch_end()

        # collapse all metrics into one dict
        all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}

        # track all metrics for callbacks
        self.callback_metrics = {k: v for d in all_callback_metrics for k, v in d.items()}

        return 0, grad_norm_dic, all_log_metrics

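    # The default LightningModule.optimizer_step(...) simply calls optimizer.step() and
    # optimizer.zero_grad(); passing `optimizer_closure` lets second-order optimizers
    # such as LBFGS re-evaluate the loss. A hypothetical user override:
    #   def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_idx, closure):
    #       optimizer.step(closure)   # e.g. for torch.optim.LBFGS
    #       optimizer.zero_grad()
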
    def __run_evaluation(self, test=False):
        # when testing make sure user defined a test step
        can_run_test_step = False
        if test:
            can_run_test_step = self.__is_overriden('test_step') and self.__is_overriden('test_end')
            if not can_run_test_step:
                m = '''You called `.test()` without defining a `test_step` or `test_end`.
                Please define them and try again.'''
                raise MisconfigurationException(m)

        # validate only if model has validation_step defined
        # test only if test_step or validation_step are defined
        run_val_step = self.__is_overriden('validation_step')

        if run_val_step or can_run_test_step:

            # hook
            model = self.__get_model()
            model.on_pre_performance_check()

            # select dataloaders
            if test:
                dataloaders = self.get_test_dataloaders()
                max_batches = self.nb_test_batches
            else:
                # val
                dataloaders = self.get_val_dataloaders()
                max_batches = self.nb_val_batches

            # cap max batches to 1 when using fast_dev_run
            if self.fast_dev_run:
                max_batches = 1

            # run evaluation
            eval_results = self.evaluate(self.model,
                                         dataloaders,
                                         max_batches,
                                         test)
            _, prog_bar_metrics, log_metrics, callback_metrics = self.__process_output(eval_results)

            # add metrics to prog bar
            self.__add_tqdm_metrics(prog_bar_metrics)

            # log metrics
            self.__log_metrics(log_metrics, {})

            # track metrics for callbacks
            self.callback_metrics = callback_metrics

            # hook
            model.on_post_performance_check()

            if self.show_progress_bar:
                # add model specific metrics
                tqdm_metrics = self.__training_tqdm_dict
                self.progress_bar.set_postfix(**tqdm_metrics)

        # model checkpointing
        if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
                                                  logs=self.callback_metrics)
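
    # Hypothetical LightningModule hooks that satisfy the test_step/test_end check above
    # (illustrative only; signatures are assumed to mirror validation_step/validation_end):
    #   def test_step(self, batch, batch_nb):
    #       x, y = batch
    #       return {'test_loss': F.cross_entropy(self(x), y)}
    #
    #   def test_end(self, outputs):
    #       avg = torch.stack([o['test_loss'] for o in outputs]).mean()
    #       return {'progress_bar': {'test_loss': avg}, 'log': {'test_loss': avg}}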