"""Elo rating model — reference stateful model for NCAA tournament prediction.
Thin wrapper around :class:`~ncaa_eval.transform.elo.EloFeatureEngine`.
All Elo math is delegated to the engine; this module adds :class:`StatefulModel`
ABC conformance, Pydantic configuration, JSON persistence, and plugin
registration.
"""
from __future__ import annotations
import dataclasses
import json
from pathlib import Path
from typing import Any, Literal, Self
from ncaa_eval.ingest.schema import Game
from ncaa_eval.model._feature_config_io import load_feature_config, save_feature_config
from ncaa_eval.model.base import ModelConfig, StatefulModel
from ncaa_eval.model.registry import register_model
from ncaa_eval.transform.elo import EloConfig, EloFeatureEngine
from ncaa_eval.transform.feature_serving import FeatureConfig
[docs]
class EloModelConfig(ModelConfig):
    """Pydantic-validated hyperparameters for the Elo model.

    The field names and default values are meant to mirror
    :class:`~ncaa_eval.transform.elo.EloConfig`; keep the two in sync.
    """

    # Fixed tag identifying this configuration as the "elo" model.
    model_name: Literal["elo"] = "elo"

    # Rating baseline.
    initial_rating: float = 1500.0

    # K values plus the game-count threshold tied to the "early" one.
    # NOTE(review): exact semantics are defined by EloConfig — confirm there.
    k_early: float = 56.0
    k_regular: float = 38.0
    k_tournament: float = 47.5
    early_game_threshold: int = 20

    # Margin-related settings (presumably margin-of-victory scaling — verify
    # against the engine).
    margin_exponent: float = 0.85
    max_margin: int = 25

    # Remaining engine settings, passed through unchanged.
    home_advantage_elo: float = 3.5
    mean_reversion_fraction: float = 0.25
[docs]
@register_model("elo")
class EloModel(StatefulModel):
"""Elo rating model wrapping :class:`EloFeatureEngine`."""
def __init__(self, config: EloModelConfig | None = None) -> None:
    """Build the model and its wrapped engine from *config*.

    Args:
        config: Pydantic configuration for the model. A default
            :class:`EloModelConfig` is created when nothing (or a falsy
            value) is supplied.
    """
    self._config = config or EloModelConfig()

    # The engine gets its own converted copy of the hyperparameters.
    engine_settings = self._to_elo_config(self._config)
    self._engine = EloFeatureEngine(engine_settings)

    # Feature-serving description: only Elo is switched on; every other
    # feature family is left empty or disabled.
    serving_settings = self._to_elo_config(self._config)
    self.feature_config = FeatureConfig(
        sequential_windows=(),
        graph_features_enabled=False,
        batch_rating_types=(),
        ordinal_composite=None,
        elo_enabled=True,
        elo_config=serving_settings,
    )
# ------------------------------------------------------------------
# StatefulModel abstract hooks
# ------------------------------------------------------------------
[docs]
def update(self, game: Game) -> None:
    """Feed a single game into the wrapped engine.

    Reads the ``w_*``/``l_*`` team ids and scores, the location, the
    tournament flag and the overtime count from *game* and forwards them
    as keyword arguments to the engine's ``update_game``.

    Args:
        game: The game record to process.
    """
    forward = self._engine.update_game
    game_fields = {
        "w_team_id": game.w_team_id,
        "l_team_id": game.l_team_id,
        "w_score": game.w_score,
        "l_score": game.l_score,
        "loc": game.loc,
        "is_tournament": game.is_tournament,
        "num_ot": game.num_ot,
    }
    forward(**game_fields)
[docs]
def start_season(self, season: int) -> None:
    """Tell the wrapped engine that *season* is beginning.

    Args:
        season: Season identifier, forwarded unchanged to the engine's
            ``start_new_season``.
    """
    engine = self._engine
    engine.start_new_season(season)
def _predict_one(self, team_a_id: int, team_b_id: int) -> float:
    """Return the engine's win probability for team A against team B.

    The two ids are passed, in order, to the wrapped engine's
    ``predict_matchup`` and its result is returned as-is.

    Args:
        team_a_id: Id of the team whose win probability is requested.
        team_b_id: Id of the opposing team.
    """
    probability = self._engine.predict_matchup(team_a_id, team_b_id)
    return probability
[docs]
def get_state(self) -> dict[str, Any]:
    """Snapshot the engine's ratings and per-team game counts.

    Returns:
        A dict with two entries: ``"ratings"`` (from the engine's
        ``get_all_ratings``) and ``"game_counts"`` (from its
        ``get_game_counts``).
    """
    engine = self._engine
    ratings = engine.get_all_ratings()
    game_counts = engine.get_game_counts()
    return {"ratings": ratings, "game_counts": game_counts}
[docs]
def set_state(self, state: dict[str, Any]) -> None:
    """Restore ratings and game counts from a snapshot.

    Both mappings are validated and coerced *before* the engine is
    touched, so a malformed entry can never leave the engine with new
    ratings but stale game counts.

    Args:
        state: Must contain ``"ratings"`` (team id -> rating) and
            ``"game_counts"`` (team id -> games played) keys, as returned
            by :meth:`get_state`. Keys may be ``int`` or ``str``; string
            keys are coerced to ``int`` so that JSON-decoded dicts (where
            all keys are strings) restore correctly.

    Raises:
        KeyError: If ``"ratings"`` or ``"game_counts"`` keys are absent.
        TypeError: If either value is not a ``dict``, or an entry has a
            type that ``int()``/``float()`` cannot accept.
        ValueError: If a key or value cannot be converted to a number.
    """
    if "ratings" not in state or "game_counts" not in state:
        missing = {"ratings", "game_counts"} - state.keys()
        msg = f"set_state() state dict missing required keys: {missing}"
        raise KeyError(msg)

    ratings = state["ratings"]
    game_counts = state["game_counts"]
    if not isinstance(ratings, dict) or not isinstance(game_counts, dict):
        msg = "set_state() 'ratings' and 'game_counts' must be dicts"
        raise TypeError(msg)

    # Coerce string keys to int so JSON-decoded dicts (all keys are str)
    # are looked up correctly by integer team id. Do the coercion for BOTH
    # mappings up front: any conversion error is raised here, before the
    # engine has been modified.
    coerced_ratings = {int(k): float(v) for k, v in ratings.items()}
    coerced_counts = {int(k): int(v) for k, v in game_counts.items()}

    self._engine.set_ratings(coerced_ratings)
    self._engine.set_game_counts(coerced_counts)
# ------------------------------------------------------------------
# Model ABC: persistence
# ------------------------------------------------------------------
[docs]
def save(self, path: Path) -> None:
    """Persist the model into the directory *path*.

    The directory is created if needed (including parents). Three
    artefacts are then written, in this order:

    1. ``config.json`` — the Pydantic config dumped as JSON.
    2. ``state.json`` — the :meth:`get_state` snapshot, with every team-id
       key converted to ``str`` for JSON.
    3. the feature-config sidecar, via ``save_feature_config``.

    Args:
        path: Target directory.
    """
    path.mkdir(parents=True, exist_ok=True)

    config_file = path / "config.json"
    config_file.write_text(self._config.model_dump_json())

    snapshot = self.get_state()
    payload = {}
    for section in ("ratings", "game_counts"):
        # JSON object keys must be strings.
        payload[section] = {str(team_id): value for team_id, value in snapshot[section].items()}

    state_file = path / "state.json"
    state_file.write_text(json.dumps(payload))

    save_feature_config(self.feature_config, path)
[docs]
@classmethod
def load(cls, path: Path) -> Self:
    """Reconstruct an EloModel from a directory written by :meth:`save`.

    The decoded ``state.json`` payload is handed straight to
    :meth:`set_state`, which validates its shape and coerces the
    JSON string keys back to ``int``. Malformed state files therefore
    fail with the same descriptive errors as a direct ``set_state`` call.

    Args:
        path: Directory containing ``config.json`` and ``state.json``.

    Returns:
        A new instance with the saved config, ratings and game counts. If
        ``load_feature_config`` returns a value for *path*, it replaces
        the default ``feature_config``.

    Raises:
        FileNotFoundError: If either ``config.json`` or ``state.json`` is
            missing. A missing file indicates an incomplete :meth:`save`
            (e.g., interrupted write).
        KeyError: If ``state.json`` lacks ``"ratings"`` or
            ``"game_counts"``.
        TypeError: If ``state.json`` is not a JSON object, or either
            section is not a JSON object.
    """
    config_path = path / "config.json"
    state_path = path / "state.json"
    missing = [p for p in (config_path, state_path) if not p.exists()]
    if missing:
        missing_names = ", ".join(p.name for p in missing)
        msg = f"Incomplete save at {path!r}: missing {missing_names}. The save may have been interrupted."
        raise FileNotFoundError(msg)

    config = EloModelConfig.model_validate_json(config_path.read_text())
    instance = cls(config)

    raw = json.loads(state_path.read_text())
    if not isinstance(raw, dict):
        msg = f"state.json must contain a JSON object, got {type(raw).__name__}"
        raise TypeError(msg)
    # set_state() owns validation and str -> int key coercion; no need to
    # duplicate (and pre-empt) it here.
    instance.set_state(raw)

    loaded_fc = load_feature_config(path)
    if loaded_fc is not None:
        instance.feature_config = loaded_fc
    return instance
# ------------------------------------------------------------------
# Model ABC: config
# ------------------------------------------------------------------
[docs]
def get_feature_importances(self, *, top_n: int = 50) -> list[tuple[str, float]] | None:
    """Return the highest-rated teams as interpretability information.

    Args:
        top_n: Maximum number of teams to report. Defaults to 50, the
            limit that was previously hard-coded.

    Returns:
        ``("team_<id>", rating)`` pairs ordered by rating, highest first,
        truncated to *top_n* entries; ``None`` when the engine holds no
        ratings.

    Raises:
        ValueError: If *top_n* is negative.
    """
    if top_n < 0:
        msg = f"top_n must be non-negative, got {top_n}"
        raise ValueError(msg)

    ratings = self._engine.get_all_ratings()
    if not ratings:
        return None

    # sorted() is stable, so teams with equal ratings keep the engine's
    # original ordering relative to each other.
    ranked = sorted(ratings.items(), key=lambda item: item[1], reverse=True)
    return [(f"team_{team_id}", rating) for team_id, rating in ranked[:top_n]]
[docs]
def get_config(self) -> EloModelConfig:
    """Return the configuration object this model was built with.

    Returns:
        The stored :class:`EloModelConfig` instance itself (not a copy).
    """
    current = self._config
    return current
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
@staticmethod
def _to_elo_config(config: EloModelConfig) -> EloConfig:
    """Translate the Pydantic config into the engine's ``EloConfig``.

    The set of arguments is derived at runtime from
    :func:`dataclasses.fields` on ``EloConfig``: every dumped value of
    *config* whose name is also an ``EloConfig`` field is passed through,
    and everything else (e.g. ``model_name``) is dropped. An ``EloConfig``
    field that *config* does not define is simply not passed, so
    ``EloModelConfig`` must keep its fields in sync with ``EloConfig``.

    Args:
        config: The Pydantic model configuration to convert.

    Returns:
        An ``EloConfig`` built from the overlapping fields.
    """
    engine_field_names = tuple(spec.name for spec in dataclasses.fields(EloConfig))
    dumped = config.model_dump()
    engine_kwargs = {name: dumped[name] for name in engine_field_names if name in dumped}
    return EloConfig(**engine_kwargs)