Classification Plots
- class PlotConfusionHeatmap(method_name: str, description: str, plot_settings)
Bases: PlotEvaluator
Plot a heatmap of the confusion matrix for a model.
This evaluator creates a visual heatmap representation of the confusion matrix, showing both the count and percentage of predictions for each class combination. The heatmap uses color intensity to represent the percentage of predictions, making it easy to identify patterns in classification performance.
- Parameters:
- method_name : str
The name of the evaluator
- description : str
The description of the evaluator output
- plot_settings : PlotSettings
The plot settings containing theme and color configuration
- Attributes:
- method_name : str
The name of the evaluator
- description : str
The description of the evaluator output
- theme : Any
The plot theme for styling plots
- primary_color : str
Primary color for plots
- secondary_color : str
Secondary color for plots
- accent_color : str
Accent color for plots
Examples
- Use the confusion matrix heatmap evaluator:
>>> from brisk.evaluation.evaluators import registry
>>> evaluator = registry.get("brisk_plot_confusion_heatmap")
>>> evaluator.plot(model, X, y, "confusion_heatmap")
- generate_plot_data(prediction: Series, y: ndarray) → DataFrame
Calculate the plot data for the confusion matrix heatmap.
Generates the data needed to create a confusion matrix heatmap, including both count and percentage information for each cell.
- Parameters:
- prediction : pd.Series
The predicted target values from the model
- y : np.ndarray
The true target values
- Returns:
- pd.DataFrame
A dataframe containing the confusion matrix heatmap data with columns for True Label, Predicted Label, Percentage, and Label
Notes
The method calculates both absolute counts and percentages for each cell in the confusion matrix. The percentage is calculated as the proportion of total predictions, and the label combines both count and percentage for display in the heatmap.
The data is structured for use with plotnine’s geom_tile and geom_text functions.
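The cell-level calculation described above can be sketched with scikit-learn's `confusion_matrix` and pandas. This is a minimal sketch, not brisk's implementation: the function name `confusion_heatmap_data` is hypothetical, the column names are taken from the Returns description, and the two-line `Label` format (count above percentage) is an assumption.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix


def confusion_heatmap_data(prediction: pd.Series, y: np.ndarray) -> pd.DataFrame:
    """Hypothetical helper: long-format confusion-matrix data for a heatmap."""
    y_true = np.asarray(y)
    y_pred = np.asarray(prediction)
    # Include every class seen in either array so the matrix is square and complete.
    labels = np.unique(np.concatenate([y_true, y_pred]))
    counts = confusion_matrix(y_true, y_pred, labels=labels)
    total = counts.sum()

    rows = []
    for i, true_label in enumerate(labels):
        for j, pred_label in enumerate(labels):
            count = int(counts[i, j])
            # Share of *all* predictions, not of the true-class row.
            percentage = 100.0 * count / total
            rows.append(
                {
                    "True Label": true_label,
                    "Predicted Label": pred_label,
                    "Percentage": percentage,
                    # Assumed display format for geom_text: count over percentage.
                    "Label": f"{count}\n({percentage:.1f}%)",
                }
            )
    return pd.DataFrame(rows)
```

For `y = [0, 0, 1, 1]` and predictions `[0, 1, 1, 1]`, the four cells hold 25%, 25%, 0% and 50% of all predictions. Normalizing by the grand total, rather than per true class as `confusion_matrix(..., normalize="true")` would, follows the Notes above.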
- plot(model: Any, X: ndarray, y: ndarray, filename: str) → None
Plot a heatmap of the confusion matrix for a model.
Executes the complete plotting workflow for generating a confusion matrix heatmap. This includes generating predictions, calculating the confusion matrix data, creating the plot, and saving the results.
- Parameters:
- model : Any
The trained classification model with a predict method
- X : np.ndarray
The input features for evaluation
- y : np.ndarray
The true target labels
- filename : str
The name of the file to save the plot to (without extension)
- Returns:
- None
Notes
This method overrides the base plot method to provide a classification-specific plotting workflow. It generates predictions using the model and creates a heatmap visualization of the confusion matrix.
The plot is saved with metadata for later analysis and reporting.
- class PlotRocCurve(method_name: str, description: str, plot_settings)
Bases: PlotEvaluator
Plot a receiver operating characteristic curve with area under the curve.
This evaluator creates ROC curve plots for binary classification models, showing the relationship between true positive rate and false positive rate. The plot includes the area under the curve (AUC) score and a reference line for random guessing.
- Parameters:
- method_name : str
The name of the evaluator
- description : str
The description of the evaluator output
- plot_settings : PlotSettings
The plot settings containing theme and color configuration
- Attributes:
- method_name : str
The name of the evaluator
- description : str
The description of the evaluator output
- theme : Any
The plot theme for styling plots
- primary_color : str
Primary color for plots
- secondary_color : str
Secondary color for plots
- accent_color : str
Accent color for plots
Notes
The ROC curve is a fundamental tool for evaluating binary classification performance. It shows the trade-off between sensitivity (true positive rate) and specificity (1 - false positive rate) across different classification thresholds.
The AUC score provides a single metric for overall performance, with values closer to 1.0 indicating better performance. A score of 0.5 indicates performance equivalent to random guessing.
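The threshold trade-off and the two reference values described above can be checked numerically. This sketch calls scikit-learn's `roc_auc_score` directly rather than the brisk evaluator, and the labels and scores are made-up toy data.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 0, 0, 1, 1, 1])

# Every positive scores above every negative: AUC = 1.0.
perfect = np.array([0.1, 0.2, 0.3, 0.7, 0.8, 0.9])

# A constant score carries no ranking information: AUC = 0.5, like random guessing.
uninformative = np.full(6, 0.5)

# The negative scored 0.75 outranks two of the three positives,
# so 7 of the 9 positive/negative pairs are ordered correctly: AUC = 7/9.
mixed = np.array([0.1, 0.2, 0.75, 0.7, 0.8, 0.6])

# One point on the "mixed" curve: call a sample positive when score >= 0.65.
predicted_positive = mixed >= 0.65
tpr = predicted_positive[y_true == 1].mean()  # 2 of 3 positives caught
fpr = predicted_positive[y_true == 0].mean()  # 1 of 3 negatives flagged

for name, scores in [("perfect", perfect), ("uninformative", uninformative), ("mixed", mixed)]:
    print(name, roc_auc_score(y_true, scores))
```

Sweeping the 0.65 threshold up or down moves along the curve: a lower threshold catches more positives but flags more negatives, which is the sensitivity/specificity trade-off the ROC curve traces.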
Examples
- Use the ROC curve evaluator:
>>> from brisk.evaluation.evaluators import registry
>>> evaluator = registry.get("brisk_plot_roc_curve")
>>> evaluator.plot(model, X, y, "roc_curve")
- generate_plot_data(model: BaseEstimator, X: ndarray, y: ndarray, pos_label: int | None = 1) → Tuple[DataFrame, DataFrame, float]
Calculate the plot data for the ROC curve.
Generates the data needed to create a ROC curve plot, including the curve data, AUC calculation data, and the AUC score.
- Parameters:
- model : base.BaseEstimator
The trained binary classification model
- X : np.ndarray
The input features for evaluation
- y : np.ndarray
The true binary labels
- pos_label : int, optional
The label of the positive class, by default 1
- Returns:
- Tuple[pd.DataFrame, pd.DataFrame, float]
A tuple containing:
- ROC curve data (DataFrame with FPR, TPR, and Type columns)
- AUC calculation data (DataFrame used for area shading)
- AUC score (float)
Notes
The method automatically detects the appropriate prediction method:
- predict_proba: uses the probability of the positive class
- decision_function: uses the decision function scores
- predict: uses binary predictions as a fallback
The ROC curve data includes both the actual curve and a reference line for random guessing (the diagonal from (0, 0) to (1, 1)).
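The score selection and data assembly described above can be sketched with scikit-learn's `roc_curve` and `auc`. This is a sketch under assumptions, not brisk's code: `positive_class_scores` and `roc_plot_data` are hypothetical names, the `Type` values ("ROC Curve", "Random Guessing") are placeholders, and the exact layout of the shading DataFrame is a guess.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import auc, roc_curve


def positive_class_scores(model, X, pos_label=1):
    """Hypothetical helper: scores where larger means 'more likely pos_label'."""
    if hasattr(model, "predict_proba"):
        # predict_proba columns follow the order of model.classes_.
        column = list(model.classes_).index(pos_label)
        return model.predict_proba(X)[:, column]
    if hasattr(model, "decision_function"):
        # For binary classifiers, positive scores favour model.classes_[1].
        scores = model.decision_function(X)
        return scores if model.classes_[1] == pos_label else -scores
    # Fallback: hard 0/1 predictions yield at most one interior point on the curve.
    return (np.asarray(model.predict(X)) == pos_label).astype(float)


def roc_plot_data(model, X, y, pos_label=1):
    """Hypothetical helper returning (curve data, shading data, AUC score)."""
    scores = positive_class_scores(model, X, pos_label)
    fpr, tpr, _ = roc_curve(y, scores, pos_label=pos_label)
    roc_auc = auc(fpr, tpr)

    curve = pd.DataFrame({"FPR": fpr, "TPR": tpr, "Type": "ROC Curve"})
    # Random-guessing reference: the diagonal from (0, 0) to (1, 1).
    reference = pd.DataFrame(
        {"FPR": [0.0, 1.0], "TPR": [0.0, 1.0], "Type": "Random Guessing"}
    )
    roc_data = pd.concat([curve, reference], ignore_index=True)

    # Shading data: the region between y = 0 and the curve (e.g. for geom_area).
    auc_data = pd.DataFrame({"FPR": fpr, "TPR": tpr})
    return roc_data, auc_data, roc_auc
```

Looking up the probability column through `model.classes_`, instead of assuming column 1, keeps `pos_label` meaningful when the labels are strings or when the positive class sorts first.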
- log_results(plot_name: str, auc: float, filename: str) → None
Log the results of the ROC curve to the console.
Displays the ROC curve plot name, AUC score, and file path for easy tracking of evaluation results.
- Parameters:
- plot_name : str
The name of the plot that was created
- auc : float
The AUC score calculated for the model
- filename : str
The name of the file where the plot was saved
- Returns:
- None
Notes
The log output includes the full output path, with the .svg extension, and the AUC score for quick performance assessment.
- plot(model: Any, X: ndarray, y: ndarray, filename: str, pos_label: int | None = 1) → None
Plot a receiver operating characteristic curve with area under the curve.
Executes the complete plotting workflow for generating a ROC curve. This includes calculating the ROC curve data, computing the AUC score, creating the plot, and saving the results.
- Parameters:
- model : Any
The trained binary classification model
- X : np.ndarray
The input features for evaluation
- y : np.ndarray
The true binary labels
- filename : str
The name of the file to save the plot to (without extension)
- pos_label : int, optional
The label of the positive class, by default 1
- Returns:
- None
Notes
This method handles different types of binary classification models by automatically detecting whether to use predict_proba, decision_function, or predict methods for obtaining prediction scores.
The plot includes both the ROC curve and a reference line for random guessing, along with the AUC score annotation.
- class PlotPrecisionRecallCurve(method_name: str, description: str, plot_settings)
Bases: PlotEvaluator
Plot a precision-recall curve with area under the curve.
This evaluator creates precision-recall curve plots for binary classification models, showing the relationship between precision and recall across different classification thresholds. The plot includes the average precision (AP) score and a reference line showing the AP score.
- Parameters:
- method_name : str
The name of the evaluator
- description : str
The description of the evaluator output
- plot_settings : PlotSettings
The plot settings containing theme and color configuration
- Attributes:
- method_name : str
The name of the evaluator
- description : str
The description of the evaluator output
- theme : Any
The plot theme for styling plots
- primary_color : str
Primary color for plots
- secondary_color : str
Secondary color for plots
- accent_color : str
Accent color for plots
Notes
The precision-recall curve is particularly useful for imbalanced datasets where the focus is on the positive class. It shows the trade-off between precision and recall across different classification thresholds.
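The average precision (AP) score summarizes the curve in a single number, with values closer to 1.0 indicating better performance. Compared with ROC AUC, AP is more sensitive to performance on the positive class: an uninformative classifier scores 0.5 ROC AUC regardless of class balance, while its AP is roughly the positive-class prevalence.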
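A small imbalanced example makes the difference in baselines concrete. This uses scikit-learn's `roc_auc_score` and `average_precision_score` directly (not the brisk evaluator) on made-up data with one positive among ten samples.

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

# Imbalanced toy data: 1 positive among 10 samples (prevalence 0.1).
y_true = np.array([0] * 9 + [1])

# Constant scores: no ranking information at all.
# ROC AUC = 0.5, but AP equals the prevalence, 0.1.
constant = np.full(10, 0.5)

# The positive (0.8) is ranked 3rd of 10: only two negatives score higher.
# ROC AUC = 7/9, but AP = 1/3, since recall first reaches 1 at precision 1/3.
ranked_third = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 0.95, 0.8])

for name, scores in [("constant", constant), ("ranked_third", ranked_third)]:
    roc_auc = roc_auc_score(y_true, scores)
    ap = average_precision_score(y_true, scores)
    print(f"{name}: ROC AUC={roc_auc:.3f}  AP={ap:.3f}")
```

The ranking that looks respectable by ROC AUC is penalized much more heavily by AP, because the two higher-scoring negatives dominate precision at the only recall level that matters for the single positive.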
Examples
- Use the precision-recall curve evaluator:
>>> from brisk.evaluation.evaluators import registry
>>> evaluator = registry.get("brisk_plot_precision_recall_curve")
>>> evaluator.plot(model, X, y, "precision_recall_curve")
- generate_plot_data(model: BaseEstimator, X: ndarray, y: ndarray, pos_label: int | None = 1) → Tuple[DataFrame, float]
Calculate the plot data for the precision-recall curve.
Generates the data needed to create a precision-recall curve plot, including the curve data and the average precision score.
- Parameters:
- model : base.BaseEstimator
The trained binary classification model
- X : np.ndarray
The input features for evaluation
- y : np.ndarray
The true binary labels
- pos_label : int, optional
The label of the positive class, by default 1
- Returns:
- Tuple[pd.DataFrame, float]
A tuple containing:
- Precision-recall curve data (DataFrame with Recall, Precision, and Type columns)
- Average precision score (float)
Notes
The method automatically detects the appropriate prediction method:
- predict_proba: uses the probability of the positive class
- decision_function: uses the decision function scores
- predict: uses binary predictions as a fallback
The precision-recall curve data includes both the actual curve and a reference line showing the average precision score.
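The data-assembly step can be sketched with scikit-learn's `precision_recall_curve` and `average_precision_score`. This is an assumption-laden sketch rather than brisk's code: `pr_plot_data` is a hypothetical name, it takes positive-class scores directly (obtained with predict_proba, decision_function, or predict as listed in the Notes above), and the `Type` values are placeholders.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import average_precision_score, precision_recall_curve


def pr_plot_data(y, scores, pos_label=1):
    """Hypothetical helper returning (curve data, AP score) from positive-class scores."""
    scores = np.asarray(scores, dtype=float)
    precision, recall, _ = precision_recall_curve(y, scores, pos_label=pos_label)
    ap_score = average_precision_score(y, scores, pos_label=pos_label)

    curve = pd.DataFrame(
        {"Recall": recall, "Precision": precision, "Type": "Precision-Recall Curve"}
    )
    # Reference line: precision held at the AP score across the full recall range.
    reference = pd.DataFrame(
        {"Recall": [0.0, 1.0], "Precision": [ap_score, ap_score], "Type": "Average Precision"}
    )
    return pd.concat([curve, reference], ignore_index=True), ap_score
```

Drawing the AP score as a horizontal line gives a visual anchor: parts of the curve above the line contribute more than their share to the summary score.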
- log_results(plot_name: str, ap_score: float, filename: str) → None
Log the results of the precision-recall curve to the console.
Displays the precision-recall curve plot name, AP score, and file path for easy tracking of evaluation results.
- Parameters:
- plot_name : str
The name of the plot that was created
- ap_score : float
The average precision score calculated for the model
- filename : str
The name of the file where the plot was saved
- Returns:
- None
Notes
The log output includes the full output path, with the .svg extension, and the AP score for quick performance assessment.
- plot(model: BaseEstimator, X: ndarray, y: ndarray, filename: str, pos_label: int | None = 1) → None
Plot a precision-recall curve with area under the curve.
Executes the complete plotting workflow for generating a precision-recall curve. This includes calculating the curve data, computing the AP score, creating the plot, and saving the results.
- Parameters:
- model : base.BaseEstimator
The trained binary classification model
- X : np.ndarray
The input features for evaluation
- y : np.ndarray
The true binary labels
- filename : str
The name of the file to save the plot to (without extension)
- pos_label : int, optional
The label of the positive class, by default 1
- Returns:
- None
Notes
This method handles different types of binary classification models by automatically detecting whether to use predict_proba, decision_function, or predict methods for obtaining prediction scores.
The plot includes both the precision-recall curve and a reference line showing the average precision score.