"""Multi-model calibrator for handling multiple models/tasks."""

import glob
import os

import structlog

from .calibrator import ConfidenceCalibrator

logger = structlog.get_logger()

# File-name suffix shared by save_all (writing) and load_all (discovery).
_CALIBRATOR_SUFFIX = "_calibrator.pkl"


class MultiModelCalibrator:
    """Calibrate confidence scores for multiple models/tasks.

    Keeps one ``ConfidenceCalibrator`` per model name and routes fit,
    calibrate and persistence calls to the matching instance.
    """

    def __init__(self) -> None:
        # Model name -> the calibrator dedicated to that model.
        self.calibrators: dict[str, ConfidenceCalibrator] = {}

    def add_calibrator(self, model_name: str, method: str = "temperature") -> None:
        """Add (or replace) the calibrator for ``model_name``.

        Args:
            model_name: Key under which the calibrator is stored.
            method: Calibration method forwarded to ``ConfidenceCalibrator``.
        """
        self.calibrators[model_name] = ConfidenceCalibrator(method)
        logger.info("Added calibrator", model=model_name, method=method)

    def fit(self, model_name: str, scores: list[float], labels: list[bool]) -> None:
        """Fit the calibrator for ``model_name``.

        A calibrator using the default method is created first if the model
        has none yet.

        Args:
            model_name: Model whose calibrator is fitted.
            scores: Raw confidence scores.
            labels: Ground-truth correctness labels, one per score.
        """
        if model_name not in self.calibrators:
            self.add_calibrator(model_name)
        self.calibrators[model_name].fit(scores, labels)

    def calibrate(self, model_name: str, scores: list[float]) -> list[float]:
        """Calibrate ``scores`` with the calibrator for ``model_name``.

        If the model has no calibrator, a warning is logged and ``scores`` is
        returned unchanged (the same list object, not a copy).

        Args:
            model_name: Model whose calibrator is applied.
            scores: Raw confidence scores.

        Returns:
            The calibrated scores, or ``scores`` itself when no calibrator exists.
        """
        if model_name not in self.calibrators:
            logger.warning("No calibrator for model", model=model_name)
            return scores
        return self.calibrators[model_name].calibrate(scores)

    def save_all(self, directory: str) -> None:
        """Save every calibrator to ``<directory>/<model_name>_calibrator.pkl``.

        The directory is created if it does not exist.
        """
        os.makedirs(directory, exist_ok=True)
        for model_name, calibrator in self.calibrators.items():
            # NOTE(review): model_name is used verbatim in the file name; a name
            # containing path separators would write outside ``directory``.
            # Confirm model names are trusted.
            filepath = os.path.join(directory, f"{model_name}{_CALIBRATOR_SUFFIX}")
            calibrator.save_model(filepath)

    def load_all(self, directory: str) -> None:
        """Load every ``*_calibrator.pkl`` file found in ``directory``.

        The model name is the file name with the trailing suffix removed.
        Loaded calibrators replace any existing entry with the same name.
        """
        # Escape the directory so glob metacharacters in the path ("[", "*",
        # "?") are matched literally instead of being treated as a pattern.
        pattern = os.path.join(glob.escape(directory), f"*{_CALIBRATOR_SUFFIX}")
        for filepath in sorted(glob.glob(pattern)):
            filename = os.path.basename(filepath)
            # Strip only the trailing suffix so a model name that itself
            # contains "_calibrator.pkl" round-trips intact.
            model_name = filename.removesuffix(_CALIBRATOR_SUFFIX)
            # Presumably load_model restores the saved method/state, overriding
            # the default constructor choice -- TODO confirm in ConfidenceCalibrator.
            calibrator = ConfidenceCalibrator()
            calibrator.load_model(filepath)
            self.calibrators[model_name] = calibrator

    def save_models(self, directory: str) -> None:
        """Save all calibrators (alias for save_all)."""
        self.save_all(directory)

    def load_models(self, directory: str) -> None:
        """Load all calibrators from directory (alias for load_all)."""
        self.load_all(directory)

    def get_model_names(self) -> list[str]:
        """Return the names of all models that have a calibrator."""
        return list(self.calibrators.keys())

    def has_model(self, model_name: str) -> bool:
        """Return True if a calibrator exists for ``model_name``."""
        return model_name in self.calibrators

    def _require(self, model_name: str) -> ConfidenceCalibrator:
        """Return the calibrator for ``model_name``; raise ValueError if absent."""
        try:
            return self.calibrators[model_name]
        except KeyError:
            raise ValueError(f"Model '{model_name}' not found") from None

    def is_fitted(self, model_name: str) -> bool:
        """Return the ``is_fitted`` flag of the calibrator for ``model_name``.

        Raises:
            ValueError: If no calibrator exists for ``model_name``.
        """
        return self._require(model_name).is_fitted

    def remove_calibrator(self, model_name: str) -> None:
        """Remove the calibrator for ``model_name``.

        Raises:
            ValueError: If no calibrator exists for ``model_name``.
        """
        self._require(model_name)
        del self.calibrators[model_name]
        logger.info("Removed calibrator", model=model_name)