"""DuckDB utilities for efficient data querying and processing."""
from pathlib import Path
from typing import Any
import duckdb
import pandas as pd
class DuckDBConnection:
    """Context manager that opens a DuckDB connection and always closes it.

    Example:
        with DuckDBConnection("trips.db") as con:
            con.execute("SELECT 1")
    """

    def __init__(self, database: str | None = None):
        """Store connection settings; no connection is opened yet.

        Args:
            database: Path to database file. If None, uses in-memory database.
        """
        self.database = database
        # Populated in __enter__, cleared in __exit__.
        self.con: duckdb.DuckDBPyConnection | None = None

    def __enter__(self) -> duckdb.DuckDBPyConnection:
        """Open and return the connection."""
        # Map None to DuckDB's documented in-memory database name rather
        # than passing None straight through to duckdb.connect().
        self.con = duckdb.connect(self.database if self.database else ":memory:")
        return self.con

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the connection if it was opened."""
        if self.con:
            self.con.close()
            self.con = None
def query_parquet(
    parquet_path: Path,
    columns: list[str] | None = None,
    filters: dict[str, Any] | None = None,
    limit: int | None = None,
) -> pd.DataFrame:
    """Query Parquet files with DuckDB.

    Filter *values* are passed as bound parameters instead of being
    interpolated into the SQL text, avoiding quoting bugs and SQL
    injection. Column names (``columns`` and filter keys) are still
    interpolated and must be trusted identifiers.

    Args:
        parquet_path: Path to parquet directory or file
        columns: List of columns to select (None = all)
        filters: Dictionary of equality filters (e.g., {"year": 2025, "month": 9})
        limit: Maximum number of rows to return (None = no limit)

    Returns:
        DataFrame with query results
    """
    select_cols = ", ".join(columns) if columns else "*"

    # Build WHERE clause with ? placeholders; values go into `params`.
    where_clauses: list[str] = []
    params: list[Any] = []
    for col, val in (filters or {}).items():
        where_clauses.append(f"{col} = ?")
        params.append(val)
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

    # `is not None` so an explicit limit=0 returns zero rows instead of all.
    limit_sql = f"LIMIT {int(limit)}" if limit is not None else ""

    # A directory is scanned recursively; a single file is read directly.
    if parquet_path.is_dir():
        pattern = str(parquet_path / "**" / "*.parquet")
    else:
        pattern = str(parquet_path)

    query = f"""
        SELECT {select_cols}
        FROM read_parquet('{pattern}', hive_partitioning=1)
        {where_sql}
        {limit_sql}
    """
    with DuckDBConnection() as con:
        if params:
            return con.execute(query, params).fetchdf()
        return con.execute(query).fetchdf()
def aggregate_trips(
    parquet_path: Path,
    group_by: list[str],
    aggregations: dict[str, str],
    filters: dict[str, Any] | None = None,
) -> pd.DataFrame:
    """Aggregate trip data using DuckDB.

    Filter *values* are bound as query parameters; filter keys,
    ``group_by`` columns and aggregation expressions are interpolated
    into the SQL and must be trusted identifiers/expressions.

    Args:
        parquet_path: Path to parquet files
        group_by: Columns to group by
        aggregations: Dict of {output_col: aggregation_expr}
            Example: {"trip_count": "COUNT(*)", "avg_duration": "AVG(duration)"}
        filters: Optional equality filters applied before aggregation

    Returns:
        DataFrame with aggregated results
    """
    # WHERE clause with ? placeholders; the values travel separately.
    where_clauses: list[str] = []
    params: list[Any] = []
    for col, val in (filters or {}).items():
        where_clauses.append(f"{col} = ?")
        params.append(val)
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

    # "expr AS alias" pairs and the shared GROUP BY / ORDER BY column list.
    agg_sql = ", ".join(f"{expr} as {col}" for col, expr in aggregations.items())
    group_sql = ", ".join(group_by)

    pattern = str(parquet_path / "**" / "*.parquet")
    query = f"""
        SELECT {group_sql}, {agg_sql}
        FROM read_parquet('{pattern}', hive_partitioning=1)
        {where_sql}
        GROUP BY {group_sql}
        ORDER BY {group_sql}
    """
    with DuckDBConnection() as con:
        if params:
            return con.execute(query, params).fetchdf()
        return con.execute(query).fetchdf()
def get_trip_stats(
    parquet_path: Path,
    start_date: str | None = None,
    end_date: str | None = None,
) -> dict[str, Any]:
    """Get summary statistics for trip data.

    Date bounds are passed as bound parameters instead of being
    interpolated into the SQL text.

    Args:
        parquet_path: Path to parquet files
        start_date: Optional inclusive start date filter (YYYY-MM-DD)
        end_date: Optional inclusive end date filter (YYYY-MM-DD)

    Returns:
        Dictionary with total_trips, n_start_stations, n_end_stations,
        first_trip, last_trip and avg_duration_minutes
    """
    where_clauses: list[str] = []
    params: list[str] = []
    if start_date:
        where_clauses.append("started_at >= ?")
        params.append(start_date)
    if end_date:
        where_clauses.append("started_at <= ?")
        params.append(end_date)
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

    pattern = str(parquet_path / "**" / "*.parquet")
    query = f"""
        SELECT
            COUNT(*) as total_trips,
            COUNT(DISTINCT start_station_name) as n_start_stations,
            COUNT(DISTINCT end_station_name) as n_end_stations,
            MIN(started_at) as first_trip,
            MAX(started_at) as last_trip,
            AVG(EXTRACT(EPOCH FROM (ended_at - started_at)) / 60) as avg_duration_minutes
        FROM read_parquet('{pattern}', hive_partitioning=1)
        {where_sql}
    """
    with DuckDBConnection() as con:
        if params:
            result = con.execute(query, params).fetchdf()
        else:
            result = con.execute(query).fetchdf()
    return {
        "total_trips": int(result["total_trips"].iloc[0]),
        "n_start_stations": int(result["n_start_stations"].iloc[0]),
        "n_end_stations": int(result["n_end_stations"].iloc[0]),
        "first_trip": result["first_trip"].iloc[0],
        "last_trip": result["last_trip"].iloc[0],
        "avg_duration_minutes": float(result["avg_duration_minutes"].iloc[0]),
    }
def count_trips_by_station(
    parquet_path: Path,
    start_date: str | None = None,
    end_date: str | None = None,
) -> pd.DataFrame:
    """Count trips by station (both starts and ends).

    Date bounds are bound as query parameters. The WHERE clause appears
    in both CTEs, so the parameter list is supplied twice.

    Args:
        parquet_path: Path to parquet files
        start_date: Optional inclusive start date filter (YYYY-MM-DD)
        end_date: Optional inclusive end date filter (YYYY-MM-DD)

    Returns:
        DataFrame with per-station departures, arrivals and total_trips,
        sorted by total_trips descending
    """
    where_clauses: list[str] = []
    params: list[str] = []
    if start_date:
        where_clauses.append("started_at >= ?")
        params.append(start_date)
    if end_date:
        where_clauses.append("started_at <= ?")
        params.append(end_date)
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

    pattern = str(parquet_path / "**" / "*.parquet")
    # Count starts and ends separately, then FULL OUTER JOIN so stations
    # that only appear on one side are still reported.
    query = f"""
        WITH starts AS (
            SELECT
                start_station_name as station_name,
                COUNT(*) as start_count
            FROM read_parquet('{pattern}', hive_partitioning=1)
            {where_sql}
            GROUP BY start_station_name
        ),
        ends AS (
            SELECT
                end_station_name as station_name,
                COUNT(*) as end_count
            FROM read_parquet('{pattern}', hive_partitioning=1)
            {where_sql}
            GROUP BY end_station_name
        )
        SELECT
            COALESCE(starts.station_name, ends.station_name) as station_name,
            COALESCE(start_count, 0) as departures,
            COALESCE(end_count, 0) as arrivals,
            COALESCE(start_count, 0) + COALESCE(end_count, 0) as total_trips
        FROM starts
        FULL OUTER JOIN ends ON starts.station_name = ends.station_name
        ORDER BY total_trips DESC
    """
    with DuckDBConnection() as con:
        if params:
            # The placeholders occur once per CTE, in order.
            return con.execute(query, params + params).fetchdf()
        return con.execute(query).fetchdf()
def export_to_parquet(
    df: pd.DataFrame,
    output_path: Path,
    partition_cols: list[str] | None = None,
    compression: str = "zstd",
) -> None:
    """Export DataFrame to Parquet using DuckDB.

    Args:
        df: DataFrame to export
        output_path: Output path for parquet file
        partition_cols: Columns to partition by (creates subdirectories)
        compression: Compression algorithm (zstd, snappy, gzip, etc.)
    """
    with DuckDBConnection() as con:
        # Expose the DataFrame to DuckDB as a queryable view.
        con.register("temp_df", df)
        # Ensure the parent directory exists before COPY writes to it.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Assemble COPY options as a list and join them: the original
        # string template left a trailing comma (", )") whenever
        # partition_cols was None, which is invalid COPY syntax.
        options = ["FORMAT PARQUET", f"COMPRESSION '{compression}'"]
        if partition_cols:
            options.append(f"PARTITION_BY ({', '.join(partition_cols)})")

        query = f"COPY temp_df TO '{output_path}' ({', '.join(options)})"
        con.execute(query)
def create_summary_table(
    parquet_path: Path,
    output_path: Path,
    start_date: str | None = None,
    end_date: str | None = None,
) -> None:
    """Create a summary table for faster querying.

    Aggregates trips by hour and station pair for modeling. Date bounds
    are bound as query parameters instead of being interpolated into
    the SQL text.

    Args:
        parquet_path: Path to raw parquet files
        output_path: Path for summary parquet output
        start_date: Optional inclusive start date filter (YYYY-MM-DD)
        end_date: Optional inclusive end date filter (YYYY-MM-DD)
    """
    where_clauses: list[str] = []
    params: list[str] = []
    if start_date:
        where_clauses.append("started_at >= ?")
        params.append(start_date)
    if end_date:
        where_clauses.append("started_at <= ?")
        params.append(end_date)
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

    pattern = str(parquet_path / "**" / "*.parquet")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Hourly origin/destination summary written straight to parquet.
    query = f"""
        COPY (
            SELECT
                DATE_TRUNC('hour', started_at) as hour,
                start_station_name,
                end_station_name,
                EXTRACT(HOUR FROM started_at) as hour_of_day,
                EXTRACT(DOW FROM started_at) IN (0, 6) as is_weekend,
                COUNT(*) as trip_count,
                AVG(EXTRACT(EPOCH FROM (ended_at - started_at)) / 60) as avg_duration_minutes
            FROM read_parquet('{pattern}', hive_partitioning=1)
            {where_sql}
            GROUP BY
                DATE_TRUNC('hour', started_at),
                start_station_name,
                end_station_name,
                EXTRACT(HOUR FROM started_at),
                EXTRACT(DOW FROM started_at) IN (0, 6)
            ORDER BY hour, start_station_name, end_station_name
        ) TO '{output_path}' (
            FORMAT PARQUET,
            COMPRESSION 'zstd'
        )
    """
    print(f"Creating summary table at {output_path}...")
    with DuckDBConnection() as con:
        if params:
            con.execute(query, params)
        else:
            con.execute(query)
    print("✓ Summary table created")