# Source code for ncaa_eval.model.elo

"""Elo rating model — reference stateful model for NCAA tournament prediction.

Thin wrapper around :class:`~ncaa_eval.transform.elo.EloFeatureEngine`.
All Elo math is delegated to the engine; this module adds :class:`StatefulModel`
ABC conformance, Pydantic configuration, JSON persistence, and plugin
registration.
"""

from __future__ import annotations

import dataclasses
import json
from pathlib import Path
from typing import Any, Literal, Self

from ncaa_eval.ingest.schema import Game
from ncaa_eval.model._feature_config_io import load_feature_config, save_feature_config
from ncaa_eval.model.base import ModelConfig, StatefulModel
from ncaa_eval.model.registry import register_model
from ncaa_eval.transform.elo import EloConfig, EloFeatureEngine
from ncaa_eval.transform.feature_serving import FeatureConfig


class EloModelConfig(ModelConfig):
    """Pydantic configuration for :class:`EloModel`.

    The hyperparameters below mirror the fields and default values of
    :class:`~ncaa_eval.transform.elo.EloConfig`; keep the two definitions in
    sync.  Field declaration order is preserved deliberately, since it fixes
    the key order of the serialised ``config.json``.
    """

    # Registry identifier; the only accepted value is "elo".
    model_name: Literal["elo"] = "elo"

    # Starting rating for teams (semantics defined by EloConfig).
    initial_rating: float = 1500.0

    # K-factors by game context.  NOTE(review): presumably ``k_early``
    # applies while a team has played fewer than ``early_game_threshold``
    # games -- confirm against EloFeatureEngine.
    k_early: float = 56.0
    k_regular: float = 38.0
    k_tournament: float = 47.5
    early_game_threshold: int = 20

    # Margin-of-victory handling (exponent and cap; see EloConfig).
    margin_exponent: float = 0.85
    max_margin: int = 25

    # Home-court bonus, expressed in Elo points per the field name.
    home_advantage_elo: float = 3.5

    # Fraction used by the engine's between-season mean reversion
    # (assumption from the name -- confirm against EloFeatureEngine).
    mean_reversion_fraction: float = 0.25
@register_model("elo")
class EloModel(StatefulModel):
    """Elo rating model wrapping :class:`EloFeatureEngine`.

    All rating arithmetic is delegated to the engine.  This class adapts the
    engine to the :class:`StatefulModel` hooks (per-game updates, season
    transitions, pairwise prediction), persists config and state as JSON, and
    reports the highest-rated teams as interpretability output.
    """

    def __init__(self, config: EloModelConfig | None = None) -> None:
        """Initialize Elo model with optional configuration.

        Args:
            config: Pydantic config; defaults to :class:`EloModelConfig` with
                standard hyperparameters when ``None``.
        """
        # Explicit ``is None`` test: only a missing config falls back to the
        # defaults, independent of how the config object behaves in a
        # boolean context.
        self._config = config if config is not None else EloModelConfig()
        # Convert once and share: EloConfig is a frozen dataclass, so the
        # engine and the feature config can safely reference the same object.
        elo_config = self._to_elo_config(self._config)
        self._engine = EloFeatureEngine(elo_config)
        self.feature_config = FeatureConfig(
            sequential_windows=(),
            graph_features_enabled=False,
            batch_rating_types=(),
            ordinal_composite=None,
            elo_enabled=True,
            elo_config=elo_config,
        )

    # ------------------------------------------------------------------
    # StatefulModel abstract hooks
    # ------------------------------------------------------------------

    def update(self, game: Game) -> None:
        """Delegate game processing to the engine.

        Forwards the winner/loser IDs and scores, location, tournament flag
        and overtime count of *game* to ``EloFeatureEngine.update_game``.
        """
        self._engine.update_game(
            w_team_id=game.w_team_id,
            l_team_id=game.l_team_id,
            w_score=game.w_score,
            l_score=game.l_score,
            loc=game.loc,
            is_tournament=game.is_tournament,
            num_ot=game.num_ot,
        )

    def start_season(self, season: int) -> None:
        """Delegate season transition to the engine."""
        self._engine.start_new_season(season)

    def _predict_one(self, team_a_id: int, team_b_id: int) -> float:
        """Return P(team_a wins) using the Elo expected-score formula.

        Delegates to the underlying EloFeatureEngine.predict_matchup(), which
        retrieves both teams' current ratings and applies the logistic
        expected-score formula.
        """
        return self._engine.predict_matchup(team_a_id, team_b_id)

    def get_state(self) -> dict[str, Any]:
        """Return ratings and game counts as a serialisable snapshot.

        Returns:
            ``{"ratings": ..., "game_counts": ...}`` as reported by the
            engine's ``get_all_ratings`` / ``get_game_counts``.
        """
        return {
            "ratings": self._engine.get_all_ratings(),
            "game_counts": self._engine.get_game_counts(),
        }

    def set_state(self, state: dict[str, Any]) -> None:
        """Restore ratings and game counts from a snapshot.

        Args:
            state: Must contain ``"ratings"`` (``dict[int, float]``) and
                ``"game_counts"`` (``dict[int, int]``) keys, as returned by
                :meth:`get_state`.  Keys may be ``int`` or ``str``; string
                keys are coerced to ``int`` so that JSON-decoded dicts (where
                all keys are strings) work correctly without silent rating
                loss.

        Raises:
            KeyError: If ``"ratings"`` or ``"game_counts"`` keys are absent.
            TypeError: If either value is not a ``dict``.
        """
        missing = {"ratings", "game_counts"} - state.keys()
        if missing:
            msg = f"set_state() state dict missing required keys: {missing}"
            raise KeyError(msg)
        ratings = state["ratings"]
        game_counts = state["game_counts"]
        if not isinstance(ratings, dict) or not isinstance(game_counts, dict):
            msg = "set_state() 'ratings' and 'game_counts' must be dicts"
            raise TypeError(msg)
        # Coerce string keys to int so JSON-decoded dicts (all keys are str)
        # work correctly -- without coercion, get_rating(team_id_int) would
        # silently return initial_rating for every team.
        self._engine.set_ratings({int(k): float(v) for k, v in ratings.items()})
        self._engine.set_game_counts({int(k): int(v) for k, v in game_counts.items()})

    # ------------------------------------------------------------------
    # Model ABC: persistence
    # ------------------------------------------------------------------

    def save(self, path: Path) -> None:
        """Persist config, state, and feature config under the *path* directory.

        Creates *path* (and missing parents), then writes:

        * ``config.json`` -- the Pydantic config as JSON.
        * ``state.json`` -- ratings and game counts from :meth:`get_state`,
          with team-ID keys stringified because JSON object keys must be
          strings.
        * the ``feature_config.json`` sidecar, via ``save_feature_config``.

        The two JSON files are written as UTF-8 text.
        """
        path.mkdir(parents=True, exist_ok=True)
        (path / "config.json").write_text(self._config.model_dump_json(), encoding="utf-8")
        state = self.get_state()
        # JSON keys must be strings
        serialisable = {
            "ratings": {str(k): v for k, v in state["ratings"].items()},
            "game_counts": {str(k): v for k, v in state["game_counts"].items()},
        }
        (path / "state.json").write_text(json.dumps(serialisable), encoding="utf-8")
        save_feature_config(self.feature_config, path)

    @classmethod
    def load(cls, path: Path) -> Self:
        """Reconstruct an EloModel from a directory written by :meth:`save`.

        If ``load_feature_config`` returns ``None``, the instance keeps the
        feature config built in ``__init__`` from the loaded model config.

        Raises:
            FileNotFoundError: If either ``config.json`` or ``state.json`` is
                missing.  A missing file indicates an incomplete :meth:`save`
                (e.g., interrupted write).
            TypeError: If ``state.json`` does not hold a JSON object, or its
                ``"ratings"`` / ``"game_counts"`` entries are not objects.
            KeyError: If ``state.json`` lacks ``"ratings"`` or
                ``"game_counts"``.
        """
        config_path = path / "config.json"
        state_path = path / "state.json"
        missing = [p for p in (config_path, state_path) if not p.exists()]
        if missing:
            missing_names = ", ".join(p.name for p in missing)
            msg = f"Incomplete save at {path!r}: missing {missing_names}. The save may have been interrupted."
            raise FileNotFoundError(msg)

        config = EloModelConfig.model_validate_json(config_path.read_text(encoding="utf-8"))
        instance = cls(config)

        raw_state = json.loads(state_path.read_text(encoding="utf-8"))
        if not isinstance(raw_state, dict):
            msg = f"Corrupt state file {state_path!r}: expected a JSON object, got {type(raw_state).__name__}"
            raise TypeError(msg)
        # set_state validates the required keys and coerces JSON's string
        # team-ID keys back to int, so the decoded dict is passed straight in
        # (pre-indexing it here would bypass that validation).
        instance.set_state(raw_state)

        loaded_fc = load_feature_config(path)
        if loaded_fc is not None:
            instance.feature_config = loaded_fc
        return instance

    # ------------------------------------------------------------------
    # Model ABC: interpretability and config
    # ------------------------------------------------------------------

    def get_feature_importances(self) -> list[tuple[str, float]] | None:
        """Return top team Elo ratings as interpretability information.

        Returns:
            Up to 50 ``("team_<id>", rating)`` pairs in descending rating
            order (``sorted`` is stable, so ties keep the engine's iteration
            order), or ``None`` when the engine holds no ratings yet.
        """
        ratings = self._engine.get_all_ratings()
        if not ratings:
            return None
        ranked = sorted(ratings.items(), key=lambda item: item[1], reverse=True)
        return [(f"team_{team_id}", rating) for team_id, rating in ranked[:50]]

    def get_config(self) -> EloModelConfig:
        """Return the Pydantic-validated configuration."""
        return self._config

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_elo_config(config: EloModelConfig) -> EloConfig:
        """Convert Pydantic config to the frozen dataclass the engine expects.

        Uses :func:`dataclasses.fields` to derive the argument set from
        ``EloConfig`` at runtime, so any new field added to ``EloConfig`` is
        automatically included -- without requiring a manual update here.
        Config entries with no matching ``EloConfig`` field are dropped.
        ``EloModelConfig`` must keep its fields in sync with ``EloConfig``.
        """
        elo_field_names = {f.name for f in dataclasses.fields(EloConfig)}
        kwargs = {k: v for k, v in config.model_dump().items() if k in elo_field_names}
        return EloConfig(**kwargs)