parent
f8c058215f
commit
b7d72706c3
|
@ -4,8 +4,8 @@ Runs a model on the CPU on a single node.
|
|||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pl_examples.models.lightning_template import LightningTemplateModel
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
|
||||
seed_everything(234)
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@ Runs a model on a single node across multiple gpus.
|
|||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pl_examples.models.lightning_template import LightningTemplateModel
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
|
||||
seed_everything(234)
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@ Multi-node example (GPU)
|
|||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pl_examples.models.lightning_template import LightningTemplateModel
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
|
||||
seed_everything(234)
|
||||
|
||||
|
|
|
@ -27,13 +27,10 @@ from pathlib import Path
|
|||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Generator, Union
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch_lightning import _logger as log
|
||||
from torch import optim
|
||||
from torch.nn import Module
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -42,6 +39,9 @@ from torchvision import transforms
|
|||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.datasets.utils import download_and_extract_archive
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import _logger as log
|
||||
|
||||
BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
|
||||
DATA_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
"""
|
||||
This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
|
||||
"""
|
||||
from argparse import ArgumentParser, Namespace
|
||||
import os
|
||||
import random
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
|
|
@ -16,12 +16,9 @@ see the metrics:
|
|||
tensorboard --logdir default
|
||||
"""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from typing import Tuple, List
|
||||
|
||||
import argparse
|
||||
from collections import OrderedDict, deque, namedtuple
|
||||
from typing import Tuple, List
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
@ -32,6 +29,8 @@ from torch.optim import Optimizer
|
|||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import IterableDataset
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
class DQN(nn.Module):
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import random
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
import numpy as np
|
||||
|
@ -7,7 +8,6 @@ import torch.nn.functional as F
|
|||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
import random
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pl_examples.models.unet import UNet
|
||||
|
|
|
@ -12,7 +12,6 @@ from torch import optim
|
|||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.core import LightningModule
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from pytorch_lightning.accelerators.cpu_backend import CPUBackend
|
||||
from pytorch_lightning.accelerators.ddp2_backend import DDP2Backend
|
||||
from pytorch_lightning.accelerators.ddp_backend import DDPBackend
|
||||
from pytorch_lightning.accelerators.ddp_spawn_backend import DDPSpawnBackend
|
||||
from pytorch_lightning.accelerators.dp_backend import DataParallelBackend
|
||||
from pytorch_lightning.accelerators.gpu_backend import GPUBackend
|
||||
from pytorch_lightning.accelerators.tpu_backend import TPUBackend
|
||||
from pytorch_lightning.accelerators.dp_backend import DataParallelBackend
|
||||
from pytorch_lightning.accelerators.ddp_spawn_backend import DDPSpawnBackend
|
||||
from pytorch_lightning.accelerators.cpu_backend import CPUBackend
|
||||
from pytorch_lightning.accelerators.ddp_backend import DDPBackend
|
||||
from pytorch_lightning.accelerators.ddp2_backend import DDP2Backend
|
||||
|
|
|
@ -13,10 +13,12 @@
|
|||
# limitations under the License
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
try:
|
||||
|
|
|
@ -13,16 +13,18 @@
|
|||
# limitations under the License
|
||||
|
||||
import os
|
||||
import torch
|
||||
import subprocess
|
||||
import sys
|
||||
from time import sleep
|
||||
import numpy as np
|
||||
from os.path import abspath
|
||||
from time import sleep
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning import _logger as log
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
from hydra.utils import to_absolute_path, get_original_cwd
|
||||
|
|
|
@ -12,11 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
|
|
@ -13,10 +13,11 @@
|
|||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.overrides.data_parallel import LightningDataParallel
|
||||
from torch import optim
|
||||
|
||||
from pytorch_lightning.overrides.data_parallel import LightningDataParallel
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import torch
|
||||
|
||||
from pytorch_lightning.core import LightningModule
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
|
|
|
@ -17,10 +17,10 @@ import os
|
|||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.core import LightningModule
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning import _logger as log
|
||||
|
||||
try:
|
||||
import torch_xla
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks.lr_logger import LearningRateLogger
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar
|
||||
|
||||
__all__ = [
|
||||
|
|
|
@ -7,7 +7,6 @@ Monitor a validation metric and stop training when it stops improving.
|
|||
"""
|
||||
from copy import deepcopy
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
|
@ -7,7 +7,6 @@ Change gradient accumulation factor according to scheduling.
|
|||
"""
|
||||
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
class GradientAccumulationScheduler(Callback):
|
||||
|
|
|
@ -8,9 +8,8 @@ Log learning rate for lr schedulers during training
|
|||
"""
|
||||
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
class LearningRateLogger(Callback):
|
||||
|
|
|
@ -8,11 +8,11 @@ Automatically save model checkpoints during training.
|
|||
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
|
||||
|
|
|
@ -4,8 +4,8 @@ import torch
|
|||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE
|
||||
|
||||
from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
|
|
@ -9,9 +9,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
|
||||
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
||||
PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
|
||||
UNKNOWN_SIZE = "?"
|
||||
|
|
|
@ -2,11 +2,11 @@ import ast
|
|||
import csv
|
||||
import inspect
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Union, Dict, Any, Optional, Callable, MutableMapping
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from argparse import Namespace
|
||||
from typing import Union, Dict, Any, Optional, Callable, MutableMapping
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import numbers
|
||||
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
|
||||
from torch import Tensor
|
||||
import torch
|
||||
from copy import copy
|
||||
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from pytorch_lightning.metrics.converters import _sync_ddp_if_available
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
from os import environ
|
||||
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, LoggerCollection
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
from pytorch_lightning.loggers.csv_logs import CSVLogger
|
||||
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
|
||||
__all__ = [
|
||||
'LightningLoggerBase',
|
||||
|
|
|
@ -5,13 +5,14 @@ CSV logger
|
|||
CSV logger for basic experiment logging that does not require opening ports
|
||||
|
||||
"""
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
import csv
|
||||
import torch
|
||||
from argparse import Namespace
|
||||
from typing import Optional, Dict, Any, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.core.saving import save_hparams_to_yaml
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||
|
|
|
@ -1,13 +1,3 @@
|
|||
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
|
||||
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
|
||||
from pytorch_lightning.metrics.regression import (
|
||||
MAE,
|
||||
MSE,
|
||||
PSNR,
|
||||
RMSE,
|
||||
RMSLE,
|
||||
SSIM
|
||||
)
|
||||
from pytorch_lightning.metrics.classification import (
|
||||
Accuracy,
|
||||
AveragePrecision,
|
||||
|
@ -24,12 +14,22 @@ from pytorch_lightning.metrics.classification import (
|
|||
PrecisionRecall,
|
||||
IoU,
|
||||
)
|
||||
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
|
||||
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
|
||||
from pytorch_lightning.metrics.nlp import BLEUScore
|
||||
from pytorch_lightning.metrics.regression import (
|
||||
MAE,
|
||||
MSE,
|
||||
PSNR,
|
||||
RMSE,
|
||||
RMSLE,
|
||||
SSIM
|
||||
)
|
||||
from pytorch_lightning.metrics.sklearns import (
|
||||
AUC,
|
||||
PrecisionRecallCurve,
|
||||
SklearnMetric,
|
||||
)
|
||||
from pytorch_lightning.metrics.nlp import BLEUScore
|
||||
|
||||
__classification_metrics = [
|
||||
"AUC",
|
||||
|
|
|
@ -10,8 +10,9 @@ from typing import Union, Any, Callable, Optional
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data._utils.collate import np_str_obj_array_pattern
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
||||
try:
|
||||
from torch.distributed import ReduceOp
|
||||
|
|
|
@ -20,6 +20,7 @@ from pytorch_lightning.metrics.functional.classification import (
|
|||
to_onehot,
|
||||
iou,
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.nlp import bleu_score
|
||||
from pytorch_lightning.metrics.functional.regression import (
|
||||
mae,
|
||||
mse,
|
||||
|
@ -28,4 +29,3 @@ from pytorch_lightning.metrics.functional.regression import (
|
|||
rmsle,
|
||||
ssim
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.nlp import bleu_score
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import sys
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional, Sequence, Tuple
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
# Authors: torchtext authors and @sluks
|
||||
# Date: 2020-07-18
|
||||
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
|
||||
from typing import Sequence, List
|
||||
from collections import Counter
|
||||
from typing import Sequence, List
|
||||
|
||||
import torch
|
||||
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import itertools
|
||||
import threading
|
||||
from itertools import chain
|
||||
from collections import Mapping, Iterable
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
from torch.cuda._utils import _get_device_index
|
||||
from torch.nn import DataParallel
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.nn.parallel._functions import Gather
|
||||
|
||||
from pytorch_lightning.core.step_result import Result
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from abc import ABC
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities import rank_zero_warn, APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_debug
|
||||
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE
|
||||
|
||||
|
||||
class TrainerAMPMixin(ABC):
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Callable, Optional
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
class ConfigValidator(object):
|
||||
|
|
|
@ -131,23 +131,17 @@ import os
|
|||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Union, List, Optional, Callable, Tuple
|
||||
import subprocess
|
||||
import sys
|
||||
from time import sleep
|
||||
import numpy as np
|
||||
from os.path import abspath
|
||||
from pkg_resources import parse_version
|
||||
from typing import Union, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
|
|
@ -18,25 +18,23 @@ Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
|
|||
|
||||
"""
|
||||
|
||||
from contextlib import ExitStack
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
import time
|
||||
import random
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import ExitStack
|
||||
from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.overrides.data_parallel import (
|
||||
LightningDistributedDataParallel,
|
||||
LightningDataParallel,
|
||||
)
|
||||
from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities import move_data_to_device
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
|
|
@ -131,13 +131,10 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
|
||||
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE, flatten_dict
|
||||
from torch import distributed as dist
|
||||
from pytorch_lightning.core.step_result import Result, EvalResult
|
||||
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE, flatten_dict
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
try:
|
||||
import torch_xla.distributed.parallel_loader as xla_pl
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
from abc import ABC
|
||||
from typing import Union, Iterable
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
"""
|
||||
Trainer Learning Rate Finder
|
||||
"""
|
||||
import os
|
||||
import importlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Sequence, Tuple, List, Union
|
||||
|
||||
|
|
|
@ -20,9 +20,10 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.distributed as torch_distrib
|
||||
import torch.multiprocessing as mp
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning.accelerators import (
|
||||
GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend)
|
||||
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
@ -33,6 +34,7 @@ from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, Simple
|
|||
from pytorch_lightning.trainer.auto_mix_precision import NATIVE_AMP_AVALAIBLE, TrainerAMPMixin
|
||||
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
|
||||
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
|
||||
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
|
||||
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
|
||||
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10
|
||||
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
|
||||
|
@ -50,9 +52,6 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
|
|||
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.debugging import InternalDebugger
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
|
||||
from pytorch_lightning.accelerators import (
|
||||
GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend)
|
||||
|
||||
# warnings to ignore in trainer
|
||||
warnings.filterwarnings(
|
||||
|
|
|
@ -89,15 +89,14 @@ import signal
|
|||
from abc import ABC
|
||||
from distutils.version import LooseVersion
|
||||
from subprocess import call
|
||||
from pkg_resources import parse_version
|
||||
|
||||
import torch
|
||||
import torch.distributed as torch_distrib
|
||||
|
||||
import pytorch_lightning
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.overrides.data_parallel import (
|
||||
LightningDistributedDataParallel,
|
||||
|
|
|
@ -157,29 +157,28 @@ in your model.
|
|||
trainer = Trainer(terminate_on_nan=True)
|
||||
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import copy
|
||||
from typing import Callable
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.distributed as torch_distrib
|
||||
from copy import copy
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.core.step_result import EvalResult, Result
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
|
||||
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.parsing import AttributeDict
|
||||
from pytorch_lightning.utilities.memory import recursive_detach
|
||||
from pytorch_lightning.core.step_result import EvalResult, TrainResult, Result
|
||||
from pytorch_lightning.utilities.parsing import AttributeDict
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
|
|
@ -13,19 +13,16 @@
|
|||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
import gc
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.callbacks import GradientAccumulationScheduler
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.loggers.base import DummyLogger
|
||||
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
import numpy
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
|
||||
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
|
||||
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict
|
||||
|
||||
try:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import importlib
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import copy
|
||||
|
@ -5,8 +6,6 @@ from typing import Any, Callable, Union
|
|||
|
||||
import torch
|
||||
|
||||
import importlib
|
||||
|
||||
TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None
|
||||
if TORCHTEXT_AVAILABLE:
|
||||
from torchtext.data import Batch
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def load(path_or_url: str, map_location=None):
|
||||
if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive: # no scheme or with a drive letter
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from functools import wraps
|
||||
import warnings
|
||||
from pytorch_lightning import _logger as log
|
||||
import os
|
||||
import warnings
|
||||
from functools import wraps
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
|
||||
|
||||
def rank_zero_only(fn):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
"""Helper functions to help with reproducibility of models. """
|
||||
|
||||
import os
|
||||
from typing import Optional, Type
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
|
|
Loading…
Reference in New Issue