Prevent crash if sync_dist=True on CPU (#4626)

* Added test/fix for sync_dist raising NotImplementedError
* Fixed comments/formatting
* Reverted base-class change; enforced sync_tensor across accelerators and added a GPU test
This commit is contained in:
parent 3d202f9ecc
commit 33470ba605
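For context (not part of the diff): the crash this commit prevents is hit by any LightningModule that logs with sync_dist=True while running on a CPU-only Trainer. A minimal illustrative sketch follows; the names MyModule and compute_loss are placeholders that do not appear in this PR:

    import torch
    from pytorch_lightning import LightningModule


    class MyModule(LightningModule):
        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical helper, placeholder only
            # sync_dist=True routes the logged value through Accelerator.sync_tensor;
            # before this fix the CPU accelerator had no implementation, so the call
            # raised NotImplementedError. After it, the value passes through unchanged.
            self.log('my_metric', torch.tensor(1.0), on_epoch=True,
                     sync_dist=True, sync_dist_op='sum')
            return loss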
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional, Union, Any
+
 import torch
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.utilities import AMPType, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -80,3 +82,9 @@ class CPUAccelerator(Accelerator):
         else:
             output = self.trainer.model.test_step(*args)
         return output
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return tensor

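The pass-through above gives CPUAccelerator a concrete implementation of the reduction hook that self.log(..., sync_dist=True) calls into. A minimal sketch of that relationship, assuming (per the commit message) a base-class hook that raises NotImplementedError; the class names below are stand-ins, not Lightning's actual source:

    from typing import Any, Optional, Union

    import torch


    class BaseAcceleratorSketch:
        """Stand-in for the real Accelerator base class (assumption for illustration)."""

        def sync_tensor(self,
                        tensor: Union[torch.Tensor],
                        group: Optional[Any] = None,
                        reduce_op: Optional[Any] = None) -> torch.Tensor:
            # Distributed accelerators must reduce the tensor across processes.
            # Without an override, self.log(..., sync_dist=True) lands here.
            raise NotImplementedError()


    class SingleProcessAcceleratorSketch(BaseAcceleratorSketch):
        def sync_tensor(self,
                        tensor: Union[torch.Tensor],
                        group: Optional[Any] = None,
                        reduce_op: Optional[Any] = None) -> torch.Tensor:
            # Only one process exists, so the "reduced" tensor is the tensor itself.
            return tensor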
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union, Optional, Any
 
 import torch
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.distributed.dist import LightningDistributed
 
@@ -120,3 +121,9 @@ class GPUAccelerator(Accelerator):
         # be referenced from and if there are multiple optimizers the batch will
         # wind up copying it to the same device repeatedly.
         return self.batch_to_device(batch, gpu_id)
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return tensor

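For contrast with the single-process pass-through, a hypothetical distributed implementation of the same hook is sketched below using plain torch.distributed. This is only an illustration of what reducing across processes would look like, not the code Lightning's DDP accelerators actually use:

    from typing import Any, Optional, Union

    import torch
    import torch.distributed as dist


    def sync_tensor_distributed(tensor: Union[torch.Tensor],
                                group: Optional[Any] = None,
                                reduce_op: Optional[Any] = None) -> torch.Tensor:
        # Reduce across all processes when a process group is initialized;
        # otherwise fall back to the same pass-through as the single-process case.
        # (Mapping string ops such as 'sum' to dist.ReduceOp is omitted here.)
        if dist.is_available() and dist.is_initialized():
            op = reduce_op if reduce_op is not None else dist.ReduceOp.SUM
            dist.all_reduce(tensor, op=op, group=group or dist.group.WORLD)
        return tensor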
@@ -14,13 +14,13 @@
 import io
 import os
 import re
-from typing import Optional
+from typing import Optional, Union, Any
 
 import torch
 import torch.multiprocessing as mp
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save
@@ -337,3 +337,9 @@ class TPUAccelerator(Accelerator):
         buffer = io.BytesIO(data.cpu().byte().numpy())
         obj = torch.load(buffer)
         return obj
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return tensor

@@ -682,3 +682,69 @@ def test_log_works_in_train_callback(tmpdir):
             assert func_name in trainer.logger_connector.progress_bar_metrics
         else:
             assert func_name not in trainer.logger_connector.progress_bar_metrics
+
+
+def test_logging_sync_dist_true_cpu(tmpdir):
+    """
+    Tests to ensure that the sync_dist flag works with CPU (should just return the original value)
+    """
+    fake_result = 1
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch[0])
+            self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('bar', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            return {"x": loss}
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    assert trainer.logged_metrics['foo'] == fake_result
+    assert trainer.logged_metrics['bar'] == fake_result
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_logging_sync_dist_true_gpu(tmpdir):
+    """
+    Tests to ensure that the sync_dist flag works with GPU (should just return the original value)
+    """
+    fake_result = 1
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch[0])
+            self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('bar', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            return {"x": loss}
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        gpus=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    assert trainer.logged_metrics['foo'] == fake_result
+    assert trainer.logged_metrics['bar'] == fake_result
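The two new tests can be selected by name without knowing which test module they were added to (the file path is not shown in this diff). A hedged example of running just these tests from the repository root:

    # Collects the test suite and filters by name; the GPU variant skips itself
    # automatically when CUDA is unavailable (see the skipif marker above).
    import pytest

    pytest.main(["-k", "test_logging_sync_dist_true"])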