Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
86 lines
3.1 KiB
Python
86 lines
3.1 KiB
Python
"""Multi-model calibrator for handling multiple models/tasks."""
|
|
|
|
import glob
|
|
import os
|
|
|
|
import structlog
|
|
|
|
from .calibrator import ConfidenceCalibrator
|
|
|
|
# Module-level structlog logger; MultiModelCalibrator uses it for its info/warning events.
logger = structlog.get_logger()
|
|
|
|
|
|
class MultiModelCalibrator:
    """Manage one ``ConfidenceCalibrator`` per model/task name.

    Calibrators live in ``self.calibrators`` keyed by model name. They can be
    fitted, applied to scores, saved to a directory and restored from it.
    """

    # File-name suffix shared by ``save_all`` and ``load_all`` so the two
    # always agree on how a model name maps to a file name.
    _FILE_SUFFIX = "_calibrator.pkl"

    def __init__(self) -> None:
        """Create an empty registry of calibrators."""
        self.calibrators: dict[str, ConfidenceCalibrator] = {}

    def add_calibrator(self, model_name: str, method: str = "temperature") -> None:
        """Register a fresh calibrator for ``model_name``.

        A calibrator already stored under the same name is replaced.

        Args:
            model_name: Key under which the calibrator is stored.
            method: Calibration method handed to ``ConfidenceCalibrator``.
        """
        self.calibrators[model_name] = ConfidenceCalibrator(method)
        logger.info("Added calibrator", model=model_name, method=method)

    def fit(self, model_name: str, scores: list[float], labels: list[bool]) -> None:
        """Fit the calibrator of ``model_name`` on ``scores`` and ``labels``.

        If no calibrator is registered for the model yet, one is created
        first with the default method of ``add_calibrator``.

        Args:
            model_name: Model whose calibrator is fitted.
            scores: Raw confidence scores.
            labels: Ground-truth outcomes forwarded with ``scores``.
        """
        if model_name not in self.calibrators:
            self.add_calibrator(model_name)

        self.calibrators[model_name].fit(scores, labels)

    def calibrate(self, model_name: str, scores: list[float]) -> list[float]:
        """Return calibrated scores for ``model_name``.

        Args:
            model_name: Model whose calibrator is applied.
            scores: Raw confidence scores.

        Returns:
            The result of the model's calibrator. When no calibrator is
            registered, a warning is logged and ``scores`` is returned
            unchanged (the same list object).
        """
        if model_name not in self.calibrators:
            logger.warning("No calibrator for model", model=model_name)
            return scores

        return self.calibrators[model_name].calibrate(scores)

    def save_all(self, directory: str) -> None:
        """Save every calibrator to ``<directory>/<model_name>_calibrator.pkl``.

        The directory is created if it does not exist.

        Args:
            directory: Target directory for the calibrator files.
        """
        os.makedirs(directory, exist_ok=True)

        for model_name, calibrator in self.calibrators.items():
            filepath = os.path.join(directory, f"{model_name}{self._FILE_SUFFIX}")
            calibrator.save_model(filepath)

    @classmethod
    def _saved_calibrators(cls, directory: str) -> list[tuple[str, str]]:
        """Return ``(model_name, filepath)`` pairs for calibrator files in ``directory``."""
        # Escape the directory so glob metacharacters in its name ("[", "*",
        # "?") are matched literally instead of silently finding nothing.
        pattern = os.path.join(glob.escape(directory), f"*{cls._FILE_SUFFIX}")
        # Strip only the trailing suffix: a model name may itself contain the
        # suffix text. Sorting makes the load order deterministic.
        return [
            (os.path.basename(path).removesuffix(cls._FILE_SUFFIX), path)
            for path in sorted(glob.glob(pattern))
        ]

    def load_all(self, directory: str) -> None:
        """Load every ``*_calibrator.pkl`` file found in ``directory``.

        Each file is loaded into a new ``ConfidenceCalibrator`` and stored
        under the model name derived from its file name, replacing any
        calibrator already registered under that name. A directory without
        matching files (or a missing directory) loads nothing.

        Args:
            directory: Directory previously written by ``save_all``.
        """
        # NOTE(review): the ".pkl" extension suggests ``load_model`` unpickles
        # the file; if so, only load from trusted directories. Confirm
        # against ConfidenceCalibrator.load_model.
        for model_name, filepath in self._saved_calibrators(directory):
            calibrator = ConfidenceCalibrator()
            calibrator.load_model(filepath)
            self.calibrators[model_name] = calibrator

    def save_models(self, directory: str) -> None:
        """Save all calibrators (alias for ``save_all``)."""
        self.save_all(directory)

    def load_models(self, directory: str) -> None:
        """Load all calibrators from ``directory`` (alias for ``load_all``)."""
        self.load_all(directory)

    def get_model_names(self) -> list[str]:
        """Return the names of all registered models, in registration order."""
        return list(self.calibrators.keys())

    def has_model(self, model_name: str) -> bool:
        """Return whether a calibrator is registered for ``model_name``."""
        return model_name in self.calibrators

    def is_fitted(self, model_name: str) -> bool:
        """Return the ``is_fitted`` flag of the calibrator for ``model_name``.

        Raises:
            ValueError: If no calibrator is registered for ``model_name``.
        """
        if model_name not in self.calibrators:
            raise ValueError(f"Model '{model_name}' not found")
        return self.calibrators[model_name].is_fitted

    def remove_calibrator(self, model_name: str) -> None:
        """Remove the calibrator registered for ``model_name``.

        Raises:
            ValueError: If no calibrator is registered for ``model_name``.
        """
        if model_name not in self.calibrators:
            raise ValueError(f"Model '{model_name}' not found")
        del self.calibrators[model_name]
        logger.info("Removed calibrator", model=model_name)
|