This commit is contained in:
William Falcon 2019-07-08 20:13:40 -04:00
parent 12ee3c60dd
commit f38f3827fd
1 changed files with 3 additions and 4 deletions

View File

@ -87,7 +87,6 @@ class Trainer(TrainerIO):
# if gpus = -1 then use all available devices # if gpus = -1 then use all available devices
# otherwise, split the string using commas # otherwise, split the string using commas
if gpus is not None: if gpus is not None:
if gpus == '-1': if gpus == '-1':
self.data_parallel_device_ids = list(range(0, torch.cuda.device_count())) self.data_parallel_device_ids = list(range(0, torch.cuda.device_count()))
else: else:
@ -123,7 +122,7 @@ class Trainer(TrainerIO):
self.__determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct) self.__determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct)
print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu)) print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
# apex test # 16 bit mixed precision training using apex
self.use_amp = use_amp and APEX_AVAILABLE self.use_amp = use_amp and APEX_AVAILABLE
if self.use_amp: if self.use_amp:
print('using 16bit precision') print('using 16bit precision')
@ -169,7 +168,7 @@ class Trainer(TrainerIO):
return tqdm_dic return tqdm_dic
def __layout_bookeeping(self, model): def __layout_bookeeping(self):
# training bookeeping # training bookeeping
self.total_batch_nb = 0 self.total_batch_nb = 0
self.running_loss = [] self.running_loss = []
@ -410,7 +409,7 @@ class Trainer(TrainerIO):
self.__get_dataloaders(ref_model) self.__get_dataloaders(ref_model)
# init training constants # init training constants
self.__layout_bookeeping(ref_model) self.__layout_bookeeping()
# add lr schedulers # add lr schedulers
if self.lr_scheduler_milestones is not None: if self.lr_scheduler_milestones is not None: