EvaluationManager
- class EvaluationManager(algorithm_config: AlgorithmCollection, metric_config: MetricManager, output_dir: str, split_metadata: Dict[str, Any], logger: Logger | None = None)
A class for evaluating machine learning models and plotting results.
This class provides methods for model evaluation, including calculating metrics, generating plots, comparing models, and hyperparameter tuning. It is designed to be used within a Workflow instance.
- Parameters:
- algorithm_config : AlgorithmCollection
Configuration for algorithms.
- metric_config : MetricManager
Configuration for evaluation metrics.
- output_dir : str
Directory to save results.
- split_metadata : Dict[str, Any]
Metadata to include in metric calculations.
- logger : Optional[logging.Logger]
Logger instance to use.
- Attributes:
- algorithm_config : AlgorithmCollection
Configuration for algorithms.
- metric_config : Any
Configuration for evaluation metrics.
- output_dir : str
Directory to save results.
- split_metadata : Dict[str, Any]
Metadata to include in metric calculations.
- logger : Optional[logging.Logger]
Logger instance to use.
- primary_color : str
Color used for primary plot elements.
- secondary_color : str
Color used for secondary plot elements.
- background_color : str
Color used for background plot elements.
- accent_color : str
Color used for accent plot elements.
- important_color : str
Color used to highlight important plot elements.
- compare_models(*models: BaseEstimator, X: DataFrame, y: Series, metrics: List[str], filename: str, calculate_diff: bool = False) → Dict[str, Dict[str, float]]
Compare multiple models using specified metrics.
- Parameters:
- *models : BaseEstimator
Models to compare.
- X : DataFrame
Input features.
- y : Series
Target values.
- metrics : list of str
Names of metrics to calculate.
- filename : str
Name for the output file (without extension).
- calculate_diff : bool, optional
Whether to calculate differences between models, by default False.
- Returns:
- dict
Nested dictionary containing metric scores for each model
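A minimal, dependency-free sketch of what a comparison with `calculate_diff=True` could look like. It assumes each model is represented by a plain predict function, uses two hypothetical metric names (`"mae"`, `"mse"`), and stores pairwise differences under a hypothetical `"differences"` key; the real method's metric registry and output layout may differ.

```python
from typing import Callable, Dict, List, Sequence


# Hypothetical metric registry: each metric maps (y_true, y_pred) -> float.
def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mean_squared_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


METRICS: Dict[str, Callable[[Sequence[float], Sequence[float]], float]] = {
    "mae": mean_absolute_error,
    "mse": mean_squared_error,
}


def compare_models(
    models: Dict[str, Callable[[Sequence[float]], List[float]]],
    X: Sequence[float],
    y: Sequence[float],
    metrics: List[str],
    calculate_diff: bool = False,
) -> Dict[str, Dict[str, float]]:
    """Score every model on every metric; optionally add pairwise differences."""
    results: Dict[str, Dict[str, float]] = {}
    for name, predict in models.items():
        y_pred = predict(X)
        results[name] = {m: METRICS[m](y, y_pred) for m in metrics}

    if calculate_diff and len(models) > 1:
        names = list(models)
        diffs: Dict[str, float] = {}
        # One entry per metric and per unordered model pair: score(b) - score(a).
        for i, a in enumerate(names):
            for b in names[i + 1:]:
                for m in metrics:
                    diffs[f"{m} ({b} - {a})"] = results[b][m] - results[a][m]
        results["differences"] = diffs
    return results
```

Keeping the differences flat (one float per key) preserves the documented `Dict[str, Dict[str, float]]` return shape.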
- confusion_matrix(model: Any, X: ndarray, y: ndarray, filename: str) → None
Generate and save a confusion matrix.
- Parameters:
- model : Any
Trained classification model with a predict method.
- X : ndarray
The input features.
- y : ndarray
The true target values.
- filename : str
The name of the output file (without extension).
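For reference, the counting behind a confusion matrix can be sketched in plain Python. The sketch assumes predictions have already been obtained (for example via `model.predict(X)`) and that the labels are sortable; it is not the class's own implementation, which may normalize or label the saved output differently.

```python
from typing import Hashable, List, Optional, Sequence


def confusion_matrix(
    y_true: Sequence[Hashable],
    y_pred: Sequence[Hashable],
    labels: Optional[List[Hashable]] = None,
) -> List[List[int]]:
    """Count (true, predicted) label pairs.

    Rows are true labels and columns are predicted labels, the same
    orientation scikit-learn uses.
    """
    if labels is None:
        labels = sorted(set(y_true) | set(y_pred))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix
```

The diagonal holds correct predictions; every off-diagonal cell counts one specific kind of mistake.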
- evaluate_model(model: BaseEstimator, X: DataFrame, y: Series, metrics: List[str], filename: str) → None
Evaluate a model on the provided metrics and save the results.
- Parameters:
- model (BaseEstimator):
The trained model to evaluate.
- X (pd.DataFrame):
The input features.
- y (pd.Series):
The target data.
- metrics (List[str]):
A list of metrics to calculate.
- filename (str):
The name of the output file without extension.
- evaluate_model_cv(model: BaseEstimator, X: DataFrame, y: Series, metrics: List[str], filename: str, cv: int = 5) → None
Evaluate a model using cross-validation and save the scores.
- Parameters:
- model (BaseEstimator):
The model to evaluate.
- X (pd.DataFrame):
The input features.
- y (pd.Series):
The target data.
- metrics (List[str]):
A list of metrics to calculate.
- filename (str):
The name of the output file without extension.
- cv (int):
The number of cross-validation folds. Defaults to 5.
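To make the `cv` parameter concrete, here is a stdlib-only sketch of plain (unshuffled) k-fold cross-validation. `fit_predict` and `score` are hypothetical callables standing in for the model and a metric; the actual method presumably delegates to scikit-learn and may shuffle or stratify the folds. Fold sizes follow the same rule as scikit-learn's unshuffled `KFold`.

```python
from typing import Callable, List, Sequence, Tuple


def kfold_indices(n_samples: int, n_splits: int = 5) -> List[Tuple[List[int], List[int]]]:
    """Contiguous folds; the first n_samples % n_splits folds get one extra sample."""
    base, extra = divmod(n_samples, n_splits)
    splits = []
    start = 0
    for i in range(n_splits):
        size = base + (1 if i < extra else 0)
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        splits.append((train, test))
        start += size
    return splits


def cross_val_scores(
    fit_predict: Callable[[list, list, list], list],
    X: Sequence,
    y: Sequence[float],
    score: Callable[[list, list], float],
    cv: int = 5,
) -> List[float]:
    """Fit on k-1 folds and score on the held-out fold, once per fold."""
    scores = []
    for train_idx, test_idx in kfold_indices(len(X), cv):
        y_pred = fit_predict(
            [X[i] for i in train_idx], [y[i] for i in train_idx], [X[i] for i in test_idx]
        )
        scores.append(score([y[i] for i in test_idx], y_pred))
    return scores
```

Each sample is held out exactly once, so the `cv` scores are usually summarized by their mean and standard deviation.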
- hyperparameter_tuning(model: BaseEstimator, method: str, X_train: DataFrame, y_train: Series, scorer: str, kf: int, num_rep: int, n_jobs: int, plot_results: bool = False) → BaseEstimator
Perform hyperparameter tuning using grid or random search.
- Parameters:
- model (BaseEstimator):
The model to be tuned.
- method (str):
The search method to use (“grid” or “random”).
- X_train (pd.DataFrame):
The training data.
- y_train (pd.Series):
The target values for training.
- scorer (str):
The scoring metric to use.
- kf (int):
Number of splits for cross-validation.
- num_rep (int):
Number of repetitions for cross-validation.
- n_jobs (int):
Number of parallel jobs to run.
- plot_results (bool):
Whether to plot the performance of hyperparameters. Defaults to False.
- Returns:
- BaseEstimator:
The tuned model.
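The difference between `method="grid"` and `method="random"` can be sketched with the standard library. `param_grid`, `n_iter`, and `cv_score` are hypothetical names (the documented signature does not show where the search space comes from), and random search is simplified here to sampling combinations without replacement from the full grid.

```python
import itertools
import random
from typing import Any, Callable, Dict, List


def candidate_params(
    param_grid: Dict[str, List[Any]], method: str, n_iter: int = 10, seed: int = 0
) -> List[Dict[str, Any]]:
    """Enumerate every combination ("grid") or sample a subset of them ("random")."""
    keys = sorted(param_grid)
    grid = [dict(zip(keys, values)) for values in itertools.product(*(param_grid[k] for k in keys))]
    if method == "grid":
        return grid
    if method == "random":
        # Simplification: sample combinations without replacement from the full grid.
        return random.Random(seed).sample(grid, min(n_iter, len(grid)))
    raise ValueError(f"method must be 'grid' or 'random', got {method!r}")


def tune(candidates: List[Dict[str, Any]], cv_score: Callable[[Dict[str, Any]], float]) -> Dict[str, Any]:
    """Return the candidate with the highest mean CV score (higher is better)."""
    return max(candidates, key=cv_score)
```

Each candidate is scored with `kf` folds repeated `num_rep` times, so the cost is candidates × `kf` × `num_rep` fits: four candidates with `kf=5` and `num_rep=2` already mean 40 fits before any final refit.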
- load_model(filepath: str) → BaseEstimator
Load a model from a pickle file.
- Parameters:
- filepath : str
Path to the saved model file.
- Returns:
- BaseEstimator
The loaded model.
- Raises:
- FileNotFoundError
If model file does not exist
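A save/load round trip with the standard `pickle` module, raising `FileNotFoundError` for a missing file as documented above. The `.pkl` extension and joining `filename` onto an `output_dir` are assumptions about how the class builds its paths. Only unpickle files you trust, since loading a pickle can execute arbitrary code.

```python
import os
import pickle
from typing import Any


def save_model(model: Any, output_dir: str, filename: str) -> str:
    """Pickle the model to <output_dir>/<filename>.pkl and return the path."""
    path = os.path.join(output_dir, f"{filename}.pkl")
    with open(path, "wb") as f:
        pickle.dump(model, f)
    return path


def load_model(filepath: str) -> Any:
    """Unpickle a model, failing early with a clear message if the file is missing."""
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"No model file found at {filepath}")
    with open(filepath, "rb") as f:
        return pickle.load(f)
```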
- plot_confusion_heatmap(model: Any, X: ndarray, y: ndarray, filename: str) → None
Plot a heatmap of the confusion matrix for a model.
- Parameters:
- model (Any):
The trained classification model with a predict method.
- X (np.ndarray):
The input features.
- y (np.ndarray):
The target labels.
- filename (str):
The path to save the confusion matrix heatmap image.
- plot_feature_importance(model: BaseEstimator, X: DataFrame, y: Series, threshold: int | float, feature_names: List[str], filename: str, metric: str, num_rep: int) → None
Plot the feature importance for the model and save the plot.
- Parameters:
- model (BaseEstimator):
The model to evaluate.
- X (pd.DataFrame):
The input features.
- y (pd.Series):
The target data.
- threshold (Union[int, float]):
The number of features or the threshold to filter features by importance.
- feature_names (List[str]):
A list of feature names corresponding to the columns in X.
- filename (str):
The name of the output file (without extension).
- metric (str):
The metric to use for evaluation.
- num_rep (int):
The number of repetitions for calculating importance.
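The `metric`, `num_rep`, and `threshold` parameters suggest a permutation-importance style procedure; that is an assumption, since the documentation above does not name the technique. The sketch shuffles one feature column at a time, repeats that `num_rep` times, and reports the mean drop in a higher-is-better score. `select_features` is a hypothetical helper that reads an `int` threshold as "keep the top N" and a `float` as a minimum importance, which is also an assumption.

```python
import random
from typing import Callable, List, Sequence, Tuple, Union


def permutation_importance(
    predict: Callable[[List[List[float]]], List[float]],
    X: List[List[float]],
    y: Sequence[float],
    score: Callable[[Sequence[float], Sequence[float]], float],
    num_rep: int = 5,
    seed: int = 0,
) -> List[float]:
    """Mean drop in a higher-is-better score when each feature column is shuffled."""
    rng = random.Random(seed)
    baseline = score(y, predict(X))
    importances = []
    for j in range(len(X[0])):
        drops = []
        for _ in range(num_rep):
            column = [row[j] for row in X]
            rng.shuffle(column)
            # Replace only column j; every other feature keeps its original values.
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            drops.append(baseline - score(y, predict(X_perm)))
        importances.append(sum(drops) / num_rep)
    return importances


def select_features(
    importances: List[float], feature_names: List[str], threshold: Union[int, float]
) -> List[Tuple[str, float]]:
    """int threshold -> top-N features; float threshold -> minimum importance (assumed)."""
    ranked = sorted(zip(feature_names, importances), key=lambda p: p[1], reverse=True)
    if isinstance(threshold, int):
        return ranked[:threshold]
    return [(name, imp) for name, imp in ranked if imp >= threshold]
```

A feature the model ignores gets an importance of exactly zero, because shuffling it cannot change the predictions.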
- plot_learning_curve(model: BaseEstimator, X_train: DataFrame, y_train: Series, cv: int = 5, num_repeats: int = 1, n_jobs: int = -1, metric: str = 'neg_mean_absolute_error', filename: str = 'learning_curve') → None
Plot learning curves showing model performance vs. training set size.
- Parameters:
- model : BaseEstimator
Model to evaluate.
- X_train : DataFrame
Training features.
- y_train : Series
Training target values.
- cv : int, optional
Number of cross-validation folds, by default 5.
- num_repeats : int, optional
Number of times to repeat cross-validation, by default 1.
- n_jobs : int, optional
Number of parallel jobs, by default -1.
- metric : str, optional
Scoring metric to use, by default “neg_mean_absolute_error”.
- filename : str, optional
Name for the output file, by default “learning_curve”.
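The default `metric` follows scikit-learn's scorer naming convention: error metrics are negated so that larger is always better, which means the plotted scores are at most zero and a curve rising toward zero is improving. The stdlib sketch below shows the idea behind a learning curve using a single train/validation split instead of repeated cross-validation; `fit_predict` is a hypothetical callable, and the default fractions match scikit-learn's default `train_sizes` (`np.linspace(0.1, 1.0, 5)`), which the method may or may not use.

```python
from typing import Callable, List, Sequence, Tuple


def neg_mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """-MAE, so that a larger (closer to zero) score is better."""
    return -sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def learning_curve(
    fit_predict: Callable[[list, list, list], list],
    X_train: list,
    y_train: list,
    X_val: list,
    y_val: list,
    fractions: Sequence[float] = (0.1, 0.325, 0.55, 0.775, 1.0),
) -> List[Tuple[int, float]]:
    """Score a model trained on growing prefixes of the training data."""
    points = []
    for f in fractions:
        size = max(1, int(f * len(X_train)))
        y_pred = fit_predict(X_train[:size], y_train[:size], X_val)
        points.append((size, neg_mean_absolute_error(y_val, y_pred)))
    return points
```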
- plot_model_comparison(*models: BaseEstimator, X: DataFrame, y: Series, metric: str, filename: str) → None
Plot a comparison of multiple models based on the specified metric.
- Parameters:
- *models (BaseEstimator):
A variable number of model instances to evaluate.
- X (pd.DataFrame):
The input features.
- y (pd.Series):
The target data.
- metric (str):
The metric to evaluate and plot.
- filename (str):
The name of the output file (without extension).
- plot_precision_recall_curve(model: Any, X: ndarray, y: ndarray, filename: str) → None
Plot a precision-recall curve with average precision.
- Parameters:
- model (Any):
The trained binary classification model.
- X (np.ndarray):
The input features.
- y (np.ndarray):
The true binary labels.
- filename (str):
The path to save the plot.
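Average precision summarizes the precision-recall curve as AP = sum over thresholds of (R_n - R_{n-1}) * P_n, i.e. the precision at each threshold weighted by the increase in recall, which is the (non-interpolated) definition used by scikit-learn's `average_precision_score`. This stdlib sketch assumes 0/1 labels and positive-class scores such as `model.predict_proba(X)[:, 1]`; it is not the method's own code.

```python
from typing import Sequence


def average_precision(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """AP = sum over thresholds of (R_n - R_{n-1}) * P_n, thresholds in decreasing order."""
    n_pos = sum(y_true)
    if n_pos == 0:
        raise ValueError("average precision is undefined without positive labels")
    pairs = sorted(zip(scores, y_true), key=lambda p: p[0], reverse=True)
    tp = fp = 0
    ap = prev_recall = 0.0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Tied scores form a single threshold, so consume them together.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        precision = tp / (tp + fp)
        recall = tp / n_pos
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap
```

A perfect ranking (every positive scored above every negative) gives an AP of 1.0.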
- plot_pred_vs_obs(model: BaseEstimator, X: DataFrame, y_true: Series, filename: str) → None
Plot predicted vs. observed values and save the plot.
- Parameters:
- model (BaseEstimator):
The trained model.
- X (pd.DataFrame):
The input features.
- y_true (pd.Series):
The true target values.
- filename (str):
The name of the output file (without extension).
- plot_residuals(model: BaseEstimator, X: DataFrame, y: Series, filename: str, add_fit_line: bool = False) → None
Plot the residuals of the model and save the plot.
- Parameters:
- model (BaseEstimator):
The trained model.
- X (pd.DataFrame):
The input features.
- y (pd.Series):
The true target values.
- filename (str):
The name of the output file (without extension).
- add_fit_line (bool):
Whether to add a line of best fit to the plot.
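The quantities behind a residual plot are simple: each residual is observed minus predicted, and `add_fit_line` presumably overlays an ordinary least-squares line. The sketch assumes residuals are plotted against the predicted values; the method may use a different x-axis.

```python
from typing import List, Sequence, Tuple


def residuals(y_true: Sequence[float], y_pred: Sequence[float]) -> List[float]:
    """Observed minus predicted, one value per sample."""
    return [t - p for t, p in zip(y_true, y_pred)]


def fit_line(x: Sequence[float], y: Sequence[float]) -> Tuple[float, float]:
    """Ordinary least-squares slope and intercept of y on x."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    if sxx == 0:
        raise ValueError("x must contain at least two distinct values")
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x
```

For example, `fit_line(y_pred, residuals(y_true, y_pred))`; a clearly nonzero slope suggests errors that depend on the size of the prediction.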
- plot_roc_curve(model: Any, X: ndarray, y: ndarray, filename: str) → None
Plot a receiver operating characteristic (ROC) curve with the area under the curve (AUC).
- Parameters:
- model (Any):
The trained binary classification model.
- X (np.ndarray):
The input features.
- y (np.ndarray):
The true binary labels.
- filename (str):
The path to save the ROC curve image.
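An ROC curve sweeps the decision threshold from high to low, recording the (false positive rate, true positive rate) pair at each distinct score; the AUC is the trapezoidal area under those points. This stdlib sketch assumes 0/1 labels and positive-class scores (for example `model.predict_proba(X)[:, 1]`). It gives the same AUC as scikit-learn's `roc_auc_score` for binary labels, but it is not the method's own code.

```python
from typing import List, Sequence, Tuple


def roc_curve_points(y_true: Sequence[int], scores: Sequence[float]) -> List[Tuple[float, float]]:
    """(FPR, TPR) pairs, one per distinct score, sweeping the threshold from high to low."""
    n_pos = sum(y_true)
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC needs both positive and negative labels")
    pairs = sorted(zip(scores, y_true), key=lambda p: p[0], reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Samples with tied scores cross the threshold together.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def auc(points: List[Tuple[float, float]]) -> float:
    """Trapezoidal area under a curve given as (x, y) points sorted by x."""
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))
```

An AUC of 0.5 corresponds to random ranking and 1.0 to a perfect separation of the two classes.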
- save_model(model: BaseEstimator, filename: str) → None
Save model to pickle file.
- Parameters:
- model (BaseEstimator):
The model to save.
- filename (str):
The name for the output file (without extension).