MetricWrapper
- class MetricWrapper(name: str, func: Callable, display_name: str, greater_is_better: bool, abbr: str | None = None, **default_params: Any)
A wrapper for metric functions with default parameters and metadata.
Wraps metric functions and provides methods to update parameters and retrieve the metric function with applied parameters. Also handles display names and abbreviations for reporting. Supports both scikit-learn metrics and custom user-defined functions.
- Parameters:
- name : str
Name of the metric
- func : Callable
Metric function to wrap
- display_name : str
Human-readable name for display in reports and plots
- greater_is_better : bool
Whether higher values indicate better performance
- abbr : str, optional
Abbreviation for the metric, by default None
- **default_params : Any
Default parameters for the metric function
- Attributes:
- name : str
Name of the metric
- func : Callable
The wrapped metric function (may be modified to accept split_metadata)
- display_name : str
Human-readable display name
- abbr : str
Abbreviation (defaults to name if not provided)
- greater_is_better : bool
Whether higher values indicate better performance
- params : dict
Current parameters for the metric
- _func_with_params : Callable
Metric function with parameters applied
- scorer : Callable
Scikit-learn scorer created from the metric
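The attribute bookkeeping listed above can be pictured with a minimal sketch. `SketchMetricWrapper` and `mean_abs_error` are hypothetical names, and the sketch assumes parameters are bound with `functools.partial`, which the documentation does not confirm. It also leaves out the scorer; scikit-learn's `make_scorer` accepts a `greater_is_better` flag plus extra keyword arguments for the metric and could fill that role.

```python
import functools
from typing import Any, Callable, Dict, Optional


class SketchMetricWrapper:
    """Hypothetical, simplified stand-in for MetricWrapper (no scorer)."""

    def __init__(
        self,
        name: str,
        func: Callable,
        display_name: str,
        greater_is_better: bool,
        abbr: Optional[str] = None,
        **default_params: Any,
    ) -> None:
        self.name = name
        self.func = func
        self.display_name = display_name
        self.greater_is_better = greater_is_better
        # Mirror the documented behavior: abbr falls back to name.
        self.abbr = abbr if abbr is not None else name
        # Copy so later updates never mutate the caller's keyword dict.
        self.params: Dict[str, Any] = dict(default_params)
        # Bind the current params; callers then only pass y_true and y_pred.
        self._func_with_params = functools.partial(self.func, **self.params)


def mean_abs_error(y_true, y_pred, scale=1.0):
    # Hypothetical metric: scaled mean absolute error over plain sequences.
    return scale * sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


w = SketchMetricWrapper("mae", mean_abs_error, "Mean Absolute Error", False, scale=2.0)
score = w._func_with_params([1, 2, 3], [1, 2, 5])
```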
Notes
The MetricWrapper automatically ensures that wrapped functions can accept a split_metadata parameter, even if the original function doesn’t support it. This allows for consistent parameter passing across all metrics.
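One plausible way to get this behavior is to inspect the function's signature and, only when it cannot take `split_metadata`, wrap it in a shim that drops the argument. This is a sketch under that assumption; `accepts_split_metadata` and `ensure_split_metadata` are hypothetical helpers, not the library's code.

```python
import functools
import inspect
from typing import Any, Callable


def accepts_split_metadata(func: Callable) -> bool:
    """True if func takes a split_metadata argument or arbitrary **kwargs."""
    try:
        sig = inspect.signature(func)
    except (TypeError, ValueError):
        # Some callables expose no signature; assume they cannot take it.
        return False
    for param in sig.parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return True
        if param.name == "split_metadata" and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


def ensure_split_metadata(func: Callable) -> Callable:
    """Hypothetical helper: return a callable that tolerates split_metadata."""
    if accepts_split_metadata(func):
        return func

    @functools.wraps(func)
    def wrapper(*args: Any, split_metadata: Any = None, **kwargs: Any) -> Any:
        # Drop split_metadata, which the original function cannot accept.
        return func(*args, **kwargs)

    return wrapper


def plain_metric(y_true, y_pred):
    return len(y_true)


def aware_metric(y_true, y_pred, split_metadata=None):
    return split_metadata


safe_plain = ensure_split_metadata(plain_metric)
safe_aware = ensure_split_metadata(aware_metric)
```

Functions that already accept the argument are returned unchanged, so metrics that actually use `split_metadata` still receive it.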
Examples
- Create a wrapper for mean squared error:
>>> from sklearn.metrics import mean_squared_error
>>> wrapper = MetricWrapper(
...     name="mse",
...     func=mean_squared_error,
...     display_name="Mean Squared Error",
...     greater_is_better=False
... )
- Create a custom metric wrapper:
>>> def custom_metric(y_true, y_pred):
...     return sum(abs(y_true - y_pred)) / len(y_true)
>>> wrapper = MetricWrapper(
...     name="custom_mae",
...     func=custom_metric,
...     display_name="Custom MAE",
...     greater_is_better=False
... )
- export_config() → Dict[str, Any]
Export this MetricWrapper’s configuration for rerun functionality.
Exports the complete configuration needed to recreate this MetricWrapper instance. Handles both built-in scikit-learn functions and custom user-defined functions by detecting the function source and exporting appropriate reconstruction information.
- Returns:
- Dict[str, Any]
Configuration dictionary that can be used to recreate this MetricWrapper instance
Notes
The export process detects the function type:
- “imported”: functions from external libraries (e.g., sklearn)
- “local”: custom functions defined in the project
- “unknown”: functions that cannot be properly identified
For imported functions, it exports the module and function names. For local functions, it exports the source code. For unknown functions, it exports basic identification information.
The split_metadata parameter is excluded from the exported params as it’s runtime-specific and not needed for reconstruction.
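A sketch of this export logic, assuming a simple module-prefix heuristic for telling project code apart from library code (the actual detection rule is not documented here). `classify_function`, `export_metric_config`, the `project_prefixes` argument, and the config keys are all hypothetical.

```python
import inspect
import statistics
from typing import Any, Callable, Dict, Sequence


def classify_function(func: Callable, project_prefixes: Sequence[str]) -> str:
    """Hypothetical heuristic returning 'local', 'imported', or 'unknown'."""
    module = getattr(func, "__module__", None)
    if not module or not getattr(func, "__name__", None):
        return "unknown"
    if any(module == p or module.startswith(p + ".") for p in project_prefixes):
        return "local"
    return "imported"


def export_metric_config(name, func, display_name, greater_is_better, abbr,
                         params, project_prefixes) -> Dict[str, Any]:
    func_type = classify_function(func, project_prefixes)
    config: Dict[str, Any] = {
        "name": name,
        "display_name": display_name,
        "greater_is_better": greater_is_better,
        "abbr": abbr,
        # split_metadata is runtime-specific, so it is never exported.
        "params": {k: v for k, v in params.items() if k != "split_metadata"},
        "func_type": func_type,
    }
    if func_type == "imported":
        # Enough to re-import the function later.
        config["module"] = func.__module__
        config["function"] = func.__name__
    elif func_type == "local":
        try:
            config["source"] = inspect.getsource(func)
        except (OSError, TypeError):
            # Source unavailable (e.g. defined interactively).
            config["source"] = None
        config["function"] = func.__name__
    else:
        # Best-effort identification only.
        config["function"] = repr(func)
    return config


config = export_metric_config(
    name="mean", func=statistics.fmean, display_name="Mean",
    greater_is_better=True, abbr="mean",
    params={"split_metadata": {"fold": 0}}, project_prefixes=["myproject"],
)
```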
- get_func_with_params() → Callable
Get the metric function with current parameters applied.
Returns a deep copy of the metric function with all current parameters applied. This ensures that the returned function is independent and can be safely used in parallel operations.
- Returns:
- Callable
Deep copy of the metric function with parameters applied
Notes
The returned function is a deep copy to prevent issues with shared state in parallel or concurrent operations.
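A minimal sketch of this method, assuming parameters are bound with `functools.partial` (an assumption; `weighted_error` is a made-up metric). One detail worth knowing: `copy.deepcopy` treats plain Python functions as atomic and returns the same object, so the independence comes from deep-copying the bound arguments rather than the function itself.

```python
import copy
import functools
from typing import Any, Callable, Dict


def get_func_with_params(func: Callable, params: Dict[str, Any]) -> Callable:
    """Return an independent callable with params bound (sketch)."""
    # deepcopy rebuilds the partial with deep-copied args and keywords;
    # the underlying plain function is shared, not copied.
    return copy.deepcopy(functools.partial(func, **params))


def weighted_error(y_true, y_pred, weights=None):
    weights = weights or [1.0] * len(y_true)
    total = sum(w * abs(t - p) for w, t, p in zip(weights, y_true, y_pred))
    return total / sum(weights)


params = {"weights": [1.0, 3.0]}
f1 = get_func_with_params(weighted_error, params)
f2 = get_func_with_params(weighted_error, params)
# Mutating one copy's bound list leaves the other copy and params untouched.
f1.keywords["weights"].append(5.0)
```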
- set_params(**params: Any) → None
Update parameters for the metric function and scorer.
Updates the internal parameter dictionary and reapplies parameters to both the function and scorer.
- Parameters:
- **params : Any
New parameters to update or add
- Returns:
- None
Notes
This method updates the internal params dictionary and then calls _apply_params to ensure both the function and scorer are updated with the new parameters.
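The update-then-reapply flow can be modeled as follows. `ParamsSketch` and `fbeta_like` are hypothetical, and binding via `functools.partial` is an assumption; the documented `_apply_params` also refreshes the scorer, which this sketch leaves out to stay dependency-free.

```python
import functools
from typing import Any, Callable, Dict


class ParamsSketch:
    """Hypothetical minimal model of set_params / _apply_params (no scorer)."""

    def __init__(self, func: Callable, **default_params: Any) -> None:
        self.func = func
        self.params: Dict[str, Any] = dict(default_params)
        self._apply_params()

    def _apply_params(self) -> None:
        # Rebuild the bound function from scratch so stale values never linger.
        self._func_with_params = functools.partial(self.func, **self.params)

    def set_params(self, **params: Any) -> None:
        # Existing keys are overwritten, new keys are added, others are kept.
        self.params.update(params)
        self._apply_params()


def fbeta_like(precision, recall, beta=1.0):
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


m = ParamsSketch(fbeta_like, beta=1.0)
before = m._func_with_params(1.0, 0.5)
m.set_params(beta=2.0)
after = m._func_with_params(1.0, 0.5)
```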