* Provide backward compatibility for e681253
* typo fix
This commit is contained in:
parent
67f6e7bb19
commit
4103a5ca73
|
@ -7,6 +7,8 @@ from pytorch_lightning.root_module.model_saving import ModelIO
|
||||||
from pytorch_lightning.root_module.hooks import ModelHooks
|
from pytorch_lightning.root_module.hooks import ModelHooks
|
||||||
from pytorch_lightning.root_module.decorators import data_loader
|
from pytorch_lightning.root_module.decorators import data_loader
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class LightningModule(GradInformation, ModelIO, ModelHooks):
|
class LightningModule(GradInformation, ModelIO, ModelHooks):
|
||||||
|
|
||||||
|
@ -110,13 +112,29 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
||||||
# clear gradients
|
# clear gradients
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
@data_loader
|
||||||
|
def tng_dataloader(self):
|
||||||
|
"""
|
||||||
|
Implement a PyTorch DataLoader
|
||||||
|
* Deprecated in v0.5.0. use train_dataloader instead. *
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@data_loader
|
@data_loader
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
"""
|
"""
|
||||||
Implement a PyTorch DataLoader
|
Implement a PyTorch DataLoader
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
#
|
||||||
|
try:
|
||||||
|
output = self.tng_dataloader()
|
||||||
|
warnings.warn("tng_dataloader has been renamed to train_dataloader since v0.5.0",
|
||||||
|
DeprecationWarning)
|
||||||
|
return output
|
||||||
|
except NotImplementedError:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@data_loader
|
@data_loader
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
|
|
|
@ -63,6 +63,7 @@ class Trainer(TrainerIOMixin):
|
||||||
early_stop_callback=True,
|
early_stop_callback=True,
|
||||||
default_save_path=None,
|
default_save_path=None,
|
||||||
gradient_clip_val=0,
|
gradient_clip_val=0,
|
||||||
|
gradient_clip=None, # backward compatible
|
||||||
process_position=0,
|
process_position=0,
|
||||||
nb_gpu_nodes=1,
|
nb_gpu_nodes=1,
|
||||||
gpus=None,
|
gpus=None,
|
||||||
|
@ -81,6 +82,7 @@ class Trainer(TrainerIOMixin):
|
||||||
val_check_interval=1.0,
|
val_check_interval=1.0,
|
||||||
log_save_interval=100,
|
log_save_interval=100,
|
||||||
row_log_interval=10,
|
row_log_interval=10,
|
||||||
|
add_row_log_interval=None, # backward compatible
|
||||||
distributed_backend=None,
|
distributed_backend=None,
|
||||||
use_amp=False,
|
use_amp=False,
|
||||||
print_nan_grads=False,
|
print_nan_grads=False,
|
||||||
|
@ -95,6 +97,7 @@ class Trainer(TrainerIOMixin):
|
||||||
:param early_stop_callback: Callback for early stopping
|
:param early_stop_callback: Callback for early stopping
|
||||||
:param default_save_path: Default path for logs+weights if no logger/ckpt_callback passed
|
:param default_save_path: Default path for logs+weights if no logger/ckpt_callback passed
|
||||||
:param gradient_clip_val: int. 0 means don't clip.
|
:param gradient_clip_val: int. 0 means don't clip.
|
||||||
|
:param gradient_clip: int. 0 means don't clip. Deprecated.
|
||||||
:param process_position: shown in the tqdm bar
|
:param process_position: shown in the tqdm bar
|
||||||
:param nb_gpu_nodes: number of GPU nodes
|
:param nb_gpu_nodes: number of GPU nodes
|
||||||
:param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1'
|
:param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1'
|
||||||
|
@ -113,6 +116,7 @@ class Trainer(TrainerIOMixin):
|
||||||
:param val_check_interval: int. Check val this frequently within a train epoch
|
:param val_check_interval: int. Check val this frequently within a train epoch
|
||||||
:param log_save_interval: int. Writes logs to disk this often
|
:param log_save_interval: int. Writes logs to disk this often
|
||||||
:param row_log_interval: int. How often to add logging rows
|
:param row_log_interval: int. How often to add logging rows
|
||||||
|
:param add_row_log_interval: int. How often to add logging rows. Deprecated.
|
||||||
:param distributed_backend: str. Options: 'dp', 'ddp', 'ddp2'.
|
:param distributed_backend: str. Options: 'dp', 'ddp', 'ddp2'.
|
||||||
:param use_amp: Bool. If true uses apex for 16bit precision
|
:param use_amp: Bool. If true uses apex for 16bit precision
|
||||||
:param print_nan_grads: Bool. Prints nan gradients
|
:param print_nan_grads: Bool. Prints nan gradients
|
||||||
|
@ -124,6 +128,11 @@ class Trainer(TrainerIOMixin):
|
||||||
# Transfer params
|
# Transfer params
|
||||||
self.nb_gpu_nodes = nb_gpu_nodes
|
self.nb_gpu_nodes = nb_gpu_nodes
|
||||||
self.log_gpu_memory = log_gpu_memory
|
self.log_gpu_memory = log_gpu_memory
|
||||||
|
if not (gradient_clip is None):
|
||||||
|
# Backward compatibility
|
||||||
|
warnings.warn("gradient_clip has renamed to gradient_clip_val since v0.5.0",
|
||||||
|
DeprecationWarning)
|
||||||
|
gradient_clip_val = gradient_clip
|
||||||
self.gradient_clip_val = gradient_clip_val
|
self.gradient_clip_val = gradient_clip_val
|
||||||
self.check_val_every_n_epoch = check_val_every_n_epoch
|
self.check_val_every_n_epoch = check_val_every_n_epoch
|
||||||
self.track_grad_norm = track_grad_norm
|
self.track_grad_norm = track_grad_norm
|
||||||
|
@ -242,6 +251,11 @@ class Trainer(TrainerIOMixin):
|
||||||
# logging
|
# logging
|
||||||
self.log_save_interval = log_save_interval
|
self.log_save_interval = log_save_interval
|
||||||
self.val_check_interval = val_check_interval
|
self.val_check_interval = val_check_interval
|
||||||
|
if not (add_row_log_interval is None):
|
||||||
|
# backward compatibility
|
||||||
|
warnings.warn("gradient_clip has renamed to gradient_clip_val since v0.5.0",
|
||||||
|
DeprecationWarning)
|
||||||
|
row_log_interval = add_row_log_interval
|
||||||
self.row_log_interval = row_log_interval
|
self.row_log_interval = row_log_interval
|
||||||
|
|
||||||
# how much of the data to use
|
# how much of the data to use
|
||||||
|
@ -520,6 +534,16 @@ class Trainer(TrainerIOMixin):
|
||||||
"""
|
"""
|
||||||
return self.__training_tqdm_dict
|
return self.__training_tqdm_dict
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tng_tqdm_dic(self):
|
||||||
|
"""
|
||||||
|
* Deprecated in v0.5.0. use training_tqdm_dict instead. *
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
warnings.warn("tng_tqdm_dict has renamed to training_tqdm_dict since v0.5.0",
|
||||||
|
DeprecationWarning)
|
||||||
|
return self.training_tqdm_dict
|
||||||
|
|
||||||
def __layout_bookeeping(self):
|
def __layout_bookeeping(self):
|
||||||
|
|
||||||
# determine number of training batches
|
# determine number of training batches
|
||||||
|
|
Loading…
Reference in New Issue