| import os
|
| import torch
|
| import easyocr
|
| import numpy as np
|
| import gc
|
| from transformers import AutoTokenizer, AutoModel, AutoProcessor, AutoModelForZeroShotImageClassification
|
| import torch.nn.functional as F
|
| from utils import build_transform
|
|
|
class ModelHandler:
    """Load and serve the InternVL, EasyOCR and CLIP models on the CPU.

    All models are loaded eagerly by ``__init__``. Loading is best-effort and
    per model: if one fails, the error is printed and its attributes are set
    to ``None``. The matching inference method then returns ``""`` / ``None``
    instead of raising.
    """

    def __init__(self):
        self.device = torch.device("cpu")
        # Image preprocessing for InternVL, provided by utils.build_transform.
        self.transform = build_transform()
        self.load_models()

    @staticmethod
    def _resolve_model_path(local_dir, hub_id, label):
        """Return ``Models/<local_dir>`` if it exists, else the Hub id *hub_id*.

        The local directory is checked relative to the current working
        directory. *label* is used only in the progress message.
        """
        local_path = os.path.join("Models", local_dir)
        if os.path.exists(local_path):
            print(f"Loading {label} from local path: {local_path}")
            return local_path
        print(f"Local model not found. Downloading {label} from HF Hub: {hub_id}")
        return hub_id

    def load_models(self):
        """Load InternVL (+ tokenizer), the EasyOCR reader and CLIP (+ processor).

        Each model is loaded in its own ``try`` block, so one failure does not
        prevent the others from loading. On failure the affected attributes
        are set to ``None``.
        """
        try:
            internvl_model_path = self._resolve_model_path(
                "InternVL2_5-1B-MPO", "OpenGVLab/InternVL2_5-1B-MPO", "InternVL"
            )
            self.model_int = AutoModel.from_pretrained(
                internvl_model_path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            ).eval()
            # eval() already makes Dropout a no-op. Zeroing p as well is
            # presumably a safeguard in case something later calls .train().
            for module in self.model_int.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.p = 0
            self.tokenizer_int = AutoTokenizer.from_pretrained(
                internvl_model_path, trust_remote_code=True
            )
            print("\nInternVL model and tokenizer loaded successfully.")
        except Exception as e:
            # Best-effort: keep the handler usable without InternVL.
            print(f"\nError loading InternVL model or tokenizer: {e}")
            self.model_int = None
            self.tokenizer_int = None

        try:
            self.reader = easyocr.Reader(['en', 'hi'], gpu=False)
            print("\nEasyOCR reader initialized successfully.")
        except Exception as e:
            print(f"\nError initializing EasyOCR reader: {e}")
            self.reader = None

        try:
            clip_model_path = self._resolve_model_path(
                "clip-vit-base-patch32", "openai/clip-vit-base-patch32", "CLIP"
            )
            self.processor_clip = AutoProcessor.from_pretrained(clip_model_path)
            self.model_clip = AutoModelForZeroShotImageClassification.from_pretrained(
                clip_model_path
            ).to(self.device)
            print("\nCLIP model and processor loaded successfully.")
        except Exception as e:
            print(f"\nError loading CLIP model or processor: {e}")
            self.model_clip = None
            self.processor_clip = None

    def easyocr_ocr(self, image):
        """Run EasyOCR on *image* and return all recognised text as one string.

        Detections are ordered by the first corner of their bounding box
        (y first, then x) and joined with single spaces. Returns ``""`` if
        the reader is unavailable or nothing is detected.
        """
        if self.reader is None:
            return ""
        image_np = np.array(image)
        results = self.reader.readtext(image_np, detail=1)

        # Drop the array copy and force a collection. This is presumably meant
        # to keep peak memory low on a constrained CPU host.
        del image_np
        gc.collect()

        if not results:
            return ""
        # With detail=1 each result is (bbox, text, confidence). bbox[0] is the
        # first corner point [x, y] (top-left in EasyOCR's output).
        ordered = sorted(results, key=lambda det: (det[0][0][1], det[0][0][0]))
        return " ".join(det[1] for det in ordered).strip()

    def intern(self, image, prompt, max_tokens):
        """Ask InternVL *prompt* about *image* and return its text reply.

        Decoding is greedy (no sampling, one beam) and limited to *max_tokens*
        new tokens. Returns ``""`` if InternVL failed to load.
        """
        # Compare against None rather than using truthiness. HF tokenizers
        # define __len__ (vocab size), so `not tokenizer` would compute it on
        # every request.
        if self.model_int is None or self.tokenizer_int is None:
            return ""

        # unsqueeze(0) adds a leading batch dimension of size 1. The dtype
        # matches the bfloat16 weights loaded in load_models.
        pixel_values = self.transform(image).unsqueeze(0).to(
            device=self.device, dtype=torch.bfloat16
        )
        generation_config = {
            "max_new_tokens": max_tokens,
            "do_sample": False,
            "num_beams": 1,
            "temperature": 1.0,
            "top_p": 1.0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "pad_token_id": self.tokenizer_int.pad_token_id,
        }
        with torch.no_grad():
            response, _ = self.model_int.chat(
                self.tokenizer_int,
                pixel_values,
                prompt,
                generation_config=generation_config,
                history=None,
                return_history=True,
            )

        del pixel_values
        gc.collect()
        return response

    def clip(self, image, labels):
        """Preprocess *image* and candidate *labels* into CLIP model inputs.

        Returns the processor's batch, with text padded to a common length,
        moved to ``self.device``. Returns ``None`` if CLIP failed to load.
        """
        if self.model_clip is None or self.processor_clip is None:
            return None

        processed = self.processor_clip(
            text=labels,
            images=image,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        # `del image, labels` was removed here: it only unbound the parameter
        # names, while the caller still holds the objects.
        gc.collect()
        return processed

    def get_clip_probs(self, image, labels):
        """Return CLIP zero-shot probabilities of *image* matching each label.

        The result is ``softmax(logits_per_image, dim=1)``, so each row sums
        to 1 across the labels. Returns ``None`` if CLIP is unavailable.
        """
        inputs = self.clip(image, labels)
        if inputs is None:
            return None

        with torch.no_grad():
            outputs = self.model_clip(**inputs)
        probs = F.softmax(outputs.logits_per_image, dim=1)

        del inputs, outputs
        gc.collect()
        return probs
|
|
|
|
|
# Shared, eagerly initialised handler. Importing this module constructs a
# ModelHandler, whose __init__ loads every model immediately.
model_handler: "ModelHandler" = ModelHandler()
|
|
|