Fix mypy errors attributed to `pytorch_lightning.demos.boring_classes` (#14201)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: otaj <ota@lightning.ai>
This commit is contained in:
Krishna Kalyan 2022-08-26 08:27:33 +01:00 committed by GitHub
parent a01e016fff
commit 6a999f123c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 35 deletions

View File

@ -51,7 +51,6 @@ warn_no_return = "False"
module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.core.datamodule",
"pytorch_lightning.demos.boring_classes",
"pytorch_lightning.demos.mnist_datamodule",
"pytorch_lightning.profilers.base",
"pytorch_lightning.profilers.pytorch",

View File

@ -11,14 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import cast, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
class RandomDictDataset(Dataset):
@ -26,7 +31,7 @@ class RandomDictDataset(Dataset):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
def __getitem__(self, index: int) -> Dict[str, Tensor]:
a = self.data[index]
b = a + 2
return {"a": a, "b": b}
@ -40,7 +45,7 @@ class RandomDataset(Dataset):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tensor:
return self.data[index]
def __len__(self) -> int:
@ -52,7 +57,7 @@ class RandomIterableDataset(IterableDataset):
self.count = count
self.size = size
def __iter__(self):
def __iter__(self) -> Iterator[Tensor]:
for _ in range(self.count):
yield torch.randn(self.size)
@ -62,16 +67,16 @@ class RandomIterableDatasetWithLen(IterableDataset):
self.count = count
self.size = size
def __iter__(self):
def __iter__(self) -> Iterator[Tensor]:
for _ in range(len(self)):
yield torch.randn(self.size)
def __len__(self):
def __len__(self) -> int:
return self.count
class BoringModel(LightningModule):
def __init__(self):
def __init__(self) -> None:
"""Testing PL Module.
Use as follows:
@ -90,60 +95,63 @@ class BoringModel(LightningModule):
super().__init__()
self.layer = torch.nn.Linear(32, 2)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
return self.layer(x)
def loss(self, batch, preds):
def loss(self, batch: Tensor, preds: Tensor) -> Tensor:
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(preds, torch.ones_like(preds))
def step(self, x):
def step(self, x: Tensor) -> Tensor:
x = self(x)
out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
return out
def training_step(self, batch, batch_idx):
def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: # type: ignore[override]
output = self(batch)
loss = self.loss(batch, output)
return {"loss": loss}
def training_step_end(self, training_step_outputs):
def training_step_end(self, training_step_outputs: STEP_OUTPUT) -> STEP_OUTPUT:
return training_step_outputs
def training_epoch_end(self, outputs) -> None:
def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["loss"] for x in outputs]).mean()
def validation_step(self, batch, batch_idx):
def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: # type: ignore[override]
output = self(batch)
loss = self.loss(batch, output)
return {"x": loss}
def validation_epoch_end(self, outputs) -> None:
def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["x"] for x in outputs]).mean()
def test_step(self, batch, batch_idx):
def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: # type: ignore[override]
output = self(batch)
loss = self.loss(batch, output)
return {"y": loss}
def test_epoch_end(self, outputs) -> None:
def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["y"] for x in outputs]).mean()
def configure_optimizers(self):
def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
return DataLoader(RandomDataset(32, 64))
def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
return DataLoader(RandomDataset(32, 64))
def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
return DataLoader(RandomDataset(32, 64))
def predict_dataloader(self):
def predict_dataloader(self) -> DataLoader:
return DataLoader(RandomDataset(32, 64))
@ -155,7 +163,7 @@ class BoringDataModule(LightningDataModule):
self.checkpoint_state: Optional[str] = None
self.random_full = RandomDataset(32, 64 * 4)
def setup(self, stage: Optional[str] = None):
def setup(self, stage: Optional[str] = None) -> None:
if stage == "fit" or stage is None:
self.random_train = Subset(self.random_full, indices=range(64))
@ -168,26 +176,27 @@ class BoringDataModule(LightningDataModule):
if stage == "predict" or stage is None:
self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
return DataLoader(self.random_train)
def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
return DataLoader(self.random_val)
def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
return DataLoader(self.random_test)
def predict_dataloader(self):
def predict_dataloader(self) -> DataLoader:
return DataLoader(self.random_predict)
class ManualOptimBoringModel(BoringModel):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.automatic_optimization = False
def training_step(self, batch, batch_idx):
def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: # type: ignore[override]
opt = self.optimizers()
assert isinstance(opt, (Optimizer, LightningOptimizer))
output = self(batch)
loss = self.loss(batch, output)
opt.zero_grad()
@ -202,21 +211,21 @@ class DemoModel(LightningModule):
self.l1 = torch.nn.Linear(32, out_dim)
self.learning_rate = learning_rate
def forward(self, x):
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_nb):
def training_step(self, batch: Tensor, batch_nb: int) -> STEP_OUTPUT: # type: ignore[override]
x = batch
x = self(x)
loss = x.sum()
return loss
def configure_optimizers(self):
def configure_optimizers(self) -> torch.optim.Optimizer:
return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
class Net(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
@ -225,7 +234,7 @@ class Net(nn.Module):
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)