fix dataparallel
This commit is contained in:
parent
8fde5e444e
commit
0f5a7c322e
|
@ -5,7 +5,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map
|
|||
import traceback
|
||||
from pytorch_lightning.root_module.model_saving import TrainerIO
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
|
||||
from torch.nn import DataParallel
|
||||
import pdb
|
||||
|
||||
try:
|
||||
|
@ -14,53 +14,37 @@ try:
|
|||
except ModuleNotFoundError:
|
||||
APEX_AVAILABLE = False
|
||||
|
||||
|
||||
def reduce_distributed_output(output, nb_gpus):
|
||||
for k, v in output.items():
|
||||
# recurse on nested dics
|
||||
if isinstance(output[k], dict):
|
||||
output[k] = reduce_distributed_output(output[k], nb_gpus)
|
||||
|
||||
# reduce only metrics that have the same nb of gpus
|
||||
elif output[k].size(0) == nb_gpus:
|
||||
reduced = torch.mean(output[k])
|
||||
output[k] = reduced
|
||||
return output
|
||||
|
||||
|
||||
class Trainer(TrainerIO):
|
||||
|
||||
def __init__(self,
|
||||
experiment,
|
||||
checkpoint_callback, early_stop_callback,
|
||||
gradient_clip=0,
|
||||
cluster=None,
|
||||
process_position=0,
|
||||
current_gpu_name=0,
|
||||
gpus=None,
|
||||
progress_bar=True,
|
||||
on_gpu=False,
|
||||
enable_tqdm=True,
|
||||
overfit_pct=0.0,
|
||||
track_grad_norm=-1,
|
||||
check_val_every_n_epoch=1,
|
||||
fast_dev_run=False,
|
||||
accumulate_grad_batches=1,
|
||||
enable_early_stop=True, max_nb_epochs=1000, min_nb_epochs=1,
|
||||
enable_early_stop=True, max_nb_epochs=5, min_nb_epochs=1,
|
||||
train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95,
|
||||
log_save_interval=100, add_log_row_interval=10,
|
||||
log_save_interval=1, add_log_row_interval=1,
|
||||
lr_scheduler_milestones=None,
|
||||
use_amp=False,
|
||||
print_nan_grads=False,
|
||||
check_grad_nans=False,
|
||||
amp_level='O2',
|
||||
nb_sanity_val_steps=5):
|
||||
|
||||
# Transfer params
|
||||
self.gradient_clip = gradient_clip
|
||||
self.check_val_every_n_epoch = check_val_every_n_epoch
|
||||
self.enable_early_stop = enable_early_stop
|
||||
self.track_grad_norm = track_grad_norm
|
||||
self.fast_dev_run = fast_dev_run
|
||||
self.on_gpu = gpus is not None and torch.cuda.is_available()
|
||||
self.progress_bar = progress_bar
|
||||
self.on_gpu = on_gpu
|
||||
self.enable_tqdm = enable_tqdm
|
||||
self.experiment = experiment
|
||||
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
|
||||
self.cluster = cluster
|
||||
|
@ -78,9 +62,9 @@ class Trainer(TrainerIO):
|
|||
self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')]
|
||||
self.lr_schedulers = []
|
||||
self.amp_level = amp_level
|
||||
self.print_nan_grads = print_nan_grads
|
||||
self.data_parallel_device_ids = gpus
|
||||
self.data_parallel = gpus is not None and len(gpus) > 0
|
||||
self.check_grad_nans = check_grad_nans
|
||||
self.data_parallel_device_ids = [0]
|
||||
self.data_parallel = False
|
||||
|
||||
# training state
|
||||
self.optimizers = None
|
||||
|
@ -128,18 +112,15 @@ class Trainer(TrainerIO):
|
|||
def __tng_tqdm_dic(self):
|
||||
tqdm_dic = {
|
||||
'tng_loss': '{0:.3f}'.format(self.avg_loss),
|
||||
'gpu': '{}'.format(self.current_gpu_name),
|
||||
'v_nb': '{}'.format(self.experiment.version),
|
||||
'epoch': '{}'.format(self.current_epoch),
|
||||
'batch_nb':'{}'.format(self.batch_nb),
|
||||
}
|
||||
tqdm_dic.update(self.tqdm_metrics)
|
||||
|
||||
if self.on_gpu:
|
||||
tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)
|
||||
|
||||
return tqdm_dic
|
||||
|
||||
def __layout_bookeeping(self, model):
|
||||
def __layout_bookeeping(self):
|
||||
# training bookeeping
|
||||
self.total_batch_nb = 0
|
||||
self.running_loss = []
|
||||
|
@ -148,17 +129,17 @@ class Trainer(TrainerIO):
|
|||
self.tqdm_metrics = {}
|
||||
|
||||
# determine number of training batches
|
||||
self.nb_tng_batches = model.nb_batches(self.tng_dataloader)
|
||||
self.nb_tng_batches = self.model.nb_batches(self.tng_dataloader)
|
||||
self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)
|
||||
|
||||
# determine number of validation batches
|
||||
self.nb_val_batches = model.nb_batches(self.val_dataloader)
|
||||
self.nb_val_batches = self.model.nb_batches(self.val_dataloader)
|
||||
self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
|
||||
self.nb_val_batches = max(1, self.nb_val_batches)
|
||||
self.nb_val_batches = self.nb_val_batches
|
||||
|
||||
# determine number of test batches
|
||||
self.nb_test_batches = model.nb_batches(self.test_dataloader)
|
||||
self.nb_test_batches = self.model.nb_batches(self.test_dataloader)
|
||||
self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check)
|
||||
|
||||
# determine when to check validation
|
||||
|
@ -166,9 +147,6 @@ class Trainer(TrainerIO):
|
|||
|
||||
def __add_tqdm_metrics(self, metrics):
|
||||
for k, v in metrics.items():
|
||||
if type(v) is torch.Tensor:
|
||||
v = v.item()
|
||||
|
||||
self.tqdm_metrics[k] = v
|
||||
|
||||
def validate(self, model, dataloader, max_batches):
|
||||
|
@ -184,7 +162,6 @@ class Trainer(TrainerIO):
|
|||
# enable eval mode
|
||||
model.zero_grad()
|
||||
model.eval()
|
||||
model.from_lightning = True
|
||||
|
||||
# disable gradients to save memory
|
||||
torch.set_grad_enabled(False)
|
||||
|
@ -205,30 +182,21 @@ class Trainer(TrainerIO):
|
|||
# -----------------
|
||||
# RUN VALIDATION STEP
|
||||
# -----------------
|
||||
if self.data_parallel:
|
||||
output = model(data_batch, batch_i)
|
||||
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
|
||||
else:
|
||||
output = model.validation_step(data_batch, batch_i)
|
||||
|
||||
output = model.validation_step(data_batch, batch_i)
|
||||
outputs.append(output)
|
||||
|
||||
# batch done
|
||||
if self.progress_bar and self.prog_bar is not None:
|
||||
if self.enable_tqdm and self.prog_bar is not None:
|
||||
self.prog_bar.update(1)
|
||||
|
||||
# give model a chance to do something with the outputs
|
||||
if self.data_parallel:
|
||||
val_results = model.module.validation_end(outputs)
|
||||
else:
|
||||
val_results = model.validation_end(outputs)
|
||||
val_results = model.validation_end(outputs)
|
||||
|
||||
# enable train mode again
|
||||
model.train()
|
||||
|
||||
# enable gradients to save memory
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
return val_results
|
||||
|
||||
def __get_dataloaders(self, model):
|
||||
|
@ -245,16 +213,14 @@ class Trainer(TrainerIO):
|
|||
# MODEL TRAINING
|
||||
# -----------------------------
|
||||
def fit(self, model):
|
||||
|
||||
# give model convenience properties
|
||||
self.model = model
|
||||
model.trainer = self
|
||||
model.experiment = self.experiment
|
||||
|
||||
# transfer data loaders from model
|
||||
self.__get_dataloaders(model)
|
||||
|
||||
# init training constants
|
||||
self.__layout_bookeeping(model)
|
||||
self.__layout_bookeeping()
|
||||
|
||||
# CHOOSE OPTIMIZER
|
||||
# filter out the weights that were done on gpu so we can load on good old cpus
|
||||
|
@ -262,8 +228,8 @@ class Trainer(TrainerIO):
|
|||
|
||||
if self.use_amp:
|
||||
# An example
|
||||
model, optimizer = amp.initialize(
|
||||
model, self.optimizers[0], opt_level=self.amp_level,
|
||||
self.model, optimizer = amp.initialize(
|
||||
self.model, self.optimizers[0], opt_level=self.amp_level,
|
||||
)
|
||||
self.optimizers[0] = optimizer
|
||||
model.trainer = self
|
||||
|
@ -279,7 +245,10 @@ class Trainer(TrainerIO):
|
|||
|
||||
# put on gpu if needed
|
||||
if self.on_gpu:
|
||||
model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
|
||||
if self.data_parallel:
|
||||
model = DataParallel(model, device_ids=self.data_parallel_device_ids)
|
||||
|
||||
model.cuda(self.data_parallel_device_ids[0])
|
||||
|
||||
# run tiny validation to make sure program won't crash during val
|
||||
_ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
|
||||
|
@ -294,7 +263,6 @@ class Trainer(TrainerIO):
|
|||
# ---------------------------
|
||||
# CORE TRAINING LOOP
|
||||
# ---------------------------
|
||||
self.model = model
|
||||
self.__train()
|
||||
|
||||
def __train(self):
|
||||
|
@ -304,28 +272,24 @@ class Trainer(TrainerIO):
|
|||
for lr_scheduler in self.lr_schedulers:
|
||||
lr_scheduler.step()
|
||||
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.current_epoch = epoch_nb
|
||||
self.model.current_epoch = epoch_nb
|
||||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_epoch_start'):
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.on_epoch_start()
|
||||
self.model.on_epoch_start()
|
||||
|
||||
self.current_epoch = epoch_nb
|
||||
self.total_batches = self.nb_tng_batches + self.nb_val_batches
|
||||
self.batch_loss_value = 0 # accumulated grads
|
||||
|
||||
# init progbar when requested
|
||||
if self.progress_bar:
|
||||
if self.enable_tqdm:
|
||||
self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position)
|
||||
|
||||
for batch_nb, data_batch in enumerate(self.tng_dataloader):
|
||||
self.batch_nb = batch_nb
|
||||
self.global_step += 1
|
||||
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.global_step = self.global_step
|
||||
self.model.global_step = self.global_step
|
||||
|
||||
# stop when the flag is changed or we've gone past the amount requested in the batches
|
||||
self.total_batch_nb += 1
|
||||
|
@ -355,10 +319,7 @@ class Trainer(TrainerIO):
|
|||
# count items in memory
|
||||
# nb_params, nb_tensors = count_mem_items()
|
||||
|
||||
if self.data_parallel:
|
||||
metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
else:
|
||||
metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
|
||||
# add gpu memory
|
||||
if self.on_gpu:
|
||||
|
@ -367,20 +328,16 @@ class Trainer(TrainerIO):
|
|||
|
||||
# add norms
|
||||
if self.track_grad_norm > 0:
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
grad_norm_dic = model.grad_norm(self.track_grad_norm)
|
||||
|
||||
grad_norm_dic = self.model.grad_norm(self.track_grad_norm)
|
||||
metrics.update(grad_norm_dic)
|
||||
|
||||
# log metrics
|
||||
scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
|
||||
self.experiment.log(scalar_metrics, global_step=self.global_step)
|
||||
self.experiment.log(metrics)
|
||||
self.experiment.save()
|
||||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_batch_end'):
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.on_batch_end()
|
||||
self.model.on_batch_end()
|
||||
|
||||
# end epoch early
|
||||
if early_stop_epoch:
|
||||
|
@ -388,8 +345,7 @@ class Trainer(TrainerIO):
|
|||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_epoch_end'):
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.on_epoch_end()
|
||||
self.model.on_epoch_end()
|
||||
|
||||
# early stopping
|
||||
if self.enable_early_stop:
|
||||
|
@ -401,24 +357,6 @@ class Trainer(TrainerIO):
|
|||
if stop:
|
||||
return
|
||||
|
||||
def __metrics_to_scalars(self, metrics, blacklist=[]):
|
||||
new_metrics = {}
|
||||
for k, v in metrics.items():
|
||||
if type(v) is torch.Tensor:
|
||||
v = v.item()
|
||||
|
||||
if type(v) is dict:
|
||||
v = self.__metrics_to_scalars(v)
|
||||
|
||||
if k not in blacklist:
|
||||
new_metrics[k] = float(v)
|
||||
|
||||
return new_metrics
|
||||
|
||||
def __log_vals_blacklist(self):
|
||||
"""avoid logging some vals lightning uses to maintain state"""
|
||||
blacklist = {'batch_nb', 'v_nb', 'epoch', 'gpu'}
|
||||
return blacklist
|
||||
|
||||
def __run_tng_batch(self, data_batch, batch_nb):
|
||||
if data_batch is None:
|
||||
|
@ -426,26 +364,16 @@ class Trainer(TrainerIO):
|
|||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_batch_start'):
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
response = model.on_batch_start(data_batch)
|
||||
|
||||
response = self.model.on_batch_start(data_batch)
|
||||
if response == -1:
|
||||
return -1
|
||||
|
||||
if self.progress_bar:
|
||||
if self.enable_tqdm:
|
||||
self.prog_bar.update(1)
|
||||
|
||||
# forward pass
|
||||
# return a scalar value and a dic with tqdm metrics
|
||||
if self.data_parallel:
|
||||
output = self.model(data_batch, batch_nb)
|
||||
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
|
||||
else:
|
||||
output = self.model.training_step(data_batch, batch_nb)
|
||||
|
||||
model_specific_tqdm_metrics_dic = output['tqdm_metrics']
|
||||
loss = output['loss']
|
||||
|
||||
loss, model_specific_tqdm_metrics_dic = self.model.training_step(data_batch, batch_nb)
|
||||
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
|
||||
|
||||
# backward pass
|
||||
|
@ -456,9 +384,8 @@ class Trainer(TrainerIO):
|
|||
else:
|
||||
loss.backward()
|
||||
|
||||
if self.print_nan_grads:
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
for param in model.parameters():
|
||||
if self.check_grad_nans:
|
||||
for param in self.model.parameters():
|
||||
print(param.grad.float().sum())
|
||||
|
||||
self.batch_loss_value += loss.item()
|
||||
|
@ -466,11 +393,6 @@ class Trainer(TrainerIO):
|
|||
# gradient update with accumulated gradients
|
||||
if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
|
||||
|
||||
# clip gradients
|
||||
if self.gradient_clip > 0:
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
torch.nn.utils.clip_grad_norm(model.parameters(), self.gradient_clip)
|
||||
|
||||
# update gradients across all optimizers
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.step()
|
||||
|
@ -487,7 +409,7 @@ class Trainer(TrainerIO):
|
|||
self.avg_loss = np.mean(self.running_loss[-100:])
|
||||
|
||||
# update progbar
|
||||
if self.progress_bar:
|
||||
if self.enable_tqdm:
|
||||
# add model specific metrics
|
||||
tqdm_metrics = self.__tng_tqdm_dic
|
||||
self.prog_bar.set_postfix(**tqdm_metrics)
|
||||
|
@ -529,7 +451,7 @@ class Trainer(TrainerIO):
|
|||
print(e)
|
||||
print(traceback.print_exc())
|
||||
|
||||
if self.progress_bar:
|
||||
if self.enable_tqdm:
|
||||
# add model specific metrics
|
||||
tqdm_metrics = self.__tng_tqdm_dic
|
||||
self.prog_bar.set_postfix(**tqdm_metrics)
|
||||
|
|
Loading…
Reference in New Issue