attempt mypy fix

This commit is contained in:
Sumanth Ratna 2021-01-28 18:54:04 -05:00
parent b52fa07d6e
commit fb7ed827d9
No known key found for this signature in database
GPG Key ID: 310949B7C8B60603
2 changed files with 35 additions and 30 deletions

View File

@ -21,23 +21,25 @@ import json
import os
import sys
from pytorch_lightning import Trainer # noqa: E402
from pytorch_lightning.callbacks import ModelCheckpoint # noqa: E402
from pytorch_lightning.utilities import HOROVOD_AVAILABLE # noqa: E402
from tests.base import EvalModelTemplate # noqa: E402
from tests.base.develop_pipelines import run_prediction # noqa: E402
from tests.base.develop_utils import (reset_seed, # noqa: E402
set_random_master_port)
# this is needed because Conda does not use `PYTHONPATH` env var while pip and virtualenv do
PYTHONPATH = os.getenv('PYTHONPATH', '')
if ':' in PYTHONPATH:
sys.path = PYTHONPATH.split(':') + sys.path
from pytorch_lightning import Trainer # noqa: E402
from pytorch_lightning.callbacks import ModelCheckpoint # noqa: E402
from pytorch_lightning.utilities import HOROVOD_AVAILABLE # noqa: E402
if HOROVOD_AVAILABLE:
import horovod.torch as hvd # noqa: E402
else:
print('You requested to import Horovod which is missing or not supported for your OS.')
from tests.base import EvalModelTemplate # noqa: E402
from tests.base.develop_pipelines import run_prediction # noqa: E402
from tests.base.develop_utils import set_random_master_port, reset_seed # noqa: E402
parser = argparse.ArgumentParser()
@ -45,7 +47,7 @@ parser.add_argument('--trainer-options', required=True)
parser.add_argument('--on-gpu', action='store_true', default=False)
def run_test_from_config(trainer_options):
def run_test_from_config(trainer_options) -> None:
"""Trains the default model with the given config."""
set_random_master_port()
reset_seed()

View File

@ -15,22 +15,25 @@
Tests to ensure that the training loop works with a dict (1.0)
"""
from copy import deepcopy
from typing import Any, Callable, Dict, List, Tuple, TypeVar
import pytest
import torch
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import \
CallbackHookNameValidator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel, RandomDataset
from torch.utils.data import DataLoader
F = TypeVar('F', bound=Callable[..., Any])
def decorator_with_arguments(fx_name='', hook_fx_name=None):
def decorator(func):
def wrapper(self, *args, **kwargs):
def decorator_with_arguments(fx_name='', hook_fx_name=None) -> Callable[[F], F]:
def decorator(func: F) -> F:
def wrapper(self, *args, **kwargs) -> Any:
# Set information
self._current_fx_name = fx_name
self._current_hook_fx_name = hook_fx_name
@ -47,7 +50,7 @@ def decorator_with_arguments(fx_name='', hook_fx_name=None):
return decorator
def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch):
def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch) -> None:
"""
Tests that LoggerConnector will properly capture logged information
and reduce them
@ -59,7 +62,7 @@ def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch):
train_losses = []
@decorator_with_arguments(fx_name="training_step")
def training_step(self, batch, batch_idx):
def training_step(self, batch, batch_idx) -> Dict[str, Any]:
output = self.layer(batch)
loss = self.loss(batch, output)
@ -69,7 +72,7 @@ def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch):
return {"loss": loss}
def training_step_end(self, *_):
def training_step_end(self, *_) -> None:
self.train_results = deepcopy(self.trainer.logger_connector.cached_results)
model = TestModel()
@ -105,7 +108,7 @@ def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch):
assert generated == excepted
def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
def test__logger_connector__epoch_result_store__train__ttbt(tmpdir) -> None:
"""
Tests that LoggerConnector will properly capture logged information with ttbt
and reduce them
@ -118,23 +121,23 @@ def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
class MockSeq2SeqDataset(torch.utils.data.Dataset):
def __getitem__(self, i):
def __getitem__(self, i) -> Tuple[Any, List[List[List[float]]]]:
return x_seq, y_seq_list
def __len__(self):
def __len__(self) -> int:
return 1
class TestModel(BoringModel):
train_losses = []
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.test_hidden = None
self.layer = torch.nn.Linear(2, 2)
@decorator_with_arguments(fx_name="training_step")
def training_step(self, batch, batch_idx, hiddens):
def training_step(self, batch, batch_idx, hiddens) -> Dict[str, Any]:
self.test_hidden = torch.rand(1)
x_tensor, y_list = batch
@ -155,7 +158,7 @@ def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
def on_train_epoch_start(self) -> None:
self.test_hidden = None
def train_dataloader(self):
def train_dataloader(self) -> Any:
return torch.utils.data.DataLoader(
dataset=MockSeq2SeqDataset(),
batch_size=batch_size,
@ -163,7 +166,7 @@ def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
sampler=None,
)
def training_step_end(self, *_):
def training_step_end(self, *_) -> None:
self.train_results = deepcopy(self.trainer.logger_connector.cached_results)
model = TestModel()
@ -200,7 +203,7 @@ def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
@pytest.mark.parametrize('num_dataloaders', [1, 2])
def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders):
def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders) -> None:
"""
Tests that LoggerConnector will properly capture logged information in multi_dataloaders scenario
"""
@ -211,7 +214,7 @@ def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, m
test_losses = {}
@decorator_with_arguments(fx_name="test_step")
def test_step(self, batch, batch_idx, dl_idx=0):
def test_step(self, batch, batch_idx, dl_idx=0) -> Dict[str, Any]:
output = self.layer(batch)
loss = self.loss(batch, output)
@ -221,15 +224,15 @@ def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, m
self.log("test_loss", loss, on_step=True, on_epoch=True)
return {"test_loss": loss}
def on_test_batch_end(self, *args, **kwargs):
def on_test_batch_end(self, *args, **kwargs) -> None:
# save objects as it will be reset at the end of epoch.
self.batch_results = deepcopy(self.trainer.logger_connector.cached_results)
def on_test_epoch_end(self):
def on_test_epoch_end(self) -> None:
# save objects as it will be reset at the end of epoch.
self.reduce_results = deepcopy(self.trainer.logger_connector.cached_results)
def test_dataloader(self):
def test_dataloader(self) -> List[Any]:
return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)]
model = TestModel()
@ -266,7 +269,7 @@ def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, m
assert abs(expected.item() - generated.item()) < 1e-6
def test_call_back_validator(tmpdir):
def test_call_back_validator(tmpdir) -> None:
funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])
@ -368,7 +371,7 @@ def test_call_back_validator(tmpdir):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires two GPUs")
def test_epoch_results_cache_dp(tmpdir):
def test_epoch_results_cache_dp(tmpdir) -> None:
root_device = torch.device("cuda", 0)