"""PDF text-line OCR pipeline.

Renders PDF pages to images (PyMuPDF), detects text lines with a
Detectron2 Mask R-CNN model, transcribes each line with TrOCR, and
optionally post-corrects the combined text with Gemini.
"""

import io
import json
import logging
import os
import shutil
import tempfile
import time
from typing import List, Optional, Tuple

import cv2
import fitz  # PyMuPDF
import numpy as np
import torch
from PIL import Image

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

logger = logging.getLogger(__name__)

# -----------------------------
# Configuration (override via env if needed)
# -----------------------------
TEXTLINE_MODEL_PATH = os.getenv("TEXTLINE_MODEL_PATH", "./model_final.pth")
USE_GPU = os.getenv("USE_GPU", "true").lower() == "true"
SCORE_THRESHOLD = float(os.getenv("SCORE_THRESHOLD", "0.5"))
AREA_THRESHOLD_PERCENT = float(os.getenv("AREA_THRESHOLD_PERCENT", "12.5"))
DPI = int(os.getenv("PDF_DPI", "200"))
TROCR_SPANISH_MODEL = os.getenv("TROCR_SPANISH_MODEL", "qantev/trocr-large-spanish")
TROCR_FALLBACK_MODEL = os.getenv("TROCR_FALLBACK_MODEL", "microsoft/trocr-base-printed")


class EnhancedTextlineExtractor:
    """Detects text lines in page images and transcribes them.

    Wraps a Detectron2 ``DefaultPredictor`` for line detection and a
    TrOCR processor/model pair for line-level recognition.
    """

    def __init__(self, model_path: str):
        """Build the detector and recognition models.

        Args:
            model_path: Path to the Detectron2 textline weights file.
        """
        self.cfg = self._setup_cfg(model_path)
        self.predictor = DefaultPredictor(self.cfg)
        # TrOCR runs on GPU only when one is available AND USE_GPU is set.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() and USE_GPU else "cpu"
        )
        self.trocr_processor, self.trocr_model = self._load_trocr()
        self.trocr_model.to(self.device)

    def _setup_cfg(self, model_path: str):
        """Return a Detectron2 config for the textline Mask R-CNN.

        Starts from the model-zoo R_101_FPN_3x instance-segmentation
        config and overrides weights, class count, and score threshold.
        """
        cfg = get_cfg()
        cfg.merge_from_file(
            model_zoo.get_config_file(
                "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
            )
        )
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # textline, baseline
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESHOLD
        cfg.MODEL.WEIGHTS = model_path
        cfg.DATASETS.TEST = ("page_test",)
        cfg.DATALOADER.NUM_WORKERS = 2
        MetadataCatalog.get("page_test").thing_classes = ["textline", "baseline"]
        return cfg

    def _load_trocr(self):
        """Load the Spanish TrOCR model, falling back to the base model.

        Returns:
            Tuple of ``(processor, model)``.
        """
        try:
            processor = TrOCRProcessor.from_pretrained(TROCR_SPANISH_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(TROCR_SPANISH_MODEL)
            return processor, model
        except Exception:
            # Best-effort fallback, but leave a trace instead of failing silently.
            logger.warning(
                "Failed to load TrOCR model %s; falling back to %s",
                TROCR_SPANISH_MODEL,
                TROCR_FALLBACK_MODEL,
                exc_info=True,
            )
            processor = TrOCRProcessor.from_pretrained(TROCR_FALLBACK_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(TROCR_FALLBACK_MODEL)
            return processor, model

    def pdf_to_images(self, pdf_path: str, dpi: int = DPI) -> List[np.ndarray]:
        """Render every page of a PDF to a BGR OpenCV image.

        Args:
            pdf_path: Path to the input PDF.
            dpi: Render resolution; PDF points are 72/inch, hence dpi/72 scale.

        Returns:
            List of BGR ``np.ndarray`` images, one per page.
        """
        doc = fitz.open(pdf_path)
        images: List[np.ndarray] = []
        try:
            for page_num in range(len(doc)):
                page = doc.load_page(page_num)
                mat = fitz.Matrix(dpi / 72, dpi / 72)
                pix = page.get_pixmap(matrix=mat)
                img_data = pix.tobytes("png")
                nparr = np.frombuffer(img_data, np.uint8)
                img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                images.append(img)
        finally:
            doc.close()
        return images

    def filter_margin_boxes_by_area(
        self, boxes, scores, area_threshold_percent: float = AREA_THRESHOLD_PERCENT
    ):
        """Split detections into main-text and margin boxes by relative area.

        A box is "margin" when its area is below ``area_threshold_percent``
        of the average detection area on the page.

        Returns:
            Tuple ``(main_boxes, main_scores, margin_boxes, margin_scores)``
            as NumPy arrays; all empty when ``boxes`` is empty.
        """
        if len(boxes) == 0:
            return np.array([]), np.array([]), np.array([]), np.array([])
        boxes_arr = np.asarray(boxes, dtype=float)
        scores_arr = np.asarray(scores)
        # Vectorized (x2 - x1) * (y2 - y1) over all boxes at once.
        areas = (boxes_arr[:, 2] - boxes_arr[:, 0]) * (boxes_arr[:, 3] - boxes_arr[:, 1])
        area_threshold = areas.mean() * (area_threshold_percent / 100.0)
        keep = areas >= area_threshold
        return (
            boxes_arr[keep],
            scores_arr[keep],
            boxes_arr[~keep],
            scores_arr[~keep],
        )

    def process_page_standard(self, image: np.ndarray) -> dict:
        """Detect and transcribe all main-text lines on one page image.

        Args:
            image: BGR page image as produced by :meth:`pdf_to_images`.

        Returns:
            On success: ``{"success": True, "line_segments": [...],
            "full_text": str}``; otherwise ``{"success": False, "error": str}``.
        """
        outputs = self.predictor(image)
        instances = outputs["instances"]
        boxes = instances.pred_boxes.tensor.cpu().numpy()
        scores = instances.scores.cpu().numpy()
        if len(boxes) == 0:
            return {"success": False, "error": "No textlines detected"}

        main_boxes, main_scores, _, _ = self.filter_margin_boxes_by_area(boxes, scores)
        if len(main_boxes) == 0:
            return {"success": False, "error": "No textlines after filtering"}

        img_h, img_w = image.shape[:2]
        line_segments = []
        full_text_lines = []
        for i, (box, score) in enumerate(zip(main_boxes, main_scores)):
            # Clamp detector coordinates to the frame: Detectron2 boxes are
            # floats that can slightly exceed the image bounds, and an empty
            # crop would crash cv2.cvtColor below.
            x1 = max(0, min(img_w, int(box[0])))
            y1 = max(0, min(img_h, int(box[1])))
            x2 = max(0, min(img_w, int(box[2])))
            y2 = max(0, min(img_h, int(box[3])))
            if x2 <= x1 or y2 <= y1:
                # Degenerate crop — record an empty segment, same shape as
                # the per-line failure case below.
                line_segments.append({
                    "line_index": i,
                    "bbox": [x1, y1, x2, y2],
                    "score": float(score),
                    "text": "",
                    "confidence": 0.0
                })
                continue
            crop_bgr = image[y1:y2, x1:x2]
            try:
                crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(crop_rgb)
                pixel_values = self.trocr_processor(
                    images=pil_image, return_tensors="pt"
                ).pixel_values
                pixel_values = pixel_values.to(self.device)
                with torch.no_grad():
                    generated_ids = self.trocr_model.generate(
                        pixel_values, max_new_tokens=128
                    )
                generated_text = self.trocr_processor.batch_decode(
                    generated_ids, skip_special_tokens=True
                )[0]
                text = generated_text.strip()
                full_text_lines.append(text)
                line_segments.append({
                    "line_index": i,
                    "bbox": [x1, y1, x2, y2],
                    "score": float(score),
                    "text": text,
                    "confidence": 1.0
                })
            except Exception:
                # Keep the page going on a single-line failure, but log it
                # instead of swallowing silently.
                logger.warning("OCR failed for line %d", i, exc_info=True)
                line_segments.append({
                    "line_index": i,
                    "bbox": [x1, y1, x2, y2],
                    "score": float(score),
                    "text": "",
                    "confidence": 0.0
                })
        return {
            "success": True,
            "line_segments": line_segments,
            "full_text": "\n".join(full_text_lines)
        }


def _zip_directory(src_dir: str, zip_path: str) -> str:
    """Zip ``src_dir`` and return the path of the created archive.

    ``zip_path``'s extension is ignored; ``shutil.make_archive`` appends
    ``.zip`` to the base name itself.
    """
    base, _ = os.path.splitext(zip_path)
    archive = shutil.make_archive(base, 'zip', src_dir)
    return archive


def run_ocr(
    pdf_path: str,
    split_page_enabled: bool = False,
    use_llm: bool = False,
    gemini_key: Optional[str] = None,
) -> Tuple[str, str]:
    """Run OCR on the provided PDF.

    Args:
        pdf_path: Path to the input PDF file.
        split_page_enabled: Accepted for interface compatibility; currently
            unused in this implementation.
        use_llm: When True (and ``gemini_key`` is set), post-correct the
            combined text with a single Gemini pass.
        gemini_key: Google Generative AI API key; ignored unless ``use_llm``.

    Returns:
        Tuple of ``(combined_text, zip_file_path)`` where the zip contains
        one JSON result file per page.
    """
    extractor = EnhancedTextlineExtractor(TEXTLINE_MODEL_PATH)
    images = extractor.pdf_to_images(pdf_path, dpi=DPI)

    temp_dir = tempfile.mkdtemp(prefix="ocr_outputs_")
    inferences_dir = os.path.join(temp_dir, "inferences")
    os.makedirs(inferences_dir, exist_ok=True)

    all_results = []
    for i, image in enumerate(images):
        result = extractor.process_page_standard(image)
        all_results.append(result)
        # Persist per-page results so partial output survives in the archive.
        page_file = os.path.join(inferences_dir, f"page_{i+1}_result.json")
        with open(page_file, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)

    combined_text = "\n\n".join(
        [r.get("full_text", "") for r in all_results if r.get("success")]
    )

    # Optional Gemini correction over combined text (simple, single pass)
    if use_llm and gemini_key and combined_text.strip():
        try:
            import google.generativeai as genai
            genai.configure(api_key=gemini_key)
            prompt = (
                "Correct the following historical Spanish OCR text while preserving grammar and style. "
                "Fix orthography, punctuation, and obvious OCR mistakes. Return only corrected text.\n\n"
                + combined_text
            )
            response = genai.GenerativeModel('gemini-2.5-pro').generate_content(prompt)
            if getattr(response, 'text', None):
                combined_text = response.text.strip()
        except Exception:
            # Best-effort correction: fall back to the raw OCR text, but log
            # the failure rather than hiding it completely.
            logger.warning("Gemini correction failed; returning raw OCR text", exc_info=True)

    zip_path = os.path.join(temp_dir, "per_page_jsons.zip")
    archive_path = _zip_directory(inferences_dir, zip_path)
    return combined_text, archive_path