remove deprecated in v0.9 (#2760)
* remove deprecated in v0.9
* data_loader
* import
* hook
* args
parent d18b9ef9d9
commit 949734489a
@@ -141,9 +141,6 @@ exclude_patterns = [
     'api/pytorch_lightning.accelerator_backends.*',
     'api/modules.rst',
     'PULL_REQUEST_TEMPLATE.md',
-
-    # deprecated/renamed:
-    'api/pytorch_lightning.logging.*',  # TODO: remove in v0.9.0
 ]
 
 # The name of the Pygments (syntax highlighting) style to use.
@@ -429,7 +429,7 @@ def main(args: argparse.Namespace) -> None:
     trainer = pl.Trainer(
         weights_summary=None,
-        show_progress_bar=True,
+        progress_bar_refresh_rate=1,
         num_sanity_val_steps=0,
         gpus=args.gpus,
         min_epochs=args.nb_epochs,
@@ -53,7 +53,7 @@ if __LIGHTNING_SETUP__:
     sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n')  # pragma: no-cover
     # We are not importing the rest of the lightning during the build process, as it may not be compiled yet
 else:
-    from pytorch_lightning.core import LightningDataModule, LightningModule, data_loader
+    from pytorch_lightning.core import LightningDataModule, LightningModule
     from pytorch_lightning.core.step_result import TrainResult, EvalResult
     from pytorch_lightning.callbacks import Callback
     from pytorch_lightning.trainer import Trainer
@@ -65,7 +65,6 @@ else:
         'LightningDataModule',
         'LightningModule',
         'Callback',
-        'data_loader',
         'seed_everything',
         'metrics',
         'EvalResult',
@@ -302,12 +302,10 @@ LightningModule Class
 """
 
 from pytorch_lightning.core.datamodule import LightningDataModule
-from pytorch_lightning.core.decorators import data_loader
 from pytorch_lightning.core.lightning import LightningModule
 
 __all__ = [
     'LightningDataModule',
     'LightningModule',
-    'data_loader',
 ]
 # __call__ = __all__
@@ -2,20 +2,6 @@ from functools import wraps
 from typing import Callable
 
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_warn
-
-
-def data_loader(fn):
-    """Decorator to make any fx with this use the lazy property.
-
-    Warnings:
-        This decorator deprecated in v0.7.0 and it will be removed v0.9.0.
-    """
-    rank_zero_warn("`data_loader` decorator deprecated in v0.7.0. It will be removed in v0.9.0", DeprecationWarning)
-
-    def inner_fx(self):
-        return fn(self)
-    return inner_fx
 
 
 def auto_move_data(fn: Callable) -> Callable:
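For anyone still using the removed `@data_loader` decorator, the replacement is simply to define the dataloader hook as a plain method on the LightningModule. A minimal, hedged sketch (the toy model and tensors below are illustrative and not part of this commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self(x), y)
            return {'loss': loss}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

        # previously written with `@data_loader` on top; now just a regular hook
        def train_dataloader(self):
            dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
            return DataLoader(dataset, batch_size=8)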
@@ -65,15 +65,6 @@ class ModelHooks(Module):
         If on DDP it is called on every process
         """
 
-    # TODO: remove in v0.9.0
-    def on_sanity_check_start(self):
-        """
-        Called before starting evaluation.
-
-        Warning:
-            Deprecated. Will be removed in v0.9.0.
-        """
-
     def on_train_start(self) -> None:
         """
         Called at the beginning of training before sanity check.
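The removed `LightningModule.on_sanity_check_start` hook has no module-level replacement; the Trainer keeps calling only its callback hook (see the `@@ -1264,7` hunk further down). A hedged sketch of reacting to the sanity check from a `Callback` instead, assuming the 0.8-era callback signature:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback


    class SanityCheckNotifier(Callback):
        # runs right before the validation sanity check, replacing the removed module hook
        def on_sanity_check_start(self, trainer, pl_module):
            print('sanity check starting')


    trainer = Trainer(callbacks=[SanityCheckNotifier()], num_sanity_val_steps=2)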
@@ -31,19 +31,6 @@ class ModelIO(object):
     CHECKPOINT_HYPER_PARAMS_NAME = 'hparams_name'
     CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type'
 
-    @classmethod
-    def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
-        r"""
-        Warning:
-            Deprecated in version 0.7.0. You should use :meth:`load_from_checkpoint` instead.
-            Will be removed in v0.9.0.
-        """
-        rank_zero_warn(
-            "`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
-            " The deprecated method will be removed in v0.9.0.", DeprecationWarning
-        )
-        return cls.load_from_checkpoint(weights_path, tags_csv=tags_csv, map_location=map_location)
-
     @classmethod
     def load_from_checkpoint(
             cls,
@@ -51,7 +38,6 @@ class ModelIO(object):
             *args,
             map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
             hparams_file: Optional[str] = None,
-            tags_csv: Optional[str] = None,  # backward compatible, todo: remove in v0.9.0
             **kwargs
     ):
         r"""
@@ -84,21 +70,6 @@ class ModelIO(object):
                 If your model's `hparams` argument is :class:`~argparse.Namespace`
                 and .yaml file has hierarchical structure, you need to refactor your model to treat
                 `hparams` as :class:`~dict`.
-
-                .csv files are acceptable here till v0.9.0, see tags_csv argument for detailed usage.
-            tags_csv:
-                .. warning:: .. deprecated:: 0.7.6
-
-                    `tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0.
-
-                Optional path to a .csv file with two columns (key, value)
-                as in this example::
-
-                    key,value
-                    drop_prob,0.2
-                    batch_size,32
-
-                Use this method to pass in a .csv file with the hparams you'd like to use.
             hparam_overrides: A dictionary with keys to override in the hparams
             kwargs: Any keyword args needed to init the model.
@@ -141,11 +112,6 @@ class ModelIO(object):
         else:
             checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
 
-        # add the hparams from csv file to checkpoint
-        if tags_csv is not None:
-            hparams_file = tags_csv
-            rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0', DeprecationWarning)
-
         if hparams_file is not None:
             extension = hparams_file.split('.')[-1]
             if extension.lower() in ('csv'):
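Callers of the removed `load_from_metrics` classmethod, or of the removed `tags_csv` argument, migrate to `load_from_checkpoint` with `hparams_file`. A hedged sketch; `MyLightningModule` and the file paths are placeholders for the user's own model and artifacts:

    # before (removed in this commit):
    #   model = MyLightningModule.load_from_metrics('weights.ckpt', tags_csv='meta_tags.csv')
    #   model = MyLightningModule.load_from_checkpoint('weights.ckpt', tags_csv='meta_tags.csv')

    # after: hparams stored outside the checkpoint now go through a .yaml hparams_file
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='weights.ckpt',
        hparams_file='hparams.yaml',
        map_location='cpu',
    )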
@@ -1,11 +0,0 @@
-"""
-.. warning:: `logging` package has been renamed to `loggers` since v0.7.0.
-    The deprecated package name will be removed in v0.9.0.
-"""
-
-from pytorch_lightning.utilities import rank_zero_warn
-
-rank_zero_warn("`logging` package has been renamed to `loggers` since v0.7.0"
-               " The deprecated package name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers import *  # noqa: F403 E402
@@ -1,10 +0,0 @@
-"""
-.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0
-"""
-
-from pytorch_lightning.utilities import rank_zero_warn
-
-rank_zero_warn("`logging.comet` module has been renamed to `loggers.comet` since v0.7.0."
-               " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.comet import CometLogger  # noqa: F403 E402
@@ -1,10 +0,0 @@
-"""
-.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0
-"""
-
-from pytorch_lightning.utilities import rank_zero_warn
-
-rank_zero_warn("`logging.mlflow` module has been renamed to `loggers.mlflow` since v0.7.0."
-               " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.mlflow import MLFlowLogger  # noqa: F403 E402
@@ -1,10 +0,0 @@
-"""
-.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0
-"""
-
-from pytorch_lightning.utilities import rank_zero_warn
-
-rank_zero_warn("`logging.neptune` module has been renamed to `loggers.neptune` since v0.7.0."
-               " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.neptune import NeptuneLogger  # noqa: F403 E402
@@ -1,10 +0,0 @@
-"""
-.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0
-"""
-
-from pytorch_lightning.utilities import rank_zero_warn
-
-rank_zero_warn("`logging.test_tube` module has been renamed to `loggers.test_tube` since v0.7.0."
-               " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.test_tube import TestTubeLogger  # noqa: F403 E402
@@ -1,10 +0,0 @@
-"""
-.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0
-"""
-
-from pytorch_lightning.utilities import rank_zero_warn
-
-rank_zero_warn("`logging.wandb` module has been renamed to `loggers.wandb` since v0.7.0."
-               " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.wandb import WandbLogger  # noqa: F403 E402
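With the `pytorch_lightning.logging` shim modules deleted, logger imports must use `pytorch_lightning.loggers`, which is where the shims were already forwarding. For example:

    # before (removed in this commit):
    #   from pytorch_lightning.logging.comet import CometLogger
    #   from pytorch_lightning.logging.wandb import WandbLogger

    # after:
    from pytorch_lightning.loggers.comet import CometLogger
    from pytorch_lightning.loggers.wandb import WandbLogger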
@@ -616,12 +616,6 @@ The Trainer uses 2 steps by default. Turn it off or modify it here.
     # check all validation data
     trainer = Trainer(num_sanity_val_steps=-1)
 
-num_tpu_cores
-^^^^^^^^^^^^^
-.. warning:: .. deprecated:: 0.7.6
-
-    Use `tpu_cores` instead. Will remove 0.9.0.
-
 Example::
 
     python -m torch_xla.distributed.xla_dist
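The removed `num_tpu_cores` section pointed to `tpu_cores`, which is the only spelling left after this commit. A hedged sketch (actually running it requires a TPU environment):

    from pytorch_lightning import Trainer

    # before (removed): trainer = Trainer(num_tpu_cores=8)
    trainer = Trainer(tpu_cores=8)     # train on 8 TPU cores
    trainer = Trainer(tpu_cores=[1])   # or pin a single specific TPU core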
@@ -737,15 +731,6 @@ Example::
     # one day
     trainer = Trainer(precision=8|4|2)
 
-print_nan_grads
-^^^^^^^^^^^^^^^
-
-.. warning:: .. deprecated:: 0.7.2.
-
-    Has no effect. When detected, NaN grads will be printed automatically.
-    Will remove 0.9.0.
-
-
 process_position
 ^^^^^^^^^^^^^^^^
 Orders the progress bar. Useful when running multiple trainers on the same node.
@@ -853,18 +838,6 @@ How often to add logging rows (does not write to disk)
     # default used by the Trainer
     trainer = Trainer(row_log_interval=50)
 
-use_amp:
-
-.. warning:: .. deprecated:: 0.7.0
-
-    Use `precision` instead. Will remove 0.9.0.
-
-show_progress_bar
-^^^^^^^^^^^^^^^^^
-
-.. warning:: .. deprecated:: 0.7.2
-
-    Set `progress_bar_refresh_rate` to 0 instead. Will remove 0.9.0.
 
 val_percent_check
 ^^^^^^^^^^^^^^^^^
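Both removed entries map onto arguments that stay in the docs: `precision` replaces `use_amp`, and `progress_bar_refresh_rate` replaces `show_progress_bar`. A short sketch of the surviving spellings:

    from pytorch_lightning import Trainer

    # before (removed): Trainer(use_amp=True)
    trainer = Trainer(precision=16)

    # before (removed): Trainer(show_progress_bar=False)
    trainer = Trainer(progress_bar_refresh_rate=0)   # 0 disables the progress bar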
@@ -21,7 +21,6 @@ class TrainerAMPMixin(ABC):
             log.info('Using native 16bit precision.')
             return
 
-        # TODO: replace `use_amp` by `precision` all below for v0.9.0
         if self.use_amp and not APEX_AVAILABLE:  # pragma: no-cover
             raise ModuleNotFoundError(
                 "You set `use_amp=True` but do not have apex installed."
@@ -6,40 +6,6 @@ from typing import Union
 from pytorch_lightning.utilities import rank_zero_warn
 
 
-class TrainerDeprecatedAPITillVer0_9(ABC):
-    progress_bar_dict: ...
-    progress_bar_callback: ...
-
-    def __init__(self):
-        super().__init__()  # mixin calls super too
-
-    @property
-    def show_progress_bar(self):
-        """Back compatibility, will be removed in v0.9.0"""
-        rank_zero_warn("Attribute `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2"
-                       " and this method will be removed in v0.9.0", DeprecationWarning)
-        return self.progress_bar_callback and self.progress_bar_callback.refresh_rate >= 1
-
-    @show_progress_bar.setter
-    def show_progress_bar(self, tf):
-        """Back compatibility, will be removed in v0.9.0"""
-        rank_zero_warn("Attribute `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2"
-                       " and this method will be removed in v0.9.0", DeprecationWarning)
-
-    @property
-    def training_tqdm_dict(self):
-        """Back compatibility, will be removed in v0.9.0"""
-        rank_zero_warn("`training_tqdm_dict` was renamed to `progress_bar_dict` in v0.7.3"
-                       " and this method will be removed in v0.9.0", DeprecationWarning)
-        return self.progress_bar_dict
-
-    @property
-    def num_tpu_cores(self):
-        """Back compatibility, will be removed in v0.9.0"""
-        rank_zero_warn("Attribute `num_tpu_cores` is now set by `tpu_cores` since v0.7.6"
-                       " and this argument will be removed in v0.9.0", DeprecationWarning)
-
-
 class TrainerDeprecatedAPITillVer0_10(ABC):
     limit_val_batches: Union[int, float]
     limit_test_batches: Union[int, float]
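Code that read the deleted shim attributes has to switch to what they forwarded to. A hedged sketch of the attribute-level migration (reading `progress_bar_dict` is usually done inside hooks or callbacks during training):

    from pytorch_lightning import Trainer

    trainer = Trainer(progress_bar_refresh_rate=20)

    # show_progress_bar (removed) -> inspect the progress bar callback
    bar_enabled = trainer.progress_bar_callback is not None and trainer.progress_bar_callback.refresh_rate > 0

    # num_tpu_cores (removed) -> tpu_cores
    print(trainer.tpu_cores)

    # training_tqdm_dict (removed) -> trainer.progress_bar_dict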
@@ -137,12 +137,6 @@ class TrainerLRFinderMixin(ABC):
             trainer.fit(model)
 
         """
-        if num_accumulation_steps is not None:
-            rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated"
-                           " since v0.7.6 and will be removed in 0.9. Please"
-                           " set trainer argument `accumulate_grad_batches` instead.",
-                           DeprecationWarning)
-
         save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')
 
         self.__lr_finder_dump_params(model)
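`lr_find` no longer takes `num_accumulation_steps`; as the removed warning said, gradient accumulation belongs to the Trainer via `accumulate_grad_batches`. A hedged sketch, assuming `model` is an existing LightningModule whose `hparams` carries an `lr` field:

    from pytorch_lightning import Trainer

    # before (removed): trainer.lr_find(model, num_accumulation_steps=4)
    trainer = Trainer(accumulate_grad_batches=4)
    lr_finder = trainer.lr_find(model)
    model.hparams.lr = lr_finder.suggestion()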
@@ -34,7 +34,7 @@ from pytorch_lightning.trainer.auto_mix_precision import NATIVE_AMP_AVALAIBLE, T
 from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
-from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_9, TrainerDeprecatedAPITillVer0_10
+from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10
 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
 from pytorch_lightning.trainer.distrib_parts import (TrainerDPMixin, _parse_gpu_ids, _parse_tpu_cores,
                                                      determine_root_gpu_device, pick_multiple_gpus)
@@ -98,7 +98,6 @@ class Trainer(
     TrainerTrainLoopMixin,
     TrainerCallbackConfigMixin,
     TrainerLRFinderMixin,
-    TrainerDeprecatedAPITillVer0_9,
     TrainerDeprecatedAPITillVer0_10,
 ):
     """
@@ -153,8 +152,6 @@ class Trainer(
             25
     """
 
-    DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores')
-
     def __init__(
             self,
             logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
@@ -188,7 +185,6 @@ class Trainer(
             row_log_interval: int = 50,
             distributed_backend: Optional[str] = None,
             precision: int = 32,
-            print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
             weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
             weights_save_path: Optional[str] = None,
             num_sanity_val_steps: int = 2,
@@ -204,9 +200,6 @@ class Trainer(
             auto_scale_batch_size: Union[str, bool] = False,
             prepare_data_per_node: bool = True,
             amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
-            num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
-            use_amp=None,  # backward compatible, todo: remove in v0.9.0
-            show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
             val_percent_check: float = None,  # backward compatible, todo: remove in v0.10.0
             test_percent_check: float = None,  # backward compatible, todo: remove in v0.10.0
             train_percent_check: float = None,  # backward compatible, todo: remove in v0.10.0
@@ -230,20 +223,10 @@ class Trainer(
 
         gradient_clip_val: 0 means don't clip.
 
-        gradient_clip:
-            .. warning:: .. deprecated:: 0.7.0
-
-                Use `gradient_clip_val` instead. Will remove 0.9.0.
-
         process_position: orders the progress bar when running multiple models on same machine.
 
         num_nodes: number of GPU nodes for distributed training.
 
-        nb_gpu_nodes:
-            .. warning:: .. deprecated:: 0.7.0
-
-                Use `num_nodes` instead. Will remove 0.9.0.
-
         gpus: Which GPUs to train on.
 
         auto_select_gpus:
@@ -255,16 +238,8 @@ class Trainer(
 
         tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]
 
-        num_tpu_cores: How many TPU cores to train on (1 or 8)
-            .. warning:: .. deprecated:: 0.7.6. Will remove 0.9.0.
-
         log_gpu_memory: None, 'min_max', 'all'. Might slow performance
 
-        show_progress_bar:
-            .. warning:: .. deprecated:: 0.7.2
-
-                Set `progress_bar_refresh_rate` to positive integer to enable. Will remove 0.9.0.
-
         progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
             Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.
@@ -285,18 +260,8 @@ class Trainer(
 
         max_epochs: Stop training once this number of epochs is reached.
 
-        max_nb_epochs:
-            .. warning:: .. deprecated:: 0.7.0
-
-                Use `max_epochs` instead. Will remove 0.9.0.
-
         min_epochs: Force training for at least these many epochs
 
-        min_nb_epochs:
-            .. warning:: .. deprecated:: 0.7.0
-
-                Use `min_epochs` instead. Will remove 0.9.0.
-
         max_steps: Stop training after this number of steps. Disabled by default (None).
 
         min_steps: Force training for at least these number of steps. Disabled by default (None).
@@ -328,26 +293,10 @@ class Trainer(
 
         row_log_interval: How often to add logging rows (does not write to disk)
 
-        add_row_log_interval:
-            .. warning:: .. deprecated:: 0.7.0
-
-                Use `row_log_interval` instead. Will remove 0.9.0.
-
         distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)
 
-        use_amp:
-            .. warning:: .. deprecated:: 0.7.0
-
-                Use `precision` instead. Will remove 0.9.0.
-
         precision: Full precision (32), half precision (16).
 
-        print_nan_grads:
-            .. warning:: .. deprecated:: 0.7.2
-
-                Has no effect. When detected, NaN grads will be printed automatically.
-                Will remove 0.9.0.
-
         weights_summary: Prints a summary of the weights when training begins.
 
         weights_save_path: Where to save weights if specified. Will override default_root_dir
@@ -480,16 +429,6 @@ class Trainer(
             raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
         self.track_grad_norm = float(track_grad_norm)
 
-        # tpu config
-        if num_tpu_cores is not None:
-            rank_zero_warn(
-                "Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6"
-                " and this argument will be removed in v0.9.0",
-                DeprecationWarning,
-            )
-
-            if tpu_cores is None:
-                tpu_cores = num_tpu_cores
         self.tpu_cores = _parse_tpu_cores(tpu_cores)
         self.on_tpu = self.tpu_cores is not None
@@ -511,14 +450,6 @@ class Trainer(
         else:
             self.num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches)
 
-        # Backward compatibility, TODO: remove in v0.9.0
-        if print_nan_grads:
-            rank_zero_warn(
-                "Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
-                " NaN grads will be printed automatically when detected.",
-                DeprecationWarning,
-            )
-
         self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
 
         self.auto_lr_find = auto_lr_find
@@ -587,10 +518,6 @@ class Trainer(
         # NVIDIA setup
         self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
 
-        # backward compatibility
-        if show_progress_bar is not None:
-            self.show_progress_bar = show_progress_bar
-
         self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
 
         # logging
@@ -651,15 +578,6 @@ class Trainer(
         self.precision = precision
         self.scaler = None
 
-        # Backward compatibility, TODO: remove in v0.9.0
-        if use_amp is not None:
-            rank_zero_warn(
-                "Argument `use_amp` is now set by `precision` since v0.7.0"
-                " and this method will be removed in v0.9.0",
-                DeprecationWarning,
-            )
-            self.precision = 16 if use_amp else 32
-
         self.amp_level = amp_level
         self.init_amp()
@@ -729,7 +647,6 @@ class Trainer(
             ...
             ('precision', (<class 'int'>,), 32),
             ('prepare_data_per_node', (<class 'bool'>,), True),
-            ('print_nan_grads', (<class 'bool'>,), False),
             ('process_position', (<class 'int'>,), 0),
             ('profiler',
              (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
@@ -1264,7 +1181,6 @@ class Trainer(
 
             # hook and callback
             self.running_sanity_check = True
-            ref_model.on_sanity_check_start()
             self.on_sanity_check_start()
 
             num_loaders = len(self.val_dataloaders)
@@ -1,2 +1,2 @@
-torchvision>=0.4.0
+torchvision>=0.4.0, <0.7
 gym>=0.17.0
@@ -55,44 +55,6 @@ def test_tbd_remove_in_v0_10_0_trainer():
     assert trainer.ckpt_path == trainer.weights_save_path == 'foo'
 
 
-def test_tbd_remove_in_v0_9_0_trainer():
-    # test show_progress_bar set by progress_bar_refresh_rate
-    with pytest.deprecated_call(match='will be removed in v0.9.0'):
-        trainer = Trainer(progress_bar_refresh_rate=0, show_progress_bar=True)
-    assert not getattr(trainer, 'show_progress_bar')
-
-    with pytest.deprecated_call(match='will be removed in v0.9.0'):
-        trainer = Trainer(progress_bar_refresh_rate=50, show_progress_bar=False)
-    assert getattr(trainer, 'show_progress_bar')
-
-    with pytest.deprecated_call(match='will be removed in v0.9.0'):
-        trainer = Trainer(num_tpu_cores=8)
-    assert trainer.tpu_cores == 8
-
-
-def test_tbd_remove_in_v0_9_0_module_imports():
-    _soft_unimport_module("pytorch_lightning.core.decorators")
-    with pytest.deprecated_call(match='will be removed in v0.9.0'):
-        from pytorch_lightning.core.decorators import data_loader  # noqa: F811
-        data_loader(print)
-
-    _soft_unimport_module("pytorch_lightning.logging.comet")
-    with pytest.deprecated_call(match='will be removed in v0.9.0'):
-        from pytorch_lightning.logging.comet import CometLogger  # noqa: F402
-    _soft_unimport_module("pytorch_lightning.logging.mlflow")
-    with pytest.deprecated_call(match='will be removed in v0.9.0'):
-        from pytorch_lightning.logging.mlflow import MLFlowLogger  # noqa: F402
-    _soft_unimport_module("pytorch_lightning.logging.neptune")
-    with pytest.deprecated_call(match='will be removed in v0.9.0'):
-        from pytorch_lightning.logging.neptune import NeptuneLogger  # noqa: F402
-    _soft_unimport_module("pytorch_lightning.logging.test_tube")
-    with pytest.deprecated_call(match='will be removed in v0.9.0'):
-        from pytorch_lightning.logging.test_tube import TestTubeLogger  # noqa: F402
-    _soft_unimport_module("pytorch_lightning.logging.wandb")
-    with pytest.deprecated_call(match='will be removed in v0.9.0'):
-        from pytorch_lightning.logging.wandb import WandbLogger  # noqa: F402
-
-
 class ModelVer0_6(EvalModelTemplate):
 
     # todo: this shall not be needed while evaluate asks for dataloader explicitly
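With the v0.9 deprecation tests deleted, nothing in the suite exercises the old names anymore. A hypothetical follow-up check (not part of this commit) could assert that the removed import paths now fail:

    import pytest


    def test_removed_in_v0_9_0_imports():
        # the logging shim package was deleted, so the old module path must fail
        with pytest.raises(ImportError):
            from pytorch_lightning.logging.comet import CometLogger  # noqa: F401

        # the data_loader decorator is gone from core.decorators
        with pytest.raises(ImportError):
            from pytorch_lightning.core.decorators import data_loader  # noqa: F401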
@@ -130,7 +92,7 @@ class ModelVer0_7(EvalModelTemplate):
     def test_end(self, outputs):
         return {'test_loss': torch.tensor(0.7)}
 
 #
 # def test_tbd_remove_in_v1_0_0_model_hooks():
 #
 #     model = ModelVer0_6()
@@ -32,7 +32,7 @@ def test_default_args(mock_argparse, tmpdir):
 
 @pytest.mark.parametrize('cli_args', [
     ['--accumulate_grad_batches=22'],
-    ['--print_nan_grads', '--weights_save_path=./'],
+    ['--weights_save_path=./'],
     []
 ])
 def test_add_argparse_args_redefined(cli_args):