From 1a345a4a7858726368d3287f0801f5920d697c17 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 6 Oct 2020 23:14:51 -0400 Subject: [PATCH] test for complicated batch structure (#3928) * test for complicated batch structure * test for complicated batch structure --- .../logging/test_train_loop_logging_1_0.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/trainer/logging/test_train_loop_logging_1_0.py b/tests/trainer/logging/test_train_loop_logging_1_0.py index 6eac64288e..ab4541489d 100644 --- a/tests/trainer/logging/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging/test_train_loop_logging_1_0.py @@ -8,6 +8,7 @@ import pytest from pytorch_lightning import Trainer from tests.base.deterministic_model import DeterministicModel +from torch.utils.data import Dataset def test__training_step__log(tmpdir): @@ -419,3 +420,56 @@ def test_validation_step_with_string_data_logging(): weights_summary=None, ) trainer.fit(model, train_data, val_data) + + +def test_nested_datasouce_batch(tmpdir): + + class NestedDictStringDataset(Dataset): + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + x = { + 'post_text': ['bird is fast', 'big cat'], + 'dense_0': [ + torch.tensor([-0.1000, 0.2000], dtype=torch.float64), + torch.tensor([1, 1], dtype=torch.uint8) + ], + 'post_id': ['115', '116'], + 'label': [torch.tensor([0, 1]), torch.tensor([1, 1], dtype=torch.uint8)] + } + return x + + def __len__(self): + return self.len + + class TestModel(BoringModel): + def on_train_epoch_start(self) -> None: + print("override any method to prove your bug") + + def training_step(self, batch, batch_idx): + output = self.layer(torch.rand(32)) + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + output = self.layer(torch.rand(32)) + loss = self.loss(batch, output) + self.log("x", loss) + return {"x": loss} + + # fake data + train_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64)) + val_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64)) + + # model + model = TestModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model, train_data, val_data)