"""Stacked ensemble model — orchestrates base models and a meta-learner.
``StackedEnsemble`` is a standalone ``@dataclass`` (not a ``Model`` subclass)
that holds a list of base ``Model`` instances and a stateless meta-learner.
The training pipeline in ``cli/train.py`` dispatches on
``isinstance(model, StackedEnsemble)`` to invoke ensemble-specific training.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from itertools import combinations
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
import numpy.typing as npt
import pandas as pd # type: ignore[import-untyped]
from pydantic import Field as _Field
from ncaa_eval.model._feature_config_io import save_feature_config
from ncaa_eval.model.base import Model, ModelConfig, StatefulModel
from ncaa_eval.model.registry import get_model, register_model
from ncaa_eval.transform.feature_serving import (
DatasetScope,
FeatureConfig,
GenderScope,
OrdinalCompositeMethod,
)
if TYPE_CHECKING:
from ncaa_eval.transform.elo import EloConfig
logger = logging.getLogger(__name__)
# ── Feature-config union helpers (extracted for complexity budget) ──────────
def _resolve_elo(
configs: list[FeatureConfig],
) -> tuple[bool, EloConfig | None]:
"""Return (elo_enabled, elo_config) from the union of *configs*."""
elo_enabled = any(c.elo_enabled for c in configs)
elo_config: EloConfig | None = None
if elo_enabled:
for c in configs:
if c.elo_enabled and c.elo_config is not None:
elo_config = c.elo_config
break
return elo_enabled, elo_config
def _resolve_ordinals(
configs: list[FeatureConfig],
) -> tuple[OrdinalCompositeMethod | None, tuple[str, ...] | None]:
"""Return (ordinal_composite, ordinal_systems) from the union of *configs*."""
ordinal_composite: OrdinalCompositeMethod | None = None
for c in configs:
if c.ordinal_composite is not None:
ordinal_composite = c.ordinal_composite
break
systems: set[str] = set()
any_systems = False
for c in configs:
if c.ordinal_systems is not None:
any_systems = True
systems.update(c.ordinal_systems)
ordinal_systems: tuple[str, ...] | None = tuple(sorted(systems)) if any_systems else None
return ordinal_composite, ordinal_systems
def _assert_agreement(
configs: list[FeatureConfig],
) -> tuple[bool, GenderScope, DatasetScope]:
"""Assert matchup_deltas / gender_scope / dataset_scope agree across *configs*."""
matchup_deltas_set = {c.matchup_deltas for c in configs}
if len(matchup_deltas_set) > 1:
msg = "All base models must agree on matchup_deltas"
raise ValueError(msg)
gender_scope_set = {c.gender_scope for c in configs}
if len(gender_scope_set) > 1:
msg = "All base models must agree on gender_scope"
raise ValueError(msg)
dataset_scope_set = {c.dataset_scope for c in configs}
if len(dataset_scope_set) > 1:
msg = "All base models must agree on dataset_scope"
raise ValueError(msg)
return matchup_deltas_set.pop(), gender_scope_set.pop(), dataset_scope_set.pop()
# ── Bracket prediction helpers (extracted for complexity budget) ─────────
def _discover_team_ids(data_dir: Path, season: int) -> list[int]:
    """Discover tournament-eligible team IDs from game data.

    Prefers tournament games (``is_tournament=True``) so the result is
    limited to the ~64 teams that actually competed in the bracket. Falls
    back to all season games if no tournament games are found (e.g. for
    seasons before the tournament or when running on partial data).
    """
    from ncaa_eval.ingest import ParquetRepository

    games = ParquetRepository(base_path=data_dir).get_games(season)
    if not games:
        msg = f"No season data for bracket prediction: season {season}"
        raise FileNotFoundError(msg)
    # Empty tournament slice falls through to the full season's games.
    pool = [g for g in games if g.is_tournament] or games
    ids = {g.w_team_id for g in pool} | {g.l_team_id for g in pool}
    return sorted(ids)
def _stateful_base_matrix(
    model: StatefulModel,
    team_ids: list[int],
    season: int,
) -> npt.NDArray[np.float64]:
    """Build pairwise probability matrix for a stateful base model."""
    from ncaa_eval.evaluation.bracket import MatchupContext
    from ncaa_eval.evaluation.kaggle_export import KAGGLE_NEUTRAL_DAY_NUM
    from ncaa_eval.evaluation.providers import EloProvider, build_probability_matrix

    # Bracket matchups are treated as neutral-site games on Kaggle's
    # canonical neutral day number.
    ctx = MatchupContext(
        season=season,
        day_num=KAGGLE_NEUTRAL_DAY_NUM,
        is_neutral=True,
    )
    return build_probability_matrix(EloProvider(model), team_ids, ctx)
def _extract_team_features(
season_df: pd.DataFrame,
team_ids: list[int],
) -> dict[int, dict[str, float]]:
"""Extract per-team feature profiles from season game data.
For each team, averages their features across all games where they
appear as either team_a or team_b. Returns a mapping of
``team_id -> {feature_name_without_suffix: value}``.
"""
# Identify per-team columns (those ending with _a or _b)
a_cols = [c for c in season_df.columns if c.endswith("_a") and c != "team_a_id"]
b_cols = [c for c in season_df.columns if c.endswith("_b") and c != "team_b_id"]
# Map base feature name to (a_col, b_col)
base_names: dict[str, tuple[str, str]] = {}
for ac in a_cols:
base = ac[:-2] # strip _a
bc = f"{base}_b"
if bc in b_cols:
base_names[base] = (ac, bc)
team_profiles: dict[int, dict[str, float]] = {}
for tid in team_ids:
# Games where this team is team_a
as_a = season_df[season_df["team_a_id"] == tid]
# Games where this team is team_b
as_b = season_df[season_df["team_b_id"] == tid]
profile: dict[str, float] = {}
for base, (ac, bc) in base_names.items():
values: list[float] = []
if not as_a.empty and ac in as_a.columns:
vals = as_a[ac].dropna()
values.extend(vals.tolist())
if not as_b.empty and bc in as_b.columns:
vals = as_b[bc].dropna()
values.extend(vals.tolist())
profile[base] = float(np.mean(values)) if values else 0.0
team_profiles[tid] = profile
return team_profiles
def _build_synthetic_feature_rows(
team_profiles: dict[int, dict[str, float]],
team_ids: list[int],
feat_names: list[str],
) -> pd.DataFrame:
"""Build synthetic feature rows for all C(n,2) team pairings.
Constructs ``_a``/``_b`` columns from team profiles and computes
``delta_*`` and ``seed_diff`` columns.
"""
rows: list[dict[str, float]] = []
pair_indices: list[tuple[int, int]] = []
for i, j in combinations(range(len(team_ids)), 2):
tid_a, tid_b = team_ids[i], team_ids[j]
prof_a = team_profiles.get(tid_a, {})
prof_b = team_profiles.get(tid_b, {})
row: dict[str, float] = {}
# Populate _a and _b columns (dict.fromkeys preserves insertion order,
# deduplicating without non-deterministic set iteration)
for base in dict.fromkeys(list(prof_a.keys()) + list(prof_b.keys())):
row[f"{base}_a"] = prof_a.get(base, 0.0)
row[f"{base}_b"] = prof_b.get(base, 0.0)
row[f"delta_{base}"] = row[f"{base}_a"] - row[f"{base}_b"]
# Special: seed_diff (derived from seed_num)
if "seed_num" in prof_a or "seed_num" in prof_b:
row["seed_diff"] = prof_a.get("seed_num", 0.0) - prof_b.get("seed_num", 0.0)
rows.append(row)
pair_indices.append((i, j))
if not rows:
return pd.DataFrame()
df = pd.DataFrame(rows)
# Fill missing feature columns with 0
for c in feat_names:
if c not in df.columns:
df[c] = 0.0
df.attrs["pair_indices"] = pair_indices
return df[feat_names]
def _stateless_base_matrix(
    model: Model,
    team_ids: list[int],
    data_dir: Path,
    season: int,
) -> npt.NDArray[np.float64]:
    """Build pairwise probability matrix for a stateless base model."""
    from ncaa_eval.cli.train import _setup_feature_server

    feat_names: list[str] = model.feature_names_  # type: ignore[attr-defined]
    server = _setup_feature_server(data_dir, model.feature_config)
    season_df = server.serve_season_features(season, mode="batch")
    profiles = _extract_team_features(season_df, team_ids)
    synthetic_df = _build_synthetic_feature_rows(profiles, team_ids, feat_names)

    n_teams = len(team_ids)
    matrix = np.zeros((n_teams, n_teams), dtype=np.float64)
    if synthetic_df.empty:
        return matrix

    pairs: list[tuple[int, int]] = synthetic_df.attrs["pair_indices"]
    probabilities = model.predict_proba(synthetic_df)
    for row_pos, (i, j) in enumerate(pairs):
        p_win = float(probabilities.iloc[row_pos])
        matrix[i, j] = p_win
        matrix[j, i] = 1.0 - p_win  # complement keeps the matrix consistent
    return matrix
def _predict_bracket_base_matrices(
    base_models: list[Model],
    data_dir: Path,
    season: int,
) -> tuple[list[int], list[npt.NDArray[np.float64]]]:
    """Build per-base-model probability matrices.

    Returns:
        Tuple of (team_ids, list of n×n probability matrices).
    """
    team_ids = _discover_team_ids(data_dir, season)
    # Stateful models predict from internal ratings; stateless ones need
    # feature rows rebuilt from the season's data.
    matrices = [
        _stateful_base_matrix(m, team_ids, season)
        if isinstance(m, StatefulModel)
        else _stateless_base_matrix(m, team_ids, data_dir, season)
        for m in base_models
    ]
    return team_ids, matrices
def _build_bracket_contextual_features(
    data_dir: Path,
    season: int,
    team_ids: list[int],
    contextual_features: list[str],
    feature_config: FeatureConfig,
) -> dict[str, npt.NDArray[np.float64]]:
    """Build contextual feature vectors for all C(n,2) matchups.

    Returns a dict mapping feature name to a 1-D array of length C(n,2),
    ordered by upper-triangle iteration.
    """
    from ncaa_eval.cli.train import _setup_feature_server

    if not contextual_features:
        return {}
    # Serve features with the ensemble-level config so dataset_scope and
    # gender_scope are honoured.
    server = _setup_feature_server(data_dir, feature_config)
    season_df = server.serve_season_features(season, mode="batch")
    profiles = _extract_team_features(season_df, team_ids)

    def _pair_value(feat: str, prof_a: dict[str, float], prof_b: dict[str, float]) -> float:
        # Purpose: compute one contextual value for a single (a, b) pairing.
        if feat == "seed_diff":
            return prof_a.get("seed_num", 0.0) - prof_b.get("seed_num", 0.0)
        if feat == "is_tournament":
            return 1.0  # bracket prediction is always tournament
        if feat == "loc_encoding":
            return 0.0  # neutral site for bracket
        # Generic: team A's value minus team B's value if available.
        return prof_a.get(feat, 0.0) - prof_b.get(feat, 0.0)

    pairings = list(combinations(range(len(team_ids)), 2))
    out: dict[str, npt.NDArray[np.float64]] = {}
    for feat in contextual_features:
        out[feat] = np.array(
            [
                _pair_value(feat, profiles.get(team_ids[i], {}), profiles.get(team_ids[j], {}))
                for i, j in pairings
            ],
            dtype=np.float64,
        )
    return out
def _assemble_bracket_meta_predictions(
*,
base_matrices: list[npt.NDArray[np.float64]],
context_features: dict[str, npt.NDArray[np.float64]],
meta_column_order: list[str],
meta_learner: Model,
n: int,
) -> npt.NDArray[np.float64]:
"""Assemble meta-input from base matrices + context and predict.
Returns the final n×n probability matrix.
"""
# Extract upper-triangle predictions from each base matrix
rows_idx, cols_idx = np.triu_indices(n, k=1)
expected_len = len(rows_idx) # C(n, 2)
meta_parts: dict[str, npt.NDArray[np.float64]] = {}
for i, mat in enumerate(base_matrices):
meta_parts[f"pred_base_{i}"] = mat[rows_idx, cols_idx].astype(np.float64)
# Add contextual features (validate length matches upper-triangle count)
for feat, vals in context_features.items():
if len(vals) != expected_len:
msg = f"Context feature {feat!r} has length {len(vals)}, expected {expected_len} (C({n}, 2))"
raise ValueError(msg)
meta_parts[feat] = vals
meta_df = pd.DataFrame(meta_parts)
# Validate column order
if not meta_column_order:
msg = "meta_column_order is empty — ensemble was not trained or loaded correctly"
raise ValueError(msg)
missing = [c for c in meta_column_order if c not in meta_df.columns]
if missing:
msg = f"Missing meta-learner input columns: {missing}"
raise ValueError(msg)
meta_X = meta_df[meta_column_order]
probs = meta_learner.predict_proba(meta_X)
# Fill n×n matrix
P = np.zeros((n, n), dtype=np.float64)
P[rows_idx, cols_idx] = probs.values.astype(np.float64)
P[cols_idx, rows_idx] = 1.0 - probs.values.astype(np.float64)
return P
class StackedEnsembleConfig(ModelConfig):
    """Configuration record for a stacked ensemble.

    Stores base model types and contextual feature names for
    serialisation and run-tracking purposes.

    Attributes:
        model_name: Registry name of the ensemble; fixed to ``"ensemble"``.
        base_model_types: ``model_name`` of each base model, in order.
        contextual_features: Feature columns the meta-learner consumes
            alongside base model predictions.
    """

    model_name: str = "ensemble"
    # default_factory avoids sharing one mutable list across instances.
    base_model_types: list[str] = _Field(default_factory=list)
    contextual_features: list[str] = _Field(default_factory=list)
# Register so that ``list_models()`` includes ``"ensemble"``
# and ``RunStore.load_model`` can resolve ``model_type="ensemble"``.
# We register a sentinel — the real "load" path is via
# ``StackedEnsemble.load()``, not ``EnsembleSentinel.load()``.
@register_model("ensemble")
class _EnsembleSentinel(Model):
    """Registry placeholder — never instantiated directly."""

    # Default (empty) FeatureConfig satisfies the Model interface.
    feature_config = FeatureConfig()

    def fit(self, X: Any, y: Any) -> None:  # pragma: no cover
        """Unsupported — the sentinel never trains."""
        raise NotImplementedError

    def predict_proba(self, X: Any) -> Any:  # pragma: no cover
        """Unsupported — the sentinel never predicts."""
        raise NotImplementedError

    def save(self, path: Path) -> None:  # pragma: no cover
        """Unsupported — ensembles are saved via ``StackedEnsemble.save``."""
        raise NotImplementedError

    @classmethod
    def load(cls, path: Path) -> _EnsembleSentinel:  # pragma: no cover
        """Unsupported — load ensembles via ``StackedEnsemble.load``."""
        raise NotImplementedError

    def get_config(self) -> ModelConfig:  # pragma: no cover
        """Unsupported — the sentinel carries no configuration."""
        raise NotImplementedError
@dataclass
class StackedEnsemble:
    """Stacked generalisation ensemble.

    Holds a list of base ``Model`` instances and a stateless meta-learner.
    The ensemble's ``feature_config`` is the union of all base models' configs.

    Attributes:
        base_models: Two or more trained (or to-be-trained) base models.
        meta_learner: A stateless ``Model`` that learns to combine base
            model predictions with contextual features.
        contextual_features: Column names appended to OOF predictions
            before meta-learner training (e.g. ``seed_diff``).
        meta_column_order: Column order the meta-learner was trained with;
            empty until training (or ``load``) populates it — predict
            paths raise ``ValueError`` while it is empty.
    """

    base_models: list[Model]
    meta_learner: Model
    # default_factory prevents one shared mutable list across instances.
    contextual_features: list[str] = field(
        default_factory=lambda: ["seed_diff", "is_tournament", "loc_encoding"],
    )
    meta_column_order: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Validate base model count and meta-learner type.

        Raises:
            TypeError: If ``meta_learner`` is a ``StatefulModel``.
            ValueError: If fewer than two base models are provided.
        """
        if isinstance(self.meta_learner, StatefulModel):
            msg = "meta_learner must be a stateless Model, not StatefulModel"
            raise TypeError(msg)
        if len(self.base_models) < 2:
            msg = "StackedEnsemble requires at least 2 base models"
            raise ValueError(msg)
# ------------------------------------------------------------------
# feature_config — union of all base models
# ------------------------------------------------------------------
@property
def feature_config(self) -> FeatureConfig:
"""Return the union of all base model feature configs."""
configs = [m.feature_config for m in self.base_models]
elo_enabled, elo_config = _resolve_elo(configs)
ordinal_composite, ordinal_systems = _resolve_ordinals(configs)
matchup_deltas, gender_scope, dataset_scope = _assert_agreement(configs)
return FeatureConfig(
sequential_windows=tuple(sorted({w for c in configs for w in c.sequential_windows})),
ewma_alphas=tuple(sorted({a for c in configs for a in c.ewma_alphas})),
graph_features_enabled=any(c.graph_features_enabled for c in configs),
batch_rating_types=tuple(sorted({t for c in configs for t in c.batch_rating_types})),
ordinal_systems=ordinal_systems,
ordinal_composite=ordinal_composite,
matchup_deltas=matchup_deltas,
gender_scope=gender_scope,
dataset_scope=dataset_scope,
elo_enabled=elo_enabled,
elo_config=elo_config,
)
# ------------------------------------------------------------------
# Pydantic config helper (for run tracking)
# ------------------------------------------------------------------
def get_config(self) -> StackedEnsembleConfig:
"""Return a serialisable configuration record."""
base_types: list[str] = []
for m in self.base_models:
cfg = m.get_config()
base_types.append(cfg.model_name)
return StackedEnsembleConfig(
base_model_types=base_types,
contextual_features=list(self.contextual_features),
)
# ------------------------------------------------------------------
# Inference
# ------------------------------------------------------------------
def predict_proba(self, X: pd.DataFrame) -> pd.Series:
"""Route features through base models and meta-learner.
For each base model, generates predictions by dispatching
stateless models through ``X[base_model.feature_names_]`` and
stateful models through the full ``X``. Assembles base
predictions and contextual features into a meta-input DataFrame
in ``self.meta_column_order``, then calls the meta-learner.
Args:
X: Feature DataFrame with at least the columns required by
each base model and all contextual features.
Returns:
Series of ensemble win probabilities, indexed like *X*.
Raises:
ValueError: If any column in ``meta_column_order`` is missing
from the assembled meta-input.
"""
base_preds: dict[str, pd.Series[float]] = {}
for i, base_model in enumerate(self.base_models):
col_name = f"pred_base_{i}"
if isinstance(base_model, StatefulModel):
base_preds[col_name] = base_model.predict_proba(X)
else:
feat_names: list[str] = base_model.feature_names_ # type: ignore[attr-defined]
base_preds[col_name] = base_model.predict_proba(X[feat_names])
# Assemble meta-input in training column order
meta_parts: dict[str, pd.Series[float]] = {**base_preds}
for feat in self.contextual_features:
if feat in X.columns:
meta_parts[feat] = X[feat]
meta_df = pd.DataFrame(meta_parts, index=X.index)
# Validate column order
if not self.meta_column_order:
msg = "meta_column_order is empty — ensemble was not trained or loaded correctly"
raise ValueError(msg)
missing = [c for c in self.meta_column_order if c not in meta_df.columns]
if missing:
msg = f"Missing meta-learner input columns: {missing}"
raise ValueError(msg)
# .copy() makes the slice independent so the fillna assignment below
# does not trigger pandas Copy-on-Write warnings in pandas 2.x+ and
# avoids any risk of mutating the source meta_df.
meta_X = meta_df[self.meta_column_order].copy()
# Fill NaN contextual features (e.g. seed_diff for non-tournament
# games) with 0 so sklearn estimators that reject NaN work correctly.
ctx_cols = [c for c in self.contextual_features if c in meta_X.columns]
if ctx_cols:
meta_X[ctx_cols] = meta_X[ctx_cols].fillna(0)
return self.meta_learner.predict_proba(meta_X)
def predict_bracket(
self,
data_dir: Path,
season: int,
) -> pd.DataFrame:
"""Generate an n×n pairwise probability matrix for bracket prediction.
Discovers tournament-eligible teams, generates per-base-model
pairwise predictions, assembles meta-input for all C(n,2) matchups,
and returns a probability matrix suitable for the Monte Carlo
bracket simulator.
Args:
data_dir: Path to the local Parquet data store.
season: Target season year.
Returns:
DataFrame of shape ``(n, n)`` with team_id index and columns.
``P[a, b]`` is the ensemble probability that team *a* beats
team *b*. Diagonal is zero; ``P[a,b] + P[b,a] ≈ 1``.
Raises:
FileNotFoundError: If no season data exists for *season*.
ValueError: If any column in ``meta_column_order`` is missing
from the assembled meta-input, or if a context feature
array has unexpected length.
"""
team_ids, base_matrices = _predict_bracket_base_matrices(self.base_models, data_dir, season)
n = len(team_ids)
# Build contextual feature vectors for all C(n,2) pairs
context_features = _build_bracket_contextual_features(
data_dir, season, team_ids, self.contextual_features, self.feature_config
)
# Assemble meta-input for upper triangle and predict
prob_matrix = _assemble_bracket_meta_predictions(
base_matrices=base_matrices,
context_features=context_features,
meta_column_order=self.meta_column_order,
meta_learner=self.meta_learner,
n=n,
)
return pd.DataFrame(prob_matrix, index=team_ids, columns=team_ids)
# ------------------------------------------------------------------
# Persistence
# ------------------------------------------------------------------
def save(self, path: Path) -> None:
"""Save the ensemble to *path*.
Layout::
path/
manifest.json
feature_config.json
base_models/
base_0/ … (Model.save)
base_1/ …
meta_learner/ … (Model.save)
"""
path.mkdir(parents=True, exist_ok=True)
# Base models
base_dir = path / "base_models"
base_dir.mkdir(exist_ok=True)
base_model_types: list[str] = []
for i, model in enumerate(self.base_models):
model_path = base_dir / f"base_{i}"
model.save(model_path)
base_model_types.append(model.get_config().model_name)
# Meta-learner
meta_path = path / "meta_learner"
self.meta_learner.save(meta_path)
# Manifest
manifest: dict[str, Any] = {
"base_model_types": base_model_types,
"base_model_count": len(self.base_models),
"contextual_features": list(self.contextual_features),
"meta_learner_type": self.meta_learner.get_config().model_name,
"meta_column_order": list(self.meta_column_order),
}
(path / "manifest.json").write_text(json.dumps(manifest, indent=2))
# Feature config
save_feature_config(self.feature_config, path)
@classmethod
def load(cls, path: Path) -> StackedEnsemble:
"""Reconstruct a ``StackedEnsemble`` from a saved directory."""
manifest_data = json.loads((path / "manifest.json").read_text())
base_model_types: list[str] = manifest_data["base_model_types"]
base_model_count: int = manifest_data["base_model_count"]
contextual_features: list[str] = manifest_data["contextual_features"]
meta_learner_type: str = manifest_data["meta_learner_type"]
meta_column_order: list[str] = manifest_data.get("meta_column_order", [])
# Load base models
base_dir = path / "base_models"
base_models: list[Model] = []
for i in range(base_model_count):
model_type = base_model_types[i]
model_cls = get_model(model_type)
base_models.append(model_cls.load(base_dir / f"base_{i}"))
# Load meta-learner
meta_cls = get_model(meta_learner_type)
meta_learner = meta_cls.load(path / "meta_learner")
return cls(
base_models=base_models,
meta_learner=meta_learner,
contextual_features=contextual_features,
meta_column_order=meta_column_order,
)