lightning/pytorch_lightning/utils/plotting.py

29 lines
828 B
Python
Raw Normal View History

2019-03-31 01:45:16 +00:00
from matplotlib import pyplot as plt
import numpy as np
# Silence numpy's divide-by-zero and invalid-value (e.g. 0/0) warnings.
# NOTE(review): this changes numpy's floating-point error handling for the
# whole process as soon as this module is imported, not just for this file.
np.seterr(invalid='ignore', divide='ignore')
def plot_confusion_matrix(cm,
                          save_path,
                          normalize=False,
                          title='Confusion matrix',
                          ylabel='y',
                          xlabel='x'):
    """
    Plot a confusion matrix as an image and save it to ``save_path``.

    Normalization can be applied by setting ``normalize=True``, in which case
    each row is divided by its own sum before plotting.

    Args:
        cm: 2-D array-like matrix to plot (anything ``np.asarray`` accepts).
        save_path: Destination passed to ``Figure.savefig``.
        normalize: If True, divide every row by its sum. Rows whose sum is
            zero become NaN.
        title: Title drawn above the matrix.
        ylabel: Label for the y axis (matrix rows).
        xlabel: Label for the x axis (matrix columns).

    Returns:
        None. The figure is written to ``save_path`` and then closed.
    """
    if normalize:
        # asarray lets plain nested lists be normalized too; dividing below
        # produces a new array, so the caller's input is never mutated.
        cm = np.asarray(cm, dtype=float)
        # Scoped so an all-zero row silently yields NaN regardless of the
        # process-wide numpy error settings.
        with np.errstate(divide='ignore', invalid='ignore'):
            cm = cm / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    # plt.matshow always creates its own figure (sized to the matrix aspect
    # ratio), so no separate plt.figure() is needed -- a prior one would just
    # be left empty and open.
    image = plt.matshow(cm)
    fig = image.figure
    ax = image.axes
    try:
        ax.set_title(title)
        fig.colorbar(image, ax=ax)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)
        fig.savefig(save_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures,
        # even if saving fails.
        plt.close(fig)