test for complicated batch structure (#3928)

* test for complicated batch structure

* test for complicated batch structure
This commit is contained in:
William Falcon 2020-10-06 23:14:51 -04:00 committed by GitHub
parent 71a4c61f6e
commit 1a345a4a78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 54 additions and 0 deletions

View File

@ -8,6 +8,7 @@ import pytest
from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel
from torch.utils.data import Dataset
def test__training_step__log(tmpdir):
@ -419,3 +420,56 @@ def test_validation_step_with_string_data_logging():
weights_summary=None,
)
trainer.fit(model, train_data, val_data)
def test_nested_datasouce_batch(tmpdir):
class NestedDictStringDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
x = {
'post_text': ['bird is fast', 'big cat'],
'dense_0': [
torch.tensor([-0.1000, 0.2000], dtype=torch.float64),
torch.tensor([1, 1], dtype=torch.uint8)
],
'post_id': ['115', '116'],
'label': [torch.tensor([0, 1]), torch.tensor([1, 1], dtype=torch.uint8)]
}
return x
def __len__(self):
return self.len
class TestModel(BoringModel):
def on_train_epoch_start(self) -> None:
print("override any method to prove your bug")
def training_step(self, batch, batch_idx):
output = self.layer(torch.rand(32))
loss = self.loss(batch, output)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
output = self.layer(torch.rand(32))
loss = self.loss(batch, output)
self.log("x", loss)
return {"x": loss}
# fake data
train_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64))
val_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64))
# model
model = TestModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=1,
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model, train_data, val_data)