test for complicated batch structure (#3928)
* test for complicated batch structure * test for complicated batch structure
This commit is contained in:
parent
71a4c61f6e
commit
1a345a4a78
|
@ -8,6 +8,7 @@ import pytest
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from tests.base.deterministic_model import DeterministicModel
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def test__training_step__log(tmpdir):
|
||||
|
@ -419,3 +420,56 @@ def test_validation_step_with_string_data_logging():
|
|||
weights_summary=None,
|
||||
)
|
||||
trainer.fit(model, train_data, val_data)
|
||||
|
||||
|
||||
def test_nested_datasouce_batch(tmpdir):
|
||||
|
||||
class NestedDictStringDataset(Dataset):
|
||||
def __init__(self, size, length):
|
||||
self.len = length
|
||||
self.data = torch.randn(length, size)
|
||||
|
||||
def __getitem__(self, index):
|
||||
x = {
|
||||
'post_text': ['bird is fast', 'big cat'],
|
||||
'dense_0': [
|
||||
torch.tensor([-0.1000, 0.2000], dtype=torch.float64),
|
||||
torch.tensor([1, 1], dtype=torch.uint8)
|
||||
],
|
||||
'post_id': ['115', '116'],
|
||||
'label': [torch.tensor([0, 1]), torch.tensor([1, 1], dtype=torch.uint8)]
|
||||
}
|
||||
return x
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def on_train_epoch_start(self) -> None:
|
||||
print("override any method to prove your bug")
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
output = self.layer(torch.rand(32))
|
||||
loss = self.loss(batch, output)
|
||||
return {"loss": loss}
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
output = self.layer(torch.rand(32))
|
||||
loss = self.loss(batch, output)
|
||||
self.log("x", loss)
|
||||
return {"x": loss}
|
||||
|
||||
# fake data
|
||||
train_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64))
|
||||
val_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64))
|
||||
|
||||
# model
|
||||
model = TestModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=os.getcwd(),
|
||||
limit_train_batches=1,
|
||||
limit_val_batches=1,
|
||||
max_epochs=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
trainer.fit(model, train_data, val_data)
|
||||
|
|
Loading…
Reference in New Issue