Source code for ncaa_eval.evaluation.splitter
"""Walk-forward cross-validation splitter with Leave-One-Tournament-Out folds.
Provides :func:`walk_forward_splits`, which partitions historical game data into
train/test folds where each fold uses one tournament year as the test set and all
prior years as training data. The 2020 COVID year is handled gracefully: its
regular-season data is included in training but no test fold is yielded (the
tournament was cancelled).
"""
from __future__ import annotations
import dataclasses
from collections.abc import Iterator, Sequence
from typing import Literal
import pandas as pd # type: ignore[import-untyped]
from ncaa_eval.transform.feature_serving import StatefulFeatureServer
from ncaa_eval.transform.serving import NO_TOURNAMENT_SEASONS
[docs]
@dataclasses.dataclass(frozen=True)
class CVFold:
"""A single cross-validation fold.
Attributes:
train: All games from seasons strictly before the test year.
test: Tournament games only from the test year.
year: The test season year.
"""
train: pd.DataFrame
test: pd.DataFrame
year: int
[docs]
def walk_forward_splits(
seasons: Sequence[int],
feature_server: StatefulFeatureServer,
*,
mode: Literal["batch", "stateful"] = "batch",
) -> Iterator[CVFold]:
"""Generate walk-forward CV folds with Leave-One-Tournament-Out splits.
Args:
seasons: Ordered sequence of season years to include
(e.g., ``range(2008, 2026)``). Must contain at least 2 seasons.
feature_server: Configured StatefulFeatureServer for building feature
matrices.
mode: Feature serving mode: ``"batch"`` (stateless models) or
``"stateful"`` (sequential-update models like Elo).
Yields:
CVFold: For each eligible test year (skipping no-tournament years like
2020): ``train`` contains all games from seasons strictly before the
test year; ``test`` contains only tournament games from the test year;
``year`` is the test season year.
Raises:
ValueError: If ``seasons`` has fewer than 2 elements, or if ``mode``
is not ``"batch"`` or ``"stateful"``.
"""
# Runtime guard: Literal["batch","stateful"] enforces at static-analysis
# time; this check also protects callers who bypass mypy (e.g. YAML config).
if mode not in ("batch", "stateful"):
msg = f"mode must be 'batch' or 'stateful', got {mode!r}"
raise ValueError(msg)
sorted_seasons = sorted(seasons)
if len(sorted_seasons) < 2:
msg = "seasons must contain at least 2 seasons (need at least one training and one test season)"
raise ValueError(msg)
# Cache feature DataFrames — serve each season exactly once
season_cache: dict[int, pd.DataFrame] = {}
for year in sorted_seasons:
season_cache[year] = feature_server.serve_season_features(year, mode=mode)
# Walk-forward: iterate from second season onward as test candidates
for i, test_year in enumerate(sorted_seasons[1:], start=1):
# Skip no-tournament seasons (e.g., 2020 COVID cancellation)
if test_year in NO_TOURNAMENT_SEASONS:
continue
# Accumulate training data from all prior seasons.
# Empty seasons (e.g. feature server returned no games) are skipped —
# pd.concat of empty DataFrames produces a column-less frame that would
# confuse downstream consumers. Including an empty frame in the list
# is harmless but produces dtype-mismatch warnings, so we filter them.
train_frames: list[pd.DataFrame] = []
for train_year in sorted_seasons[:i]:
df = season_cache[train_year]
if not df.empty:
train_frames.append(df)
# If all prior seasons were empty, yield a column-less DataFrame so the
# fold is still produced; Story 6.3 callers should check train.empty.
train_df = pd.concat(train_frames, ignore_index=True) if train_frames else pd.DataFrame()
# Test data: tournament games only from the test year
test_season_df = season_cache[test_year]
if test_season_df.empty:
test_df = test_season_df
else:
test_df = test_season_df[test_season_df["is_tournament"] == True].reset_index( # noqa: E712
drop=True,
)
yield CVFold(train=train_df, test=test_df, year=test_year)