| | """Model manager for keypoint–argument matching model""" |
| |
|
| | import os |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | import logging |
| |
|
# Module-level logger namespaced to this module (stdlib convention).
logger = logging.getLogger(__name__)
| |
|
| |
|
class KpaModelManager:
    """Manages loading and inference for a keypoint-argument matching model.

    Wraps a Hugging Face sequence-classification checkpoint (2 labels:
    matched / not matched). The model is loaded lazily via ``load_model``;
    ``predict`` scores a single (argument, key_point) pair.
    """

    def __init__(self):
        # All state is populated lazily by load_model().
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.max_length = 256  # tokenizer truncation/padding length
        self.model_id = None

    def load_model(self, model_id: str, api_key: str = None) -> None:
        """Load the model and tokenizer directly from the Hugging Face Hub.

        Idempotent: returns immediately if a model is already loaded.

        Args:
            model_id: Hub repository id of the checkpoint. Must not be None.
            api_key: Optional Hub access token for private repositories.

        Raises:
            RuntimeError: On any failure (including a None ``model_id``),
                with the original exception chained as the cause.
        """
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            # Lazy %-style args: interpolation only happens if the record
            # is actually emitted (and never masks a formatting error).
            logger.info("=== DEBUG KPA MODEL LOADING ===")
            logger.info("model_id reçu: %s", model_id)
            logger.info("model_id type: %s", type(model_id))
            logger.info("api_key présent: %s", api_key is not None)

            if model_id is None:
                raise ValueError("model_id cannot be None - check your .env file")

            logger.info("Loading KPA model from Hugging Face: %s", model_id)

            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", self.device)

            self.model_id = model_id

            # Treat an empty-string key the same as "no token".
            token = api_key or None

            logger.info("Step 1: Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            logger.info("✓ Tokenizer loaded successfully")

            logger.info("Step 2: Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            logger.info("✓ Model architecture loaded")

            self.model.to(self.device)
            self.model.eval()  # inference mode: disables dropout etc.
            logger.info("✓ Model moved to device and set to eval mode")

            self.model_loaded = True
            logger.info("✓ KPA model loaded successfully from Hugging Face!")
            logger.info("=== KPA MODEL LOADING COMPLETE ===")

        except Exception as e:
            logger.error("❌ Error loading KPA model: %s", e)
            logger.error("❌ Model ID was: %s", model_id)
            logger.error("❌ API Key present: %s", api_key is not None)
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e

    def predict(self, argument: str, key_point: str) -> dict:
        """Score one (argument, key_point) pair.

        Args:
            argument: The argument text (sequence A of the pair encoding).
            key_point: The key-point text (sequence B of the pair encoding).

        Returns:
            Dict with keys ``prediction`` (0/1), ``confidence`` (float),
            ``label`` ("apparie" / "non_apparie") and ``probabilities``
            (per-class softmax scores).

        Raises:
            RuntimeError: If no model is loaded, or if inference fails
                (original exception chained as the cause).
        """
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            # Sentence-pair encoding; tokenizer handles truncation/padding.
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)

            # No gradients needed for inference.
            with torch.no_grad():
                outputs = self.model(**encoding)
                probabilities = torch.softmax(outputs.logits, dim=-1)

            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()

            # The French label strings are part of the API contract — do not rename.
            return {
                "prediction": predicted_class,
                "confidence": confidence,
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": probabilities[0][0].item(),
                    "apparie": probabilities[0][1].item(),
                },
            }

        except Exception as e:
            logger.error("Error during prediction: %s", e)
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e

    def get_model_info(self) -> dict:
        """Return a summary of the loaded model, or ``{"loaded": False}``."""
        if not self.model_loaded:
            return {"loaded": False}

        return {
            "model_name": self.model_id,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": 2,  # binary matched/not-matched head
            "loaded": self.model_loaded,
        }
| |
|
| |
|
| | |
# Module-level singleton shared by importers of this module; the model
# itself is only loaded when load_model() is called.
kpa_model_manager = KpaModelManager()