109 lines
3.7 KiB
Python
109 lines
3.7 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
|
|
import jwt
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from jwt import InvalidTokenError, PyJWKClient
|
|
|
|
from app.core.config import settings
|
|
|
|
# Shared extractor for the ``Authorization: Bearer <token>`` header.
# ``auto_error=False`` asks FastAPI to hand the dependency ``None`` instead of
# rejecting the request itself, so ``require_frontend_principal`` can decide
# whether a missing token is an error (it is not when frontend auth is off).
BEARER_SCHEME: HTTPBearer = HTTPBearer(auto_error=False)
|
|
|
|
|
|
@dataclass(eq=True)
class FrontendPrincipal:
    """Identity resolved for a request from its frontend bearer token.

    A plain, mutable record. When frontend auth is disabled an "anonymous"
    principal with no scopes, no claims and an empty token is used instead.
    """

    # The token's ``sub`` claim ("anonymous" when auth is disabled).
    subject: str
    # Scopes granted by the token.
    scopes: list[str]
    # The full decoded claim set (empty when auth is disabled).
    claims: dict
    # The raw encoded bearer token (empty string when auth is disabled).
    token: str
|
|
|
|
|
|
class FrontendJWTVerifier:
    """Validate frontend JWT access tokens and turn them into principals.

    Signing keys come from the JWKS endpoint in
    ``settings.frontend_jwt_jwks_url``. Issuer, audience, algorithm, clock-skew
    leeway and required scopes are read from ``settings`` on every call.
    All failures are reported as ``HTTPException`` so the verifier can be used
    directly from FastAPI dependencies.
    """

    # Lazily-built JWKS client, memoized per instance by ``_jwks_client``.
    # This replaces ``lru_cache`` on the method, which keys on ``self`` and
    # pins the instance in a module-global cache (flake8-bugbear B019).
    _client: PyJWKClient | None = None

    @staticmethod
    def _require_setting(value, env_name: str):
        """Return ``value``, or raise HTTP 500 naming ``env_name`` if it is falsy.

        A missing setting is a server misconfiguration rather than a client
        error, hence 500 instead of 4xx.
        """
        if not value:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"{env_name} is not configured.",
            )
        return value

    @property
    def jwks_url(self) -> str:
        """The configured JWKS URL.

        Raises:
            HTTPException: 500 if ``FRONTEND_JWT_JWKS_URL`` is not configured.
        """
        return self._require_setting(
            settings.frontend_jwt_jwks_url, "FRONTEND_JWT_JWKS_URL"
        )

    def _jwks_client(self) -> PyJWKClient:
        """Return this verifier's JWKS client, creating it on first use.

        The client is stored only after construction succeeds, so a missing
        JWKS URL keeps being reported (HTTP 500) instead of being cached.
        """
        if self._client is None:
            self._client = PyJWKClient(self.jwks_url)
        return self._client

    @staticmethod
    def _extract_scopes(claims: dict) -> list[str]:
        """Return the scopes granted by ``claims``.

        ``scope`` is checked first, then ``scp``. The first one that is a
        string (split on whitespace) or a list (items stringified) wins.
        Both shapes are accepted for both claims because providers differ:
        RFC 8693 defines ``scope`` as a space-separated string, and ``scp`` is
        emitted as a list by some providers and as a string by others.
        Returns an empty list when neither claim is usable.
        """
        for claim_name in ("scope", "scp"):
            value = claims.get(claim_name)
            if isinstance(value, str):
                return value.split()
            if isinstance(value, list):
                return [str(item) for item in value]
        return []

    def verify(self, token: str) -> FrontendPrincipal:
        """Validate ``token`` and return the principal it describes.

        Checks the following:

        * the signature, using the JWKS key selected from the token header;
        * the algorithm, audience and issuer;
        * the expiry, with ``settings.frontend_clock_skew_seconds`` of leeway.

        It then requires a non-empty ``sub`` claim and every scope listed in
        ``settings.frontend_required_scopes_list``.

        Raises:
            HTTPException: 500 if the issuer, audience or JWKS URL setting is
                missing. 401 (with ``WWW-Authenticate: Bearer``) if the token
                is malformed, badly signed, expired, lacks ``exp``, has no
                matching signing key, or has no subject. 403 if required
                scopes are missing.
        """
        issuer = self._require_setting(
            settings.frontend_jwt_issuer_url, "FRONTEND_JWT_ISSUER_URL"
        )
        audience = self._require_setting(
            settings.frontend_jwt_audience, "FRONTEND_JWT_AUDIENCE"
        )

        try:
            signing_key = self._jwks_client().get_signing_key_from_jwt(token).key
            claims = jwt.decode(
                token,
                key=signing_key,
                algorithms=[settings.frontend_jwt_algorithm],
                audience=audience,
                issuer=issuer,
                leeway=settings.frontend_clock_skew_seconds,
                # PyJWT only validates ``exp`` when it is present; require it so
                # a token without an expiry is not accepted forever.
                options={"require": ["exp"]},
            )
        except (InvalidTokenError, jwt.PyJWKClientError) as exc:
            # PyJWKClientError is not an InvalidTokenError subclass. Without it,
            # a token whose ``kid`` is absent from the JWKS escaped as a 500.
            # NOTE(review): JWKS fetch failures are also PyJWKClientError and
            # now surface as 401. Consider a 503 for those if that matters.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid frontend access token.",
                headers={"WWW-Authenticate": "Bearer"},
            ) from exc

        subject = str(claims.get("sub") or "")
        if not subject:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Frontend token missing subject.",
                headers={"WWW-Authenticate": "Bearer"},
            )

        scopes = self._extract_scopes(claims)
        granted = set(scopes)
        missing = [
            required_scope
            for required_scope in settings.frontend_required_scopes_list
            if required_scope not in granted
        ]
        if missing:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Missing required scope(s): {', '.join(missing)}",
            )

        return FrontendPrincipal(
            subject=subject, scopes=scopes, claims=claims, token=token
        )
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_frontend_verifier() -> FrontendJWTVerifier:
    """Return the shared :class:`FrontendJWTVerifier`.

    ``lru_cache`` memoizes this zero-argument call. The first call builds the
    verifier and later calls return that same object.
    ``get_frontend_verifier.cache_clear()`` discards it, so the next call
    builds a fresh verifier.
    """
    shared_verifier = FrontendJWTVerifier()
    return shared_verifier
|
|
|
|
|
|
def require_frontend_principal(
    credentials: HTTPAuthorizationCredentials | None = Depends(BEARER_SCHEME),
) -> FrontendPrincipal:
    """FastAPI dependency that resolves the caller's :class:`FrontendPrincipal`.

    When ``settings.require_frontend_auth`` is falsy, authentication is skipped.
    In that case it returns an anonymous principal with no scopes, no claims
    and an empty token. Otherwise a bearer token is required and is validated
    by the shared verifier.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` challenge) if no
            bearer token was supplied; otherwise whatever
            ``FrontendJWTVerifier.verify`` raises (401/403/500).
    """
    if not settings.require_frontend_auth:
        return FrontendPrincipal(subject="anonymous", scopes=[], claims={}, token="")

    if credentials is None or credentials.scheme.lower() != "bearer":
        # RFC 7235 §3.1: a 401 response MUST carry a WWW-Authenticate challenge;
        # RFC 6750 §3 defines the "Bearer" scheme used here.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing bearer token.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return get_frontend_verifier().verify(credentials.credentials)
|