* Provide backward compatibility for e681253
* typo fix
This commit is contained in:
parent
67f6e7bb19
commit
4103a5ca73
|
@ -7,6 +7,8 @@ from pytorch_lightning.root_module.model_saving import ModelIO
|
|||
from pytorch_lightning.root_module.hooks import ModelHooks
|
||||
from pytorch_lightning.root_module.decorators import data_loader
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
class LightningModule(GradInformation, ModelIO, ModelHooks):
|
||||
|
||||
|
@ -110,13 +112,29 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
# clear gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
@data_loader
|
||||
def tng_dataloader(self):
|
||||
"""
|
||||
Implement a PyTorch DataLoader
|
||||
* Deprecated in v0.5.0. use train_dataloader instead. *
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@data_loader
|
||||
def train_dataloader(self):
|
||||
"""
|
||||
Implement a PyTorch DataLoader
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
#
|
||||
try:
|
||||
output = self.tng_dataloader()
|
||||
warnings.warn("tng_dataloader has been renamed to train_dataloader since v0.5.0",
|
||||
DeprecationWarning)
|
||||
return output
|
||||
except NotImplementedError:
|
||||
raise NotImplementedError
|
||||
|
||||
@data_loader
|
||||
def test_dataloader(self):
|
||||
|
|
|
@ -63,6 +63,7 @@ class Trainer(TrainerIOMixin):
|
|||
early_stop_callback=True,
|
||||
default_save_path=None,
|
||||
gradient_clip_val=0,
|
||||
gradient_clip=None, # backward compatible
|
||||
process_position=0,
|
||||
nb_gpu_nodes=1,
|
||||
gpus=None,
|
||||
|
@ -81,6 +82,7 @@ class Trainer(TrainerIOMixin):
|
|||
val_check_interval=1.0,
|
||||
log_save_interval=100,
|
||||
row_log_interval=10,
|
||||
add_row_log_interval=None, # backward compatible
|
||||
distributed_backend=None,
|
||||
use_amp=False,
|
||||
print_nan_grads=False,
|
||||
|
@ -95,6 +97,7 @@ class Trainer(TrainerIOMixin):
|
|||
:param early_stop_callback: Callback for early stopping
|
||||
:param default_save_path: Default path for logs+weights if no logger/ckpt_callback passed
|
||||
:param gradient_clip_val: int. 0 means don't clip.
|
||||
:param gradient_clip: int. 0 means don't clip. Deprecated.
|
||||
:param process_position: shown in the tqdm bar
|
||||
:param nb_gpu_nodes: number of GPU nodes
|
||||
:param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1'
|
||||
|
@ -113,6 +116,7 @@ class Trainer(TrainerIOMixin):
|
|||
:param val_check_interval: int. Check val this frequently within a train epoch
|
||||
:param log_save_interval: int. Writes logs to disk this often
|
||||
:param row_log_interval: int. How often to add logging rows
|
||||
:param add_row_log_interval: int. How often to add logging rows. Deprecated.
|
||||
:param distributed_backend: str. Options: 'dp', 'ddp', 'ddp2'.
|
||||
:param use_amp: Bool. If true uses apex for 16bit precision
|
||||
:param print_nan_grads: Bool. Prints nan gradients
|
||||
|
@ -124,6 +128,11 @@ class Trainer(TrainerIOMixin):
|
|||
# Transfer params
|
||||
self.nb_gpu_nodes = nb_gpu_nodes
|
||||
self.log_gpu_memory = log_gpu_memory
|
||||
if not (gradient_clip is None):
|
||||
# Backward compatibility
|
||||
warnings.warn("gradient_clip has renamed to gradient_clip_val since v0.5.0",
|
||||
DeprecationWarning)
|
||||
gradient_clip_val = gradient_clip
|
||||
self.gradient_clip_val = gradient_clip_val
|
||||
self.check_val_every_n_epoch = check_val_every_n_epoch
|
||||
self.track_grad_norm = track_grad_norm
|
||||
|
@ -242,6 +251,11 @@ class Trainer(TrainerIOMixin):
|
|||
# logging
|
||||
self.log_save_interval = log_save_interval
|
||||
self.val_check_interval = val_check_interval
|
||||
if not (add_row_log_interval is None):
|
||||
# backward compatibility
|
||||
warnings.warn("gradient_clip has renamed to gradient_clip_val since v0.5.0",
|
||||
DeprecationWarning)
|
||||
row_log_interval = add_row_log_interval
|
||||
self.row_log_interval = row_log_interval
|
||||
|
||||
# how much of the data to use
|
||||
|
@ -520,6 +534,16 @@ class Trainer(TrainerIOMixin):
|
|||
"""
|
||||
return self.__training_tqdm_dict
|
||||
|
||||
@property
|
||||
def tng_tqdm_dic(self):
|
||||
"""
|
||||
* Deprecated in v0.5.0. use training_tqdm_dict instead. *
|
||||
:return:
|
||||
"""
|
||||
warnings.warn("tng_tqdm_dict has renamed to training_tqdm_dict since v0.5.0",
|
||||
DeprecationWarning)
|
||||
return self.training_tqdm_dict
|
||||
|
||||
def __layout_bookeeping(self):
|
||||
|
||||
# determine number of training batches
|
||||
|
|
Loading…
Reference in New Issue