test for complicated batch structure (#3928)
* test for complicated batch structure * test for complicated batch structure
This commit is contained in:
parent
71a4c61f6e
commit
1a345a4a78
|
@ -8,6 +8,7 @@ import pytest
|
||||||
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from tests.base.deterministic_model import DeterministicModel
|
from tests.base.deterministic_model import DeterministicModel
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
def test__training_step__log(tmpdir):
|
def test__training_step__log(tmpdir):
|
||||||
|
@ -419,3 +420,56 @@ def test_validation_step_with_string_data_logging():
|
||||||
weights_summary=None,
|
weights_summary=None,
|
||||||
)
|
)
|
||||||
trainer.fit(model, train_data, val_data)
|
trainer.fit(model, train_data, val_data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_datasouce_batch(tmpdir):
|
||||||
|
|
||||||
|
class NestedDictStringDataset(Dataset):
|
||||||
|
def __init__(self, size, length):
|
||||||
|
self.len = length
|
||||||
|
self.data = torch.randn(length, size)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
x = {
|
||||||
|
'post_text': ['bird is fast', 'big cat'],
|
||||||
|
'dense_0': [
|
||||||
|
torch.tensor([-0.1000, 0.2000], dtype=torch.float64),
|
||||||
|
torch.tensor([1, 1], dtype=torch.uint8)
|
||||||
|
],
|
||||||
|
'post_id': ['115', '116'],
|
||||||
|
'label': [torch.tensor([0, 1]), torch.tensor([1, 1], dtype=torch.uint8)]
|
||||||
|
}
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.len
|
||||||
|
|
||||||
|
class TestModel(BoringModel):
|
||||||
|
def on_train_epoch_start(self) -> None:
|
||||||
|
print("override any method to prove your bug")
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
output = self.layer(torch.rand(32))
|
||||||
|
loss = self.loss(batch, output)
|
||||||
|
return {"loss": loss}
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
output = self.layer(torch.rand(32))
|
||||||
|
loss = self.loss(batch, output)
|
||||||
|
self.log("x", loss)
|
||||||
|
return {"x": loss}
|
||||||
|
|
||||||
|
# fake data
|
||||||
|
train_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64))
|
||||||
|
val_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64))
|
||||||
|
|
||||||
|
# model
|
||||||
|
model = TestModel()
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=os.getcwd(),
|
||||||
|
limit_train_batches=1,
|
||||||
|
limit_val_batches=1,
|
||||||
|
max_epochs=1,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model, train_data, val_data)
|
||||||
|
|
Loading…
Reference in New Issue