2019-03-31 01:45:16 +00:00
|
|
|
import numpy as np
|
|
|
|
# Ask NumPy to ignore divide-by-zero and invalid-value floating-point errors
# rather than warn about them; the row normalization in plot_confusion_matrix
# divides by row sums, which can be zero for an empty row.
np.seterr(divide='ignore', invalid='ignore')
|
|
|
|
|
|
|
|
|
|
|
|
def plot_confusion_matrix(cm,
                          save_path,
                          normalize=False,
                          title='Confusion matrix',
                          ylabel='y',
                          xlabel='x'):
    """
    Plot a confusion matrix with ``matshow`` and save the image.

    A one-line notice stating whether normalization was applied is printed
    to stdout. The figure is closed once it has been written, so repeated
    calls do not accumulate open matplotlib figures.

    Args:
        cm: 2-D array-like holding the confusion matrix.
        save_path: Destination handed to ``Figure.savefig``.
        normalize: If true, each row is divided by its row sum before
            plotting. Rows that sum to zero produce NaN entries.
        title: Text placed above the plot.
        ylabel: Label for the vertical axis.
        xlabel: Label for the horizontal axis.

    Returns:
        None.
    """
    # Function-scope import, as in the original: matplotlib is only needed
    # when a plot is actually produced.
    from matplotlib import pyplot as plt

    # Accept lists / nested sequences as well as ndarrays.
    cm = np.asanyarray(cm)

    if normalize:
        # Suppress 0/0 warnings locally so this does not depend on any
        # module-level np.seterr configuration.
        with np.errstate(divide='ignore', invalid='ignore'):
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    # matshow() opens its own, suitably proportioned figure, so no separate
    # plt.figure() call is needed (it would only leave an empty figure open).
    plt.matshow(cm)
    fig = plt.gcf()
    try:
        plt.title(title)
        plt.colorbar()
        plt.ylabel(ylabel)
        plt.xlabel(xlabel)
        fig.savefig(save_path)
    finally:
        # Release the figure even if decorating or saving fails.
        plt.close(fig)
|