Files
ai-tax-agent/libs/security/middleware.py
harkon b324ff09ef
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
Initial commit
2025-10-11 08:41:36 +01:00

135 lines
4.7 KiB
Python

"""Trusted proxy middleware for authentication validation."""
from collections.abc import Callable
from typing import Any

import structlog
from fastapi import HTTPException, Request, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from .auth import AuthenticationHeaders
from .utils import is_internal_request

# Module-level structlog logger used by the middleware below for request logging.
logger = structlog.get_logger()
class TrustedProxyMiddleware(
    BaseHTTPMiddleware
):  # pylint: disable=too-few-public-methods
    """Middleware to validate requests from trusted proxy (Traefik).

    Each request takes exactly one of three paths:

    * **Auth disabled** (``disable_auth=True``): a fixed development identity
      is written to ``request.state`` and the request is passed on.
    * **Public endpoint** (path is in ``public_endpoints``): an anonymous
      identity is written to ``request.state``. ``/metrics`` is additionally
      restricted to clients that ``is_internal_request`` accepts.
    * **Anything else**: the user, email and authorization token exposed by
      ``AuthenticationHeaders`` must all be present, otherwise the request is
      rejected with 401.

    Rejections are *returned* as JSON responses of the form
    ``{"detail": "..."}``. An ``HTTPException`` that escapes a middleware is
    not seen by the application's exception handlers and would reach the
    client as a 500 instead of the intended 401/403.

    NOTE(review): nothing here verifies that the request actually arrived
    through the proxy. The identity and forwarding headers are only
    trustworthy if the service cannot be reached directly and the proxy
    strips client-supplied copies of them -- confirm against the deployment.
    """

    def __init__(self, app: Any, internal_cidrs: list[str], disable_auth: bool = False):
        """Store the middleware configuration.

        Args:
            app: The ASGI application being wrapped.
            internal_cidrs: Networks passed to ``is_internal_request`` when
                deciding whether a client may read ``/metrics``.
            disable_auth: When true, skip all validation and attach a fixed
                development identity to every request.
        """
        super().__init__(app)
        self.internal_cidrs = internal_cidrs
        self.disable_auth = disable_auth
        # Exact-match paths that never require identity headers.
        self.public_endpoints = {
            "/healthz",
            "/readyz",
            "/livez",
            "/metrics",
            "/docs",
            "/openapi.json",
        }

    async def dispatch(self, request: Request, call_next: Callable[..., Any]) -> Any:
        """Attach identity to ``request.state`` or reject the request.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next layer.

        Returns:
            The downstream response, or a JSON error response (401/403) when
            the request is rejected.
        """
        # Only the identity check is guarded; errors raised by the downstream
        # application must keep propagating untouched.
        try:
            self._establish_identity(request)
        except HTTPException as exc:
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=getattr(exc, "headers", None),
            )
        return await call_next(request)

    def _establish_identity(self, request: Request) -> None:
        """Populate ``request.state`` for the request.

        Raises:
            HTTPException: 403 for ``/metrics`` from a non-internal client,
                401 when a required authentication value is missing.
        """
        # Development mode: authentication disabled, fixed identity.
        if self.disable_auth:
            request.state.user = "dev-user"
            request.state.email = "dev@example.com"
            request.state.roles = ["developers"]
            request.state.auth_token = "dev-token"
            logger.info(
                "Development mode: authentication disabled", path=request.url.path
            )
            return

        path = request.url.path

        if path in self.public_endpoints:
            # Metrics stay public in the "no identity needed" sense, but are
            # still limited to the internal network.
            if path == "/metrics" and not is_internal_request(
                self._get_client_ip(request), self.internal_cidrs
            ):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Metrics endpoint only accessible from internal network",
                )
            # Minimal anonymous state for public endpoints.
            request.state.user = None
            request.state.email = None
            request.state.roles = []
            return

        # Protected endpoint: every identity value must be present.
        auth_headers = AuthenticationHeaders(request)

        if not auth_headers.authenticated_user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing X-Authenticated-User header",
            )
        if not auth_headers.authenticated_email:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing X-Authenticated-Email header",
            )
        if not auth_headers.authorization_token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing Authorization header",
            )

        request.state.user = auth_headers.authenticated_user
        request.state.email = auth_headers.authenticated_email
        request.state.roles = auth_headers.authenticated_groups
        request.state.auth_token = auth_headers.authorization_token
        # Expose the header helper itself to downstream handlers.
        request.state.auth = auth_headers

        logger.info(
            "Authenticated request",
            user=auth_headers.authenticated_user,
            email=auth_headers.authenticated_email,
            roles=auth_headers.authenticated_groups,
            path=path,
        )

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP considering proxy headers.

        Order of precedence: first entry of ``X-Forwarded-For``, then
        ``X-Real-IP``, then the peer address of the connection. Returns
        ``"unknown"`` when none of them is available.

        NOTE(review): both headers are taken at face value. A client that can
        reach this service without going through the proxy can set them to an
        internal address -- confirm the proxy overwrites them.
        """
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            # The first entry of the chain is the originating client.
            first_hop = forwarded_for.split(",")[0].strip()
            if first_hop:
                return first_hop

        real_ip = request.headers.get("X-Real-IP")
        if real_ip and real_ip.strip():
            return real_ip.strip()

        # Fall back to the direct peer address.
        return request.client.host if request.client else "unknown"
def create_trusted_proxy_middleware(
    internal_cidrs: list[str],
    *,
    disable_auth: bool = False,
) -> Callable[[Any], TrustedProxyMiddleware]:
    """Factory function to create TrustedProxyMiddleware.

    Args:
        internal_cidrs: Forwarded unchanged to ``TrustedProxyMiddleware``.
        disable_auth: Forwarded unchanged to ``TrustedProxyMiddleware``.
            Defaults to ``False``, which matches the behaviour of calling the
            factory with ``internal_cidrs`` only.

    Returns:
        A callable that takes the ASGI app and returns the configured
        ``TrustedProxyMiddleware`` wrapping it.
    """

    def middleware_factory(app: Any) -> TrustedProxyMiddleware:
        return TrustedProxyMiddleware(app, internal_cidrs, disable_auth)

    return middleware_factory