Source code for citibike.evaluation.cross_validation

"""Rolling window cross-validation for inventory prediction."""

from collections.abc import Generator
from dataclasses import dataclass

import pandas as pd
from tqdm import tqdm

from .metrics import compute_inventory_metrics, summarize_fold_results


@dataclass
class CVFold:
    """A single cross-validation fold.

    The training window is ``[train_start, train_end)`` and the test window
    is ``[test_start, test_end)``. Both are half-open, matching how trips are
    selected with ``>=`` / ``<`` comparisons in this module.
    """

    fold_id: int  # 0-based position of the fold in split() order
    train_start: pd.Timestamp
    train_end: pd.Timestamp
    test_start: pd.Timestamp
    test_end: pd.Timestamp


class RollingWindowCV:
    """Rolling window cross-validation splitter.

    Every fold trains on ``train_weeks`` weeks and tests on the
    ``test_weeks`` weeks that immediately follow. Consecutive folds slide the
    whole window forward by ``increment_days`` days.
    """

    def __init__(
        self,
        train_weeks: int = 3,
        test_weeks: int = 1,
        increment_days: int = None,
    ):
        """Configure the window sizes.

        Args:
            train_weeks: Length of each training window, in weeks.
            test_weeks: Length of each test window, in weeks.
            increment_days: Step between consecutive folds, in days. ``None``
                (the default) steps by one test window (``test_weeks * 7``).

        Raises:
            ValueError: If either window length or the resolved increment is
                not positive. A non-positive increment would make ``split``
                yield the same fold forever.
        """
        if increment_days is None:
            # Default: non-overlapping test windows.
            increment_days = test_weeks * 7
        if train_weeks <= 0:
            raise ValueError(f"train_weeks must be positive, got {train_weeks}")
        if test_weeks <= 0:
            raise ValueError(f"test_weeks must be positive, got {test_weeks}")
        if increment_days <= 0:
            raise ValueError(f"increment_days must be positive, got {increment_days}")
        self.train_weeks = train_weeks
        self.test_weeks = test_weeks
        self.increment_days = increment_days

    def split(
        self,
        trips: pd.DataFrame,
    ) -> Generator[CVFold, None, None]:
        """Generate train/test splits.

        Windows are anchored at midnight of the first trip's day. A fold is
        produced only while its test window ends no later than midnight at
        the start of the last trip's day, so that final day is never tested.

        Args:
            trips: Trip data with a datetime ``started_at`` column.

        Yields:
            CVFold objects with consecutive ``fold_id`` values starting at 0.
            Nothing is yielded when ``trips`` has no valid ``started_at``.
        """
        first_start = trips["started_at"].min()
        last_start = trips["started_at"].max()
        # With no data, min()/max() are NaT. Every comparison against NaT is
        # False, so the stopping condition below would never fire.
        if pd.isna(first_start) or pd.isna(last_start):
            return

        min_date = first_start.normalize()
        max_date = last_start.normalize()
        train_delta = pd.Timedelta(weeks=self.train_weeks)
        test_delta = pd.Timedelta(weeks=self.test_weeks)
        increment_delta = pd.Timedelta(days=self.increment_days)

        fold_id = 0
        train_start = min_date
        while True:
            train_end = train_start + train_delta
            # The test window starts exactly where training ends.
            test_start = train_end
            test_end = test_start + test_delta
            if test_end > max_date:
                break
            yield CVFold(
                fold_id=fold_id,
                train_start=train_start,
                train_end=train_end,
                test_start=test_start,
                test_end=test_end,
            )
            fold_id += 1
            train_start = train_start + increment_delta

    def get_n_splits(self, trips: pd.DataFrame) -> int:
        """Return how many folds ``split`` would yield for ``trips``."""
        return sum(1 for _ in self.split(trips))
[docs] def track_inventory( trips: pd.DataFrame, initial_inventory: pd.Series, start_time: pd.Timestamp, end_time: pd.Timestamp, freq: str = "1h", ) -> pd.DataFrame: """Track actual inventory by applying trips to initial state. This gives us GROUND TRUTH inventory - what actually happened. Args: trips: Trip data for the period initial_inventory: Starting bike count per station start_time: Start of tracking period end_time: End of tracking period freq: Time frequency Returns: DataFrame with actual inventory (index=stations, columns=times) """ stations = initial_inventory.index.tolist() times = pd.date_range(start=start_time, end=end_time, freq=freq, inclusive="left") # Initialize inventory = pd.DataFrame(index=stations, columns=times, dtype=float) inventory[times[0]] = initial_inventory # Filter trips to time range mask = (trips["started_at"] >= start_time) & (trips["started_at"] < end_time) period_trips = trips[mask].copy() if len(period_trips) == 0: # No trips, inventory stays constant for t in times: inventory[t] = initial_inventory return inventory # Bucket trips by hour period_trips["hour_bucket"] = period_trips["started_at"].dt.floor(freq) # Track hour by hour current_inventory = initial_inventory.copy() for i, t in enumerate(times[:-1]): # Get trips in this hour hour_trips = period_trips[period_trips["hour_bucket"] == t] if len(hour_trips) > 0: # Count arrivals and departures arrivals = hour_trips.groupby("end_station_name").size() departures = hour_trips.groupby("start_station_name").size() # Apply to inventory for station in stations: arr = arrivals.get(station, 0) dep = departures.get(station, 0) current_inventory[station] = max(0, current_inventory[station] - dep + arr) # Store state at next time inventory[times[i + 1]] = current_inventory.copy() return inventory
def compute_initial_inventory_for_fold(
    trips: pd.DataFrame,
    stations: list,
    fold_start: pd.Timestamp,
    burn_in: pd.Timedelta = pd.Timedelta(weeks=1),
    default_inventory: float = 15.0,
    freq: str = "1h",
) -> pd.Series:
    """Estimate per-station bike counts at the start of a fold.

    Every station starts at zero. Trips that started in
    ``[fold_start - burn_in, fold_start)`` are then replayed in ``freq``
    buckets, in chronological order. In each bucket, arrivals are added and
    departures subtracted, and counts are floored at zero. Trips touching
    stations not in ``stations`` are ignored for those stations.

    Args:
        trips: All trip data
        stations: List of station names (duplicates are collapsed)
        fold_start: Start time of the fold
        burn_in: Length of the look-back window (default: one week)
        default_inventory: Count assigned to every station when the burn-in
            window has no trips (default 15.0, described in the original code
            as ~50% of a typical capacity)
        freq: Bucket size used when replaying burn-in trips

    Returns:
        Float Series with estimated bike count per station
    """
    # Get trips in the burn-in window just before fold start
    burn_in_start = fold_start - burn_in
    started = trips["started_at"]
    burn_in_trips = trips[(started >= burn_in_start) & (started < fold_start)]

    if len(burn_in_trips) == 0:
        # No burn-in data: use a uniform distribution
        return pd.Series(default_inventory, index=stations)

    # dict.fromkeys keeps first-seen order while dropping duplicate names.
    station_index = pd.Index(list(dict.fromkeys(stations)))
    inventory = pd.Series(0.0, index=station_index)

    buckets = burn_in_trips["started_at"].dt.floor(freq)
    # groupby sorts its keys, so buckets are replayed chronologically.
    for _, bucket_trips in burn_in_trips.groupby(buckets):
        arrivals = bucket_trips["end_station_name"].value_counts().reindex(
            station_index, fill_value=0
        )
        departures = bucket_trips["start_station_name"].value_counts().reindex(
            station_index, fill_value=0
        )
        inventory = (inventory + arrivals - departures).clip(lower=0)

    return inventory
# Headline metrics shown in the cross-validation summary, in display order.
_CV_SUMMARY_METRICS = (
    "inventory_mae",
    "inventory_rmse",
    "correlation",
    "empty_recall",
    "empty_precision",
    "empty_f1",
    "full_recall",
    "full_precision",
    "full_f1",
    "state_accuracy",
)

# Metrics whose name contains one of these are ratios and print as percents.
_PERCENT_METRIC_MARKERS = ("recall", "precision", "accuracy", "f1")


def _print_cv_summary(summary: dict[str, tuple[float, float]]) -> None:
    """Print mean ± std for each headline metric present in ``summary``."""
    print("\n" + "=" * 60)
    print("Cross-Validation Summary (mean ± std across folds):")
    print("=" * 60)
    for metric in _CV_SUMMARY_METRICS:
        if metric not in summary:
            continue
        mean, std = summary[metric]
        if any(marker in metric for marker in _PERCENT_METRIC_MARKERS):
            print(f" {metric}: {mean:.1%} ± {std:.1%}")
        else:
            print(f" {metric}: {mean:.2f} ± {std:.2f}")


def run_cross_validation(
    model,
    trips: pd.DataFrame,
    station_stats: pd.DataFrame,
    config: dict,
    verbose: bool = True,
) -> tuple[list[dict[str, float]], dict[str, tuple[float, float]]]:
    """Run rolling window cross-validation for inventory prediction.

    For each fold from ``RollingWindowCV``:

    1. fit ``model`` on the training window;
    2. estimate inventory at the test start with
       ``compute_initial_inventory_for_fold`` and clamp it to capacity;
    3. replay the test-window trips for ground truth;
    4. ask ``model.predict_inventory`` for the same window;
    5. score the prediction with ``compute_inventory_metrics``.

    Folds whose training or test window contains no trips are skipped.

    Args:
        model: Model instance (must have fit/predict_inventory methods)
        trips: Trip data
        station_stats: Station information with capacity, indexed by station
        config: Configuration dictionary. Reads ``cross_validation``
            (``train_weeks``, ``test_weeks``, ``increment_days``, and
            ``freq`` for the time step, default ``"1h"``) and ``thresholds``.
        verbose: Whether to print progress

    Returns:
        Tuple of (fold_results, summary)
    """
    cv_config = config.get("cross_validation", {})
    cv = RollingWindowCV(
        train_weeks=cv_config.get("train_weeks", 3),
        test_weeks=cv_config.get("test_weeks", 1),
        increment_days=cv_config.get("increment_days", None),
    )
    # Time step shared by ground-truth tracking and model prediction.
    freq = cv_config.get("freq", "1h")

    stations = station_stats.index.tolist()
    capacities = station_stats["capacity"].to_dict()
    thresholds = config.get("thresholds", {"empty": 0.1, "full": 0.9})

    fold_results = []
    folds = list(cv.split(trips))

    if verbose:
        print(f"\nRunning {len(folds)}-fold cross-validation...")
        print(f"Predicting inventory for {len(stations)} stations")

    for fold in tqdm(folds, desc="CV Folds", disable=not verbose):
        # Split data (half-open windows)
        train_mask = (trips["started_at"] >= fold.train_start) & (
            trips["started_at"] < fold.train_end
        )
        test_mask = (trips["started_at"] >= fold.test_start) & (trips["started_at"] < fold.test_end)
        train_trips = trips[train_mask]
        test_trips = trips[test_mask]

        if len(train_trips) == 0 or len(test_trips) == 0:
            continue

        # Fit model on training data
        model.fit(train_trips, station_stats)

        # Estimate the starting state from trips just before the test period
        # (i.e. the tail of the training window).
        initial_inventory = compute_initial_inventory_for_fold(
            trips,
            stations,
            fold.test_start,
        )

        # Clamp to capacity (30 if a station has no capacity entry)
        for station in stations:
            cap = capacities.get(station, 30)
            initial_inventory[station] = min(initial_inventory[station], cap)

        # Track actual inventory (GROUND TRUTH)
        true_inventory = track_inventory(
            test_trips,
            initial_inventory,
            fold.test_start,
            fold.test_end,
            freq=freq,
        )

        # Predict inventory
        pred_inventory = model.predict_inventory(
            initial_inventory,
            fold.test_start,
            fold.test_end,
            freq=freq,
        )

        # Compute metrics
        metrics = compute_inventory_metrics(
            true_inventory,
            pred_inventory,
            capacities,
            thresholds,
        )
        metrics["fold_id"] = fold.fold_id
        metrics["train_start"] = str(fold.train_start.date())
        metrics["test_start"] = str(fold.test_start.date())
        fold_results.append(metrics)

        if verbose:
            # tqdm.write prints above the active progress bar instead of
            # tearing it the way a bare print() does.
            tqdm.write(
                f" Fold {fold.fold_id}: "
                f"MAE={metrics.get('inventory_mae', 0):.2f} bikes, "
                f"Empty Recall={metrics.get('empty_recall', 0):.1%}, "
                f"Full Recall={metrics.get('full_recall', 0):.1%}, "
                f"State Acc={metrics.get('state_accuracy', 0):.1%}"
            )

    # Summarize
    summary = summarize_fold_results(fold_results)

    if verbose:
        _print_cv_summary(summary)

    return fold_results, summary