26 lines
828 B
Python
26 lines
828 B
Python
from unittest import mock
|
|
|
|
import pytest
|
|
from lightning.fabric.plugins.collectives import SingleDeviceCollective
|
|
|
|
|
|
def test_can_instantiate_without_args():
    """Constructing a ``SingleDeviceCollective`` must not require any arguments."""
    # The instance itself is irrelevant; the test passes as long as construction does not raise.
    _ = SingleDeviceCollective()
|
|
|
|
|
|
def test_create_group():
    """Exercise the group lifecycle of ``SingleDeviceCollective``.

    A fresh collective reports itself as initialized but owns no group, so
    reading ``group`` raises ``RuntimeError``. ``create_group`` must forward
    its keyword arguments to ``new_group``, and ``teardown`` must pass the
    group that was created to ``destroy_group``. Both class hooks are patched
    so the test can observe the calls without doing any real work.
    """
    collective = SingleDeviceCollective()
    assert collective.is_initialized()

    # No group has been created yet, so accessing the property must fail.
    with pytest.raises(RuntimeError, match=r"SingleDeviceCollective` does not own a group"):
        _ = collective.group

    with mock.patch("lightning.fabric.plugins.collectives.single_device.SingleDeviceCollective.new_group") as new_mock:
        collective.create_group(arg1=15, arg3=10)

    # The keyword order deliberately differs from the call above: only the
    # name/value pairs have to match.
    group_kwargs = {"arg3": 10, "arg1": 15}
    new_mock.assert_called_once_with(**group_kwargs)

    with mock.patch(
        "lightning.fabric.plugins.collectives.single_device.SingleDeviceCollective.destroy_group"
    ) as destroy_mock:
        collective.teardown()

    # The hook is patched, so verify teardown actually destroyed the group
    # that the (mocked) ``new_group`` call produced, and did so exactly once.
    destroy_mock.assert_called_once_with(new_mock.return_value)
|