fix dataparallel

This commit is contained in:
William Falcon 2019-07-01 18:38:07 -04:00 committed by GitHub
parent 0f5a7c322e
commit 49c27770da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 121 additions and 42 deletions

View File

@ -5,7 +5,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map
import traceback
from pytorch_lightning.root_module.model_saving import TrainerIO
from torch.optim.lr_scheduler import MultiStepLR
from torch.nn import DataParallel
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
import pdb
try:
@ -14,37 +14,53 @@ try:
except ModuleNotFoundError:
APEX_AVAILABLE = False
def reduce_distributed_output(output, nb_gpus):
for k, v in output.items():
# recurse on nested dics
if isinstance(output[k], dict):
output[k] = reduce_distributed_output(output[k], nb_gpus)
# reduce only metrics that have the same nb of gpus
elif output[k].size(0) == nb_gpus:
reduced = torch.mean(output[k])
output[k] = reduced
return output
class Trainer(TrainerIO):
def __init__(self,
experiment,
checkpoint_callback, early_stop_callback,
gradient_clip=0,
cluster=None,
process_position=0,
current_gpu_name=0,
on_gpu=False,
enable_tqdm=True,
gpus=None,
progress_bar=True,
overfit_pct=0.0,
track_grad_norm=-1,
check_val_every_n_epoch=1,
fast_dev_run=False,
accumulate_grad_batches=1,
enable_early_stop=True, max_nb_epochs=5, min_nb_epochs=1,
enable_early_stop=True, max_nb_epochs=1000, min_nb_epochs=1,
train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95,
log_save_interval=1, add_log_row_interval=1,
log_save_interval=100, add_log_row_interval=10,
lr_scheduler_milestones=None,
use_amp=False,
check_grad_nans=False,
print_nan_grads=False,
amp_level='O2',
nb_sanity_val_steps=5):
# Transfer params
self.gradient_clip = gradient_clip
self.check_val_every_n_epoch = check_val_every_n_epoch
self.enable_early_stop = enable_early_stop
self.track_grad_norm = track_grad_norm
self.fast_dev_run = fast_dev_run
self.on_gpu = on_gpu
self.enable_tqdm = enable_tqdm
self.on_gpu = gpus is not None and torch.cuda.is_available()
self.progress_bar = progress_bar
self.experiment = experiment
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
self.cluster = cluster
@ -62,9 +78,9 @@ class Trainer(TrainerIO):
self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')]
self.lr_schedulers = []
self.amp_level = amp_level
self.check_grad_nans = check_grad_nans
self.data_parallel_device_ids = [0]
self.data_parallel = False
self.print_nan_grads = print_nan_grads
self.data_parallel_device_ids = gpus
self.data_parallel = gpus is not None and len(gpus) > 0
# training state
self.optimizers = None
@ -112,15 +128,18 @@ class Trainer(TrainerIO):
def __tng_tqdm_dic(self):
tqdm_dic = {
'tng_loss': '{0:.3f}'.format(self.avg_loss),
'gpu': '{}'.format(self.current_gpu_name),
'v_nb': '{}'.format(self.experiment.version),
'epoch': '{}'.format(self.current_epoch),
'batch_nb':'{}'.format(self.batch_nb),
}
tqdm_dic.update(self.tqdm_metrics)
if self.on_gpu:
tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)
return tqdm_dic
def __layout_bookeeping(self):
def __layout_bookeeping(self, model):
# training bookeeping
self.total_batch_nb = 0
self.running_loss = []
@ -129,17 +148,17 @@ class Trainer(TrainerIO):
self.tqdm_metrics = {}
# determine number of training batches
self.nb_tng_batches = self.model.nb_batches(self.tng_dataloader)
self.nb_tng_batches = model.nb_batches(self.tng_dataloader)
self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)
# determine number of validation batches
self.nb_val_batches = self.model.nb_batches(self.val_dataloader)
self.nb_val_batches = model.nb_batches(self.val_dataloader)
self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
self.nb_val_batches = max(1, self.nb_val_batches)
self.nb_val_batches = self.nb_val_batches
# determine number of test batches
self.nb_test_batches = self.model.nb_batches(self.test_dataloader)
self.nb_test_batches = model.nb_batches(self.test_dataloader)
self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check)
# determine when to check validation
@ -147,6 +166,9 @@ class Trainer(TrainerIO):
def __add_tqdm_metrics(self, metrics):
for k, v in metrics.items():
if type(v) is torch.Tensor:
v = v.item()
self.tqdm_metrics[k] = v
def validate(self, model, dataloader, max_batches):
@ -162,6 +184,7 @@ class Trainer(TrainerIO):
# enable eval mode
model.zero_grad()
model.eval()
model.from_lightning = True
# disable gradients to save memory
torch.set_grad_enabled(False)
@ -182,21 +205,30 @@ class Trainer(TrainerIO):
# -----------------
# RUN VALIDATION STEP
# -----------------
output = model.validation_step(data_batch, batch_i)
if self.data_parallel:
output = model(data_batch, batch_i)
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
else:
output = model.validation_step(data_batch, batch_i)
outputs.append(output)
# batch done
if self.enable_tqdm and self.prog_bar is not None:
if self.progress_bar and self.prog_bar is not None:
self.prog_bar.update(1)
# give model a chance to do something with the outputs
val_results = model.validation_end(outputs)
if self.data_parallel:
val_results = model.module.validation_end(outputs)
else:
val_results = model.validation_end(outputs)
# enable train mode again
model.train()
# enable gradients to save memory
torch.set_grad_enabled(True)
return val_results
def __get_dataloaders(self, model):
@ -213,14 +245,16 @@ class Trainer(TrainerIO):
# MODEL TRAINING
# -----------------------------
def fit(self, model):
self.model = model
# give model convenience properties
model.trainer = self
model.experiment = self.experiment
# transfer data loaders from model
self.__get_dataloaders(model)
# init training constants
self.__layout_bookeeping()
self.__layout_bookeeping(model)
# CHOOSE OPTIMIZER
# filter out the weights that were done on gpu so we can load on good old cpus
@ -228,8 +262,8 @@ class Trainer(TrainerIO):
if self.use_amp:
# An example
self.model, optimizer = amp.initialize(
self.model, self.optimizers[0], opt_level=self.amp_level,
model, optimizer = amp.initialize(
model, self.optimizers[0], opt_level=self.amp_level,
)
self.optimizers[0] = optimizer
model.trainer = self
@ -245,9 +279,7 @@ class Trainer(TrainerIO):
# put on gpu if needed
if self.on_gpu:
if self.data_parallel:
model = DataParallel(model, device_ids=self.data_parallel_device_ids)
model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
model.cuda(self.data_parallel_device_ids[0])
# run tiny validation to make sure program won't crash during val
@ -263,6 +295,7 @@ class Trainer(TrainerIO):
# ---------------------------
# CORE TRAINING LOOP
# ---------------------------
self.model = model
self.__train()
def __train(self):
@ -272,24 +305,28 @@ class Trainer(TrainerIO):
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step()
self.model.current_epoch = epoch_nb
model = self.model.module if self.data_parallel else self.model
model.current_epoch = epoch_nb
# hook
if self.__is_function_implemented('on_epoch_start'):
self.model.on_epoch_start()
model = self.model.module if self.data_parallel else self.model
model.on_epoch_start()
self.current_epoch = epoch_nb
self.total_batches = self.nb_tng_batches + self.nb_val_batches
self.batch_loss_value = 0 # accumulated grads
# init progbar when requested
if self.enable_tqdm:
if self.progress_bar:
self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position)
for batch_nb, data_batch in enumerate(self.tng_dataloader):
self.batch_nb = batch_nb
self.global_step += 1
self.model.global_step = self.global_step
model = self.model.module if self.data_parallel else self.model
model.global_step = self.global_step
# stop when the flag is changed or we've gone past the amount requested in the batches
self.total_batch_nb += 1
@ -319,7 +356,10 @@ class Trainer(TrainerIO):
# count items in memory
# nb_params, nb_tensors = count_mem_items()
metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
if self.data_parallel:
metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
else:
metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
# add gpu memory
if self.on_gpu:
@ -328,16 +368,20 @@ class Trainer(TrainerIO):
# add norms
if self.track_grad_norm > 0:
grad_norm_dic = self.model.grad_norm(self.track_grad_norm)
model = self.model.module if self.data_parallel else self.model
grad_norm_dic = model.grad_norm(self.track_grad_norm)
metrics.update(grad_norm_dic)
# log metrics
self.experiment.log(metrics)
scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
self.experiment.log(scalar_metrics, global_step=self.global_step)
self.experiment.save()
# hook
if self.__is_function_implemented('on_batch_end'):
self.model.on_batch_end()
model = self.model.module if self.data_parallel else self.model
model.on_batch_end()
# end epoch early
if early_stop_epoch:
@ -345,7 +389,8 @@ class Trainer(TrainerIO):
# hook
if self.__is_function_implemented('on_epoch_end'):
self.model.on_epoch_end()
model = self.model.module if self.data_parallel else self.model
model.on_epoch_end()
# early stopping
if self.enable_early_stop:
@ -357,6 +402,24 @@ class Trainer(TrainerIO):
if stop:
return
def __metrics_to_scalars(self, metrics, blacklist=[]):
new_metrics = {}
for k, v in metrics.items():
if type(v) is torch.Tensor:
v = v.item()
if type(v) is dict:
v = self.__metrics_to_scalars(v)
if k not in blacklist:
new_metrics[k] = float(v)
return new_metrics
def __log_vals_blacklist(self):
"""avoid logging some vals lightning uses to maintain state"""
blacklist = {'batch_nb', 'v_nb', 'epoch', 'gpu'}
return blacklist
def __run_tng_batch(self, data_batch, batch_nb):
if data_batch is None:
@ -364,16 +427,26 @@ class Trainer(TrainerIO):
# hook
if self.__is_function_implemented('on_batch_start'):
response = self.model.on_batch_start(data_batch)
model = self.model.module if self.data_parallel else self.model
response = model.on_batch_start(data_batch)
if response == -1:
return -1
if self.enable_tqdm:
if self.progress_bar:
self.prog_bar.update(1)
# forward pass
# return a scalar value and a dic with tqdm metrics
loss, model_specific_tqdm_metrics_dic = self.model.training_step(data_batch, batch_nb)
if self.data_parallel:
output = self.model(data_batch, batch_nb)
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
else:
output = self.model.training_step(data_batch, batch_nb)
model_specific_tqdm_metrics_dic = output['tqdm_metrics']
loss = output['loss']
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
# backward pass
@ -384,8 +457,9 @@ class Trainer(TrainerIO):
else:
loss.backward()
if self.check_grad_nans:
for param in self.model.parameters():
if self.print_nan_grads:
model = self.model.module if self.data_parallel else self.model
for param in model.parameters():
print(param.grad.float().sum())
self.batch_loss_value += loss.item()
@ -393,6 +467,11 @@ class Trainer(TrainerIO):
# gradient update with accumulated gradients
if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
# clip gradients
if self.gradient_clip > 0:
model = self.model.module if self.data_parallel else self.model
torch.nn.utils.clip_grad_norm(model.parameters(), self.gradient_clip)
# update gradients across all optimizers
for optimizer in self.optimizers:
optimizer.step()
@ -409,7 +488,7 @@ class Trainer(TrainerIO):
self.avg_loss = np.mean(self.running_loss[-100:])
# update progbar
if self.enable_tqdm:
if self.progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
self.prog_bar.set_postfix(**tqdm_metrics)
@ -451,7 +530,7 @@ class Trainer(TrainerIO):
print(e)
print(traceback.print_exc())
if self.enable_tqdm:
if self.progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
self.prog_bar.set_postfix(**tqdm_metrics)