fix duplicate console logging bug v2 (#6275)
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

parent 0f9134e043
commit bc577ca792
CHANGELOG.md
@@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))
 
+- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
+
+
 ## [1.2.1] - 2021-02-23
 
 ### Fixed

docs/source/trainer.rst
@@ -259,13 +259,19 @@ Configure console logging
 *************************
 
 Lightning logs useful information about the training process and user warnings to the console.
-You can retrieve the Lightning logger and change it to your liking. For example, increase the logging level
-to see fewer messages like so:
+You can retrieve the Lightning logger and change it to your liking. For example, adjust the logging level
+or redirect output for certain modules to log files:
 
-.. code-block:: python
+.. testcode::
 
     import logging
-    logging.getLogger("lightning").setLevel(logging.ERROR)
+
+    # configure logging at the root level of lightning
+    logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
+
+    # configure logging on module level, redirect to file
+    logger = logging.getLogger("pytorch_lightning.core")
+    logger.addHandler(logging.FileHandler("core.log"))
 
 Read more about custom Python logging `here <https://docs.python.org/3/library/logging.html>`_.
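
With this change, the package logger only attaches its own handler when the root logger has none (see the `pytorch_lightning/__init__.py` hunk below), so an application that configures logging first picks Lightning's records up exactly once. A minimal sketch of that usage, assuming the behaviour introduced by this commit:

```python
import logging

# configure the application's root logger before importing lightning
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

import pytorch_lightning  # noqa: E402 (importing after logging setup is the point here)

# the package logger sees the root handler and keeps propagating, so records
# from any pytorch_lightning.* logger are printed once, in the format above
logging.getLogger("pytorch_lightning").info("hello from the package logger")
```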

pl_examples/domain_templates/computer_vision_fine_tuning.py
@@ -38,6 +38,7 @@ Note:
 See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
 """
 import argparse
+import logging
 import os
 from pathlib import Path
 from typing import Union
@@ -54,11 +55,11 @@ from torchvision.datasets.utils import download_and_extract_archive
 
 import pytorch_lightning as pl
 from pl_examples import cli_lightning_logo
-from pytorch_lightning import _logger as log
 from pytorch_lightning import LightningDataModule
 from pytorch_lightning.callbacks.finetuning import BaseFinetuning
 from pytorch_lightning.utilities import rank_zero_info
 
+log = logging.getLogger(__name__)
 DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
 
 # --- Finetuning Callback ---

pytorch_lightning/__init__.py
@@ -1,7 +1,8 @@
 """Root package info."""
 
-import logging as python_logging
+import logging
 import os
+import sys
 import time
 
 _this_year = time.strftime("%Y")
@@ -37,10 +38,14 @@ Documentation
 - https://pytorch-lightning.readthedocs.io/en/latest
 - https://pytorch-lightning.readthedocs.io/en/stable
 """
+_root_logger = logging.getLogger()
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
 
-_logger = python_logging.getLogger("lightning")
-_logger.addHandler(python_logging.StreamHandler())
-_logger.setLevel(python_logging.INFO)
+# if root logger has handlers, propagate messages up and let root logger process them
+if not _root_logger.hasHandlers():
+    _logger.addHandler(logging.StreamHandler())
+    _logger.propagate = False
 
 _PACKAGE_ROOT = os.path.dirname(__file__)
 _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
@@ -53,9 +58,7 @@ try:
 except NameError:
     __LIGHTNING_SETUP__: bool = False
 
-if __LIGHTNING_SETUP__:
-    import sys  # pragma: no-cover
-
+if __LIGHTNING_SETUP__:  # pragma: no-cover
     sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n')  # pragma: no-cover
     # We are not importing the rest of the lightning during the build process, as it may not be compiled yet
 else:
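
The `hasHandlers()` guard above is the heart of the fix. A minimal standalone sketch of the two cases it distinguishes (the `demo` logger name is illustrative):

```python
import logging

root = logging.getLogger()
logger = logging.getLogger("demo")  # stand-in for the "pytorch_lightning" logger
logger.setLevel(logging.INFO)

if not root.hasHandlers():
    # Case 1: the user has not configured logging. Emit directly to stderr and
    # stop propagation, so a root handler added *later* (e.g. via basicConfig)
    # cannot print the same record a second time.
    logger.addHandler(logging.StreamHandler())
    logger.propagate = False
# Case 2: the root logger already has handlers (the user configured logging
# before the import). Add nothing and let records propagate up, so the single
# root handler formats and prints each record exactly once.

logger.info("printed exactly once in both cases")
```

The old code always attached a `StreamHandler` while leaving propagation on, so any user-configured root handler printed every Lightning message a second time — the duplicate-console-logging bug this commit fixes.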

pytorch_lightning/callbacks/finetuning.py
@@ -16,6 +16,7 @@ Finetuning Callback
 ^^^^^^^^^^^^^^^^^^^^
 Freeze and unfreeze models for finetuning purposes
 """
+import logging
 from typing import Callable, Generator, Iterable, List, Optional, Union
 
 import torch
@@ -24,12 +25,13 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.container import Container, ModuleDict, ModuleList, Sequential
 from torch.optim.optimizer import Optimizer
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+log = logging.getLogger(__name__)
+
 
 def multiplicative(epoch):
     return 2
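
The same two-line substitution repeats in every file below: drop the shared `_logger` import and create a module-level logger instead. Because each logger is named after its module's import path, it becomes a child of the `pytorch_lightning` logger, and levels or handlers can be tuned per subpackage. A small sketch of the pattern (module and message are illustrative):

```python
import logging

# in e.g. pytorch_lightning/callbacks/finetuning.py, __name__ is
# "pytorch_lightning.callbacks.finetuning", a child of "pytorch_lightning"
log = logging.getLogger(__name__)


def unfreeze_backbone() -> None:
    log.info("unfreezing backbone")  # handled by whatever the ancestors configure


# user code can then quiet a whole subtree at once:
logging.getLogger("pytorch_lightning.callbacks").setLevel(logging.WARNING)
```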

pytorch_lightning/callbacks/model_checkpoint.py
@@ -18,7 +18,7 @@ Model Checkpointing
 Automatically save model checkpoints during training.
 
 """
-
+import logging
 import os
 import re
 from copy import deepcopy
@@ -29,13 +29,13 @@ import numpy as np
 import torch
 import yaml
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import WarningCache
 
+log = logging.getLogger(__name__)
 warning_cache = WarningCache()

pytorch_lightning/callbacks/pruning.py
@@ -16,6 +16,7 @@ ModelPruning
 ^^^^^^^^^^^^
 """
 import inspect
+import logging
 from copy import deepcopy
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
@@ -24,12 +25,13 @@ import torch
 import torch.nn.utils.prune as pytorch_prune
 from torch import nn
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+log = logging.getLogger(__name__)
+
 _PYTORCH_PRUNING_FUNCTIONS = {
     "ln_structured": pytorch_prune.ln_structured,
     "l1_unstructured": pytorch_prune.l1_unstructured,

pytorch_lightning/core/lightning.py
@@ -16,6 +16,7 @@
 import collections
 import copy
 import inspect
+import logging
 import os
 import tempfile
 import uuid
@@ -30,7 +31,6 @@ from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
@@ -45,6 +45,7 @@ from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args
 
 if TYPE_CHECKING:
     from pytorch_lightning.trainer.states import RunningStage
 
+log = logging.getLogger(__name__)
 
 class LightningModule(

pytorch_lightning/core/saving.py
@@ -15,6 +15,7 @@
 import ast
 import csv
 import inspect
+import logging
 import os
 from argparse import Namespace
 from copy import deepcopy
@@ -25,13 +26,13 @@ from warnings import warn
 import torch
 import yaml
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.parsing import parse_class_init_keys
 
+log = logging.getLogger(__name__)
 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)

pytorch_lightning/loggers/comet.py
@@ -16,6 +16,7 @@ Comet Logger
 ------------
 """
 
+import logging
 import os
 from argparse import Namespace
 from typing import Any, Dict, Optional, Union
@@ -23,12 +24,12 @@ from typing import Any, Dict, Optional, Union
 import torch
 from torch import is_tensor
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import _module_available, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+log = logging.getLogger(__name__)
 _COMET_AVAILABLE = _module_available("comet_ml")
 
 if _COMET_AVAILABLE:

pytorch_lightning/loggers/csv_logs.py
@@ -20,17 +20,19 @@ CSV logger for basic experiment logging that does not require opening ports
 """
 import csv
 import io
+import logging
 import os
 from argparse import Namespace
 from typing import Any, Dict, Optional, Union
 
 import torch
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
 
+log = logging.getLogger(__name__)
+
 
 class ExperimentWriter(object):
     r"""

pytorch_lightning/loggers/mlflow.py
@@ -15,17 +15,17 @@
 MLflow Logger
 -------------
 """
+import logging
 import re
 from argparse import Namespace
 from time import time
 from typing import Any, Dict, Optional, Union
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn
 
+log = logging.getLogger(__name__)
 LOCAL_FILE_URI_PREFIX = "file:"
 
 _MLFLOW_AVAILABLE = _module_available("mlflow")
 try:
     import mlflow

pytorch_lightning/loggers/neptune.py
@@ -15,16 +15,17 @@
 Neptune Logger
 --------------
 """
+import logging
 from argparse import Namespace
 from typing import Any, Dict, Iterable, Optional, Union
 
 import torch
 from torch import is_tensor
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import _module_available, rank_zero_only
 
+log = logging.getLogger(__name__)
 _NEPTUNE_AVAILABLE = _module_available("neptune")
 
 if _NEPTUNE_AVAILABLE:

pytorch_lightning/loggers/tensorboard.py
@@ -16,6 +16,7 @@ TensorBoard Logger
 ------------------
 """
 
+import logging
 import os
 from argparse import Namespace
 from typing import Any, Dict, Optional, Union
@@ -24,13 +25,14 @@ import torch
 from torch.utils.tensorboard import SummaryWriter
 from torch.utils.tensorboard.summary import hparams
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 
+log = logging.getLogger(__name__)
+
 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container, OmegaConf

pytorch_lightning/plugins/environments/slurm_environment.py
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 import re
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 
+log = logging.getLogger(__name__)
+
 
 class SLURMEnvironment(ClusterEnvironment):

pytorch_lightning/plugins/environments/torchelastic_environment.py
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.utilities import rank_zero_warn
 
+log = logging.getLogger(__name__)
+
 
 class TorchElasticEnvironment(ClusterEnvironment):

pytorch_lightning/plugins/training_type/ddp.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 import subprocess
 import sys
@@ -23,7 +24,6 @@ import torch.distributed as torch_distrib
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
@@ -44,6 +44,9 @@ if _HYDRA_AVAILABLE:
     from hydra.utils import get_original_cwd, to_absolute_path
 
 
+log = logging.getLogger(__name__)
+
+
 class DDPPlugin(ParallelPlugin):
     """
     Plugin for multi-process single-device training on one or multiple nodes.

pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 import re
 from typing import Any, Dict, List, Optional, Union
@@ -21,7 +22,6 @@ import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
@@ -39,6 +39,8 @@ from pytorch_lightning.utilities.distributed import (
 )
 from pytorch_lightning.utilities.seed import seed_everything
 
+log = logging.getLogger(__name__)
+
 
 class DDPSpawnPlugin(ParallelPlugin):

pytorch_lightning/profiler/profilers.py
@@ -15,6 +15,7 @@
 
 import cProfile
 import io
+import logging
 import os
 import pstats
 import time
@@ -25,9 +26,10 @@ from typing import Optional, Union
 
 import numpy as np
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 
+log = logging.getLogger(__name__)
+
 
 class BaseProfiler(ABC):
     """

pytorch_lightning/profiler/pytorch.py
@@ -14,18 +14,20 @@
 """Profiler to check if there are any bottlenecks in your code."""
 
 import inspect
+import logging
 import os
 from typing import List, Optional
 
 import torch
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.profiler.profilers import BaseProfiler
 from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.distributed import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+log = logging.getLogger(__name__)
+
 
 class PyTorchProfiler(BaseProfiler):

pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 from typing import List, Optional, Sequence, Union
 
 import torch
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
@@ -61,6 +61,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
 
+log = logging.getLogger(__name__)
+
 
 class AcceleratorConnector(object):

pytorch_lightning/trainer/connectors/slurm_connector.py
@@ -1,8 +1,9 @@
+import logging
 import os
 import signal
 from subprocess import call
 
-from pytorch_lightning import _logger as log
+log = logging.getLogger(__name__)
 
 
 class SLURMConnector:

pytorch_lightning/trainer/trainer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Trainer to automate the training."""
+import logging
 import warnings
 from itertools import count
 from pathlib import Path
@@ -20,7 +21,6 @@ from typing import Dict, Iterable, List, Optional, Union
 import torch
 from torch.utils.data import DataLoader
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators import Accelerator
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
@@ -63,6 +63,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.model_helpers import is_overridden
 
+log = logging.getLogger(__name__)
 # warnings to ignore in trainer
 warnings.filterwarnings(
     'ignore', message='torch.distributed.reduce_op is deprecated, '

pytorch_lightning/trainer/training_tricks.py
@@ -12,16 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from abc import ABC
 
 import torch
 from torch import Tensor
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
 
 EPSILON = 1e-6
 EPSILON_FP16 = 1e-5
+log = logging.getLogger(__name__)
 
 
 class TrainerTrainingTricksMixin(ABC):

pytorch_lightning/tuner/batch_size_scaling.py
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
+import logging
 import os
 from typing import Optional, Tuple
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
@@ -24,6 +24,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
 from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
 
+log = logging.getLogger(__name__)
+
 
 def scale_batch_size(
     trainer,

pytorch_lightning/tuner/lr_finder.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
+import logging
 import os
 from functools import wraps
 from typing import Callable, List, Optional, Sequence, Union
@@ -22,7 +23,6 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -39,6 +39,8 @@ if importlib.util.find_spec('ipywidgets') is not None:
 else:
     from tqdm import tqdm
 
+log = logging.getLogger(__name__)
+
 
 def _determine_lr_attr_name(trainer, model: LightningModule) -> str:
     if isinstance(trainer.auto_lr_find, str):

pytorch_lightning/utilities/distributed.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 import warnings
 from functools import wraps
@@ -19,7 +20,7 @@ from typing import Any, Optional, Union
 
 import torch
 
-from pytorch_lightning import _logger as log
+log = logging.getLogger(__name__)
 
 if torch.distributed.is_available():
     from torch.distributed import group, ReduceOp

pytorch_lightning/utilities/seed.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Helper functions to help with reproducibility of models. """
 
+import logging
 import os
 import random
 from typing import Optional
@@ -20,9 +21,10 @@ from typing import Optional
 import numpy as np
 import torch
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn
 
+log = logging.getLogger(__name__)
+
 
 def seed_everything(seed: Optional[int] = None) -> int:
     """

pytorch_lightning/utilities/upgrade_checkpoint.py
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import logging
 from shutil import copyfile
 
 import torch
 
-from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 
 KEYS_MAPPING = {
@@ -27,6 +27,8 @@ KEYS_MAPPING = {
     "early_stop_callback_patience": (EarlyStopping, "patience"),
 }
 
+log = logging.getLogger(__name__)
+
 
 def upgrade_checkpoint(filepath):
     checkpoint = torch.load(filepath)

tests/__init__.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 
 import numpy as np
@@ -31,3 +32,5 @@ RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
 
 if not os.path.isdir(_TEMP_PATH):
     os.mkdir(_TEMP_PATH)
+
+logging.basicConfig(level=logging.ERROR)
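
For the test suite, `basicConfig` both installs a root handler and sets the root logger's level to ERROR. Loggers without an explicit level inherit that effective level, which keeps test output quiet. A standalone sketch of the mechanism (the logger name is illustrative):

```python
import logging

logging.basicConfig(level=logging.ERROR)  # root handler + root level ERROR

# a logger with no explicit level inherits its effective level from the
# nearest configured ancestor (here the root), so INFO is dropped at the
# logger itself before any handler runs
test_log = logging.getLogger("tests.some_module")
test_log.info("dropped: effective level is ERROR")
test_log.error("printed by the root handler")
```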

tests/callbacks/test_early_stopping.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 import pickle
 from unittest import mock
@@ -20,7 +21,7 @@ import numpy as np
 import pytest
 import torch
 
-from pytorch_lightning import _logger, seed_everything, Trainer
+from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -29,6 +30,8 @@ from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
 
+_logger = logging.getLogger(__name__)
+
 
 class EarlyStoppingTestRestore(EarlyStopping):
     # this class has to be defined outside the test function, otherwise we get pickle error

tests/checkpointing/test_model_checkpoint.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import math
 import os
 import pickle
@@ -677,7 +678,8 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
         callbacks=[ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last)],
         max_epochs=max_epochs,
     )
-    trainer.fit(model)
+    with caplog.at_level(logging.INFO):
+        trainer.fit(model)
     assert caplog.messages.count('Saving latest checkpoint...') == save_last

tests/test_profiler.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import logging
 import os
 import time
 from pathlib import Path
@@ -95,7 +95,8 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5):
 
 def test_simple_profiler_describe(caplog, simple_profiler):
     """Ensure the profiler won't fail when reporting the summary."""
-    simple_profiler.describe()
+    with caplog.at_level(logging.INFO):
+        simple_profiler.describe()
 
     assert "Profiler Report" in caplog.text
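
Both test changes follow the same pattern: now that messages go through standard `logging`, pytest's `caplog` fixture can capture them, but the capture level has to be lowered to INFO explicitly. A self-contained sketch of the pattern (function and message are illustrative):

```python
import logging

log = logging.getLogger(__name__)


def describe() -> None:
    log.info("Profiler Report")


def test_describe(caplog) -> None:
    # caplog.at_level temporarily lowers the level of the targeted logger
    # (the root logger by default) so INFO records reach the capturing handler
    with caplog.at_level(logging.INFO):
        describe()
    assert "Profiler Report" in caplog.text
```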