Source code for ncaa_eval.evaluation.metrics

"""Evaluation metrics for NCAA basketball model predictions.

Provides metric functions for evaluating probabilistic predictions:

* :func:`log_loss` — Log Loss via ``sklearn.metrics.log_loss``
* :func:`brier_score` — Brier Score via ``sklearn.metrics.brier_score_loss``
* :func:`roc_auc` — ROC-AUC via ``sklearn.metrics.roc_auc_score``
* :func:`expected_calibration_error` — ECE via vectorized numpy binning
* :func:`reliability_diagram_data` — Reliability diagram bin data via
  ``sklearn.calibration.calibration_curve``

All functions accept ``npt.NDArray[np.float64]`` inputs and return ``float``
scalars or structured data (:class:`ReliabilityData`).

Metric Registry
---------------

* :func:`register_metric` — decorator to register a metric function
* :func:`get_metric` — look up a metric by name
* :func:`list_metrics` — list all registered metric names
* :class:`MetricNotFoundError` — raised for unknown metric names
"""

from __future__ import annotations

import dataclasses
from collections.abc import Callable
from typing import TypeVar

import numpy as np
import numpy.typing as npt

# ---------------------------------------------------------------------------
# Metric type alias & registry
# ---------------------------------------------------------------------------

MetricFn = Callable[[npt.NDArray[np.float64], npt.NDArray[np.float64]], float]
"""Signature for metric functions: ``(y_true, y_prob) -> float``."""

_MF = TypeVar("_MF", bound=MetricFn)

_METRIC_REGISTRY: dict[str, MetricFn] = {}


[docs] class MetricNotFoundError(KeyError): """Raised when a requested metric name is not in the registry."""
[docs] def register_metric(name: str) -> Callable[[_MF], _MF]: """Function decorator that registers a metric function. Args: name: Registry key for the metric. Returns: Decorator that registers the function and returns it unchanged. Raises: ValueError: If *name* is already registered. """ def decorator(fn: _MF) -> _MF: if name in _METRIC_REGISTRY: msg = f"Metric name {name!r} is already registered" raise ValueError(msg) _METRIC_REGISTRY[name] = fn return fn return decorator
[docs] def get_metric(name: str) -> MetricFn: """Return the metric function registered under *name*. Raises: MetricNotFoundError: If *name* is not registered. """ try: return _METRIC_REGISTRY[name] except KeyError: msg = f"No metric registered with name {name!r}. Available: {list_metrics()}" raise MetricNotFoundError(msg) from None
[docs] def list_metrics() -> list[str]: """Return all registered metric names (sorted).""" return sorted(_METRIC_REGISTRY)
[docs] @dataclasses.dataclass(frozen=True) class ReliabilityData: """Structured return type for reliability diagram data. Attributes: fraction_of_positives: Observed fraction of positives per bin (from calibration_curve). mean_predicted_value: Mean predicted probability per bin (from calibration_curve). bin_counts: Number of samples in each non-empty bin. bin_edges: Full bin edge array of shape ``(n_bins + 1,)``, i.e. ``np.linspace(0.0, 1.0, n_bins + 1)``. Includes both the lower (0.0) and upper (1.0) boundaries so callers do not need to recompute them. n_bins: Requested number of bins. """ fraction_of_positives: npt.NDArray[np.float64] mean_predicted_value: npt.NDArray[np.float64] bin_counts: npt.NDArray[np.int64] bin_edges: npt.NDArray[np.float64] n_bins: int
def _validate_inputs( y_true: npt.NDArray[np.float64], y_prob: npt.NDArray[np.float64], ) -> None: """Validate metric inputs: non-empty, matching lengths, binary y_true, probs in [0, 1]. Checks array non-emptiness, matching lengths, binary values in ``y_true``, and probability bounds in [0, 1] using NumPy vectorized comparisons, raising a descriptive ``ValueError`` for any violation. """ if len(y_true) == 0 or len(y_prob) == 0: msg = "y_true and y_prob must be non-empty arrays." raise ValueError(msg) if len(y_true) != len(y_prob): msg = f"y_true and y_prob must have the same length, got {len(y_true)} and {len(y_prob)}." raise ValueError(msg) if not np.all((y_true == 0) | (y_true == 1)): msg = "y_true must contain only binary values (0 or 1)." raise ValueError(msg) if np.any(y_prob < 0.0) or np.any(y_prob > 1.0): msg = "y_prob values must be in [0, 1]." raise ValueError(msg)
[docs] @register_metric("log_loss") def log_loss( y_true: npt.NDArray[np.float64], y_prob: npt.NDArray[np.float64], ) -> float: """Compute Log Loss (cross-entropy loss) for binary predictions. Args: y_true: Binary labels (0 or 1). y_prob: Predicted probabilities for the positive class. Returns: Log Loss value. Raises: ValueError: If inputs are empty, mismatched, or probabilities are outside [0, 1]. """ from sklearn.metrics import log_loss as sklearn_log_loss # type: ignore[import-untyped] _validate_inputs(y_true, y_prob) result: float = float(sklearn_log_loss(y_true, y_prob, labels=[0, 1])) return result
[docs] @register_metric("brier_score") def brier_score( y_true: npt.NDArray[np.float64], y_prob: npt.NDArray[np.float64], ) -> float: """Compute Brier Score for binary predictions. Args: y_true: Binary labels (0 or 1). y_prob: Predicted probabilities for the positive class. Returns: Brier Score value (lower is better). Raises: ValueError: If inputs are empty, mismatched, or probabilities are outside [0, 1]. """ from sklearn.metrics import brier_score_loss _validate_inputs(y_true, y_prob) result: float = float(brier_score_loss(y_true, y_prob)) return result
[docs] @register_metric("roc_auc") def roc_auc( y_true: npt.NDArray[np.float64], y_prob: npt.NDArray[np.float64], ) -> float: """Compute ROC-AUC for binary predictions. Args: y_true: Binary labels (0 or 1). y_prob: Predicted probabilities for the positive class. Returns: ROC-AUC value. Raises: ValueError: If inputs are empty, mismatched, probabilities are outside [0, 1], or ``y_true`` contains only one class (AUC is undefined). """ from sklearn.metrics import roc_auc_score _validate_inputs(y_true, y_prob) unique_classes = np.unique(y_true) if len(unique_classes) < 2: msg = "roc_auc requires both positive and negative samples in y_true." raise ValueError(msg) result: float = float(roc_auc_score(y_true, y_prob)) return result
[docs] @register_metric("ece") def expected_calibration_error( y_true: npt.NDArray[np.float64], y_prob: npt.NDArray[np.float64], *, n_bins: int = 10, ) -> float: """Compute Expected Calibration Error (ECE) using vectorized numpy. ECE measures how well predicted probabilities match observed frequencies. Predictions are binned into ``n_bins`` equal-width bins on [0, 1], and ECE is the weighted average of per-bin |accuracy - confidence| gaps. Args: y_true: Binary labels (0 or 1). y_prob: Predicted probabilities for the positive class. n_bins: Number of equal-width bins (default 10). Returns: ECE value in [0, 1] (lower is better). Raises: ValueError: If inputs are empty, mismatched, or probabilities are outside [0, 1]. """ if n_bins < 1: msg = f"n_bins must be >= 1, got {n_bins}." raise ValueError(msg) _validate_inputs(y_true, y_prob) # Bin edges: [0, 1/n_bins, 2/n_bins, ..., 1] bin_edges = np.linspace(0.0, 1.0, n_bins + 1) # Digitize against interior edges only (bin_edges[1:-1] excludes 0.0 and 1.0). # np.digitize returns 0 for values below the first interior edge, and n_bins-1 # for values at or above the last interior edge after clipping to [0, n_bins-1]. bin_indices = np.clip(np.digitize(y_prob, bin_edges[1:-1]), 0, n_bins - 1) # Vectorized per-bin statistics using np.bincount bin_counts = np.bincount(bin_indices, minlength=n_bins).astype(np.float64) bin_sums_true = np.bincount(bin_indices, weights=y_true, minlength=n_bins) bin_sums_prob = np.bincount(bin_indices, weights=y_prob, minlength=n_bins) # Mask for non-empty bins (avoid division by zero) non_empty = bin_counts > 0 acc = np.zeros(n_bins, dtype=np.float64) conf = np.zeros(n_bins, dtype=np.float64) acc[non_empty] = bin_sums_true[non_empty] / bin_counts[non_empty] conf[non_empty] = bin_sums_prob[non_empty] / bin_counts[non_empty] weights = bin_counts / float(len(y_true)) ece: float = float(np.sum(weights * np.abs(acc - conf))) return ece
[docs] def reliability_diagram_data( y_true: npt.NDArray[np.float64], y_prob: npt.NDArray[np.float64], *, n_bins: int = 10, ) -> ReliabilityData: """Generate reliability diagram data for calibration visualization. Uses ``sklearn.calibration.calibration_curve`` for bin statistics and augments with per-bin sample counts. Args: y_true: Binary labels (0 or 1). y_prob: Predicted probabilities for the positive class. n_bins: Number of bins (default 10). Returns: Structured data containing fraction of positives, mean predicted values, bin counts, bin edges, and requested number of bins. Raises: ValueError: If inputs are empty, mismatched, ``n_bins < 1``, or probabilities are outside [0, 1]. """ from sklearn.calibration import calibration_curve # type: ignore[import-untyped] if n_bins < 1: msg = f"n_bins must be >= 1, got {n_bins}." raise ValueError(msg) _validate_inputs(y_true, y_prob) fraction_of_positives: npt.NDArray[np.float64] mean_predicted_value: npt.NDArray[np.float64] fraction_of_positives, mean_predicted_value = calibration_curve( y_true, y_prob, n_bins=n_bins, strategy="uniform" ) # Compute bin counts using same binning as calibration_curve (uniform) bin_edges = np.linspace(0.0, 1.0, n_bins + 1) bin_indices = np.clip(np.digitize(y_prob, bin_edges[1:-1]), 0, n_bins - 1) all_bin_counts = np.bincount(bin_indices, minlength=n_bins) # calibration_curve only returns non-empty bins — filter to match non_empty_mask = all_bin_counts > 0 bin_counts: npt.NDArray[np.int64] = all_bin_counts[non_empty_mask].astype(np.int64) return ReliabilityData( fraction_of_positives=fraction_of_positives.copy(), mean_predicted_value=mean_predicted_value.copy(), bin_counts=bin_counts.copy(), bin_edges=bin_edges.copy(), n_bins=n_bins, )