added option to change default tensor

parent edd406f419
commit 90a460ec62
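The title refers to PyTorch's default tensor type, but the lines implementing it are not visible in the extracted hunks below. As a purely hypothetical illustration of what "changing the default tensor" means in PyTorch (this is not the commit's code), torch.set_default_tensor_type makes newly created tensors CUDA floats:

    import torch

    # After this call, factory functions like torch.zeros / torch.ones
    # allocate torch.cuda.FloatTensor instead of torch.FloatTensor.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    x = torch.zeros(3)  # lives on the GPU without an explicit .cuda()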
@@ -7,6 +7,9 @@ from pytorch_lightning.root_module.model_saving import TrainerIO
 from torch.optim.lr_scheduler import MultiStepLR
 import pdb

 try:
     from apex import amp
     APEX_AVAILABLE = True
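The except branch of this import guard is cut off in the view above; a sketch of the conventional optional-dependency form it presumably takes:

    try:
        from apex import amp
        APEX_AVAILABLE = True
    except ImportError:
        # NVIDIA apex is an optional dependency; without it the
        # trainer falls back to full-precision training.
        APEX_AVAILABLE = False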
@@ -83,6 +86,7 @@ class Trainer(TrainerIO):
         use_amp = True
         self.use_amp = use_amp and APEX_AVAILABLE
         if self.use_amp:
+            self.amp_handle = amp.init(enabled=True)
             print('using 16bit precision')

     def __determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct):
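amp.init() is apex's legacy "Amp handle" API: it returns a handle whose scale_loss context manager is used during the backward pass (last hunk of this file). A minimal sketch of the pattern in isolation, assuming apex and a CUDA device are available:

    import torch
    from apex import amp

    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # init() patches torch functions for mixed precision and returns a handle
    amp_handle = amp.init(enabled=True)

    loss = model(torch.randn(4, 10).cuda()).sum()
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()  # backward on a scaled loss to avoid fp16 underflow
    optimizer.step()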
@@ -219,14 +223,14 @@ class Trainer(TrainerIO):
         # filter out the weights that were done on gpu so we can load on good old cpus
         self.optimizers = model.configure_optimizers()

-        if self.use_amp:
+        # if self.use_amp:
         # An example
-        self.model, optimizer = amp.initialize(
-            self.model, self.optimizers[0], opt_level="O2",
-            keep_batchnorm_fp32=True, loss_scale="dynamic"
-        )
-        self.optimizers[0] = optimizer
-        model.trainer = self
+        # self.model, optimizer = amp.initialize(
+        #     self.model, self.optimizers[0], opt_level="O2",
+        #     keep_batchnorm_fp32=True, loss_scale="dynamic"
+        # )
+        # self.optimizers[0] = optimizer
+        # model.trainer = self

         # add lr schedulers
         if self.lr_scheduler_milestones is not None:
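The block commented out here is apex's newer amp.initialize API, which this commit sets aside in favor of the amp.init handle created in the constructor. For reference, a sketch of how that call is used (argument values taken from the commented code):

    model, optimizer = amp.initialize(
        model, optimizer,
        opt_level="O2",             # "almost fp16": fp16 model with fp32 master weights
        keep_batchnorm_fp32=True,   # batchnorm stays in fp32 for numerical stability
        loss_scale="dynamic",       # loss scale adapts automatically on overflow
    )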
@@ -370,7 +374,7 @@ class Trainer(TrainerIO):
                 # backward pass
                 if self.use_amp:
                     for optimizer in self.optimizers:
-                        with amp.scale_loss(loss, optimizer) as scaled_loss:
+                        with self.amp_handle.scale_loss(loss, optimizer) as scaled_loss:
                             scaled_loss.backward()
                 else:
                     loss.backward()
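This one-line change swaps the module-level amp.scale_loss, which requires a prior amp.initialize call, for the handle created by amp.init in the constructor, so both halves of the legacy API are now used consistently. The call site is otherwise unchanged:

    with self.amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()  # gradients are computed on the scaled loss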
setup.py (2 changes)
@@ -7,7 +7,7 @@ from setuptools import setup, find_packages
 # http://blog.ionelmc.ro/2014/05/25/python-packaging/
 setup(
     name="pytorch-lightning",
-    version='0.1.dev1721',
+    version='0.1.dev1722',
     description="The Keras for ML researchers using PyTorch",
     author="William Falcon",
     author_email="waf2107@columbia.edu",