Unified API upstream with suggestion to ben

This commit is contained in:
SeanNaren 2020-11-21 11:40:38 +00:00
parent 9c34589493
commit 1e429bae58
1 changed files with 0 additions and 17 deletions

View File

@ -11,28 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
from fairscale.optim import OSS
from torch import nn
class LightningShardedDataParallel(ShardedDataParallel):
def __init__(
self,
base_model: nn.Module,
sharded_optimizer: Union[OSS, List[OSS]],
process_group: Any = None,
broadcast_buffers: bool = True
):
super().__init__(
base_model=base_model,
sharded_optimizer=sharded_optimizer,
process_group=process_group,
broadcast_buffers=broadcast_buffers
)
self.module = base_model
def forward(self, *inputs, **kwargs):
if self.enable_broadcast_buffers: