added single gpu train test
parent 9ecb1f2aee
commit 2f7a9ad40d
@@ -393,7 +393,9 @@ class Trainer(TrainerIO):
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
 
         elif self.single_gpu:
-            output = model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_i)
+            gpu_id = self.data_parallel_device_ids[0]
+            data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
+            output = model(data_batch, batch_i)
 
         else:
             output = model.validation_step(data_batch, batch_i)
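The new lines replace a whole-batch data_batch.cuda(...) call, which only works when the batch is a single tensor, with a per-tensor transfer that also handles list-style batches. A minimal sketch of that pattern, assuming a CUDA device and toy tensors (illustrative only, not library code):

import torch

if torch.cuda.is_available():
    gpu_id = 0

    # Old style: calling .cuda() on the batch object itself only works when
    # the batch is a single tensor.
    single_tensor_batch = torch.randn(4, 10)
    single_tensor_batch = single_tensor_batch.cuda(gpu_id)

    # New style: a list-style batch has no .cuda() method, so each tensor is
    # moved individually (non-tensor elements are skipped by the isinstance check).
    list_batch = [torch.randn(4, 10), torch.randn(4, 1)]
    list_batch = [x.cuda(gpu_id) for x in list_batch if isinstance(x, torch.Tensor)]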
@@ -474,7 +476,7 @@ If you're not using SLURM, ignore this message!
             self.__dp_train(model)
 
         elif self.single_gpu:
-            self.__single_gpu_train(model)\
+            self.__single_gpu_train(model)
 
         # ON CPU
         else:
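For orientation, the branch touched here is the accelerator dispatch in fit(): DataParallel, single GPU, or CPU. A minimal self-contained sketch of that dispatch shape, with hypothetical stand-in method bodies (not the actual Trainer internals):

class TinyTrainer:
    """Toy illustration of the dp / single-GPU / CPU dispatch; method bodies are stand-ins."""

    def __init__(self, use_dp=False, single_gpu=False):
        self.use_dp = use_dp
        self.single_gpu = single_gpu

    def fit(self, model):
        if self.use_dp:
            self.__dp_train(model)          # multi-GPU DataParallel path
        elif self.single_gpu:
            self.__single_gpu_train(model)  # single-GPU path (the call fixed above)
        # ON CPU
        else:
            self.__cpu_train(model)

    def __dp_train(self, model):
        print("dp train", model)

    def __single_gpu_train(self, model):
        print("single gpu train", model)

    def __cpu_train(self, model):
        print("cpu train", model)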
@@ -846,7 +848,10 @@ We recommend you switch to ddp if you want to use amp
             output = self.model(data_batch, batch_nb)
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
         elif self.single_gpu:
+            gpu_id = self.data_parallel_device_ids[0]
+            data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
             output = self.model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
+
         else:
             output = self.model.training_step(data_batch, batch_nb)
 
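A minimal, self-contained sketch of a single-GPU forward step in the spirit of the branch above: move each tensor of the batch to the chosen device, then run the model on it. The model and batch here are toy stand-ins, assuming a CUDA device is available; note that the isinstance filter also drops any non-tensor elements from the batch.

import torch

if torch.cuda.is_available():
    gpu_id = 0
    model = torch.nn.Linear(10, 2).cuda(gpu_id)        # stand-in for self.model
    data_batch = [torch.randn(4, 10), torch.randint(0, 2, (4,))]
    batch_nb = 0

    # Per-tensor transfer, mirroring the committed list comprehension.
    data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]

    x, y = data_batch
    output = model(x)                                   # forward pass on the GPU
    loss = torch.nn.functional.cross_entropy(output, y)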