Source code for citibike.utils.data_loader

"""Data loading utilities for CitiBike trip data."""

from pathlib import Path

import duckdb
import pandas as pd
from tqdm import tqdm


def load_trip_data(
    data_dir: str = "data",
    start_date: str | None = None,
    end_date: str | None = None,
    use_parquet: bool = True,
) -> pd.DataFrame:
    """Load trip data from Parquet or CSV files using DuckDB.

    Args:
        data_dir: Directory containing trip data folders
        start_date: Optional start date filter (YYYY-MM-DD)
        end_date: Optional end date filter (YYYY-MM-DD)
        use_parquet: If True, use Parquet files; otherwise fall back to CSV

    Returns:
        DataFrame with all trip data
    """
    data_path = Path(data_dir)

    # Try Parquet first if requested
    if use_parquet:
        parquet_path = data_path / "parquet" / "trips"
        if parquet_path.exists():
            print("Loading trip data from Parquet files using DuckDB...")
            return _load_from_parquet_duckdb(parquet_path, start_date, end_date)
        else:
            print(f"Parquet directory not found at {parquet_path}, falling back to CSV...")

    # Fall back to CSV loading
    return _load_from_csv(data_path, start_date, end_date)
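
A minimal usage sketch (the paths and dates are illustrative, not part of the module; it assumes the `data/parquet/trips` layout the function looks for):

    from citibike.utils.data_loader import load_trip_data

    # Prefers the Parquet partitions when present, otherwise scans the CSVs.
    trips = load_trip_data("data", start_date="2024-06-01", end_date="2024-06-30")
    print(trips[["started_at", "ended_at", "start_station_name"]].head())
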

def _load_from_parquet_duckdb(
    parquet_path: Path,
    start_date: str | None = None,
    end_date: str | None = None,
) -> pd.DataFrame:
    """Load trip data from Parquet files using DuckDB for efficient querying.

    Args:
        parquet_path: Path to parquet directory (with year=YYYY partitions)
        start_date: Optional start date filter (YYYY-MM-DD)
        end_date: Optional end date filter (YYYY-MM-DD)

    Returns:
        DataFrame with trip data
    """
    con = duckdb.connect()
    try:
        # Build query with optional date filters
        where_clauses = []
        if start_date:
            where_clauses.append(f"started_at >= '{start_date}'")
        if end_date:
            where_clauses.append(f"started_at <= '{end_date}'")
        where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

        # Use glob pattern to read all parquet files
        parquet_pattern = str(parquet_path / "**" / "*.parquet")
        query = f"""
            SELECT *
            FROM read_parquet('{parquet_pattern}', hive_partitioning=1)
            {where_sql}
        """

        print(" Executing DuckDB query on Parquet files...")
        print(f" Pattern: {parquet_pattern}")
        if where_sql:
            print(f" Filters: {where_sql}")

        df = con.execute(query).fetchdf()

        # Ensure datetime columns are parsed
        if "started_at" in df.columns:
            df["started_at"] = pd.to_datetime(df["started_at"])
        if "ended_at" in df.columns:
            df["ended_at"] = pd.to_datetime(df["ended_at"])

        print(
            f" ✓ Loaded {len(df):,} trips from "
            f"{df['started_at'].min().date()} to {df['started_at'].max().date()}"
        )
        return df
    finally:
        con.close()


def _load_from_csv(
    data_path: Path,
    start_date: str | None = None,
    end_date: str | None = None,
) -> pd.DataFrame:
    """Load trip data from CSV files (legacy fallback).

    Args:
        data_path: Path to data directory
        start_date: Optional start date filter (YYYY-MM-DD)
        end_date: Optional end date filter (YYYY-MM-DD)

    Returns:
        DataFrame with trip data
    """
    # Find all CSV files in subdirectories
    csv_files = list(data_path.glob("**/202*.csv"))
    if not csv_files:
        raise FileNotFoundError(f"No trip data files found in {data_path}")

    print(f"Found {len(csv_files)} trip data CSV files")

    # Load all files
    dfs = []
    for f in tqdm(csv_files, desc="Loading trip data"):
        df = pd.read_csv(f, low_memory=False)
        dfs.append(df)

    df = pd.concat(dfs, ignore_index=True)

    # Parse timestamps
    df["started_at"] = pd.to_datetime(df["started_at"])
    df["ended_at"] = pd.to_datetime(df["ended_at"])

    # Filter by date if specified
    if start_date:
        df = df[df["started_at"] >= start_date]
    if end_date:
        df = df[df["started_at"] <= end_date]

    print(
        f"Loaded {len(df):,} trips from "
        f"{df['started_at'].min().date()} to {df['started_at'].max().date()}"
    )
    return df
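
The `_load_from_parquet_duckdb` helper above reduces to a single DuckDB scan over the Hive-partitioned files. A roughly equivalent standalone query, with illustrative literals substituted for the f-string placeholders, looks like:

    import duckdb

    con = duckdb.connect()
    df = con.execute(
        """
        SELECT *
        FROM read_parquet('data/parquet/trips/**/*.parquet', hive_partitioning=1)
        WHERE started_at >= '2024-06-01' AND started_at <= '2024-06-30'
        """
    ).fetchdf()
    con.close()

Note that the helper interpolates the dates directly into the SQL string, so it should only be given trusted input.
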

def load_station_info(
    station_path: str = "data/stations/station_info.csv",
    use_parquet: bool = True,
) -> pd.DataFrame:
    """Load station information including capacity.

    Args:
        station_path: Path to station info file (CSV or Parquet)
        use_parquet: If True, try Parquet first

    Returns:
        DataFrame with station information
    """
    station_path_obj = Path(station_path)

    # Try Parquet version first if requested
    if use_parquet:
        # Try replacing the extension
        parquet_path = station_path_obj.with_suffix(".parquet")

        # Also try the common parquet directory structure
        if not parquet_path.exists():
            data_dir = Path(station_path_obj.parts[0]) if station_path_obj.parts else Path("data")
            parquet_path = data_dir / "parquet" / "stations" / "station_info.parquet"

        if parquet_path.exists():
            print(f"Loading station info from Parquet: {parquet_path}")
            df = pd.read_parquet(parquet_path)
            print(f"Loaded {len(df)} stations with total capacity {df['capacity'].sum():,}")
            return df

    # Fall back to CSV
    if station_path_obj.exists():
        df = pd.read_csv(station_path)
        print(f"Loaded {len(df)} stations with total capacity {df['capacity'].sum():,}")
        return df
    else:
        raise FileNotFoundError(
            f"Station info file not found: {station_path} or parquet alternative"
        )
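
Usage sketch (the argument shown is the function's default; whether the Parquet or CSV branch runs depends on what exists on disk):

    from citibike.utils.data_loader import load_station_info

    # Checks data/stations/station_info.parquet and
    # data/parquet/stations/station_info.parquet before reading the CSV.
    stations = load_station_info("data/stations/station_info.csv")
    print(stations["capacity"].describe())
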

def prepare_data(
    trips: pd.DataFrame,
    stations: pd.DataFrame,
    config: dict,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Prepare data for modeling.

    - Filters to valid stations
    - Adds time features
    - Merges with station capacity

    Args:
        trips: Raw trip data
        stations: Station information
        config: Configuration dictionary

    Returns:
        Tuple of (processed_trips, station_stats)
    """
    # Filter out missing station names
    trips = trips.dropna(subset=["start_station_name", "end_station_name"]).copy()

    # Add time features
    trips["hour"] = trips["started_at"].dt.hour
    trips["day_of_week"] = trips["started_at"].dt.dayofweek
    trips["date"] = trips["started_at"].dt.date
    trips["is_weekend"] = trips["day_of_week"].isin([5, 6])

    # Filter stations by minimum trips; fill_value=0 keeps a station that
    # appears only as a start (or only as an end) from being dropped as NaN
    min_trips = config.get("stations", {}).get("min_trips", 100)
    station_trip_counts = trips.groupby("start_station_name").size().add(
        trips.groupby("end_station_name").size(), fill_value=0
    )
    valid_stations = station_trip_counts[station_trip_counts >= min_trips].index.tolist()

    trips = trips[
        trips["start_station_name"].isin(valid_stations)
        & trips["end_station_name"].isin(valid_stations)
    ]
    print(f"After filtering: {len(trips):,} trips, {len(valid_stations)} stations")

    # Create station stats with capacity
    station_stats = stations[["name", "capacity"]].copy()
    station_stats = station_stats.rename(columns={"name": "station_name"})
    station_stats = station_stats.drop_duplicates(subset=["station_name"])
    station_stats = station_stats.set_index("station_name")

    # Only keep stations that appear in our trip data
    station_stats = station_stats[station_stats.index.isin(valid_stations)]

    # For stations in trips but not in station info, estimate capacity
    missing_stations = set(valid_stations) - set(station_stats.index)
    if missing_stations:
        print(f"Warning: {len(missing_stations)} stations missing capacity info, using median")
        median_capacity = station_stats["capacity"].median()
        for station in missing_stations:
            station_stats.loc[station, "capacity"] = median_capacity

    return trips, station_stats
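
Putting it together, an end-to-end sketch (the config shape mirrors the `config.get("stations", {}).get("min_trips", 100)` lookup above; the threshold value and date range are illustrative):

    from citibike.utils.data_loader import (
        load_station_info,
        load_trip_data,
        prepare_data,
    )

    trips = load_trip_data("data", start_date="2024-06-01", end_date="2024-06-30")
    stations = load_station_info()
    config = {"stations": {"min_trips": 250}}  # illustrative threshold
    trips, station_stats = prepare_data(trips, stations, config)
    print(f"{len(trips):,} trips across {len(station_stats)} stations")
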