"""Bearer-token authentication for frontend requests.

Access tokens are verified against the JWKS URL, issuer, audience, algorithm
and clock-skew leeway configured on ``settings``; the resulting
:class:`FrontendPrincipal` is returned by the :func:`require_frontend_principal`
dependency.
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Any

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt import InvalidTokenError, PyJWKClient

from app.core.config import settings

# auto_error=False: the scheme yields ``None`` instead of raising on its own,
# so require_frontend_principal decides how to respond (or skips auth).
BEARER_SCHEME = HTTPBearer(auto_error=False)

# Recent PyJWT releases raise PyJWKClientConnectionError when the JWKS URL
# cannot be fetched. On releases without that class this tuple is empty,
# ``except ()`` matches nothing, and fetch failures fall through to the
# generic PyJWTError handler in FrontendJWTVerifier.verify.
_JWKS_FETCH_ERRORS: tuple[type[Exception], ...] = tuple(
    cls
    for cls in (getattr(jwt.exceptions, "PyJWKClientConnectionError", None),)
    if cls is not None
)


def _bearer_error(status_code: int, detail: str, error: str | None = None) -> HTTPException:
    """Build an HTTPException carrying a Bearer ``WWW-Authenticate`` challenge.

    Args:
        status_code: HTTP status to return (401 or 403).
        detail: Human-readable error detail for the response body.
        error: Optional RFC 6750 error code (``invalid_token``,
            ``insufficient_scope``); omitted when the request had no token.
    """
    challenge = "Bearer" if error is None else f'Bearer error="{error}"'
    return HTTPException(
        status_code=status_code,
        detail=detail,
        headers={"WWW-Authenticate": challenge},
    )


def _config_error(setting_name: str) -> HTTPException:
    """Build the 500 error raised when a required setting is empty."""
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=f"{setting_name} is not configured.",
    )


@dataclass
class FrontendPrincipal:
    """Caller identity produced by verifying a frontend access token."""

    subject: str  # the token's ``sub`` claim, or "anonymous" when auth is disabled
    scopes: list[str]  # scopes parsed from the ``scope`` / ``scp`` claim
    claims: dict[str, Any]  # full decoded claim set ({} when auth is disabled)
    token: str  # raw bearer token ("" when auth is disabled)


class FrontendJWTVerifier:
    """Validate frontend JWTs against the configured JWKS, issuer and audience."""

    def __init__(self) -> None:
        # Created lazily so a missing JWKS URL only fails when a token is verified.
        self._jwks_client_instance: PyJWKClient | None = None

    @property
    def jwks_url(self) -> str:
        """Return the configured JWKS URL.

        Raises:
            HTTPException: 500 when ``FRONTEND_JWT_JWKS_URL`` is not configured.
        """
        if not settings.frontend_jwt_jwks_url:
            raise _config_error("FRONTEND_JWT_JWKS_URL")
        return settings.frontend_jwt_jwks_url

    def _jwks_client(self) -> PyJWKClient:
        """Return this verifier's JWKS client, creating it on first use.

        Cached on the instance rather than via ``lru_cache`` on the method,
        which would key on (and keep alive) ``self`` in a module-level cache.

        Raises:
            HTTPException: 500 when the JWKS URL is not configured (not cached,
                so a later call retries).
        """
        if self._jwks_client_instance is None:
            self._jwks_client_instance = PyJWKClient(self.jwks_url)
        return self._jwks_client_instance

    @staticmethod
    def _extract_scopes(claims: dict[str, Any]) -> list[str]:
        """Return the token's scopes from the ``scope`` or ``scp`` claim.

        Either claim may be a whitespace-delimited string or a list; ``scope``
        is checked first and the first claim of a supported type wins.
        Returns an empty list when neither claim is usable.
        """
        for claim_name in ("scope", "scp"):
            value = claims.get(claim_name)
            if isinstance(value, str):
                return value.split()
            if isinstance(value, list):
                return [str(item) for item in value]
        return []

    def verify(self, token: str) -> FrontendPrincipal:
        """Verify *token* and return the authenticated principal.

        Raises:
            HTTPException: 500 when issuer, audience or JWKS URL settings are
                missing; 503 when the JWKS endpoint cannot be reached (on PyJWT
                versions that report this distinctly); 401 for malformed,
                unsigned-by-a-known-key, expired or otherwise invalid tokens and
                tokens without ``sub``; 403 when required scopes are missing.
        """
        if not settings.frontend_jwt_issuer_url:
            raise _config_error("FRONTEND_JWT_ISSUER_URL")
        if not settings.frontend_jwt_audience:
            raise _config_error("FRONTEND_JWT_AUDIENCE")

        try:
            signing_key = self._jwks_client().get_signing_key_from_jwt(token).key
        except _JWKS_FETCH_ERRORS as exc:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Unable to fetch frontend signing keys.",
            ) from exc
        except jwt.PyJWTError as exc:
            # Covers a malformed header (DecodeError) and PyJWKClientError for an
            # unknown ``kid``. PyJWKClientError is not an InvalidTokenError, so
            # without this clause a forged ``kid`` surfaced as an unhandled 500.
            raise _bearer_error(
                status.HTTP_401_UNAUTHORIZED,
                "Invalid frontend access token.",
                "invalid_token",
            ) from exc

        try:
            claims = jwt.decode(
                token,
                key=signing_key,
                algorithms=[settings.frontend_jwt_algorithm],
                audience=settings.frontend_jwt_audience,
                issuer=settings.frontend_jwt_issuer_url,
                leeway=settings.frontend_clock_skew_seconds,
            )
        except InvalidTokenError as exc:
            raise _bearer_error(
                status.HTTP_401_UNAUTHORIZED,
                "Invalid frontend access token.",
                "invalid_token",
            ) from exc

        subject = str(claims.get("sub") or "")
        if not subject:
            raise _bearer_error(
                status.HTTP_401_UNAUTHORIZED,
                "Frontend token missing subject.",
                "invalid_token",
            )

        scopes = self._extract_scopes(claims)
        granted = set(scopes)
        missing = [s for s in settings.frontend_required_scopes_list if s not in granted]
        if missing:
            raise _bearer_error(
                status.HTTP_403_FORBIDDEN,
                f"Missing required scope(s): {', '.join(missing)}",
                "insufficient_scope",
            )

        return FrontendPrincipal(subject=subject, scopes=scopes, claims=claims, token=token)


@lru_cache(maxsize=1)
def get_frontend_verifier() -> FrontendJWTVerifier:
    """Return the process-wide verifier (one instance, so its JWKS client is shared)."""
    return FrontendJWTVerifier()


def require_frontend_principal(
    credentials: HTTPAuthorizationCredentials | None = Depends(BEARER_SCHEME),
) -> FrontendPrincipal:
    """FastAPI dependency resolving the calling frontend principal.

    When ``settings.require_frontend_auth`` is false, an anonymous principal
    with no scopes is returned without inspecting the request.

    Raises:
        HTTPException: 401 when no bearer token is supplied, plus anything
            raised by :meth:`FrontendJWTVerifier.verify`.
    """
    if not settings.require_frontend_auth:
        return FrontendPrincipal(subject="anonymous", scopes=[], claims={}, token="")
    # HTTPBearer(auto_error=False) already yields None for non-bearer schemes;
    # the explicit scheme check is kept as a defensive guard.
    if credentials is None or credentials.scheme.lower() != "bearer":
        raise _bearer_error(status.HTTP_401_UNAUTHORIZED, "Missing bearer token.")
    return get_frontend_verifier().verify(credentials.credentials)