# Source code for ncaa_eval.cli.train

"""Training pipeline orchestration.

Assembles feature serving, model training, prediction generation, and
run tracking into a single ``run_training()`` function consumed by the
Typer CLI entry point.
"""

from __future__ import annotations

import copy
import json
import logging
import subprocess
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import pandas as pd  # type: ignore[import-untyped]
from rich.console import Console
from rich.progress import Progress
from rich.table import Table

from ncaa_eval.evaluation import BacktestResult, feature_cols as _feature_cols, run_backtest
from ncaa_eval.evaluation.backtest import _randomize_team_assignment
from ncaa_eval.ingest import ParquetRepository
from ncaa_eval.model.base import Model, StatefulModel
from ncaa_eval.model.ensemble import StackedEnsemble
from ncaa_eval.model.tracking import ModelRun, Prediction, RunStore
from ncaa_eval.transform.feature_serving import FeatureConfig, StatefulFeatureServer
from ncaa_eval.transform.serving import ChronologicalDataServer

_logger = logging.getLogger(__name__)


@dataclass
class _TrainingContext:
    """Internal context passed between pipeline stages."""

    model: Model  # model instance being trained (stateful or stateless)
    model_name: str  # registered plugin name, recorded in the ModelRun
    start_year: int  # first season year (inclusive)
    end_year: int  # last season year (inclusive)
    is_stateful: bool  # True when ``model`` is a StatefulModel instance
    console: Console  # Rich console used for all terminal output
    store: RunStore  # persistence layer for runs, metrics, and model artifacts
    server: StatefulFeatureServer  # serves per-season feature matrices


def _get_git_hash() -> str:
    """Return the short git hash of HEAD, or ``"unknown"`` on failure."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
        )
        return result.stdout.strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        return "unknown"


def _build_fold_predictions(result: BacktestResult) -> pd.DataFrame | None:
    """Build a fold predictions DataFrame from backtest results.

    Args:
        result: Backtest result containing fold results with game metadata.

    Returns:
        DataFrame with columns [year, game_id, team_a_id, team_b_id,
        pred_win_prob, team_a_won], or None if no fold predictions exist.
    """
    fold_frames: list[pd.DataFrame] = []
    for fr in result.fold_results:
        if fr.predictions.empty:
            continue
        fold_frames.append(
            pd.DataFrame(
                {
                    "year": fr.year,
                    "game_id": fr.test_game_ids.values,
                    "team_a_id": fr.test_team_a_ids.values,
                    "team_b_id": fr.test_team_b_ids.values,
                    "pred_win_prob": fr.predictions.values,
                    "team_a_won": fr.actuals.values,
                }
            )
        )
    if not fold_frames:
        return None
    return pd.concat(fold_frames, ignore_index=True)


def _setup_feature_server(data_dir: Path, feature_config: FeatureConfig) -> StatefulFeatureServer:
    """Initialize the repository, data server, and feature server."""
    repository = ParquetRepository(base_path=data_dir)
    return StatefulFeatureServer(
        config=feature_config,
        data_server=ChronologicalDataServer(repository),
    )


def _build_season_features(ctx: _TrainingContext) -> list[pd.DataFrame]:
    """Build feature matrices per season with a progress display.

    Serves one feature matrix per year in ``[start_year, end_year]`` via
    ``ctx.server``; empty frames (seasons with no data) are dropped.
    """
    # Serving mode is the same for every season; hoist it out of the loop.
    serve_mode: Literal["batch", "stateful"] = "stateful" if ctx.is_stateful else "batch"
    years = range(ctx.start_year, ctx.end_year + 1)
    frames: list[pd.DataFrame] = []
    with Progress() as progress:
        task = progress.add_task("Building features...", total=len(years))
        for year in years:
            season_df = ctx.server.serve_season_features(year, mode=serve_mode)
            if not season_df.empty:
                frames.append(season_df)
            progress.advance(task)
    return frames


def _prepare_and_train(ctx: _TrainingContext, combined: pd.DataFrame) -> list[str]:
    """Extract labels, check balance, compute feature columns, and train.

    Extracts ``team_a_won`` as integer labels and warns if the label mean
    is outside ``[0.05, 0.95]`` (heavy imbalance).  Computes ``feat_cols``
    via ``_feature_cols(combined)``.  For stateful models, passes the full
    ``combined`` DataFrame (model uses internal state for features); for
    stateless models, slices to ``combined[feat_cols]`` before calling
    ``model.fit``.

    Returns:
        List of feature column names used for training.
    """
    # Stateless classifiers require balanced labels; the feature server
    # assigns team_a = winner for every game, making team_a_won always True.
    if not ctx.is_stateful:
        combined = _randomize_team_assignment(combined)

    y = combined["team_a_won"].astype(int)

    label_mean = y.mean()
    if label_mean > 0.95 or label_mean < 0.05:
        ctx.console.print(
            f"[yellow]Warning: labels are heavily imbalanced "
            f"(mean={label_mean:.3f}). Consider randomising team assignment "
            f"or adjusting scale_pos_weight.[/yellow]"
        )

    feat_cols = _feature_cols(combined)
    if not ctx.is_stateful:
        # Drop columns that are entirely NaN (e.g. seed features without a seed
        # table) so sklearn classifiers that reject NaN inputs can still fit.
        feat_cols = [c for c in feat_cols if not combined[c].isna().all()]

    # Fix: the year range previously rendered with no separator
    # (e.g. "20152020"); use an en dash between the two years.
    ctx.console.print(
        f"Training [bold]{ctx.model_name}[/bold] on seasons "
        f"{ctx.start_year}–{ctx.end_year}..."
    )
    if ctx.is_stateful:
        ctx.model.fit(combined, y)
    else:
        ctx.model.fit(combined[feat_cols], y)

    return feat_cols


def _generate_tournament_predictions(
    ctx: _TrainingContext,
    combined: pd.DataFrame,
    feat_cols: list[str],
    run_id: str,
) -> list[Prediction]:
    """Generate predictions on tournament games.

    Filters ``combined`` to tournament rows, runs the trained model on
    them (full frame for stateful models, ``feat_cols`` slice otherwise),
    and wraps each probability — clamped to ``[0, 1]`` — in a
    ``Prediction`` record tagged with *run_id*.
    """
    tourney = combined[combined["is_tournament"] == True].copy()  # noqa: E712
    if tourney.empty:
        return []

    inputs = tourney if ctx.is_stateful else tourney[feat_cols]
    probs = ctx.model.predict_proba(inputs)

    records: list[Prediction] = []
    for idx, prob in probs.items():
        game = tourney.loc[idx]
        records.append(
            Prediction(
                run_id=run_id,
                game_id=str(game["game_id"]),
                season=int(game["season"]),
                team_a_id=int(game["team_a_id"]),
                team_b_id=int(game["team_b_id"]),
                pred_win_prob=float(min(max(prob, 0.0), 1.0)),
            )
        )
    return records


def _run_backtest_and_persist(ctx: _TrainingContext, run_id: str) -> None:
    """Run walk-forward backtest and persist metrics and fold predictions.

    A single season cannot produce walk-forward folds, so the backtest is
    skipped unless at least two seasons are available.  The model is
    deep-copied before being handed to ``run_backtest`` because the
    backtest's sequential ``fit`` calls would otherwise mutate the
    already-trained model held in ``ctx``.  Saves the summary metrics and,
    if fold-level predictions exist, the per-game prediction DataFrame via
    ``RunStore``.
    """
    seasons = list(range(ctx.start_year, ctx.end_year + 1))
    if len(seasons) < 2:
        ctx.console.print("[yellow]Skipping backtest: need ≥ 2 seasons.[/yellow]")
        return

    ctx.console.print("Running walk-forward backtest...")
    # run_backtest calls model.fit() on each fold; operate on a deep copy
    # so the trained model in ctx survives untouched.
    serve_mode: Literal["batch", "stateful"] = "stateful" if ctx.is_stateful else "batch"
    result = run_backtest(
        copy.deepcopy(ctx.model),
        ctx.server,
        seasons=seasons,
        mode=serve_mode,
        n_jobs=1,
        console=ctx.console,
    )
    ctx.store.save_metrics(run_id, result.summary)

    fold_preds = _build_fold_predictions(result)
    if fold_preds is not None:
        ctx.store.save_fold_predictions(run_id, fold_preds)

    ctx.console.print("[green]Backtest metrics persisted.[/green]")


def _persist_artifacts_and_summarize(
    ctx: _TrainingContext,
    run: ModelRun,
    feat_cols: list[str],
    combined: pd.DataFrame,
    predictions: list[Prediction],
) -> None:
    """Save the trained model and print a summary table.

    Args:
        ctx: Pipeline context holding the trained model, console, and store.
        run: Persisted run metadata record (supplies run_id and git hash).
        feat_cols: Feature column names the model was trained on.
        combined: Full training feature matrix (row count shown in summary).
        predictions: Tournament predictions generated for this run.
    """
    ctx.store.save_model(run.run_id, ctx.model, feature_names=feat_cols)
    ctx.console.print("[green]Model artifacts persisted.[/green]")

    table = Table(title="Training Results")
    table.add_column("Field", style="cyan")
    table.add_column("Value", style="green")
    table.add_row("Run ID", run.run_id)
    table.add_row("Model", ctx.model_name)
    # Fix: the year range previously rendered with no separator
    # (e.g. "20152020"); use an en dash between the two years.
    table.add_row("Seasons", f"{ctx.start_year}–{ctx.end_year}")
    table.add_row("Games trained", str(len(combined)))
    table.add_row("Tournament predictions", str(len(predictions)))
    table.add_row("Git hash", run.git_hash)
    ctx.console.print(table)


def run_training(  # noqa: PLR0913
    model: Model | StackedEnsemble,
    *,
    start_year: int,
    end_year: int,
    data_dir: Path,
    output_dir: Path,
    model_name: str,
    console: Console | None = None,
) -> ModelRun:
    """Execute the full train → predict → persist pipeline.

    Dispatches to ``_run_ensemble_training`` when *model* is a
    ``StackedEnsemble``; otherwise runs the single-model pipeline.

    Args:
        model: An instantiated model (stateful, stateless, or ensemble).
        start_year: First season year (inclusive) for training.
        end_year: Last season year (inclusive) for training.
        data_dir: Path to the local Parquet data store.
        output_dir: Path where run artifacts are persisted.
        model_name: Registered plugin name (used in the ModelRun record).
        console: Rich Console instance for terminal output.  Defaults to a
            fresh ``Console()`` so callers (e.g. tests) can suppress output
            by passing ``Console(quiet=True)``.

    Returns:
        The persisted run metadata record.
    """
    _console = console or Console()

    # Ensembles run a multi-step OOF pipeline; hand off entirely.
    if isinstance(model, StackedEnsemble):
        return _run_ensemble_training(
            model,
            start_year=start_year,
            end_year=end_year,
            data_dir=data_dir,
            output_dir=output_dir,
            model_name=model_name,
            console=_console,
        )

    server = _setup_feature_server(data_dir, model.feature_config)
    store = RunStore(base_path=output_dir)
    ctx = _TrainingContext(
        model=model,
        model_name=model_name,
        start_year=start_year,
        end_year=end_year,
        is_stateful=isinstance(model, StatefulModel),
        console=_console,
        store=store,
        server=server,
    )

    # Build feature matrices per season
    season_frames = _build_season_features(ctx)
    if not season_frames:
        # Nothing to train on: persist an empty run so the attempt is
        # still recorded, then bail out.
        _console.print("[yellow]No game data found for the specified year range.[/yellow]")
        run_id = str(uuid.uuid4())
        run = ModelRun(
            run_id=run_id,
            model_type=model_name,
            hyperparameters=model.get_config().model_dump(),
            git_hash=_get_git_hash(),
            start_year=start_year,
            end_year=end_year,
            prediction_count=0,
        )
        store.save_run(run, [])
        return run

    combined = pd.concat(season_frames, ignore_index=True)

    # Train model
    feat_cols = _prepare_and_train(ctx, combined)

    # Generate predictions
    run_id = str(uuid.uuid4())
    predictions = _generate_tournament_predictions(ctx, combined, feat_cols, run_id)

    # Persist run
    run = ModelRun(
        run_id=run_id,
        model_type=model_name,
        hyperparameters=model.get_config().model_dump(),
        git_hash=_get_git_hash(),
        start_year=start_year,
        end_year=end_year,
        prediction_count=len(predictions),
    )
    store.save_run(run, predictions)

    # Backtest
    _run_backtest_and_persist(ctx, run.run_id)

    # Save model and summarize
    _persist_artifacts_and_summarize(ctx, run, feat_cols, combined, predictions)

    return run
# ── Ensemble training pipeline ─────────────────────────────────────────────


def _collect_oof_predictions(  # noqa: PLR0913
    base_model: Model,
    base_idx: int,
    data_dir: Path,
    start_year: int,
    end_year: int,
    console: Console,
) -> pd.DataFrame:
    """Run a walk-forward backtest for one base model and return OOF predictions.

    Returns a DataFrame with columns ``[year, game_id, team_a_id,
    team_b_id, pred_base_<idx>, team_a_won]``.  Contextual features
    (``seed_diff``, ``is_tournament``, ``loc_encoding``) are NOT included
    here; they are joined in Step 3 of ``_run_ensemble_training`` from
    the union feature server.
    """
    server = _setup_feature_server(data_dir, base_model.feature_config)
    seasons = list(range(start_year, end_year + 1))
    mode: Literal["batch", "stateful"] = "stateful" if isinstance(base_model, StatefulModel) else "batch"

    console.print(f" OOF backtest for base model {base_idx} ({type(base_model).__name__})...")
    # Deep-copy so the backtest's fit calls do not mutate the ensemble's
    # base model before the full-dataset retraining step.
    result = run_backtest(
        copy.deepcopy(base_model),
        server,
        seasons=seasons,
        mode=mode,
        n_jobs=1,
        console=console,
    )
    fold_preds = _build_fold_predictions(result)
    if fold_preds is None:
        return pd.DataFrame()
    fold_preds = fold_preds.rename(columns={"pred_win_prob": f"pred_base_{base_idx}"})
    return fold_preds


def _align_oof_predictions(
    oof_frames: list[pd.DataFrame],
) -> pd.DataFrame:
    """Inner-join OOF predictions from all base models on ``game_id``.

    Logs a warning if >5% of games are dropped by the join.
    """
    if not oof_frames:
        return pd.DataFrame()
    aligned = oof_frames[0]
    for frame in oof_frames[1:]:
        pred_cols = [c for c in frame.columns if c.startswith("pred_base_")]
        aligned = aligned.merge(
            frame[["game_id", *pred_cols]],
            on="game_id",
            how="inner",
        )
    max_len = max(len(f) for f in oof_frames)
    if max_len > 0:
        drop_pct = 1.0 - len(aligned) / max_len
        if drop_pct > 0.05:
            # Fix: format previously rendered as "(%d%d)", concatenating
            # the before/after counts into one unreadable number.
            _logger.warning(
                "OOF alignment dropped %.1f%% of games (%d → %d)",
                drop_pct * 100,
                max_len,
                len(aligned),
            )
    return aligned


def _build_meta_training_set(
    aligned: pd.DataFrame,
    contextual_features: list[str],
) -> tuple[pd.DataFrame, pd.Series]:
    """Build meta-learner X and y from aligned OOF predictions.

    Returns (meta_X, meta_y) where meta_X has columns
    ``[pred_base_0, ..., <contextual_features>]`` and meta_y is
    ``team_a_won``.  Contextual features with missing values (e.g.
    ``seed_diff`` for regular-season games that lack tournament seeds)
    are filled with 0 so that sklearn estimators that reject NaN can
    still fit.
    """
    pred_cols = sorted(c for c in aligned.columns if c.startswith("pred_base_"))
    context_cols = [c for c in contextual_features if c in aligned.columns]
    meta_X = aligned[pred_cols + context_cols].copy()
    if context_cols:
        meta_X[context_cols] = meta_X[context_cols].fillna(0)
    meta_y = aligned["team_a_won"].astype(int)
    return meta_X, meta_y


def _retrain_base_models(
    ensemble: StackedEnsemble,
    data_dir: Path,
    start_year: int,
    end_year: int,
    console: Console,
) -> None:
    """Retrain each base model on the full dataset."""
    for i, base_model in enumerate(ensemble.base_models):
        console.print(f" Retraining base model {i} on full dataset...")
        server = _setup_feature_server(data_dir, base_model.feature_config)
        is_stateful = isinstance(base_model, StatefulModel)
        season_frames: list[pd.DataFrame] = []
        for year in range(start_year, end_year + 1):
            mode: Literal["batch", "stateful"] = "stateful" if is_stateful else "batch"
            df = server.serve_season_features(year, mode=mode)
            if not df.empty:
                season_frames.append(df)
        if not season_frames:
            continue
        combined = pd.concat(season_frames, ignore_index=True)
        # Stateless classifiers need balanced labels (team_a is always the
        # winner as served); stateful models consume games chronologically.
        if not is_stateful:
            combined = _randomize_team_assignment(combined)
        y = combined["team_a_won"].astype(int)
        if is_stateful:
            base_model.fit(combined, y)
        else:
            feat_cols = _feature_cols(combined)
            # Drop all-NaN columns so sklearn estimators can fit.
            feat_cols = [c for c in feat_cols if not combined[c].isna().all()]
            base_model.fit(combined[feat_cols], y)


def _run_ensemble_training(  # noqa: PLR0913
    ensemble: StackedEnsemble,
    *,
    start_year: int,
    end_year: int,
    data_dir: Path,
    output_dir: Path,
    model_name: str,
    console: Console,
) -> ModelRun:
    """Execute the stacked ensemble training pipeline.

    Steps:
        1. OOF generation — walk-forward backtest per base model
        2. OOF alignment — inner join on game_id
        3. Meta-training set construction
        4. Meta-learner training
        5. Full-dataset base model retraining
        6. Artifact persistence
    """
    store = RunStore(base_path=output_dir)

    # Step 1: OOF generation
    console.print("[bold]Step 1/6: Generating OOF predictions...[/bold]")
    oof_frames: list[pd.DataFrame] = []
    oof_run_ids: list[str] = []
    for i, base_model in enumerate(ensemble.base_models):
        oof_df = _collect_oof_predictions(
            base_model,
            base_idx=i,
            data_dir=data_dir,
            start_year=start_year,
            end_year=end_year,
            console=console,
        )
        oof_frames.append(oof_df)
        oof_run_ids.append(str(uuid.uuid4()))

    # Step 2: OOF alignment
    console.print("[bold]Step 2/6: Aligning OOF predictions...[/bold]")
    aligned = _align_oof_predictions(oof_frames)
    if aligned.empty:
        # No common games across base models: record an empty run and stop.
        console.print("[yellow]No aligned OOF predictions — aborting ensemble training.[/yellow]")
        run_id = str(uuid.uuid4())
        run = ModelRun(
            run_id=run_id,
            model_type="ensemble",
            hyperparameters=ensemble.get_config().model_dump(),
            git_hash=_get_git_hash(),
            start_year=start_year,
            end_year=end_year,
            prediction_count=0,
        )
        store.save_run(run, [])
        return run
    console.print(f" Aligned OOF games: {len(aligned)}")

    # Step 3: Meta-training set
    console.print("[bold]Step 3/6: Building meta-training set...[/bold]")
    # Contextual features (seed_diff, is_tournament, loc_encoding) are
    # present in the OOF fold predictions from _build_fold_predictions
    # via _evaluate_fold's test data, but _build_fold_predictions only
    # returns game_id/pred/actuals.  We need to get them from a feature
    # server.  Build one with the ensemble's union feature_config.
    union_server = _setup_feature_server(data_dir, ensemble.feature_config)
    context_frames: list[pd.DataFrame] = []
    for year in range(start_year, end_year + 1):
        df = union_server.serve_season_features(year, mode="batch")
        if not df.empty:
            context_frames.append(df[["game_id", *ensemble.contextual_features]])
    if context_frames:
        context_df = pd.concat(context_frames, ignore_index=True)
        # The feature server assigns team_a = winner always, so game_ids
        # are unique per game.  Inner join: if a game_id is missing from
        # the context server it likely indicates a data inconsistency; NaN
        # contextual features would silently poison the meta-learner.
        before_merge = len(aligned)
        aligned = aligned.merge(context_df, on="game_id", how="inner")
        if len(aligned) < before_merge:
            # Fix: format previously rendered as "(%d%d)", concatenating
            # the before/after counts into one unreadable number.
            _logger.warning(
                "Context feature join dropped %d OOF games (%d → %d); "
                "possible game_id mismatch between base model OOF and union feature server.",
                before_merge - len(aligned),
                before_merge,
                len(aligned),
            )
    meta_X, meta_y = _build_meta_training_set(aligned, ensemble.contextual_features)
    meta_column_order = list(meta_X.columns)
    ensemble.meta_column_order = meta_column_order
    console.print(f" Meta-training shape: {meta_X.shape}")

    # Step 4: Meta-learner training
    console.print("[bold]Step 4/6: Training meta-learner...[/bold]")
    ensemble.meta_learner.fit(meta_X, meta_y)

    # Step 5: Retrain base models on full dataset
    console.print("[bold]Step 5/6: Retraining base models on full dataset...[/bold]")
    _retrain_base_models(ensemble, data_dir, start_year, end_year, console)

    # Step 6: Persist
    console.print("[bold]Step 6/6: Persisting ensemble artifacts...[/bold]")
    run_id = str(uuid.uuid4())
    run = ModelRun(
        run_id=run_id,
        model_type="ensemble",
        hyperparameters=ensemble.get_config().model_dump(),
        git_hash=_get_git_hash(),
        start_year=start_year,
        end_year=end_year,
        prediction_count=0,
    )
    store.save_run(run, [])
    store.save_model(run_id, ensemble)

    # Augment manifest with OOF metadata
    model_dir = store.model_dir(run_id)
    manifest_path = model_dir / "manifest.json"
    manifest = json.loads(manifest_path.read_text())
    manifest["meta_column_order"] = meta_column_order
    manifest["oof_backtest_run_ids"] = oof_run_ids
    manifest["oof_game_count"] = len(aligned)
    max_len = max(len(f) for f in oof_frames) if oof_frames else 0
    manifest["oof_drop_pct"] = round(1.0 - len(aligned) / max_len, 4) if max_len > 0 else 0.0
    manifest_path.write_text(json.dumps(manifest, indent=2))

    # Save OOF aligned data for post-hoc analysis (e.g. tutorial notebook)
    oof_path = model_dir / "oof_aligned.parquet"
    oof_export = aligned.copy()
    # Add ensemble-level OOF predictions from the meta-learner
    oof_export["pred_ensemble"] = ensemble.meta_learner.predict_proba(meta_X).values
    oof_export.to_parquet(oof_path, index=False)

    # Summary table
    table = Table(title="Ensemble Training Results")
    table.add_column("Field", style="cyan")
    table.add_column("Value", style="green")
    table.add_row("Run ID", run.run_id)
    table.add_row("Model", model_name)
    table.add_row("Base models", str(len(ensemble.base_models)))
    # Fix: the year range previously rendered with no separator
    # (e.g. "20152020"); use an en dash between the two years.
    table.add_row("Seasons", f"{start_year}–{end_year}")
    table.add_row("OOF games", str(len(aligned)))
    table.add_row("Meta features", str(len(meta_column_order)))
    table.add_row("Git hash", run.git_hash)
    console.print(table)

    return run