[WIP] Reduction when batch size < num gpus (#1609)

* reduce if <= num_gpus
* add test with explanation
* chlog
* fix changelog

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
commit e6b34ef90d (parent fafe5d63a7)
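The bug this fixes: under `DataParallel`, a batch smaller than the number of GPUs occupies only some of the devices, so the gathered per-GPU metrics tensor has `size(0) < num_gpus`, and the old `== num_gpus` check left it unreduced. A rough CPU-only illustration of the split (not part of the diff; `torch.chunk` is used here as an approximation of how scatter divides a batch, and the sizes match the test added below):

```python
import torch

num_gpus = 3
last_batch = torch.randn(2, 10)  # final batch of the test: 2 samples on 3 GPUs

# the scatter step splits the batch roughly like torch.chunk: only 2 chunks
# exist, so only 2 per-GPU loss values come back instead of 3
chunks = torch.chunk(last_batch, num_gpus, dim=0)
print([c.size(0) for c in chunks])  # [1, 1]
```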
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
+
 ### Deprecated
 
 ### Removed
@@ -196,8 +196,8 @@ class TrainerLoggingMixin(ABC):
             elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
                 pass
 
-            # reduce only metrics that have the same number of gpus
-            elif output[k].size(0) == num_gpus:
-                reduced = torch.mean(output[k])
-                output[k] = reduced
+            # do not reduce metrics that have batch size > num gpus
+            elif output[k].size(0) <= num_gpus:
+                output[k] = torch.mean(output[k])
 
         return output
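To see why the `<=` matters, here is a minimal, CPU-only sketch of the patched branch (the `reduce_metrics` helper and the sample values are illustrative, not the library's API):

```python
import torch

def reduce_metrics(output, num_gpus):
    # mirror of the patched branch: average per-GPU scalar metrics whenever
    # at most `num_gpus` values came back; leave larger tensors untouched
    for k in output:
        if isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
            continue  # already a scalar, nothing to reduce
        if output[k].size(0) <= num_gpus:
            output[k] = torch.mean(output[k])
    return output

# full batch: one loss value per GPU
print(reduce_metrics({'loss': torch.tensor([0.9, 1.1, 1.0])}, num_gpus=3)['loss'])  # tensor(1.)

# last, smaller batch: only two GPUs report. The old `== num_gpus` check
# skipped this tensor; the new `<=` check reduces it as well.
print(reduce_metrics({'loss': torch.tensor([0.9, 1.1])}, num_gpus=3)['loss'])  # tensor(1.)
```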
@@ -2,6 +2,8 @@ import platform
 
 import pytest
 import torch
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Subset
 
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
@@ -482,3 +484,46 @@ def test_dataloader_reinit_for_subclass():
     assert isinstance(result, torch.utils.data.DataLoader)
     assert isinstance(result, CustomDataLoader)
     assert hasattr(result, 'dummy_kwarg')
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
+def test_batch_size_smaller_than_num_gpus():
+    # we need at least 3 gpus for this test
+    num_gpus = 3
+    batch_size = 3
+
+    class CurrentTestModel(
+        LightTrainDataloader,
+        TestModelBase,
+    ):
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.c_d1_bn = torch.nn.ReLU()
+
+        def train_dataloader(self):
+            dataloader = super().train_dataloader()
+            # construct a dataset with a size that is not divisible by num_gpus
+            # therefore the last batch will have a size < num_gpus
+            size = num_gpus * batch_size + (num_gpus - 1)
+            dataset = Subset(dataloader.dataset, range(size))
+            dataloader = DataLoader(
+                dataset,
+                batch_size=self.hparams.batch_size,
+                drop_last=False,
+            )
+            return dataloader
+
+    hparams = tutils.get_default_hparams()
+    hparams.batch_size = batch_size
+    model = CurrentTestModel(hparams)
+
+    trainer = Trainer(
+        max_epochs=1,
+        gpus=num_gpus,
+    )
+
+    # we expect the reduction for the metrics also to happen on the last batch
+    # where we will get fewer metrics than gpus
+    result = trainer.fit(model)
+    assert 1 == result
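Two details of the test are worth checking, and both can be verified with a quick CPU-only sketch (nothing here is part of the diff): the dataset-size arithmetic that forces a short final batch, and why the model swaps its batch-norm layer `c_d1_bn` for a `ReLU`. A plausible reason for the swap is that `BatchNorm1d` rejects training batches of size 1, which is exactly what a single GPU receives from the short batch:

```python
import torch

# dataset sizing from the test: deliberately not divisible by batch_size,
# so the final batch is smaller than num_gpus
num_gpus, batch_size = 3, 3
size = num_gpus * batch_size + (num_gpus - 1)  # 11 samples
batches = [min(batch_size, size - i) for i in range(0, size, batch_size)]
print(batches)  # [3, 3, 3, 2] -> final batch (2) < num_gpus (3)

# why BatchNorm had to go: a single-sample training batch raises an error
bn = torch.nn.BatchNorm1d(4).train()
try:
    bn(torch.randn(1, 4))
except ValueError as err:
    print(err)  # "Expected more than 1 value per channel when training ..."
```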