set device to root gpu (#3042)

This commit is contained in:
Ananya Harsh Jha 2020-08-18 19:28:35 -04:00 committed by GitHub
parent 9f6be96f84
commit 9445c800b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import AMPType
@ -32,6 +33,7 @@ class GPUBackend(object):
# call setup
self.trainer.call_setup_hook(model)
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)
# CHOOSE OPTIMIZER