set device to root gpu (#3042)
This commit is contained in:
parent
9f6be96f84
commit
9445c800b0
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from pytorch_lightning.core import LightningModule
|
||||
from pytorch_lightning.utilities import AMPType
|
||||
|
||||
|
@ -32,6 +33,7 @@ class GPUBackend(object):
|
|||
# call setup
|
||||
self.trainer.call_setup_hook(model)
|
||||
|
||||
torch.cuda.set_device(self.trainer.root_gpu)
|
||||
model.cuda(self.trainer.root_gpu)
|
||||
|
||||
# CHOOSE OPTIMIZER
|
||||
|
|
Loading…
Reference in New Issue