import os

import torch
import numpy as np
from PIL import Image, ImageDraw
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForTokenClassification
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg

from utils import normalize_box, unnormalize_box, draw_output, create_df

class Reciept_Analyzer:
    def __init__(self,
                 processor_pretrained='microsoft/layoutlmv3-base',
                 layoutlm_pretrained=os.path.join('models', 'checkpoint'),
                 yolo_pretrained=os.path.join('models', 'best.pt'),
                 vietocr_pretrained=os.path.join('models', 'vietocr', 'vgg_seq2seq.pth')):
        print("Initializing processor")
        if torch.cuda.is_available():
            print("Using GPU")
        else:
            print("No GPU detected, using CPU")

        # LayoutLMv3 processor; apply_ocr=False because VietOCR supplies the words
        self.processor = AutoProcessor.from_pretrained(
            processor_pretrained, apply_ocr=False)
        print("Finished initializing processor")

        print("Initializing LayoutLM model")
        self.lalm_model = AutoModelForTokenClassification.from_pretrained(
            layoutlm_pretrained)
        print("Finished initializing LayoutLM model")

        if yolo_pretrained is not None:
            print("Initializing YOLO model")
            self.yolo_model = YOLO(yolo_pretrained)
            print("Finished initializing YOLO model")

        print("Initializing VietOCR model")
        config = Cfg.load_config_from_name('vgg_seq2seq')
        config['weights'] = vietocr_pretrained
        config['cnn']['pretrained'] = False
        config['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.vietocr = Predictor(config)
        print("Finished initializing VietOCR model")

    def forward(self, img, output_path="output", is_save_cropped_img=False):
        input_image = Image.open(img)
        # make sure the output folder exists before any file is written
        os.makedirs(output_path, exist_ok=True)

        # text-region detection with YOLOv8
        bboxes = self.yolov8_det(input_image)

        # sort boxes into reading order
        sorted_bboxes = self.sort_bboxes(bboxes)

        # draw bounding boxes for inspection
        image_draw = input_image.copy()
        self.draw_bbox(image_draw, sorted_bboxes, output_path)

        # crop the detected regions
        cropped_images, normalized_boxes = self.get_cropped_images(
            input_image, sorted_bboxes, is_save_cropped_img, output_path)

        # text recognition with VietOCR
        texts, mapping_bbox_texts = self.ocr(cropped_images, normalized_boxes)

        # key information extraction with LayoutLMv3
        pred_texts, pred_label, boxes = self.kie(
            input_image, texts, normalized_boxes, mapping_bbox_texts, output_path)

        # assemble the predictions into a dataframe
        return create_df(pred_texts, pred_label)

    def yolov8_det(self, img):
        # single-image prediction; return integer xyxy boxes from the first result
        return self.yolo_model.predict(source=img, conf=0.3, iou=0.1)[0].boxes.xyxy.int()

    def sort_bboxes(self, bboxes):
        bbox_list = []
        for box in bboxes:
            tlx, tly, brx, bry = map(int, box)
            bbox_list.append([tlx, tly, brx, bry])
        # sort by top edge, then right edge, to approximate top-to-bottom reading order
        bbox_list.sort(key=lambda x: (x[1], x[2]))
        return bbox_list

    def draw_bbox(self, image_draw, bboxes, output_path):
        draw = ImageDraw.Draw(image_draw)
        for box in bboxes:
            draw.rectangle(box, outline='red', width=2)
        out_file = os.path.join(output_path, 'bbox.jpg')
        image_draw.save(out_file)
        print(f"Exported image with bounding boxes to {out_file}")

    def get_cropped_images(self, input_image, bboxes, is_save_cropped=False, output_path="output"):
        normalized_boxes = []
        cropped_images = []
        if is_save_cropped:
            cropped_folder = os.path.join(output_path, "cropped")
            os.makedirs(cropped_folder, exist_ok=True)
        for i, box in enumerate(bboxes):
            tlx, tly, brx, bry = map(int, box)
            # normalize coordinates for LayoutLMv3, which expects boxes in the 0-1000 range
            normalized_boxes.append(
                normalize_box(box, input_image.width, input_image.height))
            cropped = input_image.crop((tlx, tly, brx, bry))
            if is_save_cropped:
                cropped.save(os.path.join(cropped_folder, f'cropped_{i}.jpg'))
            cropped_images.append(cropped)
        return cropped_images, normalized_boxes

    def ocr(self, cropped_images, normalized_boxes):
        mapping_bbox_texts = {}
        texts = []
        for img, normalized_box in zip(cropped_images, normalized_boxes):
            result = self.vietocr.predict(img)
            text = result.strip().replace('\n', ' ')
            texts.append(text)
            # key each text by its normalized box so predictions can be mapped back in kie()
            mapping_bbox_texts[','.join(map(str, normalized_box))] = text
        return texts, mapping_bbox_texts

    def kie(self, img, texts, boxes, mapping_bbox_texts, output_path):
        encoding = self.processor(img, texts,
                                  boxes=boxes,
                                  return_offsets_mapping=True,
                                  return_tensors='pt',
                                  max_length=512,
                                  truncation=True,
                                  padding='max_length')
        offset_mapping = encoding.pop('offset_mapping')

        with torch.no_grad():
            outputs = self.lalm_model(**encoding)

        id2label = self.lalm_model.config.id2label
        logits = outputs.logits
        token_boxes = encoding.bbox.squeeze().tolist()
        offset_mapping = offset_mapping.squeeze().tolist()
        predictions = logits.argmax(-1).squeeze().tolist()

        # a token is a subword continuation if its character offset does not start at 0
        is_subword = np.array(offset_mapping)[:, 0] != 0

        true_predictions = []
        true_boxes = []
        true_texts = []
        for idx in range(len(predictions)):
            # keep only the first subword of each word and skip padding/special tokens
            if not is_subword[idx] and token_boxes[idx] != [0, 0, 0, 0]:
                true_predictions.append(id2label[predictions[idx]])
                true_boxes.append(unnormalize_box(
                    token_boxes[idx], img.width, img.height))
                true_texts.append(mapping_bbox_texts.get(
                    ','.join(map(str, token_boxes[idx])), ''))

        if isinstance(output_path, str):
            os.makedirs(output_path, exist_ok=True)
            img_output = draw_output(
                image=img,
                true_predictions=true_predictions,
                true_boxes=true_boxes)
            out_file = os.path.join(output_path, 'result.jpg')
            img_output.save(out_file)
            print(f"Exported result to {out_file}")

        return true_texts, true_predictions, true_boxes
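

# Example usage, a minimal sketch: 'receipt.jpg' is a hypothetical input path,
# and the default checkpoint locations in __init__ are assumed to exist on disk.
if __name__ == '__main__':
    analyzer = Reciept_Analyzer()
    df = analyzer.forward('receipt.jpg', output_path='output', is_save_cropped_img=True)
    print(df)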