removed self.model refs
This commit is contained in:
parent
12a0e98920
commit
bc0278252e
|
@ -293,11 +293,13 @@ class Trainer(TrainerIO):
|
|||
for lr_scheduler in self.lr_schedulers:
|
||||
lr_scheduler.step()
|
||||
|
||||
self.model.current_epoch = epoch_nb
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.current_epoch = epoch_nb
|
||||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_epoch_start'):
|
||||
self.model.on_epoch_start()
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.on_epoch_start()
|
||||
|
||||
self.current_epoch = epoch_nb
|
||||
self.total_batches = self.nb_tng_batches + self.nb_val_batches
|
||||
|
@ -310,7 +312,9 @@ class Trainer(TrainerIO):
|
|||
for batch_nb, data_batch in enumerate(self.tng_dataloader):
|
||||
self.batch_nb = batch_nb
|
||||
self.global_step += 1
|
||||
self.model.global_step = self.global_step
|
||||
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.global_step = self.global_step
|
||||
|
||||
# stop when the flag is changed or we've gone past the amount requested in the batches
|
||||
self.total_batch_nb += 1
|
||||
|
@ -340,7 +344,10 @@ class Trainer(TrainerIO):
|
|||
# count items in memory
|
||||
# nb_params, nb_tensors = count_mem_items()
|
||||
|
||||
metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
if self.data_parallel:
|
||||
metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
else:
|
||||
metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
|
||||
# add gpu memory
|
||||
if self.on_gpu:
|
||||
|
@ -349,7 +356,9 @@ class Trainer(TrainerIO):
|
|||
|
||||
# add norms
|
||||
if self.track_grad_norm > 0:
|
||||
grad_norm_dic = self.model.module.grad_norm(self.track_grad_norm)
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
grad_norm_dic = model.grad_norm(self.track_grad_norm)
|
||||
|
||||
metrics.update(grad_norm_dic)
|
||||
|
||||
# log metrics
|
||||
|
@ -358,7 +367,8 @@ class Trainer(TrainerIO):
|
|||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_batch_end'):
|
||||
self.model.on_batch_end()
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.on_batch_end()
|
||||
|
||||
# end epoch early
|
||||
if early_stop_epoch:
|
||||
|
@ -366,7 +376,8 @@ class Trainer(TrainerIO):
|
|||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_epoch_end'):
|
||||
self.model.on_epoch_end()
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
model.on_epoch_end()
|
||||
|
||||
# early stopping
|
||||
if self.enable_early_stop:
|
||||
|
@ -385,7 +396,9 @@ class Trainer(TrainerIO):
|
|||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_batch_start'):
|
||||
response = self.model.on_batch_start(data_batch)
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
response = model.on_batch_start(data_batch)
|
||||
|
||||
if response == -1:
|
||||
return -1
|
||||
|
||||
|
@ -415,7 +428,9 @@ class Trainer(TrainerIO):
|
|||
loss.backward()
|
||||
|
||||
if self.check_grad_nans:
|
||||
for param in self.model.parameters():
|
||||
|
||||
model = self.model.module if self.data_parallel else self.model
|
||||
for param in model.parameters():
|
||||
print(param.grad.float().sum())
|
||||
|
||||
self.batch_loss_value += loss.item()
|
||||
|
|
Loading…
Reference in New Issue