added test for model loading and predicting
This commit is contained in:
parent
926fa206ff
commit
85eaa28872
|
@ -54,6 +54,41 @@ def clear_save_dir():
|
|||
shutil.rmtree(save_dir)
|
||||
|
||||
|
||||
def load_model(exp, save_dir):
|
||||
|
||||
# load trained model
|
||||
tags_path = exp.get_data_path(exp.name, exp.version)
|
||||
tags_path = os.path.join(tags_path, 'meta_tags.csv')
|
||||
|
||||
checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x]
|
||||
weights_dir = os.path.join(save_dir, checkpoints[0])
|
||||
|
||||
trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=True)
|
||||
|
||||
assert trained_model is not None, 'loading model failed'
|
||||
|
||||
return trained_model
|
||||
|
||||
|
||||
def run_prediction(dataloader, trained_model):
|
||||
# run prediction on 1 batch
|
||||
for batch in dataloader:
|
||||
break
|
||||
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
|
||||
y_hat = trained_model(x)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
val_acc = torch.tensor(val_acc)
|
||||
val_acc = val_acc.item()
|
||||
|
||||
assert val_acc > 0.70, 'this model is expected to get > 0.7 in test set'
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
save_dir = init_save_dir()
|
||||
|
@ -72,7 +107,7 @@ def main():
|
|||
progress_bar=True,
|
||||
experiment=exp,
|
||||
max_nb_epochs=1,
|
||||
train_percent_check=0.1,
|
||||
train_percent_check=0.2,
|
||||
val_percent_check=0.1,
|
||||
gpus=[0, 1],
|
||||
distributed_backend='ddp',
|
||||
|
@ -84,32 +119,11 @@ def main():
|
|||
# correct result and ok accuracy
|
||||
assert result == 1, 'amp + ddp model failed to complete'
|
||||
|
||||
# load trained model
|
||||
pdb.set_trace()
|
||||
tags_path = exp.get_data_path(exp.name, exp.version)
|
||||
tags_path = os.path.join(tags_path, 'meta_tags.csv')
|
||||
# test model loading
|
||||
pretrained_model = load_model(exp, save_dir)
|
||||
|
||||
pdb.set_trace()
|
||||
checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x]
|
||||
weights_dir = os.path.join(save_dir, checkpoints[0])
|
||||
|
||||
trained_model = LightningTemplateModel.load_from_metrics(weights_path=weights_dir, tags_csv=tags_path, on_gpu=True)
|
||||
|
||||
# run prediction
|
||||
for batch in model.test_dataloader:
|
||||
break
|
||||
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
|
||||
y_hat = trained_model(x)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
val_acc = torch.tensor(val_acc)
|
||||
|
||||
print(val_acc)
|
||||
# test model preds
|
||||
run_prediction(model.test_dataloader, pretrained_model)
|
||||
|
||||
clear_save_dir()
|
||||
|
||||
|
|
Loading…
Reference in New Issue