attempt mypy fix
parent b52fa07d6e
commit fb7ed827d9
@@ -21,23 +21,25 @@ import json
 import os
 import sys
 
+from pytorch_lightning import Trainer  # noqa: E402
+from pytorch_lightning.callbacks import ModelCheckpoint  # noqa: E402
+from pytorch_lightning.utilities import HOROVOD_AVAILABLE  # noqa: E402
+
+from tests.base import EvalModelTemplate  # noqa: E402
+from tests.base.develop_pipelines import run_prediction  # noqa: E402
+from tests.base.develop_utils import (reset_seed,  # noqa: E402
+                                      set_random_master_port)
+
 # this is needed because Conda does not use `PYTHONPATH` env var while pip and virtualenv do
 PYTHONPATH = os.getenv('PYTHONPATH', '')
 if ':' in PYTHONPATH:
     sys.path = PYTHONPATH.split(':') + sys.path
 
-from pytorch_lightning import Trainer  # noqa: E402
-from pytorch_lightning.callbacks import ModelCheckpoint  # noqa: E402
-from pytorch_lightning.utilities import HOROVOD_AVAILABLE  # noqa: E402
-
 if HOROVOD_AVAILABLE:
     import horovod.torch as hvd  # noqa: E402
 else:
     print('You requested to import Horovod which is missing or not supported for your OS.')
 
-from tests.base import EvalModelTemplate  # noqa: E402
-from tests.base.develop_pipelines import run_prediction  # noqa: E402
-from tests.base.develop_utils import set_random_master_port, reset_seed  # noqa: E402
 
 
 parser = argparse.ArgumentParser()
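Note on the hunk above: the script splices `PYTHONPATH` into `sys.path` by hand because Conda ignores that environment variable, and Horovod is imported only when available. A minimal standalone sketch of that availability guard, using `importlib.util.find_spec` as a stand-in for however `HOROVOD_AVAILABLE` is actually derived inside pytorch_lightning:

import importlib.util

# Treat a package as available when it can be located on sys.path.
HOROVOD_AVAILABLE = importlib.util.find_spec("horovod") is not None

if HOROVOD_AVAILABLE:
    import horovod.torch as hvd
else:
    print("You requested to import Horovod which is missing or not supported for your OS.")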
@@ -45,7 +47,7 @@ parser.add_argument('--trainer-options', required=True)
 parser.add_argument('--on-gpu', action='store_true', default=False)
 
 
-def run_test_from_config(trainer_options):
+def run_test_from_config(trainer_options) -> None:
     """Trains the default model with the given config."""
     set_random_master_port()
     reset_seed()
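The `-> None` annotations added here and throughout the rest of this commit are what make mypy look inside these functions at all: by default mypy skips the bodies of unannotated functions, and under `disallow_untyped_defs` it flags them outright. A minimal sketch (the function name and config snippet are illustrative, not part of this diff):

# mypy.ini (illustrative):
#   [mypy]
#   disallow_untyped_defs = True

def run_smoke_test() -> None:  # annotated, so mypy type-checks the body
    total: int = 1 + 1
    assert total == 2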
@@ -15,22 +15,25 @@
 Tests to ensure that the training loop works with a dict (1.0)
 """
 from copy import deepcopy
+from typing import Any, Callable, Dict, List, Tuple, TypeVar
 
 import pytest
 import torch
+from torch.utils.data import DataLoader
 
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer import Trainer
-from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
+from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import \
+    CallbackHookNameValidator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.boring_model import BoringModel, RandomDataset
-from torch.utils.data import DataLoader
+
+F = TypeVar('F', bound=Callable[..., Any])
 
 
-def decorator_with_arguments(fx_name='', hook_fx_name=None):
-    def decorator(func):
-        def wrapper(self, *args, **kwargs):
+def decorator_with_arguments(fx_name='', hook_fx_name=None) -> Callable[[F], F]:
+    def decorator(func: F) -> F:
+        def wrapper(self, *args, **kwargs) -> Any:
             # Set information
             self._current_fx_name = fx_name
             self._current_hook_fx_name = hook_fx_name
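The `F = TypeVar('F', bound=Callable[..., Any])` line above is the standard trick for typing decorators: annotating the decorator as `Callable[[F], F]` tells mypy that the decorated function keeps its original signature. A self-contained sketch of the same pattern (`logged` and `add` are made-up names, not part of this diff):

from typing import Any, Callable, TypeVar, cast

F = TypeVar('F', bound=Callable[..., Any])


def logged(func: F) -> F:
    """Print the wrapped function's name before delegating to it."""
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)
    # cast back to F so callers keep the original signature
    return cast(F, wrapper)


@logged
def add(x: int, y: int) -> int:
    return x + y


assert add(1, 2) == 3  # mypy still sees `add` as (int, int) -> int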
@@ -47,7 +50,7 @@ def decorator_with_arguments(fx_name='', hook_fx_name=None):
     return decorator
 
 
-def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch):
+def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch) -> None:
     """
     Tests that LoggerConnector will properly capture logged information
     and reduce them
@@ -59,7 +62,7 @@ def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch):
         train_losses = []
 
         @decorator_with_arguments(fx_name="training_step")
-        def training_step(self, batch, batch_idx):
+        def training_step(self, batch, batch_idx) -> Dict[str, Any]:
             output = self.layer(batch)
             loss = self.loss(batch, output)
@@ -69,7 +72,7 @@ def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch):
 
             return {"loss": loss}
 
-        def training_step_end(self, *_):
+        def training_step_end(self, *_) -> None:
             self.train_results = deepcopy(self.trainer.logger_connector.cached_results)
 
     model = TestModel()
@@ -105,7 +108,7 @@ def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch):
     assert generated == excepted
 
 
-def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
+def test__logger_connector__epoch_result_store__train__ttbt(tmpdir) -> None:
     """
     Tests that LoggerConnector will properly capture logged information with ttbt
     and reduce them
@@ -118,23 +121,23 @@ def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
     y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
 
     class MockSeq2SeqDataset(torch.utils.data.Dataset):
-        def __getitem__(self, i):
+        def __getitem__(self, i) -> Tuple[Any, List[List[List[float]]]]:
             return x_seq, y_seq_list
 
-        def __len__(self):
+        def __len__(self) -> int:
             return 1
 
     class TestModel(BoringModel):
 
         train_losses = []
 
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.test_hidden = None
             self.layer = torch.nn.Linear(2, 2)
 
         @decorator_with_arguments(fx_name="training_step")
-        def training_step(self, batch, batch_idx, hiddens):
+        def training_step(self, batch, batch_idx, hiddens) -> Dict[str, Any]:
             self.test_hidden = torch.rand(1)
 
             x_tensor, y_list = batch
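The `Tuple[Any, List[List[List[float]]]]` annotation above spells out the exact nested shape `__getitem__` returns: `torch.rand(batch_size, sequence_size, 1).tolist()` is a three-deep list of floats. A hypothetical toy dataset annotated in the same style:

from typing import List, Tuple

import torch
from torch.utils.data import Dataset


class PairDataset(Dataset):
    """Toy dataset: every index returns the same (tensor, nested-list) pair."""

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, List[List[float]]]:
        return torch.zeros(2), [[0.0], [1.0]]

    def __len__(self) -> int:
        return 1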
@@ -155,7 +158,7 @@ def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
         def on_train_epoch_start(self) -> None:
             self.test_hidden = None
 
-        def train_dataloader(self):
+        def train_dataloader(self) -> Any:
             return torch.utils.data.DataLoader(
                 dataset=MockSeq2SeqDataset(),
                 batch_size=batch_size,
@@ -163,7 +166,7 @@ def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
                 sampler=None,
             )
 
-        def training_step_end(self, *_):
+        def training_step_end(self, *_) -> None:
             self.train_results = deepcopy(self.trainer.logger_connector.cached_results)
 
     model = TestModel()
@@ -200,7 +203,7 @@ def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
 
 
 @pytest.mark.parametrize('num_dataloaders', [1, 2])
-def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders):
+def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders) -> None:
     """
     Tests that LoggerConnector will properly capture logged information in multi_dataloaders scenario
     """
@@ -211,7 +214,7 @@ def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders):
     test_losses = {}
 
     @decorator_with_arguments(fx_name="test_step")
-    def test_step(self, batch, batch_idx, dl_idx=0):
+    def test_step(self, batch, batch_idx, dl_idx=0) -> Dict[str, Any]:
         output = self.layer(batch)
         loss = self.loss(batch, output)
@@ -221,15 +224,15 @@ def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders):
         self.log("test_loss", loss, on_step=True, on_epoch=True)
         return {"test_loss": loss}
 
-    def on_test_batch_end(self, *args, **kwargs):
+    def on_test_batch_end(self, *args, **kwargs) -> None:
         # save objects as it will be reset at the end of epoch.
         self.batch_results = deepcopy(self.trainer.logger_connector.cached_results)
 
-    def on_test_epoch_end(self):
+    def on_test_epoch_end(self) -> None:
         # save objects as it will be reset at the end of epoch.
         self.reduce_results = deepcopy(self.trainer.logger_connector.cached_results)
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> List[Any]:
         return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)]
 
     model = TestModel()
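`test_dataloader` above is annotated `List[Any]` because it returns one loader per dataloader index. A tighter annotation is possible if torch's types are available to mypy; a hypothetical sketch (`make_test_dataloaders` is an illustrative name, not part of this diff):

from typing import List

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_test_dataloaders(num_dataloaders: int = 2) -> List[DataLoader]:
    """Build one small loader per dataloader index, as the test above does."""
    dataset = TensorDataset(torch.randn(8, 3))
    return [DataLoader(dataset, batch_size=4) for _ in range(num_dataloaders)]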
@@ -266,7 +269,7 @@ def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders):
     assert abs(expected.item() - generated.item()) < 1e-6
 
 
-def test_call_back_validator(tmpdir):
+def test_call_back_validator(tmpdir) -> None:
 
     funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])
@@ -368,7 +371,7 @@ def test_call_back_validator(tmpdir):
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires two GPUs")
-def test_epoch_results_cache_dp(tmpdir):
+def test_epoch_results_cache_dp(tmpdir) -> None:
 
     root_device = torch.device("cuda", 0)
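To verify a change like this one, mypy can be run over the touched test files. `mypy.api.run` is mypy's documented programmatic entry point; the path and flags below are illustrative, not taken from this repository's config:

from mypy import api

# Returns (stdout, stderr, exit_status); exit_status 0 means no type errors.
stdout, stderr, exit_status = api.run(["tests/", "--pretty"])
print(stdout)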