Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
135 lines
4.7 KiB
Python
135 lines
4.7 KiB
Python
"""Trusted proxy middleware for authentication validation."""
|
|
|
|
from collections.abc import Callable
from typing import Any

import structlog
from fastapi import HTTPException, Request, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from .auth import AuthenticationHeaders
from .utils import is_internal_request
|
|
|
|
# Module-level structlog logger; dispatch() below uses it to record dev-mode and
# authenticated requests.
logger = structlog.get_logger()
|
|
|
|
|
|
class TrustedProxyMiddleware(BaseHTTPMiddleware):  # pylint: disable=too-few-public-methods
    """Validate requests forwarded by the trusted reverse proxy (Traefik).

    For protected paths the request must provide an authenticated user, email
    and authorization token (read via ``AuthenticationHeaders``). These values
    are copied onto ``request.state`` for downstream handlers.

    Paths listed in ``public_endpoints`` skip that check. ``/metrics`` is
    additionally limited to clients inside ``internal_cidrs``. With
    ``disable_auth=True`` every request is passed through with a fixed
    development identity.

    Rejections are *returned* as JSON responses instead of raised as
    ``HTTPException``. The framework's exception handlers run inside the
    user-middleware layer, so an exception raised from ``dispatch`` would reach
    the client as a 500 instead of the intended 401/403.
    """

    def __init__(self, app: Any, internal_cidrs: list[str], disable_auth: bool = False):
        """Configure the middleware.

        Args:
            app: The ASGI application being wrapped.
            internal_cidrs: CIDR strings treated as the internal network. They
                gate ``/metrics`` and decide whether the direct peer is a proxy
                whose forwarding headers may be trusted.
            disable_auth: Development switch that bypasses every check.
        """
        super().__init__(app)
        self.internal_cidrs = internal_cidrs
        self.disable_auth = disable_auth
        # Exact-match paths that do not require identity headers.
        self.public_endpoints = {
            "/healthz",
            "/readyz",
            "/livez",
            "/metrics",
            "/docs",
            "/openapi.json",
        }

    async def dispatch(self, request: Request, call_next: Callable[..., Any]) -> Any:
        """Authenticate ``request``, populate ``request.state`` and forward it.

        Returns:
            The downstream response, or a ``{"detail": ...}`` JSON response:
            403 when ``/metrics`` is requested from outside the internal
            network, 401 when a required authentication value is missing.
        """
        path = request.url.path

        # Development mode: skip every check and inject a fixed identity.
        if self.disable_auth:
            request.state.user = "dev-user"
            request.state.email = "dev@example.com"
            request.state.roles = ["developers"]
            request.state.auth_token = "dev-token"
            logger.info("Development mode: authentication disabled", path=path)
            return await call_next(request)

        if path in self.public_endpoints:
            # /metrics needs no identity, but must still come from the internal
            # network. The client IP is resolved only here, where it is used.
            if path == "/metrics" and not is_internal_request(
                self._get_client_ip(request), self.internal_cidrs
            ):
                return self._reject(
                    request,
                    HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="Metrics endpoint only accessible from internal network",
                    ),
                )
            # Anonymous state. auth_token is set too, so every path populates
            # the same request.state fields.
            request.state.user = None
            request.state.email = None
            request.state.roles = []
            request.state.auth_token = None
            return await call_next(request)

        # Protected endpoint: every identity value must be present. They are
        # checked in order, and the first missing one is reported.
        auth_headers = AuthenticationHeaders(request)

        if not auth_headers.authenticated_user:
            return self._reject(
                request,
                HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Missing X-Authenticated-User header",
                ),
            )

        if not auth_headers.authenticated_email:
            return self._reject(
                request,
                HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Missing X-Authenticated-Email header",
                ),
            )

        if not auth_headers.authorization_token:
            return self._reject(
                request,
                HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Missing Authorization header",
                ),
            )

        request.state.user = auth_headers.authenticated_user
        request.state.email = auth_headers.authenticated_email
        request.state.roles = auth_headers.authenticated_groups
        request.state.auth_token = auth_headers.authorization_token
        # Expose the full helper object for handlers that need more than the fields above.
        request.state.auth = auth_headers

        logger.info(
            "Authenticated request",
            user=auth_headers.authenticated_user,
            email=auth_headers.authenticated_email,
            roles=auth_headers.authenticated_groups,
            path=path,
        )

        return await call_next(request)

    @staticmethod
    def _reject(request: Request, exc: HTTPException) -> JSONResponse:
        """Turn ``exc`` into a ``{"detail": ...}`` JSON error response.

        The rejection is logged here because it no longer propagates as an
        exception that would otherwise have been reported.
        """
        logger.warning(
            "Request rejected",
            path=request.url.path,
            status_code=exc.status_code,
            detail=exc.detail,
        )
        return JSONResponse(
            status_code=exc.status_code,
            content={"detail": exc.detail},
            headers=exc.headers,
        )

    def _get_client_ip(self, request: Request) -> str:
        """Return the originating client IP for ``request``.

        ``X-Forwarded-For`` (first entry) and then ``X-Real-IP`` are honoured
        only when the direct peer is itself inside ``internal_cidrs``, i.e. the
        connection came from the proxy. Otherwise any caller could set these
        headers to an internal address and pass the ``/metrics`` network check.
        Falls back to the socket peer address, or ``"unknown"`` if there is none.
        """
        peer_ip = request.client.host if request.client else None

        if peer_ip is not None and is_internal_request(peer_ip, self.internal_cidrs):
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                # Take the first IP in the chain
                return forwarded_for.split(",")[0].strip()

            real_ip = request.headers.get("X-Real-IP")
            if real_ip:
                return real_ip

        return peer_ip if peer_ip is not None else "unknown"
|
|
|
|
|
|
def create_trusted_proxy_middleware(
    internal_cidrs: list[str],
    *,
    disable_auth: bool = False,
) -> Callable[[Any], TrustedProxyMiddleware]:
    """Return a factory that wraps an ASGI app in ``TrustedProxyMiddleware``.

    Args:
        internal_cidrs: CIDR strings passed through to the middleware as its
            internal network.
        disable_auth: Forwarded to the middleware's development bypass switch.
            Defaults to ``False``, matching the previous behaviour of this factory.

    Returns:
        A one-argument callable that takes the ASGI ``app`` and returns the
        configured middleware instance.
    """

    def middleware_factory(app: Any) -> TrustedProxyMiddleware:
        return TrustedProxyMiddleware(app, internal_cidrs, disable_auth=disable_auth)

    return middleware_factory
|