22 lines
585 B
Python
22 lines
585 B
Python
![]() |
import torch
|
||
|
|
||
|
from pytorch_lightning.utilities.data import extract_batch_size
|
||
|
|
||
|
|
||
|
def test_extract_batch_size():
    """Tests the behavior of extracting the batch size.

    Every sample below is built so that ``extract_batch_size`` is expected to
    report 11:

    * a plain string (``"test string"`` has 11 characters),
    * a bare tensor whose leading dimension is 11,
    * a dict mapping a key to such a tensor,
    * a list holding such a tensor,
    * a nested mix of dicts and lists with such a tensor at the bottom.
    """
    expected_size = 11

    # Checked in the same order as the individual assertions they replace.
    samples = (
        "test string",
        torch.zeros(expected_size, 10, 9, 8),
        {'test': torch.zeros(expected_size, 10)},
        [torch.zeros(expected_size, 10)],
        {'test': [{'test': [torch.zeros(expected_size, 10)]}]},
    )

    for sample in samples:
        assert extract_batch_size(sample) == expected_size