added safeguards for callbacks in loading saving
This commit is contained in:
parent
98c112598e
commit
55a33edd0a
|
@ -87,7 +87,23 @@ def main():
|
||||||
pdb.set_trace()
|
pdb.set_trace()
|
||||||
tags_path = exp.get_data_path(exp.name, exp.version)
|
tags_path = exp.get_data_path(exp.name, exp.version)
|
||||||
tags_path = os.path.join(tags_path, 'meta_tags.csv')
|
tags_path = os.path.join(tags_path, 'meta_tags.csv')
|
||||||
LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=tags_path)
|
trained_model = LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=tags_path)
|
||||||
|
|
||||||
|
# run prediction
|
||||||
|
for batch in model.test_dataloader:
|
||||||
|
break
|
||||||
|
|
||||||
|
x, y = batch
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
|
||||||
|
y_hat = model(x)
|
||||||
|
|
||||||
|
# acc
|
||||||
|
labels_hat = torch.argmax(y_hat, dim=1)
|
||||||
|
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||||
|
val_acc = torch.tensor(val_acc)
|
||||||
|
|
||||||
|
print(val_acc)
|
||||||
|
|
||||||
clear_save_dir()
|
clear_save_dir()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue