added safeguards for callbacks in loading saving
This commit is contained in:
parent
98c112598e
commit
55a33edd0a
|
@ -87,7 +87,23 @@ def main():
|
|||
pdb.set_trace()
|
||||
tags_path = exp.get_data_path(exp.name, exp.version)
|
||||
tags_path = os.path.join(tags_path, 'meta_tags.csv')
|
||||
LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=tags_path)
|
||||
trained_model = LightningTemplateModel.load_from_metrics(weights_path=save_dir, tags_csv=tags_path)
|
||||
|
||||
# run prediction
|
||||
for batch in model.test_dataloader:
|
||||
break
|
||||
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
|
||||
y_hat = model(x)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
val_acc = torch.tensor(val_acc)
|
||||
|
||||
print(val_acc)
|
||||
|
||||
clear_save_dir()
|
||||
|
||||
|
|
Loading…
Reference in New Issue