removed dummy d
This commit is contained in:
parent
5606fd86df
commit
480dcb0213
|
@ -3,6 +3,7 @@ from pytorch_lightning import Trainer
|
|||
from pytorch_lightning.examples.new_project_templates.lightning_module_template import LightningTemplateModel
|
||||
from argparse import Namespace
|
||||
from test_tube import Experiment
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
import numpy as np
|
||||
import warnings
|
||||
import torch
|
||||
|
@ -27,28 +28,48 @@ def get_model():
|
|||
return model
|
||||
|
||||
|
||||
def get_exp():
|
||||
def get_exp(debug=True):
|
||||
# set up exp object without actually saving logs
|
||||
root_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
exp = Experiment(debug=True, save_dir=root_dir, name='tests_tt_dir')
|
||||
exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir')
|
||||
return exp
|
||||
|
||||
|
||||
def clear_tt_dir():
|
||||
def init_save_dir():
|
||||
root_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
tt_dir = os.path.join(root_dir, 'tests_tt_dir')
|
||||
if os.path.exists(tt_dir):
|
||||
shutil.rmtree(tt_dir)
|
||||
save_dir = os.path.join(root_dir, 'save_dir')
|
||||
|
||||
if os.path.exists(save_dir):
|
||||
shutil.rmtree(save_dir)
|
||||
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
return save_dir
|
||||
|
||||
|
||||
def clear_save_dir():
|
||||
root_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
save_dir = os.path.join(root_dir, 'save_dir')
|
||||
if os.path.exists(save_dir):
|
||||
shutil.rmtree(save_dir)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
clear_tt_dir()
|
||||
save_dir = init_save_dir()
|
||||
model = get_model()
|
||||
|
||||
# exp file to get meta
|
||||
exp = get_exp(False)
|
||||
exp.save()
|
||||
|
||||
# exp file to get weights
|
||||
checkpoint = ModelCheckpoint(save_dir)
|
||||
|
||||
trainer = Trainer(
|
||||
checkpoint_callback=checkpoint,
|
||||
progress_bar=True,
|
||||
experiment=get_exp(),
|
||||
experiment=exp,
|
||||
max_nb_epochs=1,
|
||||
train_percent_check=0.1,
|
||||
val_percent_check=0.1,
|
||||
|
@ -62,22 +83,12 @@ def main():
|
|||
# correct result and ok accuracy
|
||||
assert result == 1, 'amp + ddp model failed to complete'
|
||||
|
||||
# test prediction
|
||||
data = model.val_dataloader
|
||||
for batch in data:
|
||||
break
|
||||
# load trained model
|
||||
pdb.set_trace()
|
||||
tags_path = exp.get_data_path(exp.name, exp.version)
|
||||
LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=)
|
||||
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
out = model(x)
|
||||
|
||||
labels_hat = torch.argmax(out, dim=1)
|
||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
val_acc = torch.tensor(val_acc)
|
||||
print(val_acc)
|
||||
|
||||
|
||||
clear_tt_dir()
|
||||
clear_save_dir()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue