Merge branch 'master' of https://github.com/williamFalcon/pytorch-lightning
This commit is contained in:
commit
7225e5d6d5
|
@ -43,13 +43,20 @@ trainer = Trainer(test_percent_check=0.1)
|
||||||
|
|
||||||
---
|
---
|
||||||
#### Set validation check frequency within 1 training epoch
|
#### Set validation check frequency within 1 training epoch
|
||||||
For large datasets it's often desirable to check validation multiple times within a training loop
|
For large datasets it's often desirable to check validation multiple times within a training loop.
|
||||||
|
Pass in a float to check that often within 1 training epoch.
|
||||||
|
Pass in an int k to check every k training batches. Must use an int if using
|
||||||
|
an IterableDataset.
|
||||||
|
|
||||||
``` {.python}
|
``` {.python}
|
||||||
# DEFAULT
|
# DEFAULT
|
||||||
trainer = Trainer(val_check_interval=0.95)
|
trainer = Trainer(val_check_interval=0.95)
|
||||||
|
|
||||||
# check every .25 of an epoch
|
# check every .25 of an epoch
|
||||||
trainer = Trainer(val_check_interval=0.25)
|
trainer = Trainer(val_check_interval=0.25)
|
||||||
|
|
||||||
|
# check every 100 train batches (ie: for IterableDatasets or fixed frequency)
|
||||||
|
trainer = Trainer(val_check_interval=100)
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
|
@ -2,6 +2,9 @@ import warnings
|
||||||
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
|
from pytorch_lightning.utilities.debugging import MisconfigurationException
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
|
@ -15,6 +18,9 @@ class TrainerDataLoadingMixin(object):
|
||||||
def layout_bookeeping(self):
|
def layout_bookeeping(self):
|
||||||
|
|
||||||
# determine number of training batches
|
# determine number of training batches
|
||||||
|
if isinstance(self.get_train_dataloader(), IterableDataset):
|
||||||
|
self.nb_training_batches = float('inf')
|
||||||
|
else:
|
||||||
self.nb_training_batches = len(self.get_train_dataloader())
|
self.nb_training_batches = len(self.get_train_dataloader())
|
||||||
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)
|
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)
|
||||||
|
|
||||||
|
@ -34,6 +40,11 @@ class TrainerDataLoadingMixin(object):
|
||||||
self.nb_test_batches = max(1, self.nb_test_batches)
|
self.nb_test_batches = max(1, self.nb_test_batches)
|
||||||
|
|
||||||
# determine when to check validation
|
# determine when to check validation
|
||||||
|
# if int passed in, val checks that often
|
||||||
|
# otherwise, it checks in [0, 1.0] % range of a training epoch
|
||||||
|
if isinstance(self.val_check_interval, int):
|
||||||
|
self.val_check_batch = self.val_check_interval
|
||||||
|
else:
|
||||||
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
|
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
|
||||||
self.val_check_batch = max(1, self.val_check_batch)
|
self.val_check_batch = max(1, self.val_check_batch)
|
||||||
|
|
||||||
|
@ -127,6 +138,16 @@ class TrainerDataLoadingMixin(object):
|
||||||
self.get_test_dataloaders()
|
self.get_test_dataloaders()
|
||||||
self.get_val_dataloaders()
|
self.get_val_dataloaders()
|
||||||
|
|
||||||
|
# support IterableDataset for train data
|
||||||
|
self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader(), IterableDataset)
|
||||||
|
if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
|
||||||
|
m = '''
|
||||||
|
When using an iterableDataset for train_dataloader,
|
||||||
|
Trainer(val_check_interval) must be an int.
|
||||||
|
An int k specifies checking validation every k training batches
|
||||||
|
'''
|
||||||
|
raise MisconfigurationException('when using ')
|
||||||
|
|
||||||
def determine_data_use_amount(self, train_percent_check, val_percent_check,
|
def determine_data_use_amount(self, train_percent_check, val_percent_check,
|
||||||
test_percent_check, overfit_pct):
|
test_percent_check, overfit_pct):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -31,7 +31,12 @@ class TrainerTrainLoopMixin(object):
|
||||||
|
|
||||||
# init progress_bar when requested
|
# init progress_bar when requested
|
||||||
if self.show_progress_bar:
|
if self.show_progress_bar:
|
||||||
self.progress_bar.reset(self.total_batches)
|
nb_iterations = self.total_batches
|
||||||
|
|
||||||
|
# for iterable train loader, the progress bar never ends
|
||||||
|
if self.is_iterable_train_dataloader:
|
||||||
|
nb_iterations = float('inf')
|
||||||
|
self.progress_bar.reset(nb_iterations)
|
||||||
|
|
||||||
# changing gradient according accumulation_scheduler
|
# changing gradient according accumulation_scheduler
|
||||||
self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
|
self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
|
||||||
|
|
|
@ -101,7 +101,7 @@ class Trainer(TrainerIOMixin,
|
||||||
:param train_percent_check: int. How much of train set to check
|
:param train_percent_check: int. How much of train set to check
|
||||||
:param val_percent_check: int. How much of val set to check
|
:param val_percent_check: int. How much of val set to check
|
||||||
:param test_percent_check: int. How much of test set to check
|
:param test_percent_check: int. How much of test set to check
|
||||||
:param val_check_interval: int. Check val this frequently within a train epoch
|
:param val_check_interval: float/int. If float, % of tng epoch. If int, check every n batch
|
||||||
:param log_save_interval: int. Writes logs to disk this often
|
:param log_save_interval: int. Writes logs to disk this often
|
||||||
:param row_log_interval: int. How often to add logging rows
|
:param row_log_interval: int. How often to add logging rows
|
||||||
:param add_row_log_interval: int. How often to add logging rows. Deprecated.
|
:param add_row_log_interval: int. How often to add logging rows. Deprecated.
|
||||||
|
@ -160,6 +160,7 @@ class Trainer(TrainerIOMixin,
|
||||||
self.get_train_dataloader = None
|
self.get_train_dataloader = None
|
||||||
self.get_test_dataloaders = None
|
self.get_test_dataloaders = None
|
||||||
self.get_val_dataloaders = None
|
self.get_val_dataloaders = None
|
||||||
|
self.is_iterable_train_dataloader = False
|
||||||
|
|
||||||
# training state
|
# training state
|
||||||
self.model = None
|
self.model = None
|
||||||
|
|
Loading…
Reference in New Issue