# %%
__author__ = "Sarah Shi"
import numpy as np
import pandas as pd
import warnings
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib
from matplotlib import pyplot as plt
# %%
[docs]
def confusion_matrix_df(given_min, pred_min):
    """
    Build a confusion matrix of published vs. predicted mineral classes as a
    labelled pandas DataFrame.

    Labels are first normalised: spinel-group and rhombohedral-oxide names are
    merged into "Oxide" and "Na-Pyroxene" into "Clinopyroxene" (all matched
    case-insensitively after stripping whitespace).

    When a parent label such as "Feldspar" or "Pyroxene" is present in either
    the given or predicted arrays, its child labels (e.g., "Alkali_Feldspar",
    "Plagioclase") are merged into the parent, and the parent takes the place
    of its first child in the row/column order. Both parents may be present at
    the same time.

    Rows containing NaN, or a label that is not in the resulting label list,
    trigger a UserWarning and are excluded from the confusion matrix.

    Parameters:
        given_min (array-like): The true class labels.
        pred_min (array-like): The predicted class labels. Paired with
            given_min by position; any pandas index on the inputs is ignored.

    Returns:
        cm_df (DataFrame): Confusion matrix with true classes on the rows and
            predicted classes on the columns, labelled by the canonical
            mineral list (with children replaced by any active parent).

    Raises:
        ValueError: If given_min and pred_min have different lengths.
    """
    minerals = [
        "Alkali_Feldspar",
        "Amphibole",
        "Apatite",
        "Biotite",
        "Carbonate",
        "Chlorite",
        "Clinopyroxene",
        "Epidote",
        "Garnet",
        "Glass",
        "Kalsilite",
        "Leucite",
        "Melilite",
        "Muscovite",
        "Nepheline",
        "Olivine",
        "Orthopyroxene",
        "Oxide",
        "Plagioclase",
        "Rutile",
        "Serpentine",
        "SiO2_Polymorph",
        "Titanite",
        "Tourmaline",
        "Zircon",
    ]
    # Parent-group mapping: parent label -> set of child labels
    parent_map = {
        "Feldspar": {"Alkali_Feldspar", "Plagioclase"},
        "Pyroxene": {"Clinopyroxene", "Orthopyroxene"},
    }
    # Lower-cased names that collapse into the single "Oxide" class.
    spinel_members = {"magnetite", "chromite", "hercynite", "ulvospinel"}
    rhomb_members = {"hematite", "ilmenite", "rhombohedral_oxides"}

    # Pair observations by position. Without resetting the index, a Series
    # with a non-default index would be aligned by label against the other
    # input, and the boolean masks below would silently drop or mis-pair rows.
    given = pd.Series(given_min).reset_index(drop=True)
    pred = pd.Series(pred_min).reset_index(drop=True)
    if len(given) != len(pred):
        raise ValueError(
            f"given_min and pred_min must have the same length "
            f"(got {len(given)} and {len(pred)})."
        )

    given_nans = given.isna().sum()
    pred_nans = pred.isna().sum()
    if given_nans > 0 or pred_nans > 0:
        warnings.warn(
            f"Missing data detected: {given_nans} NaN(s) in given_min, "
            f"{pred_nans} NaN(s) in pred_min. "
            f"These rows will be excluded from the confusion matrix.",
            UserWarning,
            stacklevel=2,
        )
        complete = given.notna() & pred.notna()
        given = given[complete]
        pred = pred[complete]

    # --- Case-insensitive group merges ---
    def _canonical_label(label):
        """Map oxide-group and Na-pyroxene names onto their canonical class."""
        if pd.isna(label):
            return label
        key = str(label).strip().lower()
        if "spinel" in key or key in spinel_members or key in rhomb_members:
            return "Oxide"
        if key == "na-pyroxene":
            return "Clinopyroxene"
        return label

    given = given.map(_canonical_label)
    pred = pred.map(_canonical_label)

    # --- Dynamic parent-group merges ---
    # A parent is "active" only when it literally occurs in the data; its
    # children are then folded into it on both sides.
    present_labels = set(given) | set(pred)
    child_to_parent = {
        child: parent
        for parent, children in parent_map.items()
        if parent in present_labels
        for child in children
    }
    if child_to_parent:
        given = given.map(lambda x: child_to_parent.get(x, x))
        pred = pred.map(lambda x: child_to_parent.get(x, x))

    # Build the label order in one pass: each active parent appears once, at
    # the position of its first child in the canonical list.
    active_minerals = []
    for mineral in minerals:
        label = child_to_parent.get(mineral, mineral)
        if label not in active_minerals:
            active_minerals.append(label)

    # --- Warn and drop labels not in active_minerals ---
    active_set = set(active_minerals)
    post_merge_labels = set(given) | set(pred)
    # key=str keeps the sort well-defined if non-string labels slip in.
    unrecognized = sorted(post_merge_labels - active_set, key=str)
    if unrecognized:
        warnings.warn(
            f"Unrecognized label(s) not in the canonical mineral list: "
            f"{unrecognized}. These rows will be excluded from the "
            f"confusion matrix.",
            UserWarning,
            stacklevel=2,
        )
        known = given.isin(active_set) & pred.isin(active_set)
        given = given[known]
        pred = pred[known]

    # Build the confusion matrix
    cm_matrix = confusion_matrix(given, pred, labels=active_minerals)
    cm_df = pd.DataFrame(cm_matrix, index=active_minerals, columns=active_minerals)
    return cm_df
[docs]
def pp_matrix(
    df_cm,
    annot=True,
    cmap="BuGn",
    fmt=".2f",
    fz=12,
    lw=0.5,
    cbar=False,
    figsize=(14, 14),
    show_null_values=0,
    pred_val_axis="x",
    savefig=None,
):
    """
    Draws a confusion matrix as an annotated Seaborn heatmap with row/column totals.

    Parameters:
        df_cm (pd.DataFrame): DataFrame containing the confusion matrix without totals.
        annot (bool, optional): If True, display the text in each cell. Default is True.
        cmap (str, optional): Color map for the heatmap. Default is 'BuGn'.
        fmt (str, optional): String format passed to the heatmap annotations. Default is '.2f'.
        fz (int, optional): Font size for text annotations. Default is 12.
        lw (float, optional): Line width for cell borders. Default is 0.5.
        cbar (bool, optional): If True, display the color bar. Default is False.
        figsize (sequence, optional): Figure size (width, height). Default is (14, 14).
        show_null_values (int, optional): How empty cells are labelled: 0 leaves them
            blank, 1 writes '0', any other value writes '0' with '0.0%'. Default is 0.
        pred_val_axis (str, optional): Axis holding the predicted classes. 'x' or 'col'
            puts predictions on the x axis; any other value transposes the matrix and
            puts them on the y axis. Default is 'x'.
        savefig (str, optional): If provided, the figure is saved to this path with
            '.pdf' appended.

    Returns:
        None. The heatmap is drawn on the figure named "Conf matrix default".

    Note:
        The input DataFrame is not modified; totals are added to an internal copy.
        The source of the original code is from:
        https://github.com/wcipriano/pretty-print-confusion-matrix/blob/master/pretty_confusion_matrix/pretty_confusion_matrix.py
    """
    from matplotlib.collections import QuadMesh

    if pred_val_axis in ("col", "x"):
        xlbl = "Predicted"
        ylbl = "Published [True]"
    else:
        xlbl = "Published [True]"
        ylbl = "Predicted"
        df_cm = df_cm.T

    # Work on a copy so the caller's frame does not gain the totals row/column.
    df_cm = df_cm.copy()
    insert_totals(df_cm)
    df_cm = df_cm.astype(int)

    # The figure is looked up by name, so repeated calls reuse and clear it.
    fig1 = plt.figure("Conf matrix default", figsize=figsize)
    ax1 = fig1.gca()
    ax1.cla()

    ax = sns.heatmap(
        df_cm,
        annot=annot,
        annot_kws={"size": fz},
        linewidths=lw,
        ax=ax1,
        cbar=cbar,
        cmap=cmap,
        linecolor="w",
        fmt=fmt,
    )

    # Force one centred tick (and label) per column and per row.
    n_rows, n_cols = df_cm.shape
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_xticklabels(df_cm.columns, rotation=45, fontsize=13, ha="right")
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(df_cm.index, rotation=35, fontsize=13, va="top")

    # Hide the tick marks but keep the labels. (Assigning Tick.tick1On /
    # tick2On no longer has any effect in current Matplotlib releases.)
    ax.tick_params(
        axis="both", which="major", bottom=False, top=False, left=False, right=False
    )

    # Face colours of the heatmap cells, indexed in the same order as the
    # annotation texts are visited below.
    quadmesh = ax.findobj(QuadMesh)[0]
    facecolors = quadmesh.get_facecolors()

    array_df = df_cm.to_numpy()
    text_add = []
    text_del = []
    # Snapshot the texts: they are removed/added only after the loop.
    for posi, cell_text in enumerate(list(ax.texts)):
        # Annotation texts sit at cell centres (col + 0.5, row + 0.5).
        pos = np.array(cell_text.get_position()) - [0.5, 0.5]
        lin = int(pos[1])
        col = int(pos[0])
        added, removed = config_cell_text_and_colors(
            array_df, lin, col, cell_text, facecolors, posi, fz, fmt, show_null_values
        )
        text_add.extend(added)
        text_del.extend(removed)

    # Replace the annotations of the totals cells with their stacked texts.
    for item in text_del:
        item.remove()
    for item in text_add:
        ax.text(item["x"], item["y"], item["text"], **item["kw"])

    ax.set_xlabel(xlbl)
    ax.set_ylabel(ylbl)
    plt.tight_layout()

    if savefig:
        fig1.savefig(f"{savefig}.pdf")
[docs]
def insert_totals(df_cm):
    """
    Inserts total sums for each row and column into the confusion matrix DataFrame.

    This function adds a 'sum_row' column holding the total of each row and a
    'sum_col' row holding the total of each column. The bottom-right cell
    ('sum_col', 'sum_row') holds the grand total. All totals are cast to int.

    Parameters:
        df_cm (pd.DataFrame): DataFrame representing the confusion matrix.

    Returns:
        None: The function modifies the DataFrame in place.

    Note:
        If 'sum_row' or 'sum_col' already exist in the DataFrame, they are
        dropped and recalculated, so calling the function twice is harmless.
    """
    # Remove stale totals first so they are not summed into the new ones.
    df_cm.drop(columns="sum_row", errors="ignore", inplace=True)
    df_cm.drop(index="sum_col", errors="ignore", inplace=True)

    # Add the row totals first: the column sums taken afterwards then already
    # include the grand total under 'sum_row'. Assigning a totals row that
    # lacks that entry would inject a NaN and upcast integer counts to float.
    df_cm["sum_row"] = df_cm.sum(axis=1).astype(int)
    df_cm.loc["sum_col"] = df_cm.sum(axis=0).astype(int)
[docs]
def config_cell_text_and_colors(
    array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=0
):
    """
    Configures cell text and colors for confusion matrix visualization.

    Cells in the last row or last column (the totals) have their annotation
    scheduled for replacement by three stacked texts: the count (black), the
    percentage of that total lying on the main diagonal (green) and the
    remainder (red). All other cells are relabelled in place with their count
    and their share of the grand total; diagonal cells get a green background.

    Parameters:
        array_df (np.ndarray): 2D array of the confusion matrix including the
            totals row and column; assumed square.
        lin (int): Row index of the cell to configure.
        col (int): Column index of the cell to configure.
        oText (matplotlib.text.Text): Text object of the cell.
        facecolors (np.ndarray): Array of facecolors for the cells; the entry
            at posi is overwritten for totals and diagonal cells.
        posi (int): Position index in the flattened array of cells.
        fz (int): Font size for the totals texts.
        fmt (str): Unused; kept for interface compatibility.
        show_null_values (int, optional): Label for empty cells: 0 leaves them
            blank, 1 writes '0', any other value writes '0' with '0.0%'.
            Default is 0.

    Returns:
        tuple: (text_add, text_del). text_add is a list of dicts with keys
            'x', 'y', 'text' and 'kw' describing texts to create; text_del is
            a list of Text objects to remove.
    """
    text_add = []
    text_del = []

    cell_val = array_df[lin][col]
    tot_all = array_df[-1][-1]
    # An all-zero matrix has a zero grand total; treat every share as 0%.
    per = (float(cell_val) / tot_all) * 100 if tot_all else 0.0

    last = len(array_df) - 1
    in_total_col = col == last
    in_total_row = lin == last

    if in_total_col or in_total_row:
        # Totals cell: work out how much of this total sits on the diagonal.
        if cell_val != 0:
            if in_total_col and in_total_row:
                tot_rig = sum(array_df[i][i] for i in range(last))
            elif in_total_col:
                tot_rig = array_df[lin][lin]
            else:
                tot_rig = array_df[col][col]
            per_ok = (float(tot_rig) / cell_val) * 100
            per_err = 100 - per_ok
        else:
            per_ok = per_err = 0
        per_ok_s = "100%" if per_ok == 100 else f"{per_ok:.1f}%"

        # The heatmap's own annotation is replaced by the stacked texts.
        text_del.append(oText)

        # Only totals cells need font handling, so import here.
        import matplotlib.font_manager as fm

        # Keep the DeprecationWarning suppression local to this block rather
        # than installing a permanent process-wide filter.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=DeprecationWarning)

            font_prop = fm.FontProperties(weight="bold", size=fz)
            base_kwargs = dict(
                color="k",
                ha="center",
                va="center",
                gid="sum",
                fontproperties=font_prop,
            )
            x0, y0 = oText.get_position()
            # (text, colour, vertical offset): count, % correct, % wrong.
            stacked = [
                (f"{int(cell_val)}", "k", -0.3),
                (per_ok_s, "g", 0.0),
                (f"{per_err:.1f}%", "r", 0.3),
            ]
            for label, color, dy in stacked:
                text_add.append(
                    dict(
                        x=x0,
                        y=y0 + dy,
                        text=label,
                        kw={**base_kwargs, "color": color},
                    )
                )

            # Dark background for totals; darker still for the grand total.
            if in_total_col and in_total_row:
                facecolors[posi] = [0.17, 0.20, 0.17, 1.0]
            else:
                facecolors[posi] = [0.27, 0.30, 0.27, 1.0]
    else:
        if per > 0:
            txt = f"{cell_val}\n{per:.1f}%"
        elif show_null_values == 0:
            txt = ""
        elif show_null_values == 1:
            txt = "0"
        else:
            txt = "0\n0.0%"
        oText.set_text(txt)

        if col == lin:
            # Main diagonal: black text on a green background.
            oText.set_color("k")
            facecolors[posi] = [0.35, 0.8, 0.55, 1.0]
        else:
            # Off-diagonal (misclassified) cells: red text.
            oText.set_color("r")

    return text_add, text_del