Source code for citibike.models.base

"""Base model class defining the interface for inventory prediction."""

from abc import ABC, abstractmethod
from typing import Any

import pandas as pd


[docs] class BaseModel(ABC): """Abstract base class for all bike inventory prediction models. All models must implement: - fit(): Train the model on historical data - predict_inventory(): Predict future bike counts per station """
[docs] def __init__(self, config: dict): """Initialize model with configuration. Args: config: Configuration dictionary """ self.config = config self.is_fitted = False self.station_capacities = {}
[docs] @abstractmethod def fit( self, trips: pd.DataFrame, station_stats: pd.DataFrame, ) -> "BaseModel": """Train the model on historical trip data. Args: trips: DataFrame with columns [started_at, ended_at, start_station_name, end_station_name, ...] station_stats: DataFrame indexed by station_name with capacity Returns: self (for method chaining) """ pass
[docs] @abstractmethod def predict_inventory( self, initial_inventory: pd.Series, start_time: pd.Timestamp, end_time: pd.Timestamp, freq: str = "1h", ) -> pd.DataFrame: """Predict bike inventory at each station over time. Args: initial_inventory: Series indexed by station_name with starting bike counts start_time: Start of prediction period end_time: End of prediction period freq: Time frequency (e.g., "1h" for hourly) Returns: DataFrame with predictions: - index: station_name - columns: timestamps - values: predicted bike counts """ pass
[docs] def predict_states( self, initial_inventory: pd.Series, start_time: pd.Timestamp, end_time: pd.Timestamp, freq: str = "1h", ) -> pd.DataFrame: """Predict station states (empty/normal/full) over time. Args: initial_inventory: Starting bike counts per station start_time: Start of prediction period end_time: End of prediction period freq: Time frequency Returns: DataFrame with state predictions ("empty", "normal", "full") """ # Get inventory predictions inventory = self.predict_inventory(initial_inventory, start_time, end_time, freq) # Get thresholds thresholds = self.config.get("thresholds", {"empty": 0.1, "full": 0.9}) # Convert to states states = pd.DataFrame(index=inventory.index, columns=inventory.columns, dtype=str) for station in inventory.index: capacity = self.station_capacities.get(station, 30) empty_threshold = capacity * thresholds["empty"] full_threshold = capacity * thresholds["full"] for col in inventory.columns: bikes = inventory.loc[station, col] if bikes <= empty_threshold: states.loc[station, col] = "empty" elif bikes >= full_threshold: states.loc[station, col] = "full" else: states.loc[station, col] = "normal" return states
[docs] def get_name(self) -> str: """Return the model name.""" return self.__class__.__name__
[docs] def get_params(self) -> dict[str, Any]: """Return model parameters for logging.""" return {"name": self.get_name()}