Initial commit
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
This commit is contained in:
85
libs/calibration/multi_model.py
Normal file
85
libs/calibration/multi_model.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Multi-model calibrator for handling multiple models/tasks."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
||||
import structlog
|
||||
|
||||
from .calibrator import ConfidenceCalibrator
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class MultiModelCalibrator:
|
||||
"""Calibrate confidence scores for multiple models/tasks"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calibrators: dict[str, ConfidenceCalibrator] = {}
|
||||
|
||||
def add_calibrator(self, model_name: str, method: str = "temperature") -> None:
|
||||
"""Add calibrator for a specific model"""
|
||||
self.calibrators[model_name] = ConfidenceCalibrator(method)
|
||||
logger.info("Added calibrator", model=model_name, method=method)
|
||||
|
||||
def fit(self, model_name: str, scores: list[float], labels: list[bool]) -> None:
|
||||
"""Fit calibrator for specific model"""
|
||||
if model_name not in self.calibrators:
|
||||
self.add_calibrator(model_name)
|
||||
|
||||
self.calibrators[model_name].fit(scores, labels)
|
||||
|
||||
def calibrate(self, model_name: str, scores: list[float]) -> list[float]:
|
||||
"""Calibrate scores for specific model"""
|
||||
if model_name not in self.calibrators:
|
||||
logger.warning("No calibrator for model", model=model_name)
|
||||
return scores
|
||||
|
||||
return self.calibrators[model_name].calibrate(scores)
|
||||
|
||||
def save_all(self, directory: str) -> None:
|
||||
"""Save all calibrators"""
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
for model_name, calibrator in self.calibrators.items():
|
||||
filepath = os.path.join(directory, f"{model_name}_calibrator.pkl")
|
||||
calibrator.save_model(filepath)
|
||||
|
||||
def load_all(self, directory: str) -> None:
|
||||
"""Load all calibrators from directory"""
|
||||
pattern = os.path.join(directory, "*_calibrator.pkl")
|
||||
for filepath in glob.glob(pattern):
|
||||
filename = os.path.basename(filepath)
|
||||
model_name = filename.replace("_calibrator.pkl", "")
|
||||
|
||||
calibrator = ConfidenceCalibrator()
|
||||
calibrator.load_model(filepath)
|
||||
self.calibrators[model_name] = calibrator
|
||||
|
||||
def save_models(self, directory: str) -> None:
|
||||
"""Save all calibrators (alias for save_all)"""
|
||||
self.save_all(directory)
|
||||
|
||||
def load_models(self, directory: str) -> None:
|
||||
"""Load all calibrators from directory (alias for load_all)"""
|
||||
self.load_all(directory)
|
||||
|
||||
def get_model_names(self) -> list[str]:
|
||||
"""Get list of model names"""
|
||||
return list(self.calibrators.keys())
|
||||
|
||||
def has_model(self, model_name: str) -> bool:
|
||||
"""Check if model exists"""
|
||||
return model_name in self.calibrators
|
||||
|
||||
def is_fitted(self, model_name: str) -> bool:
|
||||
"""Check if model is fitted"""
|
||||
if model_name not in self.calibrators:
|
||||
raise ValueError(f"Model '{model_name}' not found")
|
||||
return self.calibrators[model_name].is_fitted
|
||||
|
||||
def remove_calibrator(self, model_name: str) -> None:
|
||||
"""Remove calibrator for specific model"""
|
||||
if model_name not in self.calibrators:
|
||||
raise ValueError(f"Model '{model_name}' not found")
|
||||
del self.calibrators[model_name]
|
||||
logger.info("Removed calibrator", model=model_name)
|
||||
Reference in New Issue
Block a user