# lightning/pytorch_lightning/models/trainer.py
"""
The trainer handles all the logic for running the validation loop, the training loop, distributing training, etc.
"""
import subprocess
import traceback
import warnings
import os
import pdb
import re
import torch
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
import torch.distributed as dist
import numpy as np
import tqdm
from pytorch_lightning.root_module.memory import get_gpu_memory_map
from pytorch_lightning.root_module.model_saving import TrainerIO
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utils.debugging import MisconfigurationException

try:
    from apex import amp
    APEX_AVAILABLE = True
except Exception:
    APEX_AVAILABLE = False


def reduce_distributed_output(output, nb_gpus):
    if nb_gpus <= 1:
        return output

    # when using DP, we get one output per gpu
    # average outputs and return
    if type(output) is torch.Tensor:
        return output.mean()

    for k, v in output.items():
        # recurse on nested dicts
        if isinstance(output[k], dict):
            output[k] = reduce_distributed_output(output[k], nb_gpus)

        # reduce only metrics that have the same nb of gpus
        elif output[k].size(0) == nb_gpus:
            reduced = torch.mean(output[k])
            output[k] = reduced

    return output
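
# For example (a sketch of the DataParallel case, assuming 2 visible GPUs):
# DP returns one entry per device, which this helper averages back to a scalar:
#   reduce_distributed_output({'loss': torch.tensor([0.5, 0.7])}, nb_gpus=2)
#   -> {'loss': tensor(0.6000)}
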
class Trainer(TrainerIO):
    def __init__(self,
                 experiment,
                 early_stop_callback=None,
                 checkpoint_callback=None,
                 gradient_clip=0,
                 cluster=None,
                 process_position=0,
                 current_gpu_name=0,
                 nb_gpu_nodes=1,
                 gpus=None,
                 progress_bar=True,
                 overfit_pct=0.0,
                 track_grad_norm=-1,
                 check_val_every_n_epoch=1,
                 fast_dev_run=False,
                 accumulate_grad_batches=1,
                 max_nb_epochs=1000, min_nb_epochs=1,
                 train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0,
                 val_check_interval=0.95,
                 log_save_interval=100, add_log_row_interval=10,
                 distributed_backend='dp',
                 use_amp=False,
                 print_nan_grads=False,
                 print_weights_summary=True,
                 amp_level='O2',
                 nb_sanity_val_steps=5):
"""
:param experiment: Test-tube experiment
:param early_stop_callback: from pytorch_lightning import EarlyStopping
:param checkpoint_callback: from pytorch_lightning import Checkpoint
:param gradient_clip:
:param cluster:
:param process_position:
:param current_gpu_name:
:param nb_gpu_nodes:
:param gpus:
:param progress_bar:
:param overfit_pct:
:param track_grad_norm:
:param check_val_every_n_epoch:
:param fast_dev_run:
:param accumulate_grad_batches:
:param max_nb_epochs:
:param min_nb_epochs:
:param train_percent_check:
:param val_percent_check:
:param test_percent_check:
:param val_check_interval:
:param log_save_interval:
:param add_log_row_interval:
:param distributed_backend: 'np' to use DistributedParallel, 'ddp' to use DistributedDataParallel
:param use_amp:
:param print_nan_grads:
:param print_weights_summary:
:param amp_level:
:param nb_sanity_val_steps:
"""
        # Transfer params
        self.nb_gpu_nodes = nb_gpu_nodes
        self.gradient_clip = gradient_clip
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.enable_early_stop = early_stop_callback is not None
        self.track_grad_norm = track_grad_norm
        self.fast_dev_run = fast_dev_run
        self.on_gpu = gpus is not None and torch.cuda.is_available()
        self.progress_bar = progress_bar
        self.experiment = experiment
        self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
        self.cluster = cluster
        self.process_position = process_position
        self.current_gpu_name = current_gpu_name
        self.print_weights_summary = print_weights_summary
        self.checkpoint_callback = checkpoint_callback

        if self.checkpoint_callback is not None:
            self.checkpoint_callback.save_function = self.save_checkpoint

        self.early_stop = early_stop_callback
        self.model = None
        self.max_nb_epochs = max_nb_epochs
        self.accumulate_grad_batches = accumulate_grad_batches
        self.early_stop_callback = early_stop_callback
        self.min_nb_epochs = min_nb_epochs
        self.nb_sanity_val_steps = nb_sanity_val_steps
        self.lr_schedulers = []
        self.amp_level = amp_level
        self.print_nan_grads = print_nan_grads
        self.data_parallel_device_ids = None
        self.world_size = 1
        self.node_rank = 0
        self.use_ddp = False
        self.use_dp = False

        # training bookkeeping
        self.total_batch_nb = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_nb = 0
        self.tqdm_metrics = {}
        self.nb_val_batches = None
        self.nb_tng_batches = None
        self.nb_test_batches = None

        # gpus come in as a string.
        # if gpus = -1 then use all available devices
        # otherwise, split the string using commas
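        # e.g. (illustrative values): gpus='-1' selects every visible device,
        # gpus='0, 1' selects devices 0 and 1, and gpus=[0, 1] is used as-is.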
        if gpus is not None:
            if type(gpus) is list:
                self.data_parallel_device_ids = gpus
            elif type(gpus) is str:
                if gpus == '-1':
                    self.data_parallel_device_ids = list(range(0, torch.cuda.device_count()))
                else:
                    self.data_parallel_device_ids = [int(x.strip()) for x in gpus.split(',')]
            else:
                raise Exception('gpus has to be a string or list of ids')

            # set the correct cuda visible devices (using pci order)
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in self.data_parallel_device_ids])
            print(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}')

        # make DP and DDP mutually exclusive
        # single GPU will also use DP with devices=[0]
        have_gpus = self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) > 0
        if have_gpus:
            self.use_dp = distributed_backend == 'dp'
            self.use_ddp = distributed_backend == 'ddp'

            # use ddp automatically if nb_gpu_nodes > 1
            if nb_gpu_nodes > 1 and self.use_dp:  # pragma: no cover
                self.use_ddp = True
                self.use_dp = False

                w = 'DataParallel does not support nb_gpu_nodes > 1. ' \
                    'Switching to DistributedDataParallel for you. ' \
                    'To silence this warning set distributed_backend=ddp'
                warnings.warn(w)

        # extract SLURM flag vars
        # whenever we have the correct number of tasks, we let slurm manage processes
        # otherwise we launch the required number of processes
        if self.use_ddp:
            self.nb_requested_gpus = len(self.data_parallel_device_ids) * self.nb_gpu_nodes
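            # e.g. (illustrative): gpus='0,1' on nb_gpu_nodes=2 requests 4 GPUs in total,
            # so SLURM is only treated as managing the tasks when SLURM_NTASKS == 4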
            self.nb_slurm_tasks = 0
            try:
                self.nb_slurm_tasks = int(os.environ['SLURM_NTASKS'])
                self.is_slurm_managing_tasks = self.nb_slurm_tasks == self.nb_requested_gpus
            except Exception:
                # likely not on slurm, so set the slurm managed flag to false
                self.is_slurm_managing_tasks = False

        # process info
        self.proc_rank = 0

        # training state
        self.optimizers = None
        self.prog_bar = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval
        self.add_log_row_interval = add_log_row_interval

        # dataloaders
        self.tng_dataloader = None
        self.test_dataloader = None
        self.val_dataloader = None

        # how much of the data to use
        self.__determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct)
        print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))

        # 16 bit mixed precision training using apex
        self.use_amp = use_amp and APEX_AVAILABLE
        if self.use_amp:
            print('using 16bit precision')

        if use_amp and not APEX_AVAILABLE:  # pragma: no cover
            msg = '''
            You set use_amp=True but do not have apex installed.
            Install apex first using this guide and rerun with use_amp=True:
            https://github.com/NVIDIA/apex#linux

            this run will NOT use 16 bit precision
            '''
            raise ModuleNotFoundError(msg)

    @property
    def data_parallel(self):
        return self.use_dp or self.use_ddp

    def __determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct):
        """
        Use less data for debugging purposes
        """
        self.train_percent_check = train_percent_check
        self.val_percent_check = val_percent_check
        self.test_percent_check = test_percent_check

        if overfit_pct > 0:
            self.train_percent_check = overfit_pct
            self.val_percent_check = overfit_pct
            self.test_percent_check = overfit_pct
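            # e.g. (illustrative): overfit_pct=0.01 caps train, val and test at 1% of
            # their batches, to sanity-check that the model can overfit a tiny subset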

    def __get_model(self):
        return self.model.module if self.data_parallel else self.model

    def __is_function_implemented(self, f_name):
        model = self.__get_model()
        f_op = getattr(model, f_name, None)
        return callable(f_op)

    @property
    def __tng_tqdm_dic(self):
        tqdm_dic = {
            'tng_loss': '{0:.3f}'.format(self.avg_loss),
            'v_nb': '{}'.format(self.experiment.version),
            'epoch': '{}'.format(self.current_epoch),
            'batch_nb': '{}'.format(self.batch_nb),
        }
        tqdm_dic.update(self.tqdm_metrics)

        if self.on_gpu:
            tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)

        return tqdm_dic

    @property
    def tng_tqdm_dic(self):
        """
        Read-only view of the tqdm metrics
        :return:
        """
        return self.__tng_tqdm_dic

    def __layout_bookeeping(self):
        # determine number of training batches
        self.nb_tng_batches = len(self.tng_dataloader)
        self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)

        # determine number of validation batches
        self.nb_val_batches = len(self.val_dataloader)
        self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
        self.nb_val_batches = max(1, self.nb_val_batches)

        # determine number of test batches
        self.nb_test_batches = len(self.test_dataloader)
        self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check)

        # determine when to check validation
        self.val_check_batch = int(self.nb_tng_batches * self.val_check_interval)
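        # e.g. (illustrative): with 100 training batches and val_check_interval=0.95,
        # validation is triggered after the 95th training batch of each epoch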

    def __add_tqdm_metrics(self, metrics):
        for k, v in metrics.items():
            if type(v) is torch.Tensor:
                v = v.item()

            self.tqdm_metrics[k] = v

    def validate(self, model, dataloader, max_batches):
        """
        Run validation code
        :param model: PT model
        :param dataloader: PT dataloader
        :param max_batches: Scalar
        :return:
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        outputs = []

        # run validation
        for batch_i, data_batch in enumerate(dataloader):
            if data_batch is None:  # pragma: no cover
                continue

            # stop short when on fast dev run
            if max_batches is not None and batch_i >= max_batches:
                break

            # -----------------
            # RUN VALIDATION STEP
            # -----------------
            if self.use_ddp:
                output = model(data_batch, batch_i)
            elif self.use_dp:
                output = model(data_batch, batch_i)
                output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
            else:
                output = model.validation_step(data_batch, batch_i)

            outputs.append(output)

            # batch done
            if self.progress_bar and self.prog_bar is not None:
                self.prog_bar.update(1)

        # give model a chance to do something with the outputs
        if self.data_parallel:
            val_results = model.module.validation_end(outputs)
        else:
            val_results = model.validation_end(outputs)

        # enable train mode again
        model.train()

        # re-enable gradients
        torch.set_grad_enabled(True)

        return val_results

    def get_dataloaders(self, model):
        """
        Dataloaders are provided by the model
        :param model:
        :return:
        """
        self.tng_dataloader = model.tng_dataloader
        self.test_dataloader = model.test_dataloader
        self.val_dataloader = model.val_dataloader

        if self.use_ddp and not isinstance(self.tng_dataloader.sampler, DistributedSampler):
            msg = '''
            when using multiple gpus and multiple nodes you must pass a DistributedSampler to DataLoader(sampler).

            ie: this:
            dataset = myDataset()
            dataloader = DataLoader(dataset)

            becomes:
            dataset = myDataset()
            dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            dataloader = DataLoader(dataset, sampler=dist_sampler)
            '''
            raise MisconfigurationException(msg)

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    def fit(self, model):
        # when using multi-node or DDP within a node, start each module in a separate process
        if self.use_ddp:
            # must copy only the meta of the exp so it survives pickle/unpickle
            # when going to a new process
            self.experiment = self.experiment.get_meta_copy()

            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
                self.ddp_train(task, model)
            else:
                msg = f"""
                You requested {self.nb_requested_gpus} GPUs but launched {self.nb_slurm_tasks} slurm tasks.
                We will launch {self.nb_requested_gpus} processes for you.
                We recommend you let slurm manage the processes by setting: --ntasks-per-node={self.nb_requested_gpus}
                If you're not using SLURM, ignore this message!
                """
                warnings.warn(msg)
                mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model,))

        # 1 gpu or dp option triggers training using DP module
        # easier to avoid NCCL issues
        elif self.use_dp:
            self.__dp_train(model)

        # ON CPU
        else:
            # run through amp wrapper
            if self.use_amp:
                raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

            # CHOOSE OPTIMIZER
            # allow for lr schedulers as well
            self.optimizers = model.configure_optimizers()
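            # e.g. configure_optimizers() may return either a list of optimizers or a
            # 2-tuple of (optimizers, lr_schedulers); the 2-tuple case is unpacked below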
            if len(self.optimizers) == 2:
                self.optimizers, self.lr_schedulers = self.optimizers

            self.__run_pretrain_routine(model)

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1

    def __dp_train(self, model):

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers = model.configure_optimizers()
        if len(self.optimizers) == 2:
            self.optimizers, self.lr_schedulers = self.optimizers

        model.cuda(self.data_parallel_device_ids[0])

        # check for this bug (amp + dp + opt_level != O1 doesn't work)
        # https://github.com/NVIDIA/apex/issues/227
        if self.use_dp and self.use_amp:
            m = f'amp level {self.amp_level} with DataParallel is not supported. ' \
                f'See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227. ' \
                f'We recommend you switch to ddp if you want to use amp'
            raise MisconfigurationException(m)

        model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)

        self.__run_pretrain_routine(model)

    def ddp_train(self, gpu_nb, model):
        """
        Entry point into a DDP process
        :param gpu_nb:
        :param model:
        :return:
        """
        # node rank using relative slurm id
        # otherwise default to node rank 0
        try:
            node_id = os.environ['SLURM_NODEID']
            self.node_rank = int(node_id)
        except Exception:
            self.node_rank = 0

        # recover the original experiment before it went into the new process
        # init in write mode only on proc 0
        self.experiment.debug = self.proc_rank > 0
        self.experiment = self.experiment.get_non_ddp_exp()

        # show progbar only on node 0, gpu 0
        self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0

        # determine which process we are and the world size
        self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb
        self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)
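        # e.g. (illustrative): with 2 nodes and gpus='0,1', world_size is 4 and the
        # process running on node 1, gpu 0 gets proc_rank 2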

        # let the exp know the rank to avoid overwriting logs
        self.experiment.rank = self.proc_rank

        # set up the server using proc 0's ip address
        self.__init_tcp_connection()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers = model.configure_optimizers()
        if len(self.optimizers) == 2:
            self.optimizers, self.lr_schedulers = self.optimizers

        # MODEL
        # copy model to each gpu
        torch.cuda.set_device(gpu_nb)
        model.cuda(gpu_nb)

        # AMP
        # run through amp wrapper before going to distributed DP
        if self.use_amp:
            # wrap the model and optimizers so amp handles casting and loss scaling
            model, optimizers = amp.initialize(
                model, self.optimizers, opt_level=self.amp_level,
            )
            self.optimizers = optimizers

        model = LightningDistributedDataParallel(model, device_ids=[gpu_nb], find_unused_parameters=True)

        # continue training routine
        self.__run_pretrain_routine(model)

    def __init_tcp_connection(self):
        """
        Connect all procs in the world using the env:// init
        Use the first node as the root address
        :return:
        """
        # sets the appropriate port
        try:
            port = os.environ['MASTER_PORT']
        except Exception:
            port = 12910
        os.environ['MASTER_PORT'] = f'{port}'

        # figure out the root node addr
        try:
            root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
        except Exception:
            root_node = '127.0.0.2'

        root_node = self.resolve_root_node_address(root_node)
        os.environ['MASTER_ADDR'] = root_node
        dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)

    def resolve_root_node_address(self, root_node):
        if '[' in root_node:
            name = root_node.split('[')[0]
            number = root_node.split(',')[0]

            if '-' in number:
                number = number.split('-')[0]

            number = re.sub('[^0-9]', '', number)
            root_node = name + number

        return root_node
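
    # e.g. (illustrative): a SLURM nodelist such as 'node[025-055,124]' resolves to
    # 'node025', so the first host in the allocation becomes the DDP master address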

    def __run_pretrain_routine(self, model):
        """
        Sanity check a few things before starting actual training
        :param model:
        :return:
        """
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        ref_model.trainer = self

        # set local properties on the model
        ref_model.on_gpu = self.on_gpu

        # transfer data loaders from model
        self.get_dataloaders(ref_model)

        # init training constants
        self.__layout_bookeeping()

        # print model summary
        if self.proc_rank == 0 and self.print_weights_summary:
            ref_model.summarize()

        # give model convenience properties
        ref_model.trainer = self
        ref_model.experiment = self.experiment

        # run tiny validation to make sure the program won't crash during val
        _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)

        # save exp to get started
        if self.proc_rank == 0:
            self.experiment.save()

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # enable cluster checkpointing
        # also restores training state
        if self.cluster is not None:  # pragma: no cover
            self.enable_auto_hpc_walltime_manager()

        # ---------------------------
        # CORE TRAINING LOOP
        # ---------------------------
        self.__train()

    def __train(self):
        # run all epochs
        for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
            # update the lr scheduler
            if self.lr_schedulers is not None:
                for lr_scheduler in self.lr_schedulers:
                    lr_scheduler.step()

            model = self.__get_model()
            model.current_epoch = epoch_nb

            # hook
            if self.__is_function_implemented('on_epoch_start'):
                model = self.__get_model()
                model.on_epoch_start()

            self.current_epoch = epoch_nb
            self.total_batches = self.nb_tng_batches + self.nb_val_batches
            self.batch_loss_value = 0  # accumulated grads

            # init progbar when requested
            if self.progress_bar:
                self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position)

            for batch_nb, data_batch in enumerate(self.tng_dataloader):
                self.batch_nb = batch_nb
                self.global_step += 1

                model = self.__get_model()
                model.global_step = self.global_step

                # stop the epoch once we have run the requested number of training batches
                self.total_batch_nb += 1
                met_batch_limit = batch_nb > self.nb_tng_batches
                if met_batch_limit:
                    break

                # ---------------
                # RUN TRAIN STEP
                # ---------------
                batch_result = self.__run_tng_batch(data_batch, batch_nb)
                early_stop_epoch = batch_result == -1

                # ---------------
                # RUN VAL STEP
                # ---------------
                is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
                if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
                    self.__run_validation()

                # when the experiment should be saved
                if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
                    if self.proc_rank == 0:
                        self.experiment.save()

                # when metrics should be logged
                if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
                    model = self.__get_model()
                    metrics = self.__tng_tqdm_dic

                    # add gpu memory
                    if self.on_gpu:
                        mem_map = get_gpu_memory_map()
                        metrics.update(mem_map)

                    # add norms
                    if self.track_grad_norm > 0:
                        model = self.__get_model()
                        grad_norm_dic = model.grad_norm(self.track_grad_norm)
                        metrics.update(grad_norm_dic)

                    if self.__is_function_implemented('on_tng_metrics'):
                        model.on_tng_metrics(metrics)

                    # log metrics
                    scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
                    if self.proc_rank == 0:
                        self.experiment.log(scalar_metrics, global_step=self.global_step)
                        self.experiment.save()

                # hook
                if self.__is_function_implemented('on_batch_end'):
                    model = self.__get_model()
                    model.on_batch_end()

                # end epoch early
                if early_stop_epoch:
                    break

            # hook
            if self.__is_function_implemented('on_epoch_end'):
                model = self.__get_model()
                model.on_epoch_end()

            # early stopping
            met_min_epochs = epoch_nb > self.min_nb_epochs
            if self.enable_early_stop and met_min_epochs:
                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb, logs=self.__tng_tqdm_dic)

                # stop training
                stop = should_stop and met_min_epochs
                if stop:
                    return

    def __metrics_to_scalars(self, metrics, blacklist=[]):
        new_metrics = {}
        for k, v in metrics.items():
            if type(v) is torch.Tensor:
                v = v.item()

            # recurse on nested dicts instead of trying to call float() on them
            if type(v) is dict:
                v = self.__metrics_to_scalars(v)

            if k not in blacklist:
                new_metrics[k] = v if type(v) is dict else float(v)

        return new_metrics
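
    # e.g. (illustrative): {'tng_loss': torch.tensor(0.25), 'v_nb': 3} with
    # blacklist={'v_nb'} becomes {'tng_loss': 0.25}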

    def __log_vals_blacklist(self):
        """avoid logging some vals lightning uses to maintain state"""
        blacklist = {'batch_nb', 'v_nb', 'gpu'}
        return blacklist

    def __run_tng_batch(self, data_batch, batch_nb):
        if data_batch is None:
            return 0

        # hook
        if self.__is_function_implemented('on_batch_start'):
            model_ref = self.__get_model()
            response = model_ref.on_batch_start(data_batch)
            if response == -1:
                return -1

        if self.progress_bar:
            self.prog_bar.update(1)

        # forward pass
        # return a scalar value and a dict with tqdm metrics
        if self.use_ddp:
            output = self.model(data_batch, batch_nb)
        elif self.use_dp:
            output = self.model(data_batch, batch_nb)
            output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
        else:
            output = self.model.training_step(data_batch, batch_nb)

        try:
            model_specific_tqdm_metrics_dic = output['prog']
        except Exception:
            model_specific_tqdm_metrics_dic = {}

        # if the output dict doesn't have the keyword loss
        # then assume the output itself is the loss (if it is a scalar tensor)
        try:
            loss = output['loss']
        except Exception:
            if type(output) is torch.Tensor:
                loss = output

        self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)

        # backward pass
        if self.use_amp:
            # scale loss when using amp
            for optimizer in self.optimizers:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
        else:
            loss.backward()

        # hook: runs right after the backward pass
        if self.__is_function_implemented('on_after_backward'):
            model_ref = self.__get_model()
            response = model_ref.on_after_backward()

        if self.print_nan_grads:
            model = self.__get_model()
            for param in model.parameters():
                print(param.grad.float().sum())

        # track the loss as a python float to avoid holding onto the graph
        self.batch_loss_value += loss.item()

        # gradient update with accumulated gradients
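        # e.g. (illustrative): with accumulate_grad_batches=4, gradients from 4 batches
        # are summed before optimizer.step() runs, and the logged loss is divided by 4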
        if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
            # clip gradients
            if self.gradient_clip > 0:
                model = self.__get_model()
                torch.nn.utils.clip_grad_norm(model.parameters(), self.gradient_clip)

            # update gradients across all optimizers
            for optimizer in self.optimizers:
                optimizer.step()

                # insert after step hook
                if self.__is_function_implemented('on_before_zero_grad'):
                    model_ref = self.__get_model()
                    response = model_ref.on_before_zero_grad(optimizer)

                # clear gradients
                optimizer.zero_grad()

            # queuing loss across batches blows it up proportionally... divide out the number accumulated
            self.batch_loss_value = self.batch_loss_value / self.accumulate_grad_batches

            # track loss
            self.running_loss.append(self.batch_loss_value)
            self.batch_loss_value = 0
            self.avg_loss = np.mean(self.running_loss[-100:])

        # update progbar
        if self.progress_bar:
            # add model specific metrics
            tqdm_metrics = self.__tng_tqdm_dic
            self.prog_bar.set_postfix(**tqdm_metrics)

        # activate batch end hook
        if self.__is_function_implemented('on_batch_end'):
            model = self.__get_model()
            model.on_batch_end()

        return 0

    def __run_validation(self):
        # decide if this epoch can run a val check
        can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
        if self.fast_dev_run:
            print('skipping to check performance bc of --fast_dev_run')
        elif not can_check_epoch:
            return

        # hook
        if self.__is_function_implemented('on_pre_performance_check'):
            model = self.__get_model()
            model.on_pre_performance_check()

        # use the full val set, unless on fast_dev_run (then only 1 batch)
        max_batches = None if not self.fast_dev_run else 1
        model_specific_tqdm_metrics_dic = self.validate(
            self.model,
            self.val_dataloader,
            max_batches
        )
        self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)

        # hook
        if self.__is_function_implemented('on_post_performance_check'):
            model = self.__get_model()
            model.on_post_performance_check()

        if self.progress_bar:
            # add model specific metrics
            tqdm_metrics = self.__tng_tqdm_dic
            self.prog_bar.set_postfix(**tqdm_metrics)

        # model checkpointing
        if self.proc_rank == 0 and self.checkpoint_callback is not None:
            print('save callback...')
            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)