Workflow

class Workflow(evaluation_manager: EvaluationManager, X_train: DataFrame, X_test: DataFrame, y_train: Series, y_test: Series, output_dir: str, algorithm_names: List[str], feature_names: List[str], workflow_attributes: Dict[str, Any])

Abstract base class for machine learning workflows.

This class defines the interface and common functionality for machine learning workflows. It provides a standardized way to structure machine learning experiments with consistent data handling, model evaluation, visualization, and result saving capabilities.

The Workflow class serves as a foundation that specific workflow implementations must inherit from. It delegates evaluation and visualization tasks to the EvaluationManager, ensuring consistent behavior across different workflow types.

Parameters:
evaluation_manager : EvaluationManager

Manager for model evaluation, visualization, and analysis

X_train : pd.DataFrame

Training feature data with pandas DataFrame structure

X_test : pd.DataFrame

Test feature data with pandas DataFrame structure

y_train : pd.Series

Training target data with pandas Series structure

y_test : pd.Series

Test target data with pandas Series structure

output_dir : str

Directory path where workflow results will be saved

algorithm_names : List[str]

List of algorithm names to be used in the workflow

feature_names : List[str]

List of feature names corresponding to the data columns

workflow_attributes : Dict[str, Any]

Additional attributes to be unpacked as instance attributes

Attributes:
evaluation_manager : EvaluationManager

Manager for model evaluation, visualization, and analysis

X_train : pd.DataFrame

Training feature data with ‘is_test’ attribute set to False

X_test : pd.DataFrame

Test feature data with ‘is_test’ attribute set to True

y_train : pd.Series

Training target data with ‘is_test’ attribute set to False

y_test : pd.Series

Test target data with ‘is_test’ attribute set to True

output_dir : str

Output directory path for saving results

algorithm_names : List[str]

List of algorithm names for the workflow

feature_names : List[str]

List of feature names for the dataset

model1, model2, … : BaseEstimator

Model instances unpacked from workflow_attributes
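The unpacking of workflow_attributes into instance attributes can be pictured with a minimal sketch. The class name AttributeHolder and the placeholder values are hypothetical, and the use of setattr is an assumption about the mechanism, not a description of the brisk source.

```python
from typing import Any, Dict


class AttributeHolder:
    """Hypothetical stand-in for Workflow; not the brisk implementation."""

    def __init__(self, workflow_attributes: Dict[str, Any]) -> None:
        # Assumption: each dictionary key becomes an instance attribute,
        # so {"model1": ...} is reachable afterwards as self.model1.
        for name, value in workflow_attributes.items():
            setattr(self, name, value)


# Placeholder strings stand in for real estimator instances.
holder = AttributeHolder({"model1": "first estimator", "model2": "second estimator"})
print(holder.model1)
```

Under this assumption, a workflow() implementation would refer to its configured models as self.model1, self.model2, and so on.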

Notes

The Workflow class provides a comprehensive interface for machine learning experiments, including:

- Model evaluation and comparison
- Visualization and plotting
- Hyperparameter tuning
- Model saving and loading
- SHAP value analysis

All data objects (X_train, X_test, y_train, y_test) are marked with an ‘is_test’ attribute to distinguish between training and test data.

Subclasses must implement the abstract workflow() method that defines the specific workflow logic for their use case.

Examples

>>> from brisk.training.workflow import Workflow
>>> from brisk.evaluation import evaluation_manager
>>> import pandas as pd
>>> 
>>> class ClassificationWorkflow(Workflow):
...     def workflow(self, X_train, X_test, y_train, y_test,
...                  output_dir, feature_names):
...         # Train models
...         from sklearn.ensemble import RandomForestClassifier
...         model = RandomForestClassifier()
...         model.fit(self.X_train, self.y_train)
...         
...         # Evaluate model
...         self.evaluate_model(
...             model, self.X_test, self.y_test,
...             ['accuracy', 'precision', 'recall'], 'rf_results'
...         )
...         
...         # Generate plots
...         self.plot_confusion_heatmap(
...             model, self.X_test.values, 
...             self.y_test.values, 'confusion_matrix'
...         )
...         self.plot_roc_curve(model, self.X_test.values, 
...                           self.y_test.values, 'roc_curve')
...         
...         # Save model
...         self.save_model(model, 'trained_model')
compare_models(*models: BaseEstimator, X: DataFrame, y: Series, metrics: List[str], filename: str, calculate_diff: bool = False) -> Dict[str, Dict[str, float]]

Compare multiple models using specified metrics.

This method evaluates multiple models on the same data and metrics, allowing for direct comparison of their performance. It can optionally calculate differences between model performances.

Parameters:
*models : BaseEstimator

Variable number of trained models to compare

X : pd.DataFrame

Feature data for evaluation

y : pd.Series

Target data for evaluation

metrics : List[str]

List of metric names to calculate for comparison

filename : str

Base filename for saving comparison results (without extension)

calculate_diff : bool, default=False

Whether to compute and include differences between model performances

Returns:
Dict[str, Dict[str, float]]

Nested dictionary with model names as keys and metric results as values. Structure: {model_name: {metric_name: metric_value}}

Notes

The method uses the ‘brisk_compare_models’ evaluator from the evaluation manager. Results are saved and returned for further analysis or reporting.
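The documented return structure, {model_name: {metric_name: metric_value}}, can be post-processed with plain dictionary operations. The sketch below uses made-up scores and a hypothetical helper, best_model_per_metric, and it assumes that higher values are better for every metric, which does not hold for error metrics.

```python
from typing import Dict


def best_model_per_metric(results: Dict[str, Dict[str, float]]) -> Dict[str, str]:
    """Return the name of the highest-scoring model for each metric."""
    # Assumption: every model reports the same set of metrics.
    metrics = next(iter(results.values())).keys()
    return {
        metric: max(results, key=lambda name: results[name][metric])
        for metric in metrics
    }


# Made-up values shaped like the documented return value.
results = {
    "model1": {"accuracy": 0.91, "f1_score": 0.88},
    "model2": {"accuracy": 0.89, "f1_score": 0.90},
}
best = best_model_per_metric(results)
print(best)
```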

Examples

>>> results = workflow.compare_models(model1, model2, model3,
...                                  X=X_test, y=y_test,
...                                  metrics=['accuracy', 'f1_score'],
...                                  filename='model_comparison',
...                                  calculate_diff=True)
>>> print(results['model1']['accuracy'])
confusion_matrix(model: Any, X: ndarray, y: ndarray, filename: str) -> None

Generate and save a confusion matrix.

This method creates a confusion matrix for classification models, showing the count of correct and incorrect predictions for each class. It’s useful for understanding model performance on classification tasks.

Parameters:
model : Any

Trained classification model with predict method

X : np.ndarray

Feature data for making predictions

y : np.ndarray

True class labels

filename : str

Output filename for the confusion matrix (without extension)

Notes

The method uses the ‘brisk_confusion_matrix’ evaluator from the evaluation manager. The confusion matrix shows:

- True Positives (TP): Correctly predicted positive cases
- True Negatives (TN): Correctly predicted negative cases
- False Positives (FP): Incorrectly predicted positive cases
- False Negatives (FN): Incorrectly predicted negative cases
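The four counts listed above can be computed by hand for a binary problem. This is a minimal pure-Python sketch with a hypothetical helper, binary_confusion_counts, and toy labels; it illustrates the definitions only and is not the evaluator that brisk runs.

```python
from typing import Sequence, Tuple


def binary_confusion_counts(
    y_true: Sequence[int], y_pred: Sequence[int], pos_label: int = 1
) -> Tuple[int, int, int, int]:
    """Return (TP, TN, FP, FN) for binary labels."""
    tp = tn = fp = fn = 0
    for truth, pred in zip(y_true, y_pred):
        if pred == pos_label:
            if truth == pos_label:
                tp += 1  # predicted positive, actually positive
            else:
                fp += 1  # predicted positive, actually negative
        else:
            if truth == pos_label:
                fn += 1  # predicted negative, actually positive
            else:
                tn += 1  # predicted negative, actually negative
    return tp, tn, fp, fn


# Toy labels for illustration.
y_true = [1, 0, 1, 1, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0]
counts = binary_confusion_counts(y_true, y_pred)
print(counts)
```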

Examples

>>> workflow.confusion_matrix(
...     model, X_test.values, y_test.values, 'confusion_matrix'
... )
evaluate_model(model: BaseEstimator, X: DataFrame, y: Series, metrics: List[str], filename: str) -> None

Evaluate model on specified metrics and save results.

This method evaluates a trained model on the provided data using the specified metrics and saves the results to files. It delegates to the evaluation manager to perform the actual evaluation.

Parameters:
model : BaseEstimator

Trained model to evaluate (must have predict method)

X : pd.DataFrame

Feature data for evaluation

y : pd.Series

Target data for evaluation

metrics : List[str]

List of metric names to calculate (e.g., ['accuracy', 'precision'])

filename : str

Base filename for saving results (without extension)

Notes

The method uses the ‘brisk_evaluate_model’ evaluator from the evaluation manager. Results are saved in the workflow’s output directory with the specified filename.

Examples

>>> workflow.evaluate_model(model, X_test, y_test, 
...                        ['accuracy', 'precision', 'recall'], 
...                        'model_evaluation')
evaluate_model_cv(model: BaseEstimator, X: DataFrame, y: Series, metrics: List[str], filename: str, cv: int = 5) -> None

Evaluate model using cross-validation.

This method evaluates a model using k-fold cross-validation to provide more robust performance estimates. It trains and evaluates the model on multiple train/test splits and saves the results.

Parameters:
model : BaseEstimator

Model to evaluate (will be cloned for each CV fold)

X : pd.DataFrame

Feature data for evaluation

y : pd.Series

Target data for evaluation

metrics : List[str]

List of metric names to calculate

filename : str

Base filename for saving results (without extension)

cv : int, default=5

Number of cross-validation folds to use

Notes

The method uses the ‘brisk_evaluate_model_cv’ evaluator from the evaluation manager. Cross-validation provides more reliable performance estimates by testing on multiple data splits.
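To make "multiple data splits" concrete, the sketch below partitions sample indices into k folds, using each fold once as the held-out set. The helper k_fold_indices is hypothetical, does no shuffling or stratification, and is not the splitter brisk uses internally.

```python
from typing import List, Tuple


def k_fold_indices(n_samples: int, k: int) -> List[Tuple[List[int], List[int]]]:
    """Split range(n_samples) into k contiguous (train, test) index pairs."""
    indices = list(range(n_samples))
    # Spread any remainder over the first folds so sizes differ by at most one.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    folds = []
    start = 0
    for size in fold_sizes:
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        folds.append((train, test))
        start += size
    return folds


folds = k_fold_indices(n_samples=10, k=5)
for train, test in folds:
    print("train:", train, "test:", test)
```

Each sample lands in exactly one held-out fold, so every row contributes to the evaluation once.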

Examples

>>> workflow.evaluate_model_cv(model, X_train, y_train,
...                           ['accuracy', 'f1_score'], 
...                           'cv_evaluation', cv=10)
hyperparameter_tuning(model: BaseEstimator, method: str, X_train: DataFrame, y_train: Series, scorer: str, kf: int, num_rep: int, n_jobs: int, plot_results: bool = False) -> BaseEstimator

Perform hyperparameter tuning using grid or random search.

This method optimizes model hyperparameters using either grid search or random search with cross-validation. It returns the best model found during the search process.

Parameters:
model : BaseEstimator

Base model to tune (will be cloned for each parameter combination)

method : str

Search method to use ('grid' or 'random')

X_train : pd.DataFrame

Training feature data

y_train : pd.Series

Training target data

scorer : str

Scoring metric to optimize (e.g., 'accuracy', 'neg_mean_squared_error')

kf : int

Number of cross-validation folds

num_rep : int

Number of CV repetitions for stability

n_jobs : int

Number of parallel jobs (-1 uses all cores)

plot_results : bool, default=False

Whether to generate plots showing hyperparameter performance

Returns:
BaseEstimator

Best model found during hyperparameter search

Notes

The method uses the ‘brisk_hyperparameter_tuning’ evaluator from the evaluation manager. The search process tests different parameter combinations and selects the one with the best cross-validation score.
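The grid variant of this search can be sketched as an exhaustive loop over parameter combinations. Everything here is illustrative: grid_search and toy_score are hypothetical names, the grid values are made up, and toy_score stands in for the mean cross-validation score that a real search would compute.

```python
from itertools import product
from typing import Any, Callable, Dict, List, Tuple


def grid_search(
    param_grid: Dict[str, List[Any]],
    score_fn: Callable[[Dict[str, Any]], float],
) -> Tuple[Dict[str, Any], float]:
    """Try every combination in param_grid and keep the highest score."""
    names = list(param_grid)
    best_params, best_score = None, float("-inf")
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        score = score_fn(params)  # stand-in for a cross-validated score
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


def toy_score(params: Dict[str, Any]) -> float:
    """Made-up score surface that peaks at max_depth=4, n_estimators=100."""
    return -((params["max_depth"] - 4) ** 2) - abs(params["n_estimators"] - 100) / 100


grid = {"max_depth": [2, 4, 8], "n_estimators": [50, 100]}
best_params, best_score = grid_search(grid, toy_score)
print(best_params, best_score)
```

A random search differs only in sampling a fixed number of combinations instead of enumerating all of them.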

Examples

>>> tuned_model = workflow.hyperparameter_tuning(
...     RandomForestClassifier(), 'grid', X_train, y_train,
...     'accuracy', kf=5, num_rep=3, n_jobs=-1, plot_results=True)
load_model(filepath: str) -> BaseEstimator

Load model from pickle file.

This method loads a previously saved model from a pickle file, allowing it to be used for inference or further analysis.

Parameters:
filepath : str

Path to the saved model file (with extension)

Returns:
BaseEstimator

Loaded model ready for use

Raises:
FileNotFoundError

If the model file does not exist at the specified path

Notes

The method delegates to the evaluation manager’s load_model method. The loaded model can be used for making predictions or further evaluation.

Examples

>>> loaded_model = workflow.load_model('my_model.pkl')
>>> predictions = loaded_model.predict(X_new)
plot_confusion_heatmap(model: Any, X: ndarray, y: ndarray, filename: str) -> None

Plot a heatmap of the confusion matrix for a model.

This method generates a visual heatmap representation of the confusion matrix, making it easier to interpret classification performance across different classes.

Parameters:
model : Any

Trained classification model with predict method

X : np.ndarray

Feature data for making predictions

y : np.ndarray

True class labels

filename : str

Output filename for the heatmap plot (without extension)

Notes

The method uses the ‘brisk_plot_confusion_heatmap’ evaluator from the evaluation manager. The heatmap uses color intensity to show the count of predictions, making it easy to identify patterns in classification errors.

Examples

>>> workflow.plot_confusion_heatmap(
...     model, X_test.values, y_test.values, 'confusion_heatmap'
... )
plot_feature_importance(model: BaseEstimator, X: DataFrame, y: Series, threshold: int | float, feature_names: List[str], filename: str, metric: str, num_rep: int) -> None

Plot feature importance for the model and save the plot.

This method generates a plot showing the importance of each feature in the model’s predictions. It can filter features by importance threshold or number of top features.

Parameters:
model : BaseEstimator

Trained model with feature importance or permutation importance

X : pd.DataFrame

Feature data for importance calculation

y : pd.Series

Target data for importance calculation

threshold : Union[int, float]

- If int: number of top features to show
- If float: minimum importance threshold for features

feature_names : List[str]

List of feature names corresponding to X columns

filename : str

Output filename for the plot (without extension)

metric : str

Metric to use for importance calculation

num_rep : int

Number of repetitions for calculating importance (for stability)

Notes

The method uses the ‘brisk_plot_feature_importance’ evaluator from the evaluation manager. Feature importance helps identify which features contribute most to model predictions.
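The two meanings of the threshold parameter (an int selects the top N features, a float sets a minimum importance) can be sketched as follows. The helper select_features and the importance values are hypothetical, and whether brisk treats the float cutoff as inclusive is an assumption.

```python
from typing import Dict, List, Union


def select_features(
    importances: Dict[str, float], threshold: Union[int, float]
) -> List[str]:
    """Pick features by count (int) or by minimum importance (float)."""
    ranked = sorted(importances.items(), key=lambda item: item[1], reverse=True)
    if isinstance(threshold, int):
        # int: keep the `threshold` most important features
        return [name for name, _ in ranked[:threshold]]
    # float: keep features whose importance is at least `threshold`
    # (inclusive comparison is an assumption)
    return [name for name, value in ranked if value >= threshold]


# Made-up importance scores.
importances = {"age": 0.40, "income": 0.35, "height": 0.15, "zip": 0.10}
top_two = select_features(importances, 2)
above_cutoff = select_features(importances, 0.15)
print(top_two, above_cutoff)
```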

Examples

>>> workflow.plot_feature_importance(model, X_train, y_train,
...                                 threshold=10,
...                                 feature_names=feature_names,
...                                 filename='feature_importance',
...                                 metric='accuracy', num_rep=5)
plot_learning_curve(model: BaseEstimator, X_train: DataFrame, y_train: Series, filename: str = 'learning_curve', cv: int = 5, num_repeats: int = 1, n_jobs: int = -1, metric: str = 'neg_mean_absolute_error') -> None

Plot learning curves showing model performance vs training size.

This method generates learning curves that show how model performance changes as the training set size increases. This helps identify whether the model would benefit from more data or if it’s suffering from high bias or variance.

Parameters:
model : BaseEstimator

Model to evaluate (will be cloned for each training size)

X_train : pd.DataFrame

Training feature data

y_train : pd.Series

Training target data

filename : str, default="learning_curve"

Base filename for saving the plot (without extension)

cv : int, default=5

Number of cross-validation folds for each training size

num_repeats : int, default=1

Number of times to repeat cross-validation for stability

n_jobs : int, default=-1

Number of parallel jobs for cross-validation (-1 uses all cores)

metric : str, default="neg_mean_absolute_error"

Scoring metric to use for evaluation

Notes

The method uses the ‘brisk_plot_learning_curve’ evaluator from the evaluation manager. Learning curves help diagnose model behavior:

- High bias: both training and validation scores are low
- High variance: large gap between training and validation scores
- Good fit: both scores converge to similar high values

Examples

>>> workflow.plot_learning_curve(model, X_train, y_train,
...                             filename='rf_learning_curve',
...                             cv=10, metric='neg_mean_squared_error')
plot_model_comparison(*models: BaseEstimator, X: DataFrame, y: Series, metric: str, filename: str) -> None

Plot comparison of multiple models based on specified metric.

This method generates a visualization comparing the performance of multiple models on a single metric, making it easy to identify the best performing model.

Parameters:
*models : BaseEstimator

Variable number of trained models to compare

X : pd.DataFrame

Feature data for evaluation

y : pd.Series

Target data for evaluation

metric : str

Single metric name to use for comparison

filename : str

Output filename for the plot (without extension)

Notes

The method uses the ‘brisk_plot_model_comparison’ evaluator from the evaluation manager. The plot typically shows model names on one axis and metric values on the other, making performance comparison easy.

Examples

>>> workflow.plot_model_comparison(model1, model2, model3,
...                               X=X_test, y=y_test,
...                               metric='accuracy',
...                               filename='model_comparison')
plot_precision_recall_curve(model: Any, X: ndarray, y: ndarray, filename: str, pos_label: int | None = 1) -> None

Plot a precision-recall curve with average precision.

This method generates a precision-recall curve for binary classification models, showing the trade-off between precision and recall at different classification thresholds. This is particularly useful for imbalanced datasets.

Parameters:
model : Any

Trained binary classification model with predict_proba method

X : np.ndarray

Feature data for making predictions

y : np.ndarray

True binary class labels (0 and 1)

filename : str

Output filename for the precision-recall curve plot (without extension)

pos_label : Optional[int], default=1

Label of the positive class for precision-recall calculation

Notes

The method uses the ‘brisk_plot_precision_recall_curve’ evaluator from the evaluation manager. The precision-recall curve shows:

- X-axis: Recall (True Positive Rate)
- Y-axis: Precision (Positive Predictive Value)
- AP: Average Precision (higher is better)

Precision-recall curves are especially useful for imbalanced datasets where the positive class is rare.
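Each point on a precision-recall curve comes from applying one classification threshold to the predicted scores. The sketch below computes a single point in pure Python; precision_recall_at is a hypothetical helper, the scores are toy values, and returning precision 1.0 when nothing is predicted positive is a convention chosen here, not something the brisk documentation specifies.

```python
from typing import Sequence, Tuple


def precision_recall_at(
    y_true: Sequence[int],
    scores: Sequence[float],
    threshold: float,
    pos_label: int = 1,
) -> Tuple[float, float]:
    """Return (precision, recall) when scores >= threshold count as positive."""
    tp = fp = fn = 0
    for truth, score in zip(y_true, scores):
        predicted_positive = score >= threshold
        if predicted_positive and truth == pos_label:
            tp += 1
        elif predicted_positive:
            fp += 1
        elif truth == pos_label:
            fn += 1
    # Convention (assumption): precision is 1.0 when no positives are predicted.
    precision = tp / (tp + fp) if tp + fp else 1.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Toy data: two negatives, two positives.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
strict = precision_recall_at(y_true, scores, threshold=0.5)
lenient = precision_recall_at(y_true, scores, threshold=0.3)
print(strict, lenient)
```

Lowering the threshold raises recall at the cost of precision, which is the trade-off the curve visualizes.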

Examples

>>> workflow.plot_precision_recall_curve(
...     model, X_test.values, y_test.values, 'pr_curve'
... )
plot_pred_vs_obs(model: BaseEstimator, X: DataFrame, y_true: Series, filename: str) -> None

Plot predicted vs. observed values and save the plot.

This method generates a scatter plot comparing predicted values against observed values, which is useful for regression model evaluation and identifying prediction patterns.

Parameters:
model : BaseEstimator

Trained model with predict method

X : pd.DataFrame

Feature data for making predictions

y_true : pd.Series

True target values for comparison

filename : str

Output filename for the plot (without extension)

Notes

The method uses the ‘brisk_plot_pred_vs_obs’ evaluator from the evaluation manager. The plot helps assess model performance by showing how well predictions align with actual values.

Examples

>>> workflow.plot_pred_vs_obs(model, X_test, y_test, 'pred_vs_obs')
plot_residuals(model: BaseEstimator, X: DataFrame, y: Series, filename: str, add_fit_line: bool = False) -> None

Plot residuals of the model and save the plot.

This method generates a residual plot showing the difference between predicted and actual values. Residual plots help assess model assumptions and identify patterns in prediction errors.

Parameters:
model : BaseEstimator

Trained model with predict method

X : pd.DataFrame

Feature data for making predictions

y : pd.Series

True target values

filename : str

Output filename for the plot (without extension)

add_fit_line : bool, default=False

Whether to add a line of best fit to the residual plot

Notes

The method uses the ‘brisk_plot_residuals’ evaluator from the evaluation manager. Residual plots help identify:

- Non-linear patterns in residuals
- Heteroscedasticity (varying variance)
- Outliers and influential points

Examples

>>> workflow.plot_residuals(
...     model, X_test, y_test, 'residuals', add_fit_line=True
... )
plot_roc_curve(model: Any, X: ndarray, y: ndarray, filename: str, pos_label: int | None = 1) -> None

Plot a receiver operating characteristic (ROC) curve with AUC.

This method generates a ROC curve for binary classification models, showing the trade-off between true positive rate and false positive rate at different classification thresholds.

Parameters:
model : Any

Trained binary classification model with predict_proba method

X : np.ndarray

Feature data for making predictions

y : np.ndarray

True binary class labels (0 and 1)

filename : str

Output filename for the ROC curve plot (without extension)

pos_label : Optional[int], default=1

Label of the positive class for ROC calculation

Notes

The method uses the ‘brisk_plot_roc_curve’ evaluator from the evaluation manager. The ROC curve shows:

- X-axis: False Positive Rate (1 - Specificity)
- Y-axis: True Positive Rate (Sensitivity)
- AUC: Area Under the Curve (higher is better)
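The axes and the AUC described above can be reproduced by hand on a toy example. This pure-Python sketch sweeps each distinct score as a threshold and integrates with the trapezoidal rule; roc_points and auc are hypothetical helpers, and the code assumes both classes are present in y_true. It is not the routine brisk calls.

```python
from typing import List, Sequence, Tuple


def roc_points(
    y_true: Sequence[int], scores: Sequence[float], pos_label: int = 1
) -> List[Tuple[float, float]]:
    """Return (FPR, TPR) points, one per distinct score threshold."""
    n_pos = sum(1 for truth in y_true if truth == pos_label)
    n_neg = len(y_true) - n_pos  # assumes n_pos > 0 and n_neg > 0
    points = [(0.0, 0.0)]
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for t, s in zip(y_true, scores) if s >= threshold and t == pos_label)
        fp = sum(1 for t, s in zip(y_true, scores) if s >= threshold and t != pos_label)
        points.append((fp / n_neg, tp / n_pos))
    return points


def auc(points: Sequence[Tuple[float, float]]) -> float:
    """Area under a piecewise-linear curve via the trapezoidal rule."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


# Toy data: two negatives, two positives.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
points = roc_points(y_true, scores)
area = auc(points)
print(points, area)
```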

Examples

>>> workflow.plot_roc_curve(
...     model, X_test.values, y_test.values, 'roc_curve'
... )
plot_shapley_values(model: BaseEstimator, X: DataFrame, y: Series, filename: str = 'shapley_values', plot_type: str = 'bar') -> None

Generate SHAP value plots for feature importance.

This method generates SHAP (SHapley Additive exPlanations) value plots to explain individual predictions and feature importance. SHAP values provide a unified framework for explaining model predictions.

Parameters:
model : BaseEstimator

Trained model to explain (must be compatible with SHAP)

X : pd.DataFrame

Feature data for generating explanations

y : pd.Series

Target data (used for context in some plot types)

filename : str, default="shapley_values"

Base output filename for SHAP plots (without extension)

plot_type : str, default="bar"

Type of SHAP plot to generate. Options:

- 'bar': Bar plot of mean SHAP values
- 'waterfall': Waterfall plot for individual predictions
- 'violin': Violin plot showing SHAP value distributions
- 'beeswarm': Beeswarm plot for feature importance

Multiple types can be specified as a comma-separated string, such as 'bar,waterfall', to generate multiple plots.

Notes

The method uses the ‘brisk_plot_shapley_values’ evaluator from the evaluation manager. SHAP values provide:

- Feature importance rankings
- Individual prediction explanations
- Feature interaction effects
- Model interpretability insights

Examples

>>> workflow.plot_shapley_values(model, X_test, y_test, 
...                             'shap_explanation', 'bar,waterfall')
run() -> None

Execute the workflow.

This method serves as the entry point for running the workflow. It delegates to the abstract workflow() method that must be implemented by subclasses, providing a consistent interface for workflow execution.

Raises:
NotImplementedError

If called directly on the base Workflow class without implementing the abstract workflow() method

Notes

This method is marked with # pragma: no cover because it only delegates to the abstract workflow() method, which subclasses must override. The actual workflow logic is implemented in the workflow() method of concrete subclasses.
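The relationship between run() and workflow() is the template-method pattern, which can be sketched with Python's abc module. MiniWorkflow and SumWorkflow are hypothetical classes, and the sketch does not claim that brisk's Workflow is built on abc.ABC. Note that in this version, instantiating the base class fails with TypeError before run() can be called.

```python
from abc import ABC, abstractmethod
from typing import List


class MiniWorkflow(ABC):
    """Hypothetical base class: run() is the fixed entry point."""

    def __init__(self, data: List[int]) -> None:
        self.data = data

    def run(self) -> None:
        # Entry point: hand the stored data to the subclass-defined logic.
        self.workflow(self.data)

    @abstractmethod
    def workflow(self, data: List[int]) -> None:
        raise NotImplementedError("Subclasses must implement workflow().")


class SumWorkflow(MiniWorkflow):
    """Concrete subclass supplying the actual logic."""

    def workflow(self, data: List[int]) -> None:
        self.total = sum(data)


wf = SumWorkflow([1, 2, 3])
wf.run()
print(wf.total)
```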

save_model(model: BaseEstimator, filename: str) -> None

Save model to pickle file.

This method saves a trained model to a pickle file in the workflow’s output directory, allowing it to be loaded later for inference or further analysis.

Parameters:
model : BaseEstimator

Trained model to save

filename : str

Base filename for the saved model (without extension)

Notes

The method delegates to the evaluation manager’s save_model method. The model is saved in the workflow’s output directory with the specified filename.
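A save and load round trip through pickle can be sketched with the standard library alone. The helpers save_object and load_object are hypothetical, a plain dictionary stands in for a trained model, and appending ".pkl" to the base filename is an assumption inferred from the load_model example rather than confirmed behavior.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_object(obj: Any, output_dir: str, filename: str) -> Path:
    """Pickle obj to <output_dir>/<filename>.pkl and return the path."""
    path = Path(output_dir) / f"{filename}.pkl"  # ".pkl" suffix is an assumption
    with path.open("wb") as handle:
        pickle.dump(obj, handle)
    return path


def load_object(filepath: Path) -> Any:
    """Load a pickled object; open() raises FileNotFoundError if it is missing."""
    with open(filepath, "rb") as handle:
        return pickle.load(handle)


# A dictionary stands in for a trained model.
with tempfile.TemporaryDirectory() as tmp:
    saved_path = save_object({"coef": [0.5, 1.5]}, tmp, "my_model")
    restored = load_object(saved_path)

print(saved_path.name, restored)
```

Unpickling executes code embedded in the file, so only load model files from sources you trust.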

Examples

>>> workflow.save_model(trained_model, 'my_model')
abstractmethod workflow(X_train, X_test, y_train, y_test, output_dir, feature_names) -> None

Abstract method defining the workflow logic.

This method must be implemented by all concrete subclasses of Workflow. It should contain the specific logic for the machine learning workflow, including model training, evaluation, visualization, and result saving.

Raises:
NotImplementedError

Raised by the base implementation, because the method is abstract and must be overridden

Notes

This is an abstract method that must be implemented by subclasses. The implementation should define the complete workflow logic for the specific use case, utilizing the available data (X_train, X_test, y_train, y_test) and the evaluation manager for model assessment.

Typical workflow implementations include:

- Model instantiation and training
- Model evaluation using provided methods
- Visualization generation
- Result saving and reporting