"""Plotly visualization adapters for evaluation results.
Provides standalone functions that accept evaluation result objects
and return interactive ``plotly.graph_objects.Figure`` instances for
Jupyter notebook rendering.
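
Example:
    Typical notebook flow; ``bt_result`` stands in for a
    ``BacktestResult`` produced by the backtest runner.

    >>> fig = plot_backtest_summary(bt_result)  # doctest: +SKIP
    >>> fig.show()  # doctest: +SKIP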
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go # type: ignore[import-untyped]
from ncaa_eval.evaluation.backtest import BacktestResult
from ncaa_eval.evaluation.metrics import reliability_diagram_data
from ncaa_eval.evaluation.simulation import (
N_ROUNDS,
BracketDistribution,
SimulationResult,
)
# UX spec color palette
COLOR_GREEN: str = "#28a745"
COLOR_RED: str = "#dc3545"
COLOR_NEUTRAL: str = "#6c757d"
# Extended palette for multi-trace plots (green, red, neutral, then extras)
_PALETTE: tuple[str, ...] = (
COLOR_GREEN,
COLOR_RED,
COLOR_NEUTRAL,
"#17a2b8", # teal
"#ffc107", # amber
"#6f42c1", # purple
"#fd7e14", # orange
"#20c997", # mint
)
# Use plotly_dark template for dark-mode compatibility
TEMPLATE: str = "plotly_dark"
# Round labels for advancement heatmap
_ROUND_LABELS: tuple[str, ...] = ("R64", "R32", "S16", "E8", "F4", "Championship")
def plot_reliability_diagram(
y_true: npt.NDArray[np.float64],
y_prob: npt.NDArray[np.float64],
*,
n_bins: int = 10,
title: str | None = None,
) -> go.Figure:
"""Reliability diagram: predicted vs. actual probability with bin counts.
Args:
y_true: Binary labels (0 or 1).
y_prob: Predicted probabilities for the positive class.
n_bins: Number of calibration bins (default 10).
title: Optional figure title.
Returns:
Interactive Plotly Figure with calibration curve, diagonal
reference, and bar overlay of per-bin sample counts.
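
    Example:
        A minimal sketch on synthetic, roughly calibrated data; only the
        array shapes and dtypes are assumed.

        >>> import numpy as np
        >>> rng = np.random.default_rng(0)
        >>> y_prob = rng.uniform(size=200)
        >>> y_true = (rng.uniform(size=200) < y_prob).astype(np.float64)
        >>> fig = plot_reliability_diagram(y_true, y_prob, n_bins=5)
        >>> fig.show()  # doctest: +SKIP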
"""
data = reliability_diagram_data(y_true, y_prob, n_bins=n_bins)
fig = go.Figure()
# Bar trace: bin counts on secondary y-axis
fig.add_trace(
go.Bar(
x=data.mean_predicted_value,
y=data.bin_counts,
name="Bin Count",
marker_color=COLOR_NEUTRAL,
opacity=0.3,
yaxis="y2",
)
)
# Diagonal reference line: perfect calibration
fig.add_trace(
go.Scatter(
x=[0, 1],
y=[0, 1],
mode="lines",
line={"dash": "dash", "color": COLOR_NEUTRAL},
name="Perfect",
showlegend=True,
)
)
# Scatter trace: calibration curve
fig.add_trace(
go.Scatter(
x=data.mean_predicted_value,
y=data.fraction_of_positives,
mode="lines+markers",
marker={"color": COLOR_GREEN, "size": 8},
line={"color": COLOR_GREEN},
name="Calibration",
text=[str(c) for c in data.bin_counts],
hovertemplate=("Predicted: %{x:.3f}<br>Observed: %{y:.3f}<br>Count: %{text}<extra></extra>"),
)
)
fig.update_layout(
template=TEMPLATE,
title=title or "Reliability Diagram",
xaxis_title="Mean Predicted Probability",
yaxis_title="Fraction of Positives",
yaxis2={
"title": "Bin Count",
"overlaying": "y",
"side": "right",
"showgrid": False,
},
legend={"x": 0.01, "y": 0.99},
)
return fig
def plot_backtest_summary(
result: BacktestResult,
*,
metrics: Sequence[str] | None = None,
) -> go.Figure:
"""Per-year metric values from a backtest result.
Args:
result: Backtest result containing the summary DataFrame.
metrics: Metric column names to include. Defaults to all
metric columns (excludes ``elapsed_seconds``).
Returns:
Interactive Plotly Figure with one line per metric, x=year.
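
    Example:
        Illustrative call; ``result`` is a ``BacktestResult`` and the
        column names ``"log_loss"`` and ``"brier_score"`` are hypothetical.

        >>> fig = plot_backtest_summary(result, metrics=["log_loss", "brier_score"])  # doctest: +SKIP
        >>> fig.show()  # doctest: +SKIP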
"""
summary = result.summary
if metrics is None:
metric_cols = [c for c in summary.columns if c != "elapsed_seconds"]
else:
metric_cols = list(metrics)
if not metric_cols:
msg = "No metric columns to plot. BacktestResult.summary has no columns besides 'elapsed_seconds'."
raise ValueError(msg)
years = summary.index.tolist()
fig = go.Figure()
for i, col in enumerate(metric_cols):
color = _PALETTE[i % len(_PALETTE)]
fig.add_trace(
go.Scatter(
x=years,
y=summary[col].tolist(),
mode="lines+markers",
name=col,
line={"color": color},
marker={"color": color},
)
)
fig.update_layout(
template=TEMPLATE,
title="Backtest Summary",
xaxis_title="Year",
yaxis_title="Metric Value",
)
return fig
def plot_metric_comparison(
results: Mapping[str, BacktestResult],
metric: str,
) -> go.Figure:
"""Multi-model overlay: one line per model for a given metric across years.
Args:
results: Mapping of model name to BacktestResult.
metric: Metric column name to compare.
Returns:
Interactive Plotly Figure with one line per model.
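
    Example:
        Illustrative overlay of two hypothetical models on a hypothetical
        ``"brier_score"`` column.

        >>> fig = plot_metric_comparison(
        ...     {"elo": elo_result, "xgboost": xgb_result},
        ...     "brier_score",
        ... )  # doctest: +SKIP
        >>> fig.show()  # doctest: +SKIP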
"""
fig = go.Figure()
for i, (model_name, bt) in enumerate(results.items()):
if metric not in bt.summary.columns:
available = [c for c in bt.summary.columns if c != "elapsed_seconds"]
msg = f"metric {metric!r} not found in results[{model_name!r}].summary. Available: {available}"
raise ValueError(msg)
color = _PALETTE[i % len(_PALETTE)]
years = bt.summary.index.tolist()
values = bt.summary[metric].tolist()
fig.add_trace(
go.Scatter(
x=years,
y=values,
mode="lines+markers",
name=model_name,
line={"color": color},
marker={"color": color},
hovertemplate=(f"{model_name}<br>Year: %{{x}}<br>{metric}: %{{y:.4f}}<extra></extra>"),
)
)
fig.update_layout(
template=TEMPLATE,
title=f"Model Comparison — {metric}",
xaxis_title="Year",
yaxis_title=metric,
)
return fig
def plot_advancement_heatmap(
result: SimulationResult,
team_labels: Mapping[int, str] | None = None,
) -> go.Figure:
"""Heatmap of per-team advancement probabilities by round.
Args:
result: Simulation result with ``advancement_probs`` array.
team_labels: Optional mapping of **team index** (0..n-1, bracket
position order) to display name. When ``None``, team indices
are shown as-is. Note: keys are bracket indices, not canonical
team IDs — use ``BracketStructure.team_index_map`` to translate
from team IDs to indices before passing this argument.
Returns:
Interactive Plotly Figure showing a heatmap with teams on
y-axis and rounds on x-axis.
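
    Example:
        Illustrative labeling; assumes ``structure`` is the
        ``BracketStructure`` used for the simulation (its
        ``team_index_map`` maps team ID to index) and ``names`` is a
        hypothetical mapping of team ID to display name.

        >>> labels = {idx: names[tid] for tid, idx in structure.team_index_map.items()}  # doctest: +SKIP
        >>> fig = plot_advancement_heatmap(result, team_labels=labels)  # doctest: +SKIP
        >>> fig.show()  # doctest: +SKIP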
"""
adv = result.advancement_probs # shape (n_teams, n_rounds)
n_teams = adv.shape[0]
n_rounds = min(adv.shape[1], N_ROUNDS)
round_labels = list(_ROUND_LABELS[:n_rounds])
if team_labels is not None:
y_labels = [team_labels.get(i, str(i)) for i in range(n_teams)]
else:
y_labels = [str(i) for i in range(n_teams)]
fig = go.Figure(
data=go.Heatmap(
z=adv[:, :n_rounds],
x=round_labels,
y=y_labels,
colorscale=[[0, COLOR_RED], [1, COLOR_GREEN]],
zmin=0.0,
zmax=1.0,
hovertemplate=("Team: %{y}<br>Round: %{x}<br>P(advance): %{z:.3f}<extra></extra>"),
)
)
fig.update_layout(
template=TEMPLATE,
title="Advancement Probabilities",
xaxis_title="Round",
yaxis_title="Team",
yaxis={"autorange": "reversed"},
)
return fig
def plot_score_distribution(
dist: BracketDistribution,
*,
title: str | None = None,
) -> go.Figure:
"""Histogram of bracket score distribution with percentile markers.
Args:
dist: Bracket distribution with pre-computed histogram data
and percentile values.
title: Optional figure title.
Returns:
Interactive Plotly Figure with histogram bars and vertical
percentile lines at 5th, 25th, 50th, 75th, and 95th.
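
    Example:
        Illustrative call; ``dist`` stands in for a ``BracketDistribution``
        produced by the simulation layer.

        >>> fig = plot_score_distribution(dist, title="Submitted Bracket")  # doctest: +SKIP
        >>> fig.show()  # doctest: +SKIP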
"""
# Convert bin edges to bin centers for the bar chart
bin_centers = (dist.histogram_bins[:-1] + dist.histogram_bins[1:]) / 2.0
bin_width = (
float(dist.histogram_bins[1] - dist.histogram_bins[0]) if len(dist.histogram_bins) >= 2 else 1.0
)
fig = go.Figure()
# Histogram bars
fig.add_trace(
go.Bar(
x=bin_centers.tolist(),
y=dist.histogram_counts.tolist(),
width=bin_width,
marker_color=COLOR_GREEN,
opacity=0.7,
name="Score Distribution",
)
)
# Percentile vertical lines
percentile_colors = {
5: COLOR_RED,
25: COLOR_NEUTRAL,
50: COLOR_GREEN,
75: COLOR_NEUTRAL,
95: COLOR_RED,
}
max_count = int(np.max(dist.histogram_counts)) if len(dist.histogram_counts) > 0 else 1
for pct, value in sorted(dist.percentiles.items()):
color = percentile_colors.get(pct, COLOR_NEUTRAL)
fig.add_trace(
go.Scatter(
x=[value, value],
y=[0, max_count],
mode="lines",
line={"color": color, "dash": "dash", "width": 2},
name=f"P{pct}: {value:.1f}",
)
)
fig.update_layout(
template=TEMPLATE,
title=title or "Bracket Score Distribution",
xaxis_title="Score",
yaxis_title="Count",
bargap=0.05,
)
return fig