diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index cf2667a04d..5bb222e3a9 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -5,7 +5,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map
 import traceback
 from pytorch_lightning.root_module.model_saving import TrainerIO
 from torch.optim.lr_scheduler import MultiStepLR
-from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
+from torch.nn import DataParallel
 import pdb
 
 try:
@@ -14,53 +14,37 @@ try:
 except ModuleNotFoundError:
     APEX_AVAILABLE = False
 
-
-def reduce_distributed_output(output, nb_gpus):
-    for k, v in output.items():
-        # recurse on nested dics
-        if isinstance(output[k], dict):
-            output[k] = reduce_distributed_output(output[k], nb_gpus)
-
-        # reduce only metrics that have the same nb of gpus
-        elif output[k].size(0) == nb_gpus:
-            reduced = torch.mean(output[k])
-            output[k] = reduced
-    return output
-
-
 class Trainer(TrainerIO):
 
     def __init__(self,
                  experiment,
                  checkpoint_callback, early_stop_callback,
-                 gradient_clip=0,
                  cluster=None,
                  process_position=0,
                  current_gpu_name=0,
-                 gpus=None,
-                 progress_bar=True,
+                 on_gpu=False,
+                 enable_tqdm=True,
                  overfit_pct=0.0,
                  track_grad_norm=-1,
                  check_val_every_n_epoch=1,
                  fast_dev_run=False,
                  accumulate_grad_batches=1,
-                 enable_early_stop=True, max_nb_epochs=1000, min_nb_epochs=1,
+                 enable_early_stop=True, max_nb_epochs=5, min_nb_epochs=1,
                  train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95,
-                 log_save_interval=100, add_log_row_interval=10,
+                 log_save_interval=1, add_log_row_interval=1,
                  lr_scheduler_milestones=None,
                  use_amp=False,
-                 print_nan_grads=False,
+                 check_grad_nans=False,
                  amp_level='O2',
                  nb_sanity_val_steps=5):
 
         # Transfer params
-        self.gradient_clip = gradient_clip
         self.check_val_every_n_epoch = check_val_every_n_epoch
         self.enable_early_stop = enable_early_stop
         self.track_grad_norm = track_grad_norm
         self.fast_dev_run = fast_dev_run
-        self.on_gpu = gpus is not None and torch.cuda.is_available()
-        self.progress_bar = progress_bar
+        self.on_gpu = on_gpu
+        self.enable_tqdm = enable_tqdm
         self.experiment = experiment
         self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
         self.cluster = cluster
@@ -78,9 +62,9 @@ class Trainer(TrainerIO):
         self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')]
         self.lr_schedulers = []
         self.amp_level = amp_level
-        self.print_nan_grads = print_nan_grads
-        self.data_parallel_device_ids = gpus
-        self.data_parallel = gpus is not None and len(gpus) > 0
+        self.check_grad_nans = check_grad_nans
+        self.data_parallel_device_ids = [0]
+        self.data_parallel = False
 
         # training state
         self.optimizers = None
@@ -128,18 +112,15 @@ class Trainer(TrainerIO):
     def __tng_tqdm_dic(self):
         tqdm_dic = {
             'tng_loss': '{0:.3f}'.format(self.avg_loss),
+            'gpu': '{}'.format(self.current_gpu_name),
             'v_nb': '{}'.format(self.experiment.version),
             'epoch': '{}'.format(self.current_epoch),
             'batch_nb':'{}'.format(self.batch_nb),
         }
         tqdm_dic.update(self.tqdm_metrics)
-
-        if self.on_gpu:
-            tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)
-
         return tqdm_dic
 
-    def __layout_bookeeping(self, model):
+    def __layout_bookeeping(self):
         # training bookeeping
         self.total_batch_nb = 0
         self.running_loss = []
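The first two hunks swap LightningDataParallel back to plain torch.nn.DataParallel and delete the reduce_distributed_output helper. For reference while reading the rest of the diff: DataParallel gathers each replica's scalars into a tensor with a leading device dimension, and the deleted helper mean-reduced exactly those entries. A minimal standalone sketch of that behavior (the torch.is_tensor/dim() guard is an added assumption, not in the original helper):

```python
import torch

def reduce_distributed_output(output, nb_gpus):
    """Recursively average any per-GPU tensors (first dim == nb_gpus) in a nested dict."""
    for k in output:
        if isinstance(output[k], dict):
            output[k] = reduce_distributed_output(output[k], nb_gpus)
        elif torch.is_tensor(output[k]) and output[k].dim() > 0 and output[k].size(0) == nb_gpus:
            output[k] = torch.mean(output[k])
    return output

# DataParallel stacks each replica's scalar loss into a tensor of shape [nb_gpus]
dp_output = {'loss': torch.tensor([0.9, 1.1]), 'tqdm_metrics': {'acc': torch.tensor([0.5, 0.7])}}
print(reduce_distributed_output(dp_output, nb_gpus=2))  # loss -> tensor(1.), acc -> tensor(0.6000)
```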
@@ -148,17 +129,17 @@ class Trainer(TrainerIO):
         self.tqdm_metrics = {}
 
         # determine number of training batches
-        self.nb_tng_batches = model.nb_batches(self.tng_dataloader)
+        self.nb_tng_batches = self.model.nb_batches(self.tng_dataloader)
         self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)
 
         # determine number of validation batches
-        self.nb_val_batches = model.nb_batches(self.val_dataloader)
+        self.nb_val_batches = self.model.nb_batches(self.val_dataloader)
         self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
         self.nb_val_batches = max(1, self.nb_val_batches)
         self.nb_val_batches = self.nb_val_batches
 
         # determine number of test batches
-        self.nb_test_batches = model.nb_batches(self.test_dataloader)
+        self.nb_test_batches = self.model.nb_batches(self.test_dataloader)
         self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check)
 
         # determine when to check validation
@@ -166,9 +147,6 @@
 
     def __add_tqdm_metrics(self, metrics):
         for k, v in metrics.items():
-            if type(v) is torch.Tensor:
-                v = v.item()
-
             self.tqdm_metrics[k] = v
 
     def validate(self, model, dataloader, max_batches):
@@ -184,7 +162,6 @@
         # enable eval mode
         model.zero_grad()
         model.eval()
-        model.from_lightning = True
 
         # disable gradients to save memory
         torch.set_grad_enabled(False)
@@ -205,30 +182,21 @@
             # -----------------
             # RUN VALIDATION STEP
             # -----------------
-            if self.data_parallel:
-                output = model(data_batch, batch_i)
-                output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
-            else:
-                output = model.validation_step(data_batch, batch_i)
-
+            output = model.validation_step(data_batch, batch_i)
             outputs.append(output)
 
             # batch done
-            if self.progress_bar and self.prog_bar is not None:
+            if self.enable_tqdm and self.prog_bar is not None:
                 self.prog_bar.update(1)
 
         # give model a chance to do something with the outputs
-        if self.data_parallel:
-            val_results = model.module.validation_end(outputs)
-        else:
-            val_results = model.validation_end(outputs)
+        val_results = model.validation_end(outputs)
 
         # enable train mode again
         model.train()
 
         # enable gradients to save memory
         torch.set_grad_enabled(True)
-
         return val_results
 
     def __get_dataloaders(self, model):
@@ -245,16 +213,14 @@
     # MODEL TRAINING
     # -----------------------------
     def fit(self, model):
-
-        # give model convenience properties
+        self.model = model
         model.trainer = self
-        model.experiment = self.experiment
 
         # transfer data loaders from model
         self.__get_dataloaders(model)
 
         # init training constants
-        self.__layout_bookeeping(model)
+        self.__layout_bookeeping()
 
         # CHOOSE OPTIMIZER
         # filter out the weights that were done on gpu so we can load on good old cpus
@@ -262,8 +228,8 @@
 
         if self.use_amp:
             # An example
-            model, optimizer = amp.initialize(
-                model, self.optimizers[0], opt_level=self.amp_level,
+            self.model, optimizer = amp.initialize(
+                self.model, self.optimizers[0], opt_level=self.amp_level,
             )
             self.optimizers[0] = optimizer
             model.trainer = self
@@ -279,7 +245,10 @@
 
         # put on gpu if needed
         if self.on_gpu:
-            model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
+            if self.data_parallel:
+                model = DataParallel(model, device_ids=self.data_parallel_device_ids)
+
+            model.cuda(self.data_parallel_device_ids[0])
 
         # run tiny validation to make sure program won't crash during val
         _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
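The put-on-gpu hunk above wraps first and moves parameters afterwards, which matches how plain torch.nn.DataParallel expects to be used: the parameters must live on device_ids[0], and replication onto the remaining devices happens inside every forward. A sketch assuming at least two visible CUDA devices (the toy module and shapes are made up):

```python
import torch
from torch.nn import DataParallel

# wrap first, then move the parameters to device_ids[0]; DataParallel
# replicates the module per forward and scatters the batch along dim 0
net = torch.nn.Linear(8, 2)
net = DataParallel(net, device_ids=[0, 1])
net.cuda(0)

out = net(torch.randn(4, 8).cuda(0))  # rows 0-1 run on GPU 0, rows 2-3 on GPU 1
print(out.shape)  # torch.Size([4, 2]); outputs are gathered back onto device_ids[0]
```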
@@ -294,7 +263,6 @@
         # ---------------------------
         # CORE TRAINING LOOP
         # ---------------------------
-        self.model = model
         self.__train()
 
     def __train(self):
@@ -304,28 +272,24 @@
             for lr_scheduler in self.lr_schedulers:
                 lr_scheduler.step()
 
-            model = self.model.module if self.data_parallel else self.model
-            model.current_epoch = epoch_nb
+            self.model.current_epoch = epoch_nb
 
             # hook
             if self.__is_function_implemented('on_epoch_start'):
-                model = self.model.module if self.data_parallel else self.model
-                model.on_epoch_start()
+                self.model.on_epoch_start()
 
             self.current_epoch = epoch_nb
             self.total_batches = self.nb_tng_batches + self.nb_val_batches
             self.batch_loss_value = 0  # accumulated grads
 
             # init progbar when requested
-            if self.progress_bar:
+            if self.enable_tqdm:
                 self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position)
 
             for batch_nb, data_batch in enumerate(self.tng_dataloader):
                 self.batch_nb = batch_nb
                 self.global_step += 1
-
-                model = self.model.module if self.data_parallel else self.model
-                model.global_step = self.global_step
+                self.model.global_step = self.global_step
 
                 # stop when the flag is changed or we've gone past the amount requested in the batches
                 self.total_batch_nb += 1
@@ -355,10 +319,7 @@
 
                     # count items in memory
                     # nb_params, nb_tensors = count_mem_items()
-                    if self.data_parallel:
-                        metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
-                    else:
-                        metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
+                    metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
 
                     # add gpu memory
                     if self.on_gpu:
@@ -367,20 +328,16 @@
                     # add norms
                     if self.track_grad_norm > 0:
-                        model = self.model.module if self.data_parallel else self.model
-                        grad_norm_dic = model.grad_norm(self.track_grad_norm)
-
+                        grad_norm_dic = self.model.grad_norm(self.track_grad_norm)
                         metrics.update(grad_norm_dic)
 
                     # log metrics
-                    scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
-                    self.experiment.log(scalar_metrics, global_step=self.global_step)
+                    self.experiment.log(metrics)
                     self.experiment.save()
 
                 # hook
                 if self.__is_function_implemented('on_batch_end'):
-                    model = self.model.module if self.data_parallel else self.model
-                    model.on_batch_end()
+                    self.model.on_batch_end()
 
                 # end epoch early
                 if early_stop_epoch:
                     break
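With the logging hunk above, metrics now reach self.experiment.log unconverted; the deleted __metrics_to_scalars/__log_vals_blacklist pair (removed in a later hunk) was what guaranteed plain floats. A standalone sketch of that conversion, with the nested-dict branch reordered so float() is never applied to a dict (the original ordering would have raised there):

```python
import torch

def metrics_to_scalars(metrics, blacklist=frozenset({'batch_nb', 'v_nb', 'epoch', 'gpu'})):
    """Recursively turn tensors into floats, dropping keys lightning keeps for state."""
    out = {}
    for k, v in metrics.items():
        if torch.is_tensor(v):
            v = v.item()
        if k in blacklist:
            continue
        out[k] = metrics_to_scalars(v) if isinstance(v, dict) else float(v)
    return out

print(metrics_to_scalars({'tng_loss': torch.tensor(0.25), 'epoch': 3, 'norms': {'grad_2.0': 1.5}}))
# -> {'tng_loss': 0.25, 'norms': {'grad_2.0': 1.5}}
```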
@@ -388,8 +345,7 @@
 
             # hook
             if self.__is_function_implemented('on_epoch_end'):
-                model = self.model.module if self.data_parallel else self.model
-                model.on_epoch_end()
+                self.model.on_epoch_end()
 
             # early stopping
             if self.enable_early_stop:
@@ -401,24 +357,6 @@
                 if stop:
                     return
 
-    def __metrics_to_scalars(self, metrics, blacklist=[]):
-        new_metrics = {}
-        for k, v in metrics.items():
-            if type(v) is torch.Tensor:
-                v = v.item()
-
-            if type(v) is dict:
-                v = self.__metrics_to_scalars(v)
-
-            if k not in blacklist:
-                new_metrics[k] = float(v)
-
-        return new_metrics
-
-    def __log_vals_blacklist(self):
-        """avoid logging some vals lightning uses to maintain state"""
-        blacklist = {'batch_nb', 'v_nb', 'epoch', 'gpu'}
-        return blacklist
 
     def __run_tng_batch(self, data_batch, batch_nb):
         if data_batch is None:
             return 0
 
         # hook
         if self.__is_function_implemented('on_batch_start'):
-            model = self.model.module if self.data_parallel else self.model
-            response = model.on_batch_start(data_batch)
-
+            response = self.model.on_batch_start(data_batch)
             if response == -1:
                 return -1
 
-        if self.progress_bar:
+        if self.enable_tqdm:
             self.prog_bar.update(1)
 
         # forward pass
         # return a scalar value and a dic with tqdm metrics
-        if self.data_parallel:
-            output = self.model(data_batch, batch_nb)
-            output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
-        else:
-            output = self.model.training_step(data_batch, batch_nb)
-
-        model_specific_tqdm_metrics_dic = output['tqdm_metrics']
-        loss = output['loss']
-
+        loss, model_specific_tqdm_metrics_dic = self.model.training_step(data_batch, batch_nb)
         self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
 
         # backward pass
@@ -456,9 +384,8 @@
         else:
             loss.backward()
 
-        if self.print_nan_grads:
-            model = self.model.module if self.data_parallel else self.model
-            for param in model.parameters():
+        if self.check_grad_nans:
+            for param in self.model.parameters():
                 print(param.grad.float().sum())
 
         self.batch_loss_value += loss.item()
@@ -466,11 +393,6 @@
 
         # gradient update with accumulated gradients
         if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
-            # clip gradients
-            if self.gradient_clip > 0:
-                model = self.model.module if self.data_parallel else self.model
-                torch.nn.utils.clip_grad_norm(model.parameters(), self.gradient_clip)
-
             # update gradients across all optimizers
             for optimizer in self.optimizers:
                 optimizer.step()
@@ -487,7 +409,7 @@
             self.avg_loss = np.mean(self.running_loss[-100:])
 
             # update progbar
-            if self.progress_bar:
+            if self.enable_tqdm:
                 # add model specific metrics
                 tqdm_metrics = self.__tng_tqdm_dic
                 self.prog_bar.set_postfix(**tqdm_metrics)
@@ -529,7 +451,7 @@
             print(e)
             print(traceback.print_exc())
 
-            if self.progress_bar:
+            if self.enable_tqdm:
                 # add model specific metrics
                 tqdm_metrics = self.__tng_tqdm_dic
                 self.prog_bar.set_postfix(**tqdm_metrics)
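A final note on the accumulated-gradients hunk: dropping the gradient_clip flag also drops clipping from the update step. A self-contained sketch of the accumulate-then-step pattern with the clip re-added (toy model and data; clip_grad_norm_ is the current spelling of the deprecated clip_grad_norm call the diff deletes):

```python
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loader = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)]
accumulate_grad_batches, gradient_clip = 4, 0.5

for batch_nb, (x, y) in enumerate(loader):
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()  # gradients from each batch keep summing until the step
    if (batch_nb + 1) % accumulate_grad_batches == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()
        optimizer.zero_grad()
```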