This commit is contained in:
William Falcon 2019-10-22 10:34:30 +03:00
commit 7225e5d6d5
4 changed files with 41 additions and 7 deletions

View File

@ -43,13 +43,20 @@ trainer = Trainer(test_percent_check=0.1)
--- ---
#### Set validation check frequency within 1 training epoch #### Set validation check frequency within 1 training epoch
For large datasets it's often desirable to check validation multiple times within a training loop For large datasets it's often desirable to check validation multiple times within a training loop.
Pass in a float to check that often within 1 training epoch.
Pass in an int k to check every k training batches. Must use an int if using
an IterableDataset.
``` {.python} ``` {.python}
# DEFAULT # DEFAULT
trainer = Trainer(val_check_interval=0.95) trainer = Trainer(val_check_interval=0.95)
# check every .25 of an epoch # check every .25 of an epoch
trainer = Trainer(val_check_interval=0.25) trainer = Trainer(val_check_interval=0.25)
# check every 100 train batches (ie: for IterableDatasets or fixed frequency)
trainer = Trainer(val_check_interval=100)
``` ```
--- ---

View File

@ -2,6 +2,9 @@ import warnings
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data import IterableDataset
from pytorch_lightning.utilities.debugging import MisconfigurationException
try: try:
from apex import amp from apex import amp
@ -15,8 +18,11 @@ class TrainerDataLoadingMixin(object):
def layout_bookeeping(self): def layout_bookeeping(self):
# determine number of training batches # determine number of training batches
self.nb_training_batches = len(self.get_train_dataloader()) if isinstance(self.get_train_dataloader(), IterableDataset):
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check) self.nb_training_batches = float('inf')
else:
self.nb_training_batches = len(self.get_train_dataloader())
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)
# determine number of validation batches # determine number of validation batches
# val datasets could be none, 1 or 2+ # val datasets could be none, 1 or 2+
@ -34,8 +40,13 @@ class TrainerDataLoadingMixin(object):
self.nb_test_batches = max(1, self.nb_test_batches) self.nb_test_batches = max(1, self.nb_test_batches)
# determine when to check validation # determine when to check validation
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval) # if int passed in, val checks that often
self.val_check_batch = max(1, self.val_check_batch) # otherwise, it checks in [0, 1.0] % range of a training epoch
if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
else:
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
def get_dataloaders(self, model): def get_dataloaders(self, model):
""" """
@ -127,6 +138,16 @@ class TrainerDataLoadingMixin(object):
self.get_test_dataloaders() self.get_test_dataloaders()
self.get_val_dataloaders() self.get_val_dataloaders()
# support IterableDataset for train data
self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader(), IterableDataset)
if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
m = '''
When using an iterableDataset for train_dataloader,
Trainer(val_check_interval) must be an int.
An int k specifies checking validation every k training batches
'''
raise MisconfigurationException('when using ')
def determine_data_use_amount(self, train_percent_check, val_percent_check, def determine_data_use_amount(self, train_percent_check, val_percent_check,
test_percent_check, overfit_pct): test_percent_check, overfit_pct):
""" """

View File

@ -31,7 +31,12 @@ class TrainerTrainLoopMixin(object):
# init progress_bar when requested # init progress_bar when requested
if self.show_progress_bar: if self.show_progress_bar:
self.progress_bar.reset(self.total_batches) nb_iterations = self.total_batches
# for iterable train loader, the progress bar never ends
if self.is_iterable_train_dataloader:
nb_iterations = float('inf')
self.progress_bar.reset(nb_iterations)
# changing gradient according accumulation_scheduler # changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin(epoch_nb, self) self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)

View File

@ -101,7 +101,7 @@ class Trainer(TrainerIOMixin,
:param train_percent_check: int. How much of train set to check :param train_percent_check: int. How much of train set to check
:param val_percent_check: int. How much of val set to check :param val_percent_check: int. How much of val set to check
:param test_percent_check: int. How much of test set to check :param test_percent_check: int. How much of test set to check
:param val_check_interval: int. Check val this frequently within a train epoch :param val_check_interval: float/int. If float, % of tng epoch. If int, check every n batch
:param log_save_interval: int. Writes logs to disk this often :param log_save_interval: int. Writes logs to disk this often
:param row_log_interval: int. How often to add logging rows :param row_log_interval: int. How often to add logging rows
:param add_row_log_interval: int. How often to add logging rows. Deprecated. :param add_row_log_interval: int. How often to add logging rows. Deprecated.
@ -160,6 +160,7 @@ class Trainer(TrainerIOMixin,
self.get_train_dataloader = None self.get_train_dataloader = None
self.get_test_dataloaders = None self.get_test_dataloaders = None
self.get_val_dataloaders = None self.get_val_dataloaders = None
self.is_iterable_train_dataloader = False
# training state # training state
self.model = None self.model = None