Files
ai-tax-agent/libs/calibration/multi_model.py
harkon b324ff09ef
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
Initial commit
2025-10-11 08:41:36 +01:00

86 lines
3.1 KiB
Python

"""Multi-model calibrator for handling multiple models/tasks."""
import glob
import os
import structlog
from .calibrator import ConfidenceCalibrator
# Structured (structlog) logger shared by MultiModelCalibrator's methods
# for its log events (e.g. calibrator added/removed, unknown model).
logger = structlog.get_logger()
class MultiModelCalibrator:
    """Calibrate confidence scores for multiple models/tasks.

    Holds one ``ConfidenceCalibrator`` per model name and routes ``fit`` /
    ``calibrate`` calls to it. Calibrators can be persisted to, and restored
    from, a directory as ``<model_name>_calibrator.pkl`` files.
    """

    # Filename suffix shared by save_all (writing) and load_all (parsing the
    # model name back out) so the two directions cannot drift apart.
    _CALIBRATOR_SUFFIX = "_calibrator.pkl"

    def __init__(self) -> None:
        # Model name -> calibrator registered for that model.
        self.calibrators: dict[str, ConfidenceCalibrator] = {}

    def add_calibrator(self, model_name: str, method: str = "temperature") -> None:
        """Register a new calibrator for ``model_name``.

        Args:
            model_name: Key under which the calibrator is stored.
            method: Calibration method passed to ``ConfidenceCalibrator``.

        Any calibrator already registered under ``model_name`` is replaced.
        """
        self.calibrators[model_name] = ConfidenceCalibrator(method)
        logger.info("Added calibrator", model=model_name, method=method)

    def fit(self, model_name: str, scores: list[float], labels: list[bool]) -> None:
        """Fit the calibrator for ``model_name`` on ``scores`` and ``labels``.

        If no calibrator is registered for ``model_name``, one is created
        first with ``add_calibrator``'s default method ("temperature").
        """
        if model_name not in self.calibrators:
            self.add_calibrator(model_name)
        self.calibrators[model_name].fit(scores, labels)

    def calibrate(self, model_name: str, scores: list[float]) -> list[float]:
        """Return calibrated scores for ``model_name``.

        If no calibrator is registered for ``model_name``, a warning is logged
        and ``scores`` is returned unchanged (the same object, not a copy).
        """
        if model_name not in self.calibrators:
            logger.warning("No calibrator for model", model=model_name)
            return scores
        return self.calibrators[model_name].calibrate(scores)

    def save_all(self, directory: str) -> None:
        """Save every registered calibrator into ``directory``.

        The directory is created if needed. Each calibrator is written to
        ``<directory>/<model_name>_calibrator.pkl``.
        """
        # NOTE(review): model_name is used verbatim in the filename; names that
        # contain path separators will not round-trip through load_all, which
        # only scans the top level of ``directory``.
        os.makedirs(directory, exist_ok=True)
        for model_name, calibrator in self.calibrators.items():
            calibrator.save_model(self._calibrator_path(directory, model_name))

    def load_all(self, directory: str) -> None:
        """Load every ``*_calibrator.pkl`` file found in ``directory``.

        Each file is loaded into a fresh ``ConfidenceCalibrator`` and stored
        under the model name recovered from its filename. An existing entry
        with the same name is replaced; other registered calibrators are kept.
        A warning is logged when no calibrator files are found.
        """
        filepaths = self._calibrator_files(directory)
        if not filepaths:
            logger.warning("No calibrators found", directory=directory)
        for filepath in filepaths:
            # NOTE(review): these are .pkl files; if load_model unpickles them,
            # only load from trusted directories.
            calibrator = ConfidenceCalibrator()
            calibrator.load_model(filepath)
            self.calibrators[self._model_name_from_path(filepath)] = calibrator

    def save_models(self, directory: str) -> None:
        """Save all calibrators (alias for ``save_all``)."""
        self.save_all(directory)

    def load_models(self, directory: str) -> None:
        """Load all calibrators from ``directory`` (alias for ``load_all``)."""
        self.load_all(directory)

    def get_model_names(self) -> list[str]:
        """Return the registered model names, in insertion order."""
        return list(self.calibrators.keys())

    def has_model(self, model_name: str) -> bool:
        """Return True if a calibrator is registered for ``model_name``."""
        return model_name in self.calibrators

    def is_fitted(self, model_name: str) -> bool:
        """Return the ``is_fitted`` flag of ``model_name``'s calibrator.

        Raises:
            ValueError: If no calibrator is registered for ``model_name``.
        """
        if model_name not in self.calibrators:
            raise ValueError(f"Model '{model_name}' not found")
        return self.calibrators[model_name].is_fitted

    def remove_calibrator(self, model_name: str) -> None:
        """Unregister the calibrator for ``model_name``.

        Raises:
            ValueError: If no calibrator is registered for ``model_name``.
        """
        if model_name not in self.calibrators:
            raise ValueError(f"Model '{model_name}' not found")
        del self.calibrators[model_name]
        logger.info("Removed calibrator", model=model_name)

    @classmethod
    def _calibrator_path(cls, directory: str, model_name: str) -> str:
        """Return the file path used to persist ``model_name``'s calibrator."""
        return os.path.join(directory, f"{model_name}{cls._CALIBRATOR_SUFFIX}")

    @classmethod
    def _calibrator_files(cls, directory: str) -> list[str]:
        """Return the saved-calibrator files directly in ``directory``, sorted."""
        # glob.escape stops metacharacters in the directory name ("[", "*", "?")
        # from being interpreted as wildcards.
        pattern = os.path.join(glob.escape(directory), f"*{cls._CALIBRATOR_SUFFIX}")
        # Sorted so the load order is deterministic across filesystems.
        return sorted(glob.glob(pattern))

    @classmethod
    def _model_name_from_path(cls, filepath: str) -> str:
        """Recover the model name from a path built by ``_calibrator_path``."""
        # removesuffix strips only the trailing suffix; str.replace would also
        # remove occurrences embedded in the model name itself.
        return os.path.basename(filepath).removesuffix(cls._CALIBRATOR_SUFFIX)