Fix conversion in on_before_forward
parent 50ed083fc7
commit df416f6c78
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Any, Optional, Union
+from typing import List, Optional, Union
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
@@ -50,11 +50,12 @@ class DDPShardedPlugin(DDPPlugin):
         optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
 
-    def on_before_forward(self, args: Any, model: LightningModule):
+    def on_before_forward(self, model: LightningModule, *args):
+        args = list(args)
         batch = args[0]
         batch = model.transfer_batch_to_device(batch, model.trainer.root_gpu)
         args[0] = batch
-        return args
+        return tuple(args)
 
     def _check_fairscale(self):
         if not FAIRSCALE_AVAILABLE:
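The change moves the hook from taking a single pre-packed `args` object to receiving the forward inputs via `*args`, so they arrive as a tuple. The body converts them to a list (tuples are immutable), swaps the batch in slot 0 for its device-placed copy, and converts back to a tuple before returning, keeping the output the same type as the input. A minimal sketch of a caller that relies on this contract; the wrapper class and attribute names below are hypothetical, not the library's actual sharded data-parallel wrapper:

    from typing import Any

    class ShardedWrapper:
        """Hypothetical wrapper that runs the plugin hook before forwarding."""

        def __init__(self, plugin: Any, model: Any) -> None:
            self.plugin = plugin
            self.model = model

        def forward(self, *inputs: Any) -> Any:
            # The hook takes the model first and the inputs unpacked,
            # matching the new signature on_before_forward(model, *args).
            inputs = self.plugin.on_before_forward(self.model, *inputs)
            # The hook hands back a tuple, so unpacking it into the
            # wrapped module mirrors exactly how the inputs arrived.
            return self.model(*inputs)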