added option to change default tensor
parent f228e5ae66
commit 4693276494
@@ -6,6 +6,11 @@ import traceback
 from pytorch_lightning.root_module.model_saving import TrainerIO
 from torch.optim.lr_scheduler import MultiStepLR
 
+try:
+    from apex import amp
+    APEX_AVAILABLE = True
+except ModuleNotFoundError:
+    APEX_AVAILABLE = False
 
 
 class Trainer(TrainerIO):
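Not part of the commit, but for readers unfamiliar with the pattern: the try/except above makes apex an optional dependency, and later code branches on APEX_AVAILABLE instead of failing at import time. A minimal standalone sketch of the same guard (the `amp = None` fallback and the prints are illustrative additions, not in the diff):

    try:
        from apex import amp  # NVIDIA apex mixed-precision utilities
        APEX_AVAILABLE = True
    except ModuleNotFoundError:
        amp = None            # placeholder so the name exists; callers must check APEX_AVAILABLE
        APEX_AVAILABLE = False

    if APEX_AVAILABLE:
        print('apex found, mixed precision can be enabled')
    else:
        print('apex not installed, falling back to full precision')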
@@ -73,6 +78,10 @@ class Trainer(TrainerIO):
         self.__determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct)
         print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
 
+        # apex test
+        use_amp = True
+        self.use_amp = use_amp and APEX_AVAILABLE
+
     def __determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct):
         """
         Use less data for debugging purposes
@@ -207,6 +216,15 @@ class Trainer(TrainerIO):
         # filter out the weights that were done on gpu so we can load on good old cpus
         self.optimizers = model.configure_optimizers()
 
+        if self.use_amp:
+            # An example
+            self.model, optimizer = amp.initialize(
+                self.model, self.optimizers[0], opt_level="O2",
+                keep_batchnorm_fp32=True, loss_scale="dynamic"
+            )
+            self.optimizers[0] = optimizer
+        model.trainer = self
+
         # add lr schedulers
         if self.lr_scheduler_milestones is not None:
             for optimizer in self.optimizers:
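As a side note, the amp.initialize call above can be exercised in isolation. The sketch below is illustrative, not part of the commit: it assumes NVIDIA apex is installed and a CUDA device is available, and it uses a throwaway model and optimizer; the opt_level, keep_batchnorm_fp32, and loss_scale arguments simply mirror the values in the hunk.

    import torch
    from torch import nn
    from apex import amp  # assumes apex is installed

    # Throwaway model/optimizer purely for illustration.
    model = nn.Linear(10, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # amp.initialize returns patched versions of the model and optimizer;
    # "O2" casts the model to fp16 while keeping batchnorm and master weights
    # in fp32, and "dynamic" loss scaling adjusts the scale factor automatically.
    model, optimizer = amp.initialize(
        model, optimizer, opt_level="O2",
        keep_batchnorm_fp32=True, loss_scale="dynamic"
    )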
@@ -347,7 +365,13 @@ class Trainer(TrainerIO):
             self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
 
             # backward pass
-            loss.backward()
+            if self.use_amp:
+                for optimizer in self.optimizers:
+                    with amp.scale_loss(loss, optimizer) as scaled_loss:
+                        scaled_loss.backward()
+            else:
+                loss.backward()
+
             self.batch_loss_value += loss.item()
 
             # gradient update with accumulated gradients
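Again only for illustration: pulled out of the trainer, the scaled-backward pattern from this hunk looks roughly like the step below. The names model, optimizer, x, y, and backward_step are placeholders, and amp is assumed to be available as in the guarded import at the top of the file.

    import torch
    from apex import amp  # assumed available, as in the guarded import above

    def backward_step(model, optimizer, x, y, use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)

        if use_amp:
            # amp.scale_loss multiplies the loss by the current loss scale so that
            # small fp16 gradients do not underflow; backward() then produces
            # scaled gradients, which apex unscales before the optimizer step.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        return loss.item()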