Unified API upstream with suggestion to ben
This commit is contained in:
parent
9c34589493
commit
1e429bae58
|
@ -11,28 +11,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, List, Union
|
||||
|
||||
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
|
||||
from fairscale.optim import OSS
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LightningShardedDataParallel(ShardedDataParallel):
|
||||
def __init__(
|
||||
self,
|
||||
base_model: nn.Module,
|
||||
sharded_optimizer: Union[OSS, List[OSS]],
|
||||
process_group: Any = None,
|
||||
broadcast_buffers: bool = True
|
||||
):
|
||||
super().__init__(
|
||||
base_model=base_model,
|
||||
sharded_optimizer=sharded_optimizer,
|
||||
process_group=process_group,
|
||||
broadcast_buffers=broadcast_buffers
|
||||
)
|
||||
self.module = base_model
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
if self.enable_broadcast_buffers:
|
||||
|
|
Loading…
Reference in New Issue