removed direct self.model refs (use self.model.module when data_parallel is set)

This commit is contained in:
William Falcon 2019-06-26 18:04:29 -04:00
parent 12a0e98920
commit bc0278252e
1 changed file with 24 additions and 9 deletions

View File

@@ -293,11 +293,13 @@ class Trainer(TrainerIO):
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step()
self.model.current_epoch = epoch_nb
model = self.model.module if self.data_parallel else self.model
model.current_epoch = epoch_nb
# hook
if self.__is_function_implemented('on_epoch_start'):
self.model.on_epoch_start()
model = self.model.module if self.data_parallel else self.model
model.on_epoch_start()
self.current_epoch = epoch_nb
self.total_batches = self.nb_tng_batches + self.nb_val_batches
@@ -310,7 +312,9 @@ class Trainer(TrainerIO):
for batch_nb, data_batch in enumerate(self.tng_dataloader):
self.batch_nb = batch_nb
self.global_step += 1
self.model.global_step = self.global_step
model = self.model.module if self.data_parallel else self.model
model.global_step = self.global_step
# stop when the flag is changed or we've gone past the amount requested in the batches
self.total_batch_nb += 1
@@ -340,7 +344,10 @@ class Trainer(TrainerIO):
# count items in memory
# nb_params, nb_tensors = count_mem_items()
metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
if self.data_parallel:
metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
else:
metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
# add gpu memory
if self.on_gpu:
@@ -349,7 +356,9 @@ class Trainer(TrainerIO):
# add norms
if self.track_grad_norm > 0:
grad_norm_dic = self.model.module.grad_norm(self.track_grad_norm)
model = self.model.module if self.data_parallel else self.model
grad_norm_dic = model.grad_norm(self.track_grad_norm)
metrics.update(grad_norm_dic)
# log metrics
@@ -358,7 +367,8 @@ class Trainer(TrainerIO):
# hook
if self.__is_function_implemented('on_batch_end'):
self.model.on_batch_end()
model = self.model.module if self.data_parallel else self.model
model.on_batch_end()
# end epoch early
if early_stop_epoch:
@@ -366,7 +376,8 @@ class Trainer(TrainerIO):
# hook
if self.__is_function_implemented('on_epoch_end'):
self.model.on_epoch_end()
model = self.model.module if self.data_parallel else self.model
model.on_epoch_end()
# early stopping
if self.enable_early_stop:
@@ -385,7 +396,9 @@ class Trainer(TrainerIO):
# hook
if self.__is_function_implemented('on_batch_start'):
response = self.model.on_batch_start(data_batch)
model = self.model.module if self.data_parallel else self.model
response = model.on_batch_start(data_batch)
if response == -1:
return -1
@@ -415,7 +428,9 @@ class Trainer(TrainerIO):
loss.backward()
if self.check_grad_nans:
for param in self.model.parameters():
model = self.model.module if self.data_parallel else self.model
for param in model.parameters():
print(param.grad.float().sum())
self.batch_loss_value += loss.item()