from __future__ import annotations import hashlib import logging from collections.abc import Sequence from time import perf_counter import pandas as pd from opentelemetry import metrics, trace from sqlalchemy import text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError from app.db import queries LOGGER = logging.getLogger(__name__) class ReadOnlyWarehouseClient: def __init__(self, engines: dict[str, Engine]) -> None: self.engines = engines self.tracer = trace.get_tracer(__name__) self.meter = metrics.get_meter(__name__) self.query_counter = self.meter.create_counter( name="warehouse_queries_total", description="Total warehouse query executions", ) self.query_latency = self.meter.create_histogram( name="warehouse_query_latency_ms", unit="ms", description="Warehouse query latency", ) def _validate_read_only_query(self, sql: str) -> None: normalized = sql.strip().lower() if not (normalized.startswith("select") or normalized.startswith("with")): raise ValueError("Only read-only SELECT/CTE SQL statements are allowed.") def _run_query_list( self, source: str, sql_candidates: Sequence[str] ) -> pd.DataFrame: engine = self.engines[source] last_error: Exception | None = None for candidate in sql_candidates: self._validate_read_only_query(candidate) query_hash = hashlib.sha256(candidate.encode("utf-8")).hexdigest()[:12] with self.tracer.start_as_current_span("warehouse.query") as span: span.set_attribute("db.system", "mssql") span.set_attribute("db.source", source) span.set_attribute("db.query.hash", query_hash) started = perf_counter() try: with engine.connect() as conn: with self.tracer.start_as_current_span( "warehouse.query.execute" ): df = pd.read_sql_query(sql=text(candidate), con=conn) elapsed_ms = (perf_counter() - started) * 1000 self.query_latency.record(elapsed_ms, attributes={"source": source}) self.query_counter.add( 1, attributes={"source": source, "status": "ok"} ) return df except SQLAlchemyError as exc: last_error = exc elapsed_ms = (perf_counter() - started) * 1000 self.query_latency.record(elapsed_ms, attributes={"source": source}) self.query_counter.add( 1, attributes={"source": source, "status": "error"} ) LOGGER.warning( "Query failed for %s with hash %s: %s", source, query_hash, exc ) if last_error is not None: raise RuntimeError( f"All query candidates failed for source '{source}'." ) from last_error return pd.DataFrame() def fetch_daily_sales(self) -> pd.DataFrame: aw = self._run_query_list("aw", queries.AW_DAILY_SALES_QUERIES) aw["source"] = "AdventureWorks2022DWH" wwi = self._run_query_list("wwi", queries.WWI_DAILY_SALES_QUERIES) wwi["source"] = "WorldWideImporters" return pd.concat([aw, wwi], ignore_index=True) def fetch_product_performance(self) -> pd.DataFrame: aw = self._run_query_list("aw", queries.AW_PRODUCT_PERFORMANCE_QUERIES) aw["source"] = "AdventureWorks2022DWH" wwi = self._run_query_list("wwi", queries.WWI_PRODUCT_PERFORMANCE_QUERIES) wwi["source"] = "WorldWideImporters" return pd.concat([aw, wwi], ignore_index=True) def fetch_customer_performance(self) -> pd.DataFrame: aw = self._run_query_list("aw", queries.AW_CUSTOMER_QUERIES) aw["source"] = "AdventureWorks2022DWH" wwi = self._run_query_list("wwi", queries.WWI_CUSTOMER_QUERIES) wwi["source"] = "WorldWideImporters" return pd.concat([aw, wwi], ignore_index=True)