Change nb to num in ABCs, comments, and tqdm logging (#613)

* Change nb to num in ABCs, comments, and tqdm logging

* Fix warnings text

* Make warnings one line

* Change num to number in comments
Elliot Waite 2019-12-09 04:40:27 -08:00 committed by William Falcon
parent 607dbdaefd
commit b492e2b89e
4 changed files with 29 additions and 10 deletions


@@ -145,8 +145,8 @@ class TrainerEvaluationLoopMixin(ABC):
         self.single_gpu = None
         self.data_parallel_device_ids = None
         self.model = None
-        self.nb_test_batches = None
-        self.nb_val_batches = None
+        self.num_test_batches = None
+        self.num_val_batches = None
         self.fast_dev_run = None
         self.process_position = None
         self.show_progress_bar = None


@@ -177,7 +177,7 @@ class TrainerLoggingMixin(ABC):
             elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
                 pass

-            # reduce only metrics that have the same nb of gpus
+            # reduce only metrics that have the same number of gpus
             elif output[k].size(0) == num_gpus:
                 reduced = torch.mean(output[k])
                 output[k] = reduced

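For context, a minimal standalone sketch (illustrative values, not the library's actual code) of the reduction logic whose comment the hunk above fixes: a metric stacked across GPUs is collapsed to its mean only when its first dimension matches the GPU count.

import torch

num_gpus = 2
output = {'val_loss': torch.tensor([0.25, 0.35])}  # one value per GPU
for k in output:
    # scalars need no reduction
    if isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
        pass
    # reduce only metrics that have the same number of gpus
    elif output[k].size(0) == num_gpus:
        output[k] = torch.mean(output[k])  # -> tensor(0.3000)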

@@ -346,7 +346,7 @@ class Trainer(TrainerIOMixin,
             tqdm_dict['split_idx'] = self.split_idx

         if self.logger is not None and self.logger.version is not None:
-            tqdm_dict['v_nb'] = self.logger.version
+            tqdm_dict['v_num'] = self.logger.version

         tqdm_dict.update(self.tqdm_metrics)

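A hedged sketch of where the renamed key surfaces: the logger version now appears in the tqdm progress-bar postfix under 'v_num' rather than 'v_nb'. The values below are made up for illustration.

from tqdm import tqdm

tqdm_dict = {'loss': 0.123}
logger_version = 7  # hypothetical logger.version
tqdm_dict['v_num'] = logger_version  # key renamed from 'v_nb'

pbar = tqdm(total=100)
pbar.set_postfix(tqdm_dict)  # postfix reads: loss=0.123, v_num=7
pbar.update(10)
pbar.close()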

@@ -151,6 +151,7 @@ When this flag is enabled each batch is split into sequences of size truncated_b
 import inspect
 from abc import ABC, abstractmethod
+import warnings

 import numpy as np
@@ -169,22 +170,22 @@ class TrainerTrainLoopMixin(ABC):
     def __init__(self):
         # this is just a summary on variables used in this abstract class,
         # the proper values/initialisation should be done in child class
-        self.max_nb_epochs = None
+        self.max_epochs = None
+        self.min_epochs = None
         self.use_ddp = None
         self.use_dp = None
         self.use_ddp2 = None
         self.single_gpu = None
         self.data_parallel_device_ids = None
         self.check_val_every_n_epoch = None
-        self.nb_training_batches = None
+        self.num_training_batches = None
         self.val_check_batch = None
-        self.nb_val_batches = None
+        self.num_val_batches = None
         self.fast_dev_run = None
         self.is_iterable_train_dataloader = None
         self.main_progress_bar = None
         self.accumulation_scheduler = None
         self.lr_schedulers = None
-        self.min_nb_epochs = None
         self.enable_early_stop = None
         self.early_stop_callback = None
         self.callback_metrics = None
@@ -194,7 +195,7 @@ class TrainerTrainLoopMixin(ABC):
         self.log_save_interval = None
         self.proc_rank = None
         self.row_log_interval = None
-        self.total_batch_nb = None
+        self.total_batches = None
         self.truncated_bptt_steps = None
         self.optimizers = None
         self.accumulate_grad_batches = None
@@ -207,6 +208,24 @@ class TrainerTrainLoopMixin(ABC):
         self.get_train_dataloader = None
         self.reduce_lr_on_plateau_scheduler = None

+    @property
+    def max_nb_epochs(self):
+        """
+        .. warning:: `max_nb_epochs` is deprecated and will be removed in v0.8.0, use `max_epochs` instead.
+        """
+        warnings.warn("`max_nb_epochs` is deprecated and will be removed in "
+                      "v0.8.0, use `max_epochs` instead.", DeprecationWarning)
+        return self.max_epochs
+
+    @property
+    def min_nb_epochs(self):
+        """
+        .. warning:: `min_nb_epochs` is deprecated and will be removed in v0.8.0, use `min_epochs` instead.
+        """
+        warnings.warn("`min_nb_epochs` is deprecated and will be removed in "
+                      "v0.8.0, use `min_epochs` instead.", DeprecationWarning)
+        return self.min_epochs
+
     @abstractmethod
     def get_model(self):
         # this is just empty shell for code from other class
@@ -391,7 +410,7 @@ class TrainerTrainLoopMixin(ABC):
             if early_stop_epoch or self.fast_dev_run:
                 break

-            # stop epoch if we limited nb batches
+            # stop epoch if we limited the number of training batches
             met_batch_limit = batch_idx >= self.num_training_batches
             if met_batch_limit:
                 break
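
The two properties added in this commit implement a property-based deprecation shim: the old attribute names keep resolving to the renamed attributes, but each access emits a DeprecationWarning. A self-contained sketch (hypothetical demo class, not the real mixin) showing the observable behavior:

import warnings

class DemoMixin:
    def __init__(self):
        self.max_epochs = 1000  # new canonical attribute

    @property
    def max_nb_epochs(self):
        # deprecated alias: warn, then delegate to the new name
        warnings.warn("`max_nb_epochs` is deprecated and will be removed in "
                      "v0.8.0, use `max_epochs` instead.", DeprecationWarning)
        return self.max_epochs

demo = DemoMixin()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert demo.max_nb_epochs == 1000  # old name still resolves to the new value
assert caught[0].category is DeprecationWarning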