232 lines
7.6 KiB
Python
232 lines
7.6 KiB
Python
from __future__ import annotations

from dataclasses import dataclass, field
from functools import lru_cache
from time import time
from uuid import uuid4

import jwt
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt import InvalidTokenError, PyJWKClient

from app.core.config import settings
|
|
|
|
# Shared extractor for the ``Authorization: Bearer ...`` header. With
# ``auto_error=False`` it does not reject the request itself; the dependency
# that consumes it (``require_frontend_principal``) decides how to respond,
# including skipping authentication entirely when it is disabled.
BEARER_SCHEME: HTTPBearer = HTTPBearer(auto_error=False)
|
|
|
|
|
|
@dataclass
class FrontendPrincipal:
    """Identity of a frontend caller, derived from a verified bearer JWT.

    Produced by ``FrontendJWTVerifier.verify``. When frontend auth is disabled,
    ``require_frontend_principal`` produces an ``"anonymous"`` placeholder with
    no scopes, no claims and an empty token.
    """

    # Value of the token's ``sub`` claim (or ``"anonymous"``).
    subject: str
    # Scopes parsed from the token's ``scope``/``scp`` claim.
    scopes: list[str]
    # Full decoded claim set.
    claims: dict
    # Raw bearer token. Excluded from ``repr`` so that logging or printing a
    # principal does not leak a replayable credential.
    token: str = field(repr=False)
|
|
|
|
|
|
@dataclass
class InternalPrincipal:
    """Identity of a calling internal service, derived from a verified token.

    Produced by ``InternalTokenManager.verify``. When internal auth is
    disabled, ``require_internal_principal`` produces an ``"internal-unauth"``
    placeholder with no scopes, no claims and an empty token.
    """

    # Value of the token's ``sub`` claim (or ``"internal-unauth"``).
    subject: str
    # Scopes parsed from the token's space-delimited ``scope`` claim.
    scopes: list[str]
    # Full decoded claim set.
    claims: dict
    # Raw internal token. Excluded from ``repr`` so that logging or printing a
    # principal does not leak a replayable credential.
    token: str = field(repr=False)
|
|
|
|
|
|
class FrontendJWTVerifier:
    """Validate frontend bearer JWTs against the identity provider's JWKS.

    All configuration is read from ``settings`` at call time: JWKS URL, issuer,
    audience, algorithm, clock skew and required scopes. A missing JWKS URL,
    issuer or audience surfaces as an HTTP 500 rather than failing at import.
    """

    # JWKS client, built on first use and cached on the instance. This avoids
    # ``functools.lru_cache`` on the method, which keys on ``self`` and keeps
    # instances alive in a single shared cache.
    _jwks_client_instance: PyJWKClient | None = None

    @property
    def jwks_url(self) -> str:
        """Return the configured JWKS URL.

        Raises:
            HTTPException: 500 when ``FRONTEND_JWT_JWKS_URL`` is not set.
        """
        if not settings.frontend_jwt_jwks_url:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="FRONTEND_JWT_JWKS_URL is not configured.",
            )
        return settings.frontend_jwt_jwks_url

    def _jwks_client(self) -> PyJWKClient:
        """Return this verifier's ``PyJWKClient``, creating it on first use.

        Nothing is cached when ``jwks_url`` raises, so a configuration fix
        takes effect on the next call.
        """
        client = self._jwks_client_instance
        if client is None:
            client = PyJWKClient(self.jwks_url)
            self._jwks_client_instance = client
        return client

    @staticmethod
    def _extract_scopes(claims: dict) -> list[str]:
        """Return the token's scopes as a list of strings.

        ``scope`` is checked first, then ``scp``. Either claim may be a
        space-delimited string (the ``scope`` claim form in RFC 8693; some
        identity providers also emit ``scp`` this way) or a list of values.
        The first claim having one of those shapes wins. Otherwise the result
        is ``[]``.
        """
        for claim_name in ("scope", "scp"):
            value = claims.get(claim_name)
            if isinstance(value, str):
                return [item for item in value.split(" ") if item]
            if isinstance(value, list):
                return [str(item) for item in value]
        return []

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build a 401 carrying the RFC 6750 ``WWW-Authenticate`` challenge."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": 'Bearer error="invalid_token"'},
        )

    def verify(self, token: str) -> FrontendPrincipal:
        """Verify ``token`` and return the authenticated frontend principal.

        Checks performed:
        - the signature, against the JWKS key selected by the token header;
        - the algorithm, issuer and audience;
        - ``exp``, which must be present, and ``nbf`` if present, allowing
          ``settings.frontend_clock_skew_seconds`` of leeway;
        - a non-empty ``sub``;
        - that every scope in ``settings.frontend_required_scopes_list`` is
          granted.

        Raises:
            HTTPException: 500 if issuer, audience or JWKS URL is not
                configured. 503 if the JWKS endpoint cannot be reached (only
                on PyJWT releases that expose ``PyJWKClientConnectionError``).
                401 for an invalid, expired or unmatched token, or a missing
                subject. 403 when required scopes are missing.
        """
        if not settings.frontend_jwt_issuer_url:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="FRONTEND_JWT_ISSUER_URL is not configured.",
            )
        if not settings.frontend_jwt_audience:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="FRONTEND_JWT_AUDIENCE is not configured.",
            )

        # Newer PyJWT releases raise this subclass of PyJWKClientError for
        # JWKS transport failures. On older releases the empty tuple matches
        # nothing, and such failures fall through to the 401 branch below.
        jwks_unavailable = getattr(jwt.exceptions, "PyJWKClientConnectionError", ())
        try:
            signing_key = self._jwks_client().get_signing_key_from_jwt(token).key
            claims = jwt.decode(
                token,
                key=signing_key,
                algorithms=[settings.frontend_jwt_algorithm],
                audience=settings.frontend_jwt_audience,
                issuer=settings.frontend_jwt_issuer_url,
                leeway=settings.frontend_clock_skew_seconds,
                # Without this, a token lacking ``exp`` never expires.
                options={"require": ["exp"]},
            )
        except jwks_unavailable as exc:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Unable to fetch frontend signing keys.",
            ) from exc
        except (InvalidTokenError, jwt.exceptions.PyJWKClientError) as exc:
            # PyJWKClientError (e.g. no JWKS key matches the token's ``kid``)
            # is not an InvalidTokenError. Without this clause, a forged
            # token would escape as an unhandled 500.
            raise self._unauthorized("Invalid frontend access token.") from exc

        subject = str(claims.get("sub") or "")
        if not subject:
            raise self._unauthorized("Frontend token missing subject.")

        scopes = self._extract_scopes(claims)
        missing = [
            scope
            for scope in settings.frontend_required_scopes_list
            if scope not in scopes
        ]
        if missing:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Missing required scope(s): {', '.join(missing)}",
            )
        return FrontendPrincipal(
            subject=subject, scopes=scopes, claims=claims, token=token
        )
|
|
|
|
|
|
class InternalTokenManager:
    """Mint and verify HS256 tokens used for service-to-service calls.

    Tokens are signed with ``settings.internal_service_shared_secret``. They
    carry ``typ`` = ``token_type``, which ``verify`` insists on.
    """

    token_type = "internal-service"

    @staticmethod
    def _misconfigured(detail: str) -> HTTPException:
        """Build the 500 error used when the signing secret is unusable."""
        return HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail,
        )

    @staticmethod
    def _rejected(detail: str) -> HTTPException:
        """Build the 401 error used for any unacceptable internal token."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
        )

    @staticmethod
    def _assert_secret() -> str:
        """Return the shared signing secret after sanity checks.

        Raises:
            HTTPException: 500 if the secret is empty, still the
                ``"change-me"`` placeholder, or shorter than 32 UTF-8 bytes.
        """
        secret = settings.internal_service_shared_secret
        unset_or_placeholder = not secret or secret == "change-me"
        if unset_or_placeholder:
            raise InternalTokenManager._misconfigured(
                "INTERNAL_SERVICE_SHARED_SECRET must be configured."
            )
        secret_length = len(secret.encode("utf-8"))
        if secret_length < 32:
            raise InternalTokenManager._misconfigured(
                "INTERNAL_SERVICE_SHARED_SECRET must be at least 32 bytes for "
                "HS256 token signing."
            )
        return secret

    def mint(
        self,
        *,
        subject: str,
        scopes: list[str],
        source_service: str,
    ) -> str:
        """Return a signed token identifying ``subject`` issued by ``source_service``.

        The token's ``iat`` and ``nbf`` are set to the current Unix time.
        ``exp`` is set ``settings.internal_service_token_ttl_seconds`` later,
        and ``jti`` is a fresh UUID4. ``scopes`` are joined with single spaces
        into the ``scope`` claim.
        """
        issued_at = int(time())
        # Keyword order fixes the claim order in the encoded payload.
        claims = dict(
            sub=subject,
            scope=" ".join(scopes),
            iss=source_service,
            aud=settings.internal_service_token_audience,
            typ=self.token_type,
            iat=issued_at,
            nbf=issued_at,
            exp=issued_at + settings.internal_service_token_ttl_seconds,
            jti=str(uuid4()),
        )
        signing_secret = self._assert_secret()
        return jwt.encode(claims, signing_secret, algorithm="HS256")

    def verify(self, token: str) -> InternalPrincipal:
        """Validate an internal token and return the calling service's principal.

        Raises:
            HTTPException: 500 if the shared secret is unusable. 401 if
                decoding fails (bad signature, audience, expiry or a missing
                required claim), the subject is empty, the issuer is not in
                ``settings.internal_service_allowed_issuers_list``, or ``typ``
                does not match ``token_type``.
        """
        required_claims = ["sub", "iss", "aud", "exp", "iat", "nbf", "jti", "typ"]
        try:
            claims = jwt.decode(
                token,
                self._assert_secret(),
                algorithms=["HS256"],
                audience=settings.internal_service_token_audience,
                options={"require": required_claims},
                leeway=settings.internal_token_clock_skew_seconds,
            )
        except InvalidTokenError as exc:
            raise self._rejected("Invalid internal service token.") from exc

        subject = str(claims.get("sub") or "")
        if not subject:
            raise self._rejected("Internal token missing subject.")

        issuer = str(claims.get("iss") or "")
        if issuer not in settings.internal_service_allowed_issuers_list:
            raise self._rejected("Internal token issuer is not allowed.")

        if str(claims.get("typ") or "") != self.token_type:
            raise self._rejected("Internal token type is invalid.")

        raw_scope = claims.get("scope")
        # filter(None, ...) drops the empty strings left by repeated spaces.
        scopes = list(filter(None, str(raw_scope).split(" "))) if raw_scope else []
        return InternalPrincipal(
            subject=subject, scopes=scopes, claims=claims, token=token
        )
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_frontend_verifier() -> FrontendJWTVerifier:
    """Return the shared ``FrontendJWTVerifier``.

    The factory takes no arguments and is wrapped in ``lru_cache``. The first
    call therefore builds the verifier, and every later call in the process
    returns that same object.
    """
    verifier = FrontendJWTVerifier()
    return verifier
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_internal_token_manager() -> InternalTokenManager:
    """Return the shared ``InternalTokenManager``.

    The factory takes no arguments and is wrapped in ``lru_cache``. The first
    call therefore builds the manager, and every later call in the process
    returns that same object.
    """
    manager = InternalTokenManager()
    return manager
|
|
|
|
|
|
def require_frontend_principal(
    credentials: HTTPAuthorizationCredentials | None = Depends(BEARER_SCHEME),
) -> FrontendPrincipal:
    """FastAPI dependency that resolves the authenticated frontend caller.

    When ``settings.require_frontend_auth`` is false, every request gets an
    ``"anonymous"`` principal with no scopes. Otherwise a bearer token is
    mandatory and is validated by the shared ``FrontendJWTVerifier``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            no bearer credentials were supplied. Otherwise, whatever
            ``FrontendJWTVerifier.verify`` raises for a rejected token.
    """
    if not settings.require_frontend_auth:
        return FrontendPrincipal(subject="anonymous", scopes=[], claims={}, token="")

    # BEARER_SCHEME is configured with auto_error=False, so an absent header
    # arrives here as None. The scheme comparison is a defensive guard.
    if credentials is None or credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing bearer token.",
            # RFC 7235 §3.1 / RFC 6750 §3: a 401 must carry a challenge.
            headers={"WWW-Authenticate": "Bearer"},
        )
    return get_frontend_verifier().verify(credentials.credentials)
|
|
|
|
|
|
def require_internal_principal(
    internal_token: str | None = Header(default=None, alias="x-internal-service-token"),
) -> InternalPrincipal:
    """FastAPI dependency that authenticates service-to-service calls.

    Reads the ``x-internal-service-token`` header. When
    ``settings.internal_service_auth_enabled`` is false, the header is ignored
    and an ``"internal-unauth"`` placeholder principal is returned.

    Raises:
        HTTPException: 401 when auth is enabled and the header is absent or
            empty. Otherwise, whatever ``InternalTokenManager.verify`` raises
            for a rejected token.
    """
    if settings.internal_service_auth_enabled:
        if internal_token:
            return get_internal_token_manager().verify(internal_token)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing x-internal-service-token header.",
        )
    # Auth disabled: hand back an unauthenticated placeholder identity.
    return InternalPrincipal(subject="internal-unauth", scopes=[], claims={}, token="")
|