"""Prediction orchestration for CLI predict command.
Loads a trained model, generates win-probability predictions for a target
season, and formats output as CSV. Supports stateful (Elo), stateless
(XGBoost, LogisticRegression), and stacked-ensemble model types.
"""
from __future__ import annotations
import csv
import io
import sys
from pathlib import Path
import pandas as pd # type: ignore[import-untyped]
from rich.console import Console
from ncaa_eval.cli.train import _setup_feature_server
from ncaa_eval.evaluation.bracket import MatchupContext
from ncaa_eval.evaluation.kaggle_export import KAGGLE_NEUTRAL_DAY_NUM
from ncaa_eval.evaluation.providers import EloProvider, build_probability_matrix
from ncaa_eval.ingest import ParquetRepository
from ncaa_eval.model.base import Model, StatefulModel
from ncaa_eval.model.ensemble import StackedEnsemble
from ncaa_eval.model.tracking import RunStore
def _build_stateful_predictions(
    *,
    model: StatefulModel,
    season: int,
    repo: ParquetRepository,
) -> list[tuple[int, int, int, float]]:
    """Build pairwise predictions for a stateful (Elo) model.

    Every team that appears as a winner or loser in the season's games is
    included. The model is queried through ``EloProvider`` for a
    neutral-site matchup on ``KAGGLE_NEUTRAL_DAY_NUM``.

    Args:
        model: Trained stateful model.
        season: Target season year.
        repo: Repository supplying the season's games.

    Returns:
        List of ``(season, team_a_id, team_b_id, pred_win_prob)`` tuples
        for all C(n,2) team pairs where ``team_a_id < team_b_id``.

    Raises:
        FileNotFoundError: If the repository returns no games for ``season``.
    """
    games = repo.get_games(season)
    if not games:
        raise FileNotFoundError(f"No games found for season {season}")

    # Sorted ascending so the upper triangle (col > row) gives team_a_id < team_b_id.
    team_ids = sorted({tid for game in games for tid in (game.w_team_id, game.l_team_id)})

    provider = EloProvider(model)
    neutral = MatchupContext(season=season, day_num=KAGGLE_NEUTRAL_DAY_NUM, is_neutral=True)
    matrix = build_probability_matrix(provider, team_ids, neutral)

    n_teams = len(team_ids)
    return [
        (season, team_ids[a], team_ids[b], float(matrix[a, b]))
        for a in range(n_teams)
        for b in range(a + 1, n_teams)
    ]
def _build_stateless_predictions(
    *,
    model: Model,
    run_id: str,
    season: int,
    data_dir: Path,
    store: RunStore,
) -> list[tuple[int, int, int, float]]:
    """Build game-level predictions for a stateless model.

    Serves the season's features in batch mode, selects the feature columns
    recorded for ``run_id``, and scores them with ``model.predict_proba``.
    Each probability is matched back to its game row by index label.

    Args:
        model: Trained stateless model; its ``feature_config`` configures the
            feature server.
        run_id: Run whose saved feature names select the model inputs.
        season: Target season year.
        data_dir: Path to the local data directory.
        store: Run store used to load the saved feature names.

    Returns:
        List of ``(season, team_a_id, team_b_id, pred_win_prob)`` tuples
        for all games in the season dataset, probabilities clamped to [0, 1].

    Raises:
        FileNotFoundError: If no feature names were saved for the run, or the
            served season frame is empty.
    """
    feature_names = store.load_feature_names(run_id)
    if feature_names is None:
        raise FileNotFoundError(f"No feature names saved for run {run_id!r}")

    server = _setup_feature_server(data_dir, model.feature_config)
    season_df = server.serve_season_features(season, mode="batch")
    if season_df.empty:
        raise FileNotFoundError(f"No game data found for season {season}")

    win_probs = model.predict_proba(season_df[feature_names])

    def as_row(label: object, prob: float) -> tuple[int, int, int, float]:
        # Label-based lookup: assumes predict_proba keeps the input's index.
        game = season_df.loc[label]
        return (
            season,
            int(game["team_a_id"]),
            int(game["team_b_id"]),
            float(min(max(prob, 0.0), 1.0)),
        )

    return [as_row(label, prob) for label, prob in win_probs.items()]
def _build_ensemble_predictions(
    *,
    ensemble: StackedEnsemble,
    season: int,
    data_dir: Path,
) -> list[tuple[int, int, int, float]]:
    """Build pairwise predictions for a stacked ensemble.

    Calls ``ensemble.predict_bracket`` to obtain an n×n probability matrix
    and emits one row per unordered team pair, in the same format as the
    single-model paths.

    The matrix is read positionally (row ``i``, column ``j``), so its
    columns are expected to be in the same order as its index. Teams are
    visited in ascending ID order regardless of how ``predict_bracket``
    ordered the index. This guarantees ``team_a_id < team_b_id`` in every
    row, matching the stateful path, which sorts its team IDs explicitly.

    Args:
        ensemble: Trained stacked ensemble.
        season: Target season year.
        data_dir: Path to the local data directory, forwarded to
            ``predict_bracket``.

    Returns:
        List of ``(season, team_a_id, team_b_id, pred_win_prob)`` tuples
        for all C(n,2) team pairs where ``team_a_id < team_b_id``. Team IDs
        are plain ``int`` and probabilities are clamped to [0, 1].
    """
    prob_df = ensemble.predict_bracket(data_dir, season)
    # Index labels are typically numpy integers; coerce them so the tuples
    # match the declared ``int`` type (as the stateless path does).
    team_ids = [int(team_id) for team_id in prob_df.index]
    # Matrix positions ordered by ascending team ID. Without this, an
    # unsorted index would emit rows with team_a_id > team_b_id.
    order = sorted(range(len(team_ids)), key=team_ids.__getitem__)
    # Convert once instead of doing a scalar .iloc lookup for every pair.
    matrix = prob_df.to_numpy(dtype=float)

    rows: list[tuple[int, int, int, float]] = []
    for rank, i in enumerate(order):
        for j in order[rank + 1 :]:
            prob = float(matrix[i, j])
            rows.append((season, team_ids[i], team_ids[j], min(max(prob, 0.0), 1.0)))
    return rows
[docs]
def build_predictions(*, run_id: str, season: int, data_dir: Path) -> str:
    """Load a trained model and render its predictions as a CSV string.

    Loads the model for ``run_id`` from the run store rooted at ``data_dir``
    and dispatches on its type:

    * ``StackedEnsemble``: pairwise matrix from ``predict_bracket``.
    * ``StatefulModel`` (e.g. Elo): pairwise matrix over every team that
      appears in the season's games.
    * Any other model: one row per game in the season feature set.

    The rows are formatted with ``format_predictions_csv``. Nothing is
    written here; callers decide where the output goes.

    Args:
        run_id: Model run identifier.
        season: Target season year (e.g. 2026).
        data_dir: Path to the local data directory.

    Returns:
        CSV string with ``season,team_a_id,team_b_id,pred_win_prob`` header.

    Raises:
        FileNotFoundError: If the run, model, feature names, or season data
            cannot be loaded.
    """
    store = RunStore(base_path=data_dir)
    model = store.load_model(run_id)
    if model is None:
        raise FileNotFoundError(f"No model found for run {run_id!r}")

    if isinstance(model, StackedEnsemble):
        rows = _build_ensemble_predictions(
            ensemble=model,
            season=season,
            data_dir=data_dir,
        )
    else:
        # NOTE(review): only the stateful path reads from ``repo``; it is
        # still constructed for stateless models, as before.
        repo = ParquetRepository(base_path=data_dir)
        if isinstance(model, StatefulModel):
            rows = _build_stateful_predictions(model=model, season=season, repo=repo)
        else:
            rows = _build_stateless_predictions(
                model=model,
                run_id=run_id,
                season=season,
                data_dir=data_dir,
                store=store,
            )
    return format_predictions_csv(rows)
[docs]
def run_predict(
    *,
    run_id: str,
    season: int,
    data_dir: Path,
    output: Path | None,
    console: Console | None = None,
) -> str:
    """Load a model and produce a predictions CSV.

    Thin CLI wrapper around ``build_predictions``. It prints progress and
    writes the CSV to ``output``, or to stdout when ``output`` is ``None``.

    The output file is written as UTF-8 with newline translation disabled,
    so its contents are exactly the returned string. A text-mode write with
    default settings would use the locale encoding and, on Windows, turn
    CRLF line endings into CR CR LF.

    Args:
        run_id: Model run identifier.
        season: Target season year (e.g. 2026).
        data_dir: Path to the local data directory.
        output: File path to write the CSV. ``None`` means stdout.
        console: Rich Console instance for status output. Defaults to a new
            Console that writes to stderr when the CSV goes to stdout.

    Returns:
        The CSV string.

    Raises:
        FileNotFoundError: If the run, model, or season data cannot be loaded.
        TypeError: If the model lacks a ``feature_config`` attribute
            (e.g. a malformed plugin).
        AttributeError: If the model subclass did not set ``feature_config``.
        OSError: If ``output`` cannot be opened for writing.
    """
    # When writing CSV to stdout, route status messages to stderr so the
    # output stream stays pipe-safe (no non-CSV lines mixed in).
    con = console if console is not None else Console(stderr=output is None)
    con.print(f"Generating predictions for season {season}...")
    csv_str = build_predictions(run_id=run_id, season=season, data_dir=data_dir)
    if output is not None:
        # newline="" disables newline translation (as the csv module docs
        # recommend), so the file content is exactly csv_str.
        with output.open("w", encoding="utf-8", newline="") as fh:
            fh.write(csv_str)
        con.print(f"[green]Predictions written to {output}[/green]")
    else:
        sys.stdout.write(csv_str)
    return csv_str