"""Trusted proxy middleware for authentication validation.""" from collections.abc import Callable from typing import Any import structlog from fastapi import HTTPException, Request, status from starlette.middleware.base import BaseHTTPMiddleware from .auth import AuthenticationHeaders from .utils import is_internal_request logger = structlog.get_logger() class TrustedProxyMiddleware( BaseHTTPMiddleware ): # pylint: disable=too-few-public-methods """Middleware to validate requests from trusted proxy (Traefik)""" def __init__(self, app: Any, internal_cidrs: list[str], disable_auth: bool = False): super().__init__(app) self.internal_cidrs = internal_cidrs self.disable_auth = disable_auth self.public_endpoints = { "/healthz", "/readyz", "/livez", "/metrics", "/docs", "/openapi.json", } async def dispatch(self, request: Request, call_next: Callable[..., Any]) -> Any: """Process request through middleware""" # Get client IP (considering proxy headers) client_ip = self._get_client_ip(request) # Check if authentication is disabled (development mode) if self.disable_auth: # Set development state request.state.user = "dev-user" request.state.email = "dev@example.com" request.state.roles = ["developers"] request.state.auth_token = "dev-token" logger.info( "Development mode: authentication disabled", path=request.url.path ) return await call_next(request) # Check if this is a public endpoint if request.url.path in self.public_endpoints: # For metrics endpoint, still require internal network if request.url.path == "/metrics": if not is_internal_request(client_ip, self.internal_cidrs): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Metrics endpoint only accessible from internal network", ) # Set minimal state for public endpoints request.state.user = None request.state.email = None request.state.roles = [] return await call_next(request) # For protected endpoints, validate authentication headers auth_headers = AuthenticationHeaders(request) # Require authentication headers if not auth_headers.authenticated_user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing X-Authenticated-User header", ) if not auth_headers.authenticated_email: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing X-Authenticated-Email header", ) if not auth_headers.authorization_token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header", ) # Set request state request.state.user = auth_headers.authenticated_user request.state.email = auth_headers.authenticated_email request.state.roles = auth_headers.authenticated_groups request.state.auth_token = auth_headers.authorization_token # Add authentication helper to request request.state.auth = auth_headers logger.info( "Authenticated request", user=auth_headers.authenticated_user, email=auth_headers.authenticated_email, roles=auth_headers.authenticated_groups, path=request.url.path, ) return await call_next(request) def _get_client_ip(self, request: Request) -> str: """Get client IP considering proxy headers""" # Check X-Forwarded-For header first (set by Traefik) forwarded_for = request.headers.get("X-Forwarded-For") if forwarded_for: # Take the first IP in the chain return forwarded_for.split(",")[0].strip() # Check X-Real-IP header real_ip = request.headers.get("X-Real-IP") if real_ip: return real_ip # Fall back to direct client IP return request.client.host if request.client else "unknown" def create_trusted_proxy_middleware( internal_cidrs: list[str], ) -> Callable[[Any], TrustedProxyMiddleware]: """Factory function to create TrustedProxyMiddleware""" def middleware_factory( # pylint: disable=unused-argument app: Any, ) -> TrustedProxyMiddleware: return TrustedProxyMiddleware(app, internal_cidrs) return middleware_factory