"""Authentication dependencies for frontend and internal-service callers.

* Frontend callers send ``Authorization: Bearer <jwt>``; the token is verified
  against the JWKS endpoint, issuer, audience and algorithm from ``settings``.
* Internal services send an ``x-internal-service-token`` header carrying an
  HS256 JWT signed with ``settings.internal_service_shared_secret``.
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from time import time
from uuid import uuid4

import jwt
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt import InvalidTokenError, PyJWKClient, PyJWKClientError

from app.core.config import settings

# auto_error=False: a missing or non-bearer Authorization header yields None
# instead of FastAPI raising its own error, so require_frontend_principal can
# honor settings.require_frontend_auth and choose the response itself.
BEARER_SCHEME = HTTPBearer(auto_error=False)

# Newer PyJWT releases raise this PyJWKClientError subclass when the JWKS
# endpoint cannot be fetched; older releases do not define it (None here), in
# which case fetch failures are indistinguishable from an unmatched key.
_JWKS_CONNECTION_ERROR: type[Exception] | None = getattr(
    jwt, "PyJWKClientConnectionError", None
)


@dataclass
class FrontendPrincipal:
    """Identity established from a frontend access token."""

    subject: str  # the token's ``sub`` claim ("anonymous" when auth is off)
    scopes: list[str]  # scopes parsed from the ``scope`` / ``scp`` claims
    claims: dict  # full decoded claim set (empty when auth is off)
    token: str  # raw encoded token (empty when auth is off)


@dataclass
class InternalPrincipal:
    """Identity established from an internal service token."""

    subject: str  # the token's ``sub`` claim ("internal-unauth" when auth is off)
    scopes: list[str]  # scopes parsed from the ``scope`` claim
    claims: dict  # full decoded claim set (empty when auth is off)
    token: str  # raw encoded token (empty when auth is off)


def _scope_claim_to_list(value: object) -> list[str]:
    """Normalize a scope claim given as a space-delimited string or a list.

    Any other type (including a missing claim) yields an empty list.
    """
    if isinstance(value, str):
        # RFC 6749 scope tokens cannot contain whitespace, so a whitespace
        # split is safe and also drops empty items from repeated spaces.
        return value.split()
    if isinstance(value, list):
        return [str(item) for item in value]
    return []


def _config_error(detail: str) -> HTTPException:
    """Build the 500 raised when a required auth setting is missing or unsafe."""
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=detail,
    )


def _unauthorized(detail: str, *, bearer_challenge: bool = False) -> HTTPException:
    """Build a 401; optionally add the RFC 6750 ``WWW-Authenticate: Bearer`` header."""
    headers = {"WWW-Authenticate": "Bearer"} if bearer_challenge else None
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers=headers,
    )


class FrontendJWTVerifier:
    """Verify frontend bearer JWTs against the configured JWKS, issuer and audience."""

    def __init__(self) -> None:
        # Built on first use by _jwks_client(), so a missing JWKS URL only
        # fails the requests that actually need it.
        self._jwks_client_instance: PyJWKClient | None = None

    @property
    def jwks_url(self) -> str:
        """Return ``settings.frontend_jwt_jwks_url``.

        Raises:
            HTTPException: 500 when ``FRONTEND_JWT_JWKS_URL`` is not configured.
        """
        if not settings.frontend_jwt_jwks_url:
            raise _config_error("FRONTEND_JWT_JWKS_URL is not configured.")
        return settings.frontend_jwt_jwks_url

    def _jwks_client(self) -> PyJWKClient:
        """Return this verifier's PyJWKClient, creating it on the first call."""
        if self._jwks_client_instance is None:
            self._jwks_client_instance = PyJWKClient(self.jwks_url)
        return self._jwks_client_instance

    @staticmethod
    def _extract_scopes(claims: dict) -> list[str]:
        """Return the scopes granted by ``claims``.

        ``scope`` wins when it is a string or a list; otherwise ``scp`` is
        used. Either claim may be a space-delimited string or a JSON array,
        since identity providers differ in which shape they emit.
        """
        scope = claims.get("scope")
        if isinstance(scope, (str, list)):
            return _scope_claim_to_list(scope)
        return _scope_claim_to_list(claims.get("scp"))

    def verify(self, token: str) -> FrontendPrincipal:
        """Validate ``token`` and return the authenticated principal.

        The signing key is selected from the JWKS using the token header. The
        signature, algorithm, audience, issuer and expiry are then checked;
        ``exp`` is mandatory. Finally the token must carry a ``sub`` and every
        scope in ``settings.frontend_required_scopes_list``.

        Raises:
            HTTPException: 500 if issuer, audience or JWKS settings are
                missing; 503 if the JWKS endpoint cannot be reached (only on
                PyJWT versions that report this separately); 401 if the token
                is invalid or lacks a subject; 403 if required scopes are
                missing.
        """
        if not settings.frontend_jwt_issuer_url:
            raise _config_error("FRONTEND_JWT_ISSUER_URL is not configured.")
        if not settings.frontend_jwt_audience:
            raise _config_error("FRONTEND_JWT_AUDIENCE is not configured.")

        try:
            signing_key = self._jwks_client().get_signing_key_from_jwt(token).key
            claims = jwt.decode(
                token,
                key=signing_key,
                algorithms=[settings.frontend_jwt_algorithm],
                audience=settings.frontend_jwt_audience,
                issuer=settings.frontend_jwt_issuer_url,
                # PyJWT validates ``exp`` only when present; requiring it
                # stops a token without an expiry from being valid forever.
                options={"require": ["exp"]},
                leeway=settings.frontend_clock_skew_seconds,
            )
        except InvalidTokenError as exc:
            raise _unauthorized(
                "Invalid frontend access token.", bearer_challenge=True
            ) from exc
        except PyJWKClientError as exc:
            # Not an InvalidTokenError subclass. It is raised, e.g., when no
            # JWKS key matches the token's ``kid``. That is a client error and
            # must not escape as an unhandled 500.
            if _JWKS_CONNECTION_ERROR is not None and isinstance(
                exc, _JWKS_CONNECTION_ERROR
            ):
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Unable to fetch frontend signing keys.",
                ) from exc
            raise _unauthorized(
                "Invalid frontend access token.", bearer_challenge=True
            ) from exc

        subject = str(claims.get("sub") or "")
        if not subject:
            raise _unauthorized(
                "Frontend token missing subject.", bearer_challenge=True
            )

        scopes = self._extract_scopes(claims)
        granted = set(scopes)
        missing = [
            scope
            for scope in settings.frontend_required_scopes_list
            if scope not in granted
        ]
        if missing:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Missing required scope(s): {', '.join(missing)}",
            )
        return FrontendPrincipal(
            subject=subject, scopes=scopes, claims=claims, token=token
        )


class InternalTokenManager:
    """Mint and verify HS256 service-to-service tokens signed with a shared secret."""

    # Written to the ``typ`` claim by mint(); verify() rejects any other value.
    token_type = "internal-service"

    @staticmethod
    def _assert_secret() -> str:
        """Return the shared signing secret after validating it.

        Raises:
            HTTPException: 500 if the secret is unset, still the placeholder
                ``"change-me"``, or shorter than 32 bytes.
        """
        secret = settings.internal_service_shared_secret
        if not secret or secret == "change-me":
            raise _config_error("INTERNAL_SERVICE_SHARED_SECRET must be configured.")
        # RFC 7518 §3.2: an HS256 key must be at least 256 bits long.
        if len(secret.encode("utf-8")) < 32:
            raise _config_error(
                "INTERNAL_SERVICE_SHARED_SECRET must be at least 32 bytes for "
                "HS256 token signing."
            )
        return secret

    def mint(
        self,
        *,
        subject: str,
        scopes: list[str],
        source_service: str,
    ) -> str:
        """Create a signed internal token.

        Args:
            subject: Value for the ``sub`` claim.
            scopes: Joined with single spaces into the ``scope`` claim.
            source_service: Value for ``iss``. verify() accepts only issuers
                in ``settings.internal_service_allowed_issuers_list``.

        Returns:
            The encoded HS256 JWT. It is valid from now (``iat``/``nbf``) for
            ``settings.internal_service_token_ttl_seconds`` and has a random
            ``jti``.

        Raises:
            HTTPException: 500 if the shared secret is misconfigured.
        """
        now = int(time())
        payload = {
            "sub": subject,
            "scope": " ".join(scopes),
            "iss": source_service,
            "aud": settings.internal_service_token_audience,
            "typ": self.token_type,
            "iat": now,
            "nbf": now,
            "exp": now + settings.internal_service_token_ttl_seconds,
            "jti": str(uuid4()),
        }
        return jwt.encode(payload, self._assert_secret(), algorithm="HS256")

    def verify(self, token: str) -> InternalPrincipal:
        """Validate an internal token and return its principal.

        Raises:
            HTTPException: 500 if the shared secret is misconfigured; 401 if
                the token fails signature, audience or time checks, lacks a
                required claim or subject, has an issuer outside the allow
                list, or has the wrong ``typ``.
        """
        try:
            claims = jwt.decode(
                token,
                self._assert_secret(),
                algorithms=["HS256"],
                audience=settings.internal_service_token_audience,
                options={
                    "require": ["sub", "iss", "aud", "exp", "iat", "nbf", "jti", "typ"]
                },
                leeway=settings.internal_token_clock_skew_seconds,
            )
        except InvalidTokenError as exc:
            raise _unauthorized("Invalid internal service token.") from exc

        subject = str(claims.get("sub") or "")
        if not subject:
            raise _unauthorized("Internal token missing subject.")

        # The issuer is checked against an allow list here rather than by
        # jwt.decode(issuer=...), which this call does not use.
        issuer = str(claims.get("iss") or "")
        if issuer not in settings.internal_service_allowed_issuers_list:
            raise _unauthorized("Internal token issuer is not allowed.")

        token_type = str(claims.get("typ") or "")
        if token_type != self.token_type:
            raise _unauthorized("Internal token type is invalid.")

        scopes = _scope_claim_to_list(claims.get("scope"))
        return InternalPrincipal(
            subject=subject, scopes=scopes, claims=claims, token=token
        )


@lru_cache(maxsize=1)
def get_frontend_verifier() -> FrontendJWTVerifier:
    """Return a single cached FrontendJWTVerifier so its JWKS client is reused."""
    return FrontendJWTVerifier()


@lru_cache(maxsize=1)
def get_internal_token_manager() -> InternalTokenManager:
    """Return a single cached InternalTokenManager."""
    return InternalTokenManager()


def require_frontend_principal(
    credentials: HTTPAuthorizationCredentials | None = Depends(BEARER_SCHEME),
) -> FrontendPrincipal:
    """FastAPI dependency that authenticates a frontend bearer token.

    When ``settings.require_frontend_auth`` is false, every request is treated
    as an ``"anonymous"`` principal with no scopes.

    Raises:
        HTTPException: 401 when no bearer token is supplied, plus anything
            raised by FrontendJWTVerifier.verify.
    """
    if not settings.require_frontend_auth:
        return FrontendPrincipal(subject="anonymous", scopes=[], claims={}, token="")
    if credentials is None or credentials.scheme.lower() != "bearer":
        raise _unauthorized("Missing bearer token.", bearer_challenge=True)
    return get_frontend_verifier().verify(credentials.credentials)


def require_internal_principal(
    internal_token: str | None = Header(default=None, alias="x-internal-service-token"),
) -> InternalPrincipal:
    """FastAPI dependency that authenticates an ``x-internal-service-token`` header.

    When ``settings.internal_service_auth_enabled`` is false, every request is
    treated as an ``"internal-unauth"`` principal with no scopes.

    Raises:
        HTTPException: 401 when the header is missing or empty, plus anything
            raised by InternalTokenManager.verify.
    """
    if not settings.internal_service_auth_enabled:
        return InternalPrincipal(
            subject="internal-unauth", scopes=[], claims={}, token=""
        )
    if not internal_token:
        raise _unauthorized("Missing x-internal-service-token header.")
    return get_internal_token_manager().verify(internal_token)