Classification Plots

class PlotConfusionHeatmap(method_name: str, description: str, plot_settings)

Bases: PlotEvaluator

Plot a heatmap of the confusion matrix for a model.

This evaluator creates a visual heatmap representation of the confusion matrix, showing both the count and percentage of predictions for each class combination. The heatmap uses color intensity to represent the percentage of predictions, making it easy to identify patterns in classification performance.

Parameters:
method_name : str

The name of the evaluator

description : str

The description of the evaluator output

plot_settings : PlotSettings

The plot settings containing theme and color configuration

Attributes:
method_name : str

The name of the evaluator

description : str

The description of the evaluator output

theme : Any

The plot theme for styling plots

primary_color : str

Primary color for plots

secondary_color : str

Secondary color for plots

accent_color : str

Accent color for plots

Examples

Use the confusion matrix heatmap evaluator:
>>> from brisk.evaluation.evaluators import registry
>>> evaluator = registry.get("brisk_plot_confusion_heatmap")
>>> evaluator.plot(model, X, y, "confusion_heatmap")

generate_plot_data(prediction: Series, y: ndarray) → DataFrame

Calculate the plot data for the confusion matrix heatmap.

Generates the data needed to create a confusion matrix heatmap, including both count and percentage information for each cell.

Parameters:
prediction : pd.Series

The predicted target values from the model

y : np.ndarray

The true target values

Returns:
pd.DataFrame

A dataframe containing the confusion matrix heatmap data with columns for True Label, Predicted Label, Percentage, and Label

Notes

The method calculates both absolute counts and percentages for each cell in the confusion matrix. The percentage is calculated as the proportion of total predictions, and the label combines both count and percentage for display in the heatmap.

The data is structured for use with plotnine’s geom_tile and geom_text functions.
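A minimal sketch of how long-format heatmap data like this could be assembled with scikit-learn and pandas. The column names follow the Returns description above; the helper name `confusion_heatmap_data` and the exact label text (count, then percentage in parentheses) are assumptions for illustration, not brisk's implementation.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix


def confusion_heatmap_data(prediction: pd.Series, y: np.ndarray) -> pd.DataFrame:
    """Hypothetical helper: one row per (true, predicted) cell of the matrix."""
    # Fix a single class order so matrix rows/columns match the labels
    labels = np.unique(np.concatenate([np.asarray(y), prediction.to_numpy()]))
    cm = confusion_matrix(y, prediction, labels=labels)
    total = cm.sum()

    rows = []
    for i, true_label in enumerate(labels):
        for j, pred_label in enumerate(labels):
            count = int(cm[i, j])
            # Share of *all* predictions that fall into this cell
            pct = 100 * count / total
            rows.append({
                "True Label": true_label,
                "Predicted Label": pred_label,
                "Percentage": pct,
                # Text drawn on each tile (assumed format)
                "Label": f"{count}\n({pct:.1f}%)",
            })
    return pd.DataFrame(rows)


df = confusion_heatmap_data(pd.Series([0, 1, 1, 0]), np.array([0, 1, 0, 0]))
```

The "Percentage" column would drive the tile fill (for example via `geom_tile`) and "Label" the text layer (for example via `geom_text`).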

plot(model: Any, X: ndarray, y: ndarray, filename: str) → None

Plot a heatmap of the confusion matrix for a model.

Executes the complete plotting workflow for generating a confusion matrix heatmap. This includes generating predictions, calculating the confusion matrix data, creating the plot, and saving the results.

Parameters:
model : Any

The trained classification model with a predict method

X : np.ndarray

The input features for evaluation

y : np.ndarray

The true target labels

filename : str

The name of the file to save the plot to (without extension)

Returns:
None

Notes

This method overrides the base plot method to provide a classification-specific plotting workflow. It generates predictions using the model and creates a heatmap visualization of the confusion matrix.

The plot is saved with metadata for later analysis and reporting.

class PlotRocCurve(method_name: str, description: str, plot_settings)

Bases: PlotEvaluator

Plot a receiver operating characteristic curve with area under the curve.

This evaluator creates ROC curve plots for binary classification models, showing the relationship between true positive rate and false positive rate. The plot includes the area under the curve (AUC) score and a reference line for random guessing.

Parameters:
method_name : str

The name of the evaluator

description : str

The description of the evaluator output

plot_settings : PlotSettings

The plot settings containing theme and color configuration

Attributes:
method_name : str

The name of the evaluator

description : str

The description of the evaluator output

theme : Any

The plot theme for styling plots

primary_color : str

Primary color for plots

secondary_color : str

Secondary color for plots

accent_color : str

Accent color for plots

Notes

The ROC curve is a fundamental tool for evaluating binary classification performance. It shows the trade-off between sensitivity (true positive rate) and specificity (1 - false positive rate) across different classification thresholds.

The AUC score provides a single metric for overall performance, with values closer to 1.0 indicating better performance. A score of 0.5 indicates performance equivalent to random guessing.
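As a small illustration of this interpretation (toy data, not part of brisk's API), scikit-learn's `roc_auc_score` equals the fraction of (positive, negative) sample pairs in which the positive sample receives the higher score. Here 3 of the 4 pairs are ordered correctly, giving 0.75, while a model that gives every sample the same score lands on the random-guessing value of 0.5.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Toy labels and scores (illustrative data only)
y_true = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])

auc = roc_auc_score(y_true, scores)

# Same quantity via pairwise ranking: share of (positive, negative)
# pairs where the positive sample is scored higher (ties count 0.5)
pos, neg = scores[y_true == 1], scores[y_true == 0]
pairwise = np.mean([(p > n) + 0.5 * (p == n) for p in pos for n in neg])

# Constant scores cannot rank anything: the curve is the diagonal
baseline = roc_auc_score(y_true, np.full(4, 0.5))

print(auc, pairwise, baseline)  # 0.75 0.75 0.5
```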

Examples

Use the ROC curve evaluator:
>>> from brisk.evaluation.evaluators import registry
>>> evaluator = registry.get("brisk_plot_roc_curve")
>>> evaluator.plot(model, X, y, "roc_curve")

generate_plot_data(model: BaseEstimator, X: ndarray, y: ndarray, pos_label: int | None = 1) → Tuple[DataFrame, DataFrame, float]

Calculate the plot data for the ROC curve.

Generates the data needed to create a ROC curve plot, including the curve data, AUC calculation data, and the AUC score.

Parameters:
model : base.BaseEstimator

The trained binary classification model

X : np.ndarray

The input features for evaluation

y : np.ndarray

The true binary labels

pos_label : int, optional

The label of the positive class, by default 1

Returns:
Tuple[pd.DataFrame, pd.DataFrame, float]

A tuple containing:

- ROC curve data (DataFrame with FPR, TPR, and Type columns)
- AUC calculation data (DataFrame for area shading)
- AUC score (float)

Notes

The method automatically detects the appropriate prediction method:

- predict_proba: uses the probability of the positive class
- decision_function: uses decision function scores
- predict: uses binary predictions as a fallback
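The detection order described above can be sketched as follows, assuming a fitted scikit-learn-style estimator that exposes `classes_`. The function name `get_scores` and the handling of `pos_label` (selecting the matching `predict_proba` column, flipping the sign of `decision_function` when needed) are illustrative assumptions; brisk's internal logic may differ.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC


def get_scores(model, X, pos_label=1):
    """Hypothetical helper: return one score per sample, higher = more positive."""
    if hasattr(model, "predict_proba"):
        # Column of predict_proba that corresponds to pos_label
        pos_index = list(model.classes_).index(pos_label)
        return model.predict_proba(X)[:, pos_index]
    if hasattr(model, "decision_function"):
        scores = model.decision_function(X)
        # Positive scores favour classes_[1]; flip the sign if pos_label is classes_[0]
        return scores if model.classes_[1] == pos_label else -scores
    # Fallback: hard predictions mapped to 1.0 for pos_label, 0.0 otherwise
    return (model.predict(X) == pos_label).astype(float)


X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])
logreg = LogisticRegression().fit(X, y)
svm = LinearSVC().fit(X, y)
proba_scores = get_scores(logreg, X)  # uses predict_proba
svm_scores = get_scores(svm, X)       # LinearSVC has no predict_proba
```

With the `predict` fallback every sample gets one of two scores, so the resulting curve has only a single interior threshold.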

The ROC curve data includes both the model's curve and a reference line for random guessing (the diagonal from (0, 0) to (1, 1)).
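A minimal sketch of the ROC data described above, assuming a fitted binary classifier with `predict_proba` and synthetic data from `make_classification`. The FPR/TPR/Type columns follow the Returns description; the Type values ("ROC Curve", "Random Guessing") and the shape of the shading frame are assumptions.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve

X, y = make_classification(n_samples=200, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, y)
scores = model.predict_proba(X)[:, 1]  # probability of the positive class (1)

fpr, tpr, _ = roc_curve(y, scores, pos_label=1)
auc_score = auc(fpr, tpr)

# Model curve and the random-guessing diagonal, stacked in long format
curve = pd.DataFrame({"FPR": fpr, "TPR": tpr, "Type": "ROC Curve"})
reference = pd.DataFrame({"FPR": [0.0, 1.0], "TPR": [0.0, 1.0], "Type": "Random Guessing"})
roc_data = pd.concat([curve, reference], ignore_index=True)

# Points under the model curve only, e.g. for an area/ribbon layer
auc_data = curve[["FPR", "TPR"]].copy()
```

Keeping both lines in one long-format frame lets a single `Type` aesthetic color the model curve and the reference diagonal differently.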

log_results(plot_name: str, auc: float, filename: str) → None

Log the results of the ROC curve to console.

Displays the ROC curve plot name, AUC score, and file path for easy tracking of evaluation results.

Parameters:
plot_name : str

The name of the plot that was created

auc : float

The AUC score calculated for the model

filename : str

The name of the file where the plot was saved

Returns:
None

Notes

The logging includes the full output path with .svg extension and the AUC score for quick performance assessment.
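A hypothetical sketch of such a console message built with the standard `logging` module. The output directory, logger name, helper name `format_roc_log`, and message wording are all assumptions, not brisk's actual log format.

```python
import logging
from pathlib import Path

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("plot_results")  # hypothetical logger name


def format_roc_log(plot_name: str, auc: float, filename: str, output_dir: Path) -> str:
    """Build the message: plot name, AUC to 4 decimals, and the .svg path."""
    # Append the extension explicitly so dots inside filename are preserved
    path = output_dir / f"{filename}.svg"
    return f"{plot_name} (AUC = {auc:.4f}) saved to {path}"


message = format_roc_log("ROC Curve", 0.9123, "roc_curve", Path("results"))
logger.info(message)
```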

plot(model: Any, X: ndarray, y: ndarray, filename: str, pos_label: int | None = 1) → None

Plot a receiver operating characteristic curve with area under the curve.

Executes the complete plotting workflow for generating a ROC curve. This includes calculating the ROC curve data, computing the AUC score, creating the plot, and saving the results.

Parameters:
model : Any

The trained binary classification model

X : np.ndarray

The input features for evaluation

y : np.ndarray

The true binary labels

filename : str

The name of the file to save the plot to (without extension)

pos_label : int, optional

The label of the positive class, by default 1

Returns:
None

Notes

This method handles different types of binary classification models by automatically detecting whether to use predict_proba, decision_function, or predict methods for obtaining prediction scores.

The plot includes both the ROC curve and a reference line for random guessing, along with the AUC score annotation.

class PlotPrecisionRecallCurve(method_name: str, description: str, plot_settings)

Bases: PlotEvaluator

Plot a precision-recall curve with area under the curve.

This evaluator creates precision-recall curve plots for binary classification models, showing the relationship between precision and recall across different classification thresholds. The plot includes the average precision (AP) score and a reference line showing the AP score.

Parameters:
method_name : str

The name of the evaluator

description : str

The description of the evaluator output

plot_settings : PlotSettings

The plot settings containing theme and color configuration

Attributes:
method_name : str

The name of the evaluator

description : str

The description of the evaluator output

theme : Any

The plot theme for styling plots

primary_color : str

Primary color for plots

secondary_color : str

Secondary color for plots

accent_color : str

Accent color for plots

Notes

The precision-recall curve is particularly useful for imbalanced datasets where the focus is on the positive class. It shows the trade-off between precision and recall across different classification thresholds.

The average precision (AP) score provides a single metric for overall performance, with values closer to 1.0 indicating better performance. Unlike ROC AUC, AP is more sensitive to performance on the positive class.
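To make AP concrete: scikit-learn's `average_precision_score` defines it as the recall-weighted sum of precision, AP = Σₙ (Rₙ − Rₙ₋₁) Pₙ over decreasing thresholds. The sketch below reproduces that sum by hand on toy data (illustrative only, not brisk code, and assuming no tied scores) and prints ROC AUC on the same scores for comparison.

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

y_true = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])

# Sweep thresholds from the highest score down (assumes no tied scores)
order = np.argsort(-scores)
hits = y_true[order]
tp = np.cumsum(hits)                          # true positives at each cut-off
precision = tp / np.arange(1, len(hits) + 1)  # P_n
recall = tp / hits.sum()                      # R_n

# AP = sum_n (R_n - R_{n-1}) * P_n, with R_0 = 0
ap_manual = np.sum(np.diff(np.concatenate([[0.0], recall])) * precision)
ap_sklearn = average_precision_score(y_true, scores)

print(round(ap_manual, 4), round(ap_sklearn, 4))  # 0.8333 0.8333
print(roc_auc_score(y_true, scores))              # 0.75
```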

Examples

Use the precision-recall curve evaluator:
>>> from brisk.evaluation.evaluators import registry
>>> evaluator = registry.get("brisk_plot_precision_recall_curve")
>>> evaluator.plot(model, X, y, "precision_recall_curve")

generate_plot_data(model: BaseEstimator, X: ndarray, y: ndarray, pos_label: int | None = 1) → Tuple[DataFrame, float]

Calculate the plot data for the precision-recall curve.

Generates the data needed to create a precision-recall curve plot, including the curve data and the average precision score.

Parameters:
model : base.BaseEstimator

The trained binary classification model

X : np.ndarray

The input features for evaluation

y : np.ndarray

The true binary labels

pos_label : int, optional

The label of the positive class, by default 1

Returns:
Tuple[pd.DataFrame, float]

A tuple containing:

- Precision-recall curve data (DataFrame with Recall, Precision, and Type columns)
- Average precision score (float)

Notes

The method automatically detects the appropriate prediction method:

- predict_proba: uses the probability of the positive class
- decision_function: uses decision function scores
- predict: uses binary predictions as a fallback

The precision-recall curve data includes both the actual curve and a reference line showing the average precision score.
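A minimal sketch of the precision-recall data described above, assuming a fitted classifier with `predict_proba` on synthetic, imbalanced data. The Recall/Precision/Type columns follow the Returns description; drawing the AP reference as a horizontal line at height AP, and the Type labels, are assumptions.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve

# Imbalanced toy data: roughly 10% positives
X, y = make_classification(n_samples=500, weights=[0.9], random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, y)
scores = model.predict_proba(X)[:, 1]

precision, recall, _ = precision_recall_curve(y, scores, pos_label=1)
ap_score = average_precision_score(y, scores, pos_label=1)

curve = pd.DataFrame({"Recall": recall, "Precision": precision, "Type": "PR Curve"})
# Assumed reference: horizontal line at the AP score across all recall values
reference = pd.DataFrame({
    "Recall": [0.0, 1.0],
    "Precision": [ap_score, ap_score],
    "Type": f"AP = {ap_score:.2f}",
})
pr_data = pd.concat([curve, reference], ignore_index=True)
```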

log_results(plot_name: str, ap_score: float, filename: str) → None

Log the results of the precision-recall curve to console.

Displays the precision-recall curve plot name, AP score, and file path for easy tracking of evaluation results.

Parameters:
plot_name : str

The name of the plot that was created

ap_score : float

The average precision score calculated for the model

filename : str

The name of the file where the plot was saved

Returns:
None

Notes

The logging includes the full output path with .svg extension and the AP score for quick performance assessment.

plot(model: BaseEstimator, X: ndarray, y: ndarray, filename: str, pos_label: int | None = 1) → None

Plot a precision-recall curve with area under the curve.

Executes the complete plotting workflow for generating a precision-recall curve. This includes calculating the curve data, computing the AP score, creating the plot, and saving the results.

Parameters:
model : base.BaseEstimator

The trained binary classification model

X : np.ndarray

The input features for evaluation

y : np.ndarray

The true binary labels

filename : str

The name of the file to save the plot to (without extension)

pos_label : int, optional

The label of the positive class, by default 1

Returns:
None

Notes

This method handles different types of binary classification models by automatically detecting whether to use predict_proba, decision_function, or predict methods for obtaining prediction scores.

The plot includes both the precision-recall curve and a reference line showing the average precision score.