docs
This commit is contained in:
parent
12ee3c60dd
commit
f38f3827fd
|
@ -87,7 +87,6 @@ class Trainer(TrainerIO):
|
|||
# if gpus = -1 then use all available devices
|
||||
# otherwise, split the string using commas
|
||||
if gpus is not None:
|
||||
|
||||
if gpus == '-1':
|
||||
self.data_parallel_device_ids = list(range(0, torch.cuda.device_count()))
|
||||
else:
|
||||
|
@ -123,7 +122,7 @@ class Trainer(TrainerIO):
|
|||
self.__determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct)
|
||||
print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
|
||||
|
||||
# apex test
|
||||
# 16 bit mixed precision training using apex
|
||||
self.use_amp = use_amp and APEX_AVAILABLE
|
||||
if self.use_amp:
|
||||
print('using 16bit precision')
|
||||
|
@ -169,7 +168,7 @@ class Trainer(TrainerIO):
|
|||
|
||||
return tqdm_dic
|
||||
|
||||
def __layout_bookeeping(self, model):
|
||||
def __layout_bookeeping(self):
|
||||
# training bookeeping
|
||||
self.total_batch_nb = 0
|
||||
self.running_loss = []
|
||||
|
@ -410,7 +409,7 @@ class Trainer(TrainerIO):
|
|||
self.__get_dataloaders(ref_model)
|
||||
|
||||
# init training constants
|
||||
self.__layout_bookeeping(ref_model)
|
||||
self.__layout_bookeeping()
|
||||
|
||||
# add lr schedulers
|
||||
if self.lr_scheduler_milestones is not None:
|
||||
|
|
Loading…
Reference in New Issue