Workflow
- class Workflow(evaluation_manager: EvaluationManager, X_train: DataFrame, X_test: DataFrame, y_train: Series, y_test: Series, output_dir: str, algorithm_names: List[str], feature_names: List[str], workflow_attributes: Dict[str, Any])
Abstract base class for machine learning workflows.
This class defines the interface and common functionality for machine learning workflows. It provides a standardized way to structure machine learning experiments with consistent data handling, model evaluation, visualization, and result saving capabilities.
The Workflow class serves as a foundation that specific workflow implementations must inherit from. It delegates evaluation and visualization tasks to the EvaluationManager, ensuring consistent behavior across different workflow types.
- Parameters:
- evaluation_manager : EvaluationManager
Manager for model evaluation, visualization, and analysis
- X_train : pd.DataFrame
Training feature data with pandas DataFrame structure
- X_test : pd.DataFrame
Test feature data with pandas DataFrame structure
- y_train : pd.Series
Training target data with pandas Series structure
- y_test : pd.Series
Test target data with pandas Series structure
- output_dir : str
Directory path where workflow results will be saved
- algorithm_names : List[str]
List of algorithm names to be used in the workflow
- feature_names : List[str]
List of feature names corresponding to the data columns
- workflow_attributes : Dict[str, Any]
Additional attributes to be unpacked as instance attributes
- Attributes:
- evaluation_manager : EvaluationManager
Manager for model evaluation, visualization, and analysis
- X_train : pd.DataFrame
Training feature data with ‘is_test’ attribute set to False
- X_test : pd.DataFrame
Test feature data with ‘is_test’ attribute set to True
- y_train : pd.Series
Training target data with ‘is_test’ attribute set to False
- y_test : pd.Series
Test target data with ‘is_test’ attribute set to True
- output_dir : str
Output directory path for saving results
- algorithm_names : List[str]
List of algorithm names for the workflow
- feature_names : List[str]
List of feature names for the dataset
- model1, model2, … : BaseEstimator
Model instances unpacked from workflow_attributes
Notes
The Workflow class provides a comprehensive interface for machine learning experiments, including:
- Model evaluation and comparison
- Visualization and plotting
- Hyperparameter tuning
- Model saving and loading
- SHAP value analysis
All data objects (X_train, X_test, y_train, y_test) are marked with an ‘is_test’ attribute to distinguish between training and test data.
Subclasses must implement the abstract workflow() method that defines the specific workflow logic for their use case.
Examples
>>> from brisk.training.workflow import Workflow
>>> from brisk.evaluation import evaluation_manager
>>> import pandas as pd
>>>
>>> class ClassificationWorkflow(Workflow):
...     def workflow(self):
...         # Train models
...         from sklearn.ensemble import RandomForestClassifier
...         model = RandomForestClassifier()
...         model.fit(self.X_train, self.y_train)
...
...         # Evaluate model
...         self.evaluate_model(
...             model, self.X_test, self.y_test,
...             ['accuracy', 'precision', 'recall'], 'rf_results'
...         )
...
...         # Generate plots
...         self.plot_confusion_heatmap(
...             model, self.X_test.values,
...             self.y_test.values, 'confusion_matrix'
...         )
...         self.plot_roc_curve(model, self.X_test.values,
...                             self.y_test.values, 'roc_curve')
...
...         # Save model
...         self.save_model(model, 'trained_model')
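To make the delegation pattern concrete, here is a minimal, stdlib-only sketch of an abstract base class whose run() calls an abstract workflow(), tags the data splits with is_test, and unpacks workflow_attributes with setattr. It is a hypothetical simplification (MiniWorkflow, TaggedList and SizeReportWorkflow are illustrative names), not brisk's source; a list subclass stands in for the pandas objects so the sketch runs without dependencies.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class TaggedList(list):
    """Hypothetical stand-in for a DataFrame/Series: a list that accepts attributes."""


class MiniWorkflow(ABC):
    """Hypothetical, simplified analogue of Workflow; not brisk's code."""

    def __init__(self, X_train: TaggedList, X_test: TaggedList,
                 workflow_attributes: Dict[str, Any]) -> None:
        # Tag each split so downstream code can tell training data from test data.
        X_train.is_test = False
        X_test.is_test = True
        self.X_train = X_train
        self.X_test = X_test
        # Unpack extra attributes (e.g. model1, model2) onto the instance.
        for name, value in workflow_attributes.items():
            setattr(self, name, value)

    def run(self) -> None:
        # Single entry point: delegate to the subclass-defined logic.
        self.workflow()

    @abstractmethod
    def workflow(self) -> None:
        raise NotImplementedError


class SizeReportWorkflow(MiniWorkflow):
    """Toy subclass: records the split sizes and the unpacked model1 attribute."""

    def workflow(self) -> None:
        self.report = (len(self.X_train), len(self.X_test), self.model1)


wf = SizeReportWorkflow(TaggedList([1, 2, 3]), TaggedList([4]), {"model1": "rf"})
wf.run()
```

Because workflow() is decorated with @abstractmethod, Python refuses to instantiate MiniWorkflow itself and raises TypeError at construction time; the NotImplementedError body is only reached if a subclass explicitly calls super().workflow().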
- compare_models(*models: BaseEstimator, X: DataFrame, y: Series, metrics: List[str], filename: str, calculate_diff: bool = False) → Dict[str, Dict[str, float]]
Compare multiple models using specified metrics.
This method evaluates multiple models on the same data and metrics, allowing for direct comparison of their performance. It can optionally calculate differences between model performances.
- Parameters:
- *models : BaseEstimator
Variable number of trained models to compare
- X : pd.DataFrame
Feature data for evaluation
- y : pd.Series
Target data for evaluation
- metrics : List[str]
List of metric names to calculate for comparison
- filename : str
Base filename for saving comparison results (without extension)
- calculate_diff : bool, default=False
Whether to compute and include differences between model performances
- Returns:
- Dict[str, Dict[str, float]]
Nested dictionary with model names as keys and metric results as values. Structure: {model_name: {metric_name: metric_value}}
Notes
The method uses the ‘brisk_compare_models’ evaluator from the evaluation manager. Results are saved and returned for further analysis or reporting.
Examples
>>> results = workflow.compare_models(model1, model2, model3,
...                                   X=X_test, y=y_test,
...                                   metrics=['accuracy', 'f1_score'],
...                                   filename='model_comparison',
...                                   calculate_diff=True)
>>> print(results['model1']['accuracy'])
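A rough, stdlib-only sketch of the comparison logic, assuming models are plain prediction callables and metrics are passed as a name-to-function mapping. It returns the documented {model_name: {metric_name: metric_value}} structure; how brisk lays out the differences when calculate_diff=True is not specified here, so the "b - a" entries below are a hypothetical choice.

```python
from itertools import combinations
from statistics import mean
from typing import Callable, Dict, Sequence

Predict = Callable[[Sequence], Sequence]
Metric = Callable[[Sequence, Sequence], float]


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    # Fraction of positions where the prediction equals the true label.
    return mean(1.0 if t == p else 0.0 for t, p in zip(y_true, y_pred))


def compare_models(models: Dict[str, Predict], X: Sequence, y: Sequence,
                   metrics: Dict[str, Metric],
                   calculate_diff: bool = False) -> Dict[str, Dict[str, float]]:
    results: Dict[str, Dict[str, float]] = {}
    for name, predict in models.items():
        preds = predict(X)  # predict once per model, reuse for every metric
        results[name] = {m: fn(y, preds) for m, fn in metrics.items()}
    if calculate_diff:
        # Hypothetical layout: one extra entry per model pair, keyed "b - a".
        for a, b in combinations(list(models), 2):
            results[f"{b} - {a}"] = {m: results[b][m] - results[a][m] for m in metrics}
    return results


results = compare_models(
    {"always_one": lambda X: [1] * len(X),
     "threshold": lambda X: [1 if x > 0 else 0 for x in X]},
    X=[-2, -1, 1, 2], y=[0, 0, 1, 1],
    metrics={"accuracy": accuracy}, calculate_diff=True)
```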
- confusion_matrix(model: Any, X: ndarray, y: ndarray, filename: str) → None
Generate and save a confusion matrix.
This method creates a confusion matrix for classification models, showing the count of correct and incorrect predictions for each class. It’s useful for understanding model performance on classification tasks.
- Parameters:
- model : Any
Trained classification model with a predict method
- X : np.ndarray
Feature data for making predictions
- y : np.ndarray
True class labels
- filename : str
Output filename for the confusion matrix (without extension)
Notes
The method uses the ‘brisk_confusion_matrix’ evaluator from the evaluation manager. For a binary problem, the confusion matrix shows:
- True Positives (TP): Correctly predicted positive cases
- True Negatives (TN): Correctly predicted negative cases
- False Positives (FP): Incorrectly predicted positive cases
- False Negatives (FN): Incorrectly predicted negative cases
Examples
>>> workflow.confusion_matrix(
...     model, X_test.values, y_test.values, 'confusion_matrix'
... )
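The counting behind a confusion matrix fits in a few lines of plain Python. This is an illustration, not the ‘brisk_confusion_matrix’ evaluator: it uses the rows-are-true-labels, columns-are-predicted-labels orientation that scikit-learn's confusion_matrix also uses, and assuming brisk follows the same orientation is exactly that, an assumption.

```python
from typing import Hashable, List, Sequence, Tuple


def confusion_matrix(y_true: Sequence[Hashable],
                     y_pred: Sequence[Hashable]) -> Tuple[List[Hashable], List[List[int]]]:
    """Rows are true labels, columns are predicted labels (labels must be sortable)."""
    labels = sorted(set(y_true) | set(y_pred))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return labels, matrix


labels, cm = confusion_matrix([1, 0, 1, 1, 0], [1, 0, 0, 1, 1])
# With labels [0, 1]: cm[0][0] = TN, cm[0][1] = FP, cm[1][0] = FN, cm[1][1] = TP.
```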
- evaluate_model(model: BaseEstimator, X: DataFrame, y: Series, metrics: List[str], filename: str) → None
Evaluate model on specified metrics and save results.
This method evaluates a trained model on the provided data using the specified metrics and saves the results to files. It delegates to the evaluation manager to perform the actual evaluation.
- Parameters:
- model : BaseEstimator
Trained model to evaluate (must have a predict method)
- X : pd.DataFrame
Feature data for evaluation
- y : pd.Series
Target data for evaluation
- metrics : List[str]
List of metric names to calculate (e.g., ['accuracy', 'precision'])
- filename : str
Base filename for saving results (without extension)
Notes
The method uses the ‘brisk_evaluate_model’ evaluator from the evaluation manager. Results are saved in the workflow’s output directory with the specified filename.
Examples
>>> workflow.evaluate_model(model, X_test, y_test,
...                         ['accuracy', 'precision', 'recall'],
...                         'model_evaluation')
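As a hedged illustration of "evaluate on named metrics, then save", the sketch below looks each metric name up in a small hypothetical registry and writes the scores to <output_dir>/<filename>.json. The registry contents, the prediction callable, and the JSON file format are all assumptions; brisk's evaluator may compute and store its results differently.

```python
import json
import os
import tempfile
from statistics import mean
from typing import Callable, Dict, List, Sequence

# Hypothetical name -> function registry; brisk's metric names may differ.
METRICS: Dict[str, Callable[[Sequence, Sequence], float]] = {
    "accuracy": lambda y, p: mean(1.0 if a == b else 0.0 for a, b in zip(y, p)),
    "mean_absolute_error": lambda y, p: mean(abs(a - b) for a, b in zip(y, p)),
}


def evaluate_model(predict: Callable[[Sequence], Sequence], X: Sequence, y: Sequence,
                   metrics: List[str], output_dir: str, filename: str) -> Dict[str, float]:
    preds = predict(X)
    results = {name: METRICS[name](y, preds) for name in metrics}
    # Save as <output_dir>/<filename>.json; the extension and format are assumptions.
    with open(os.path.join(output_dir, f"{filename}.json"), "w") as fh:
        json.dump(results, fh, indent=2)
    return results


with tempfile.TemporaryDirectory() as out_dir:
    scores = evaluate_model(lambda X: [1 if x > 0 else 0 for x in X],
                            [-1, 2, 3, -4], [0, 1, 0, 0],
                            ["accuracy", "mean_absolute_error"], out_dir, "rf_results")
    with open(os.path.join(out_dir, "rf_results.json")) as fh:
        saved = json.load(fh)
```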
- evaluate_model_cv(model: BaseEstimator, X: DataFrame, y: Series, metrics: List[str], filename: str, cv: int = 5) → None
Evaluate model using cross-validation.
This method evaluates a model using k-fold cross-validation to provide more robust performance estimates. It trains and evaluates the model on multiple train/test splits and saves the results.
- Parameters:
- model : BaseEstimator
Model to evaluate (will be cloned for each CV fold)
- X : pd.DataFrame
Feature data for evaluation
- y : pd.Series
Target data for evaluation
- metrics : List[str]
List of metric names to calculate
- filename : str
Base filename for saving results (without extension)
- cv : int, default=5
Number of cross-validation folds to use
Notes
The method uses the ‘brisk_evaluate_model_cv’ evaluator from the evaluation manager. Cross-validation provides more reliable performance estimates by testing on multiple data splits.
Examples
>>> workflow.evaluate_model_cv(model, X_train, y_train,
...                            ['accuracy', 'f1_score'],
...                            'cv_evaluation', cv=10)
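The k-fold procedure can be sketched without any ML library. The helper builds contiguous, unshuffled folds (the same layout scikit-learn's KFold produces with shuffle=False), fits a fresh toy model on each training split, and scores the held-out fold. fit_predict, mean_predictor and mae are hypothetical names, and brisk may shuffle or stratify its folds.

```python
from statistics import mean
from typing import Callable, List, Sequence, Tuple


def kfold_indices(n_samples: int, k: int) -> List[Tuple[List[int], List[int]]]:
    """Contiguous, unshuffled folds; the first n_samples % k folds get one extra sample."""
    sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    splits, start = [], 0
    for size in sizes:
        stop = start + size
        test_idx = list(range(start, stop))
        train_idx = list(range(0, start)) + list(range(stop, n_samples))
        splits.append((train_idx, test_idx))
        start = stop
    return splits


def cross_validate(fit_predict: Callable[[List, List, List], List], X: Sequence,
                   y: Sequence, metric: Callable[[List, List], float],
                   k: int = 5) -> List[float]:
    """Fit a fresh model per fold via the hypothetical fit_predict and score the held-out fold."""
    scores = []
    for train_idx, test_idx in kfold_indices(len(X), k):
        preds = fit_predict([X[i] for i in train_idx], [y[i] for i in train_idx],
                            [X[i] for i in test_idx])
        scores.append(metric([y[i] for i in test_idx], preds))
    return scores


def mean_predictor(X_tr: List, y_tr: List, X_te: List) -> List[float]:
    # Toy "model": predict the training-fold mean for every test row.
    return [mean(y_tr)] * len(X_te)


def mae(y_true: List, y_pred: List) -> float:
    return mean(abs(t - p) for t, p in zip(y_true, y_pred))


fold_scores = cross_validate(mean_predictor, [0, 1, 2, 3], [1.0, 2.0, 3.0, 4.0], mae, k=2)
```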
- hyperparameter_tuning(model: BaseEstimator, method: str, X_train: DataFrame, y_train: Series, scorer: str, kf: int, num_rep: int, n_jobs: int, plot_results: bool = False) → BaseEstimator
Perform hyperparameter tuning using grid or random search.
This method optimizes model hyperparameters using either grid search or random search with cross-validation. It returns the best model found during the search process.
- Parameters:
- model : BaseEstimator
Base model to tune (will be cloned for each parameter combination)
- method : str
Search method to use (‘grid’ or ‘random’)
- X_train : pd.DataFrame
Training feature data
- y_train : pd.Series
Training target data
- scorer : str
Scoring metric to optimize (e.g., ‘accuracy’, ‘neg_mean_squared_error’)
- kf : int
Number of cross-validation folds
- num_rep : int
Number of CV repetitions for stability
- n_jobs : int
Number of parallel jobs (-1 uses all cores)
- plot_results : bool, default=False
Whether to generate plots showing hyperparameter performance
- Returns:
- BaseEstimator
Best model found during hyperparameter search
Notes
The method uses the ‘brisk_hyperparameter_tuning’ evaluator from the evaluation manager. The search process tests different parameter combinations and selects the one with the best cross-validation score.
Examples
>>> tuned_model = workflow.hyperparameter_tuning(
...     RandomForestClassifier(), 'grid', X_train, y_train,
...     'accuracy', kf=5, num_rep=3, n_jobs=-1, plot_results=True)
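The search loop can be sketched as follows, assuming a hypothetical cv_score(params, rep) callback that returns one cross-validated score per repetition. Grid search tries every combination; the random variant here samples a few combinations from the same grid (scikit-learn's RandomizedSearchCV can also draw from continuous distributions). Scores follow the higher-is-better convention, which is why error metrics are passed as neg_* scorers.

```python
import itertools
import random
from statistics import mean
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple


def candidates(param_grid: Dict[str, Sequence[Any]], method: str,
               n_iter: int = 5, seed: int = 0) -> List[Dict[str, Any]]:
    keys = sorted(param_grid)
    combos = [dict(zip(keys, values))
              for values in itertools.product(*(param_grid[k] for k in keys))]
    if method == "grid":
        return combos  # every combination
    if method == "random":
        # Sample n_iter distinct combinations from the same grid.
        return random.Random(seed).sample(combos, min(n_iter, len(combos)))
    raise ValueError(f"method must be 'grid' or 'random', got {method!r}")


def search(param_grid: Dict[str, Sequence[Any]], method: str,
           cv_score: Callable[[Dict[str, Any], int], float],
           num_rep: int = 3) -> Tuple[Optional[Dict[str, Any]], float]:
    best_params, best_score = None, float("-inf")
    for params in candidates(param_grid, method):
        # Average the CV score over repetitions; higher is better.
        score = mean(cv_score(params, rep) for rep in range(num_rep))
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


def toy_score(params: Dict[str, Any], rep: int) -> float:
    # Deterministic toy score peaking at max_depth=3; rep is ignored here.
    return -(params["max_depth"] - 3) ** 2 - params["n_estimators"] / 1000


grid = {"max_depth": [1, 3, 5], "n_estimators": [10, 100]}
best, best_score = search(grid, "grid", toy_score)
```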
- load_model(filepath: str) → BaseEstimator
Load model from pickle file.
This method loads a previously saved model from a pickle file, allowing it to be used for inference or further analysis.
- Parameters:
- filepath : str
Path to the saved model file (with extension)
- Returns:
- BaseEstimator
Loaded model ready for use
- Raises:
- FileNotFoundError
If the model file does not exist at the specified path
Notes
The method delegates to the evaluation manager’s load_model method. The loaded model can be used for making predictions or further evaluation.
Examples
>>> loaded_model = workflow.load_model('my_model.pkl')
>>> predictions = loaded_model.predict(X_new)
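A minimal sketch of the pickle round trip behind save_model and load_model, using only the standard library. Appending .pkl to the base filename is inferred from the examples in this section and is an assumption, as is the explicit existence check.

```python
import os
import pickle
import tempfile
from typing import Any


def save_model(model: Any, output_dir: str, filename: str) -> str:
    # Appending ".pkl" is an assumption about the naming convention.
    path = os.path.join(output_dir, f"{filename}.pkl")
    with open(path, "wb") as fh:
        pickle.dump(model, fh)
    return path


def load_model(filepath: str) -> Any:
    # open() already raises FileNotFoundError; this check only adds a clearer message.
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"No model file found at {filepath!r}")
    with open(filepath, "rb") as fh:
        return pickle.load(fh)


with tempfile.TemporaryDirectory() as out_dir:
    saved_path = save_model({"coef": [0.5, -1.25]}, out_dir, "my_model")
    restored = load_model(saved_path)
    missing_raised = False
    try:
        load_model(os.path.join(out_dir, "absent.pkl"))
    except FileNotFoundError:
        missing_raised = True
```

Only unpickle files you trust: pickle.load can execute arbitrary code embedded in the file.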
- plot_confusion_heatmap(model: Any, X: ndarray, y: ndarray, filename: str) → None
Plot a heatmap of the confusion matrix for a model.
This method generates a visual heatmap representation of the confusion matrix, making it easier to interpret classification performance across different classes.
- Parameters:
- model : Any
Trained classification model with a predict method
- X : np.ndarray
Feature data for making predictions
- y : np.ndarray
True class labels
- filename : str
Output filename for the heatmap plot (without extension)
Notes
The method uses the ‘brisk_plot_confusion_heatmap’ evaluator from the evaluation manager. The heatmap uses color intensity to show the count of predictions, making it easy to identify patterns in classification errors.
Examples
>>> workflow.plot_confusion_heatmap(
...     model, X_test.values, y_test.values, 'confusion_heatmap'
... )
- plot_feature_importance(model: BaseEstimator, X: DataFrame, y: Series, threshold: int | float, feature_names: List[str], filename: str, metric: str, num_rep: int) → None
Plot feature importance for the model and save the plot.
This method generates a plot showing the importance of each feature in the model’s predictions. It can filter features by importance threshold or number of top features.
- Parameters:
- model : BaseEstimator
Trained model with feature importance or permutation importance
- X : pd.DataFrame
Feature data for importance calculation
- y : pd.Series
Target data for importance calculation
- threshold : Union[int, float]
If int, the number of top features to show; if float, the minimum importance a feature must have to be shown
- feature_names : List[str]
List of feature names corresponding to X columns
- filename : str
Output filename for the plot (without extension)
- metric : str
Metric to use for importance calculation
- num_rep : int
Number of repetitions for calculating importance (for stability)
Notes
The method uses the ‘brisk_plot_feature_importance’ evaluator from the evaluation manager. Feature importance helps identify which features contribute most to model predictions.
Examples
>>> workflow.plot_feature_importance(model, X_train, y_train,
...                                  threshold=10,
...                                  feature_names=feature_names,
...                                  filename='feature_importance',
...                                  metric='accuracy', num_rep=5)
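Given the metric and num_rep parameters, the importance may be permutation-based; the sketch below shows that technique under this assumption. It shuffles one column at a time, measures the drop in a higher-is-better metric, averages over num_rep shuffles, then applies the int/float threshold rule described above. All names are hypothetical, and rows are plain lists rather than DataFrames.

```python
import random
from statistics import mean
from typing import Callable, Dict, List, Sequence, Union


def permutation_importance(predict: Callable[[List[List[float]]], List[float]],
                           X: List[List[float]], y: Sequence[float],
                           metric: Callable[[Sequence[float], Sequence[float]], float],
                           feature_names: List[str], num_rep: int = 5,
                           seed: int = 0) -> Dict[str, float]:
    """Mean drop in a higher-is-better metric when one column is shuffled."""
    rng = random.Random(seed)
    baseline = metric(y, predict(X))
    importances = {}
    for j, name in enumerate(feature_names):
        drops = []
        for _ in range(num_rep):
            column = [row[j] for row in X]
            rng.shuffle(column)
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            drops.append(baseline - metric(y, predict(X_perm)))
        importances[name] = mean(drops)
    return importances


def select_features(importances: Dict[str, float],
                    threshold: Union[int, float]) -> List[str]:
    """int -> top-N features; float -> features with importance >= threshold."""
    ranked = sorted(importances, key=importances.get, reverse=True)
    if isinstance(threshold, int):
        return ranked[:threshold]
    return [name for name in ranked if importances[name] >= threshold]


def neg_mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    return -mean(abs(t - p) for t, p in zip(y_true, y_pred))


# Toy model that only reads feature "a"; feature "b" is a constant column.
X = [[float(i), 0.0] for i in range(10)]
y = [float(i) for i in range(10)]
imp = permutation_importance(lambda rows: [r[0] for r in rows], X, y, neg_mae, ["a", "b"])
```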
- plot_learning_curve(model: BaseEstimator, X_train: DataFrame, y_train: Series, filename: str = 'learning_curve', cv: int = 5, num_repeats: int = 1, n_jobs: int = -1, metric: str = 'neg_mean_absolute_error') → None
Plot learning curves showing model performance vs training size.
This method generates learning curves that show how model performance changes as the training set size increases. This helps identify whether the model would benefit from more data or is suffering from high bias or high variance.
- Parameters:
- model : BaseEstimator
Model to evaluate (will be cloned for each training size)
- X_train : pd.DataFrame
Training feature data
- y_train : pd.Series
Training target data
- filename : str, default="learning_curve"
Base filename for saving the plot (without extension)
- cv : int, default=5
Number of cross-validation folds for each training size
- num_repeats : int, default=1
Number of times to repeat cross-validation for stability
- n_jobs : int, default=-1
Number of parallel jobs for cross-validation (-1 uses all cores)
- metric : str, default="neg_mean_absolute_error"
Scoring metric to use for evaluation
Notes
The method uses the ‘brisk_plot_learning_curve’ evaluator from the evaluation manager. Learning curves help diagnose model behavior:
- High bias: both training and validation scores are low
- High variance: large gap between training and validation scores
- Good fit: both scores converge to similar high values
Examples
>>> workflow.plot_learning_curve(model, X_train, y_train,
...                              filename='rf_learning_curve',
...                              cv=10, metric='neg_mean_squared_error')
- plot_model_comparison(*models: BaseEstimator, X: DataFrame, y: Series, metric: str, filename: str) → None
Plot comparison of multiple models based on specified metric.
This method generates a visualization comparing the performance of multiple models on a single metric, making it easy to identify the best performing model.
- Parameters:
- *models : BaseEstimator
Variable number of trained models to compare
- X : pd.DataFrame
Feature data for evaluation
- y : pd.Series
Target data for evaluation
- metric : str
Single metric name to use for comparison
- filename : str
Output filename for the plot (without extension)
Notes
The method uses the ‘brisk_plot_model_comparison’ evaluator from the evaluation manager. The plot typically shows model names on one axis and metric values on the other, making performance comparison easy.
Examples
>>> workflow.plot_model_comparison(model1, model2, model3,
...                                X=X_test, y=y_test,
...                                metric='accuracy',
...                                filename='model_comparison'
... )
- plot_precision_recall_curve(model: Any, X: ndarray, y: ndarray, filename: str, pos_label: int | None = 1) → None
Plot a precision-recall curve with average precision.
This method generates a precision-recall curve for binary classification models, showing the trade-off between precision and recall at different classification thresholds. This is particularly useful for imbalanced datasets.
- Parameters:
- model : Any
Trained binary classification model with a predict_proba method
- X : np.ndarray
Feature data for making predictions
- y : np.ndarray
True binary class labels (0 and 1)
- filename : str
Output filename for the precision-recall curve plot (without extension)
- pos_label : Optional[int], default=1
Label of the positive class for precision-recall calculation
Notes
The method uses the ‘brisk_plot_precision_recall_curve’ evaluator from the evaluation manager. The precision-recall curve shows:
- X-axis: Recall (True Positive Rate)
- Y-axis: Precision (Positive Predictive Value)
- AP: Average Precision (higher is better)
Precision-recall curves are especially useful for imbalanced datasets where the positive class is rare.
Examples
>>> workflow.plot_precision_recall_curve(
...     model, X_test.values, y_test.values, 'pr_curve'
... )
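The curve and its AP summary can be computed directly from predicted scores. This sketch sweeps thresholds from the highest score down, records (precision, recall) once per distinct score so tied samples are counted together, and sums AP = sum over n of (R_n - R_{n-1}) * P_n, the non-interpolated definition documented for scikit-learn's average_precision_score. It illustrates the calculation; it is not brisk's evaluator.

```python
from typing import Hashable, List, Sequence, Tuple


def precision_recall_points(y_true: Sequence[Hashable], scores: Sequence[float],
                            pos_label: Hashable = 1) -> List[Tuple[float, float]]:
    """(precision, recall) at each distinct score threshold, highest threshold first."""
    pairs = sorted(zip(scores, y_true), key=lambda pair: pair[0], reverse=True)
    n_pos = sum(1 for _, label in pairs if label == pos_label)
    points: List[Tuple[float, float]] = []
    tp = fp = 0
    for i, (score, label) in enumerate(pairs):
        if label == pos_label:
            tp += 1
        else:
            fp += 1
        # Emit a point only after every sample tied at this score has been counted.
        if i == len(pairs) - 1 or pairs[i + 1][0] != score:
            points.append((tp / (tp + fp), tp / n_pos))
    return points


def average_precision(y_true: Sequence[Hashable], scores: Sequence[float],
                      pos_label: Hashable = 1) -> float:
    """Step-wise AP without interpolation."""
    ap, prev_recall = 0.0, 0.0
    for precision, recall in precision_recall_points(y_true, scores, pos_label):
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


ap = average_precision([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```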
- plot_pred_vs_obs(model: BaseEstimator, X: DataFrame, y_true: Series, filename: str) → None
Plot predicted vs. observed values and save the plot.
This method generates a scatter plot comparing predicted values against observed values, which is useful for regression model evaluation and identifying prediction patterns.
- Parameters:
- model : BaseEstimator
Trained model with a predict method
- X : pd.DataFrame
Feature data for making predictions
- y_true : pd.Series
True target values for comparison
- filename : str
Output filename for the plot (without extension)
Notes
The method uses the ‘brisk_plot_pred_vs_obs’ evaluator from the evaluation manager. The plot helps assess model performance by showing how well predictions align with actual values.
Examples
>>> workflow.plot_pred_vs_obs(model, X_test, y_test, 'pred_vs_obs')
- plot_residuals(model: BaseEstimator, X: DataFrame, y: Series, filename: str, add_fit_line: bool = False) → None
Plot residuals of the model and save the plot.
This method generates a residual plot showing the difference between predicted and actual values. Residual plots help assess model assumptions and identify patterns in prediction errors.
- Parameters:
- model : BaseEstimator
Trained model with a predict method
- X : pd.DataFrame
Feature data for making predictions
- y : pd.Series
True target values
- filename : str
Output filename for the plot (without extension)
- add_fit_line : bool, default=False
Whether to add a line of best fit to the residual plot
Notes
The method uses the ‘brisk_plot_residuals’ evaluator from the evaluation manager. Residual plots help identify:
- Non-linear patterns in residuals
- Heteroscedasticity (varying variance)
- Outliers and influential points
Examples
>>> workflow.plot_residuals(
...     model, X_test, y_test, 'residuals', add_fit_line=True
... )
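A small sketch of what a residual plot is built from, assuming residuals are observed minus predicted (brisk's sign convention is not stated here) and that the optional fit line is an ordinary least-squares line of residuals against predicted values. Both choices are assumptions made for illustration.

```python
from typing import List, Sequence, Tuple


def residuals(y_true: Sequence[float], y_pred: Sequence[float]) -> List[float]:
    # Observed minus predicted (sign convention assumed).
    return [t - p for t, p in zip(y_true, y_pred)]


def fit_line(x: Sequence[float], y: Sequence[float]) -> Tuple[float, float]:
    """Ordinary least-squares slope and intercept of y on x."""
    n = len(x)
    x_mean = sum(x) / n
    y_mean = sum(y) / n
    sxx = sum((xi - x_mean) ** 2 for xi in x)
    sxy = sum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
    slope = sxy / sxx
    return slope, y_mean - slope * x_mean


y_obs = [1.0, 2.0, 3.0, 4.0]
y_hat = [1.5, 2.0, 2.5, 3.0]
res = residuals(y_obs, y_hat)
slope, intercept = fit_line(y_hat, res)
```

A fit line with non-zero slope signals systematic error. In this toy data the slope of 1 shows the residuals growing with the prediction, meaning the model increasingly under-predicts large targets.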
- plot_roc_curve(model: Any, X: ndarray, y: ndarray, filename: str, pos_label: int | None = 1) → None
Plot a receiver operating characteristic (ROC) curve with AUC.
This method generates a ROC curve for binary classification models, showing the trade-off between true positive rate and false positive rate at different classification thresholds.
- Parameters:
- model : Any
Trained binary classification model with a predict_proba method
- X : np.ndarray
Feature data for making predictions
- y : np.ndarray
True binary class labels (0 and 1)
- filename : str
Output filename for the ROC curve plot (without extension)
- pos_label : Optional[int], default=1
Label of the positive class for ROC calculation
Notes
The method uses the ‘brisk_plot_roc_curve’ evaluator from the evaluation manager. The ROC curve shows:
- X-axis: False Positive Rate (1 - Specificity)
- Y-axis: True Positive Rate (Sensitivity)
- AUC: Area Under the Curve (higher is better)
Examples
>>> workflow.plot_roc_curve(
...     model, X_test.values, y_test.values, 'roc_curve'
... )
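A stdlib-only sketch of the ROC computation: sweep thresholds from the highest score down, record (FPR, TPR) after each group of tied scores, and integrate with the trapezoidal rule to get the AUC. It assumes both classes are present (otherwise one of the rates is undefined) and is an illustration rather than brisk's evaluator.

```python
from typing import Hashable, List, Sequence, Tuple


def roc_points(y_true: Sequence[Hashable], scores: Sequence[float],
               pos_label: Hashable = 1) -> List[Tuple[float, float]]:
    """(FPR, TPR) pairs from the strictest threshold to the loosest; needs both classes."""
    pairs = sorted(zip(scores, y_true), key=lambda pair: pair[0], reverse=True)
    n_pos = sum(1 for _, label in pairs if label == pos_label)
    n_neg = len(pairs) - n_pos
    points = [(0.0, 0.0)]
    tp = fp = 0
    for i, (score, label) in enumerate(pairs):
        if label == pos_label:
            tp += 1
        else:
            fp += 1
        # Emit a point only after every sample tied at this score has been counted.
        if i == len(pairs) - 1 or pairs[i + 1][0] != score:
            points.append((fp / n_neg, tp / n_pos))
    return points


def trapezoid_auc(points: List[Tuple[float, float]]) -> float:
    # Sum of trapezoid areas between consecutive (FPR, TPR) points.
    return sum((x1 - x0) * (y0 + y1) / 2
               for (x0, y0), (x1, y1) in zip(points, points[1:]))


auc = trapezoid_auc(roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```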
- plot_shapley_values(model: BaseEstimator, X: DataFrame, y: Series, filename: str = 'shapley_values', plot_type: str = 'bar') → None
Generate SHAP value plots for feature importance.
This method generates SHAP (SHapley Additive exPlanations) value plots to explain individual predictions and feature importance. SHAP values provide a unified framework for explaining model predictions.
- Parameters:
- model : BaseEstimator
Trained model to explain (must be compatible with SHAP)
- X : pd.DataFrame
Feature data for generating explanations
- y : pd.Series
Target data (used for context in some plot types)
- filename : str, default="shapley_values"
Base output filename for SHAP plots (without extension)
- plot_type : str, default="bar"
Type of SHAP plot to generate. Options:
  - ‘bar’: Bar plot of mean SHAP values
  - ‘waterfall’: Waterfall plot for individual predictions
  - ‘violin’: Violin plot showing SHAP value distributions
  - ‘beeswarm’: Beeswarm plot for feature importance
Multiple types can be specified as ‘bar,waterfall’ to generate multiple plots
Notes
The method uses the ‘brisk_plot_shapley_values’ evaluator from the evaluation manager. SHAP values provide:
- Feature importance rankings
- Individual prediction explanations
- Feature interaction effects
- Model interpretability insights
Examples
>>> workflow.plot_shapley_values(model, X_test, y_test,
...                              'shap_explanation', 'bar,waterfall')
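To show what a Shapley value is, here is a brute-force computation for a single prediction with a handful of features. It replaces "absent" features with a baseline vector, which is one common choice of value function and an assumption here. The SHAP library uses faster model-specific or sampling-based explainers instead of this exponential-time enumeration, so treat this as a conceptual sketch only.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence, Set


def exact_shapley(f: Callable[[List[float]], float], x: Sequence[float],
                  baseline: Sequence[float]) -> List[float]:
    """Exact Shapley values for f(x); features outside a coalition take baseline values."""
    n = len(x)

    def value(coalition: Set[int]) -> float:
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


def model(z: List[float]) -> float:
    # Toy model with an interaction term z0 * z1.
    return 2 * z[0] + 3 * z[1] + z[0] * z[1]


phi = exact_shapley(model, [1.0, 2.0], [0.0, 0.0])
```

The values add up to f(x) - f(baseline) (the efficiency property), and the interaction term z0 * z1 is split equally between the two features.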
- run() → None
Execute the workflow.
This method serves as the entry point for running the workflow. It delegates to the abstract workflow() method that must be implemented by subclasses, providing a consistent interface for workflow execution.
- Raises:
- NotImplementedError
If called directly on the base Workflow class without implementing the abstract workflow() method
Notes
This method is marked with # pragma: no cover because it only delegates to workflow(), the abstract method that subclasses override. The actual workflow logic is implemented in the workflow() method of concrete subclasses.
- save_model(model: BaseEstimator, filename: str) → None
Save model to pickle file.
This method saves a trained model to a pickle file in the workflow’s output directory, allowing it to be loaded later for inference or further analysis.
- Parameters:
- model : BaseEstimator
Trained model to save
- filename : str
Base filename for the saved model (without extension)
Notes
The method delegates to the evaluation manager’s save_model method. The model is saved in the workflow’s output directory with the specified filename.
Examples
>>> workflow.save_model(trained_model, 'my_model')
- abstractmethod workflow(X_train, X_test, y_train, y_test, output_dir, feature_names) → None
Abstract method defining the workflow logic.
This method must be implemented by all concrete subclasses of Workflow. It should contain the specific logic for the machine learning workflow, including model training, evaluation, visualization, and result saving.
- Raises:
- NotImplementedError
Always raises this error as it’s an abstract method
Notes
This is an abstract method that must be implemented by subclasses. The implementation should define the complete workflow logic for the specific use case, utilizing the available data (X_train, X_test, y_train, y_test) and the evaluation manager for model assessment.
Typical workflow implementations include:
- Model instantiation and training
- Model evaluation using provided methods
- Visualization generation
- Result saving and reporting