import gradio as gr
import onnxruntime as ort
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
import os
import json


print("loading models...")

# model repositories and filenames on the hugging face hub
det_model_repo = "rtr46/meiki.text.detect.v0"
det_model_name = "meiki.text.detect.v0.1.960x544.onnx"
rec_model_repo = "rtr46/meiki.txt.recognition.v0"
rec_model_name = "meiki.text.rec.v0.960x32.onnx"

# fixed input resolutions expected by the detection and recognition models
input_det_width = 960
input_det_height = 544
input_rec_height = 32
input_rec_width = 960

# a character candidate is dropped if it overlaps an already accepted one by
# more than this fraction of its own width
x_overlap_threshold = 0.3
# guards against division by zero when computing overlap ratios
epsilon = 1e-6

try:
    # download the onnx weights from the hub (cached between runs) and create cpu inference sessions
    det_model_path = hf_hub_download(repo_id=det_model_repo, filename=det_model_name)
    rec_model_path = hf_hub_download(repo_id=rec_model_repo, filename=rec_model_name)

    providers = ['CPUExecutionProvider']
    det_session = ort.InferenceSession(det_model_path, providers=providers)
    rec_session = ort.InferenceSession(rec_model_path, providers=providers)

    print("models loaded successfully.")
except Exception as e:
    det_session, rec_session = None, None
    print(f"error loading models: {e}")
    raise gr.Error(f"failed to load models. please check space logs. error: {e}")


def preprocess_for_detection(image):
    # resize to the detector's fixed input size and convert hwc uint8 -> nchw float32 in [0, 1]
    h_orig, w_orig, _ = image.shape
    resized = cv2.resize(image, (input_det_width, input_det_height), interpolation=cv2.INTER_LINEAR)
    input_tensor = resized.astype(np.float32) / 255.0
    input_tensor = np.transpose(input_tensor, (2, 0, 1))
    input_tensor = np.expand_dims(input_tensor, axis=0)
    # scale factors to map detected boxes back to the original resolution
    scale_x = w_orig / input_det_width
    scale_y = h_orig / input_det_height
    return input_tensor, scale_x, scale_y


def postprocess_detection_results(raw_outputs, scale_x, scale_y, conf_threshold):
    # raw_outputs is (labels, boxes, scores); labels are not used here
    _, boxes, scores = raw_outputs
    boxes, scores = boxes[0], scores[0]
    text_boxes = []
    for box, score in zip(boxes, scores):
        if score < conf_threshold: continue
        x1, y1, x2, y2 = box
        # rescale from detector input coordinates to original image coordinates
        x1_orig, y1_orig = int(x1 * scale_x), int(y1 * scale_y)
        x2_orig, y2_orig = int(x2 * scale_x), int(y2 * scale_y)
        text_boxes.append({'bbox': [x1_orig, y1_orig, x2_orig, y2_orig]})
    # sort text lines top to bottom
    text_boxes.sort(key=lambda tb: tb['bbox'][1])
    return text_boxes


def preprocess_for_recognition(image, text_boxes):
    tensors, valid_indices, crop_metadata = [], [], []
    for i, tb in enumerate(text_boxes):
        x1, y1, x2, y2 = tb['bbox']
        width, height = x2 - x1, y2 - y1
        # skip degenerate and vertical boxes; the recognizer expects horizontal lines
        if width < height or width == 0 or height == 0: continue
        crop = image[y1:y2, x1:x2]
        h, w, _ = crop.shape
        # scale the crop to the recognizer's fixed height, preserving aspect ratio
        new_h, new_w = input_rec_height, int(round(w * (input_rec_height / h)))
        if new_w > input_rec_width:
            scale = input_rec_width / new_w
            new_w, new_h = input_rec_width, int(round(new_h * scale))
        resized = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        # pad on the right/bottom up to the fixed recognizer input size
        pad_w, pad_h = input_rec_width - new_w, input_rec_height - new_h
        padded = np.pad(resized, ((0, pad_h), (0, pad_w), (0, 0)), constant_values=0)
        tensor = padded.astype(np.float32) / 255.0
        tensor = np.transpose(tensor, (2, 0, 1))
        tensors.append(tensor)
        valid_indices.append(i)
        crop_metadata.append({'orig_bbox': [x1, y1, x2, y2], 'effective_w': new_w})
    if not tensors: return None, [], []
    return np.stack(tensors, axis=0), valid_indices, crop_metadata


def postprocess_recognition_results(raw_rec_outputs, valid_indices, crop_metadata, rec_conf_threshold, num_total_boxes):
    labels_batch, boxes_batch, scores_batch = raw_rec_outputs
    full_results = [{'text': '', 'chars': []} for _ in range(num_total_boxes)]
    for i, (labels, boxes, scores) in enumerate(zip(labels_batch, boxes_batch, scores_batch)):
        meta = crop_metadata[i]
        gx1, gy1, gx2, gy2 = meta['orig_bbox']
        crop_w, crop_h = gx2 - gx1, gy2 - gy1
        effective_w = meta['effective_w']
        candidates = []
        for lbl, box, scr in zip(labels, boxes, scores):
            if scr < rec_conf_threshold: continue
            # labels are unicode code points
            char = chr(lbl)
            rx1, ry1, rx2, ry2 = box
            # clamp to the unpadded region, then map from recognizer input
            # coordinates back to global image coordinates
            rx1, rx2 = min(rx1, effective_w), min(rx2, effective_w)
            cx1, cx2 = (rx1 / effective_w) * crop_w, (rx2 / effective_w) * crop_w
            cy1, cy2 = (ry1 / input_rec_height) * crop_h, (ry2 / input_rec_height) * crop_h
            gx1_char, gy1_char = gx1 + int(cx1), gy1 + int(cy1)
            gx2_char, gy2_char = gx1 + int(cx2), gy1 + int(cy2)
            candidates.append({'char': char, 'bbox': [gx1_char, gy1_char, gx2_char, gy2_char], 'x_interval': (gx1_char, gx2_char), 'conf': float(scr)})
        # greedy suppression along x: keep the highest-confidence candidates and
        # drop any candidate that overlaps an accepted one by too much
        candidates.sort(key=lambda c: c['conf'], reverse=True)
        accepted = []
        for cand in candidates:
            x1_c, x2_c = cand['x_interval']
            width_c = x2_c - x1_c + epsilon
            is_overlap = any((max(0, min(x2_c, x2_a) - max(x1_c, x1_a)) / width_c) > x_overlap_threshold for x1_a, x2_a in (acc['x_interval'] for acc in accepted))
            if not is_overlap: accepted.append(cand)
        # reading order: left to right
        accepted.sort(key=lambda c: c['x_interval'][0])
        text = ''.join(c['char'] for c in accepted)
        final_chars = [{'char': c['char'], 'bbox': c['bbox'], 'conf': c['conf']} for c in accepted]
        full_results[valid_indices[i]] = {'text': text, 'chars': final_chars}
    return full_results


def run_ocr_pipeline(input_image, det_threshold, rec_threshold):
    if input_image is None:
        raise gr.Error("please upload an image to process.")

    # stage 1: detect text-line boxes
    det_input, sx, sy = preprocess_for_detection(input_image)
    det_raw = det_session.run(None, {det_session.get_inputs()[0].name: det_input, det_session.get_inputs()[1].name: np.array([[input_det_width, input_det_height]], dtype=np.int64)})
    text_boxes = postprocess_detection_results(det_raw, sx, sy, det_threshold)

    if not text_boxes:
        return input_image, "no text detected. try lowering the 'detection confidence' slider.", ""

    # stage 2: recognize characters in each detected line
    rec_batch, valid_indices, crop_metadata = preprocess_for_recognition(input_image, text_boxes)
    if rec_batch is None:
        # all detected boxes were vertical or degenerate, so there is nothing to recognize
        return input_image, "no horizontal text lines to recognize. try adjusting the 'detection confidence' slider.", ""
    rec_raw = rec_session.run(None, {"images": rec_batch, "orig_target_sizes": np.array([[input_rec_width, input_rec_height]], dtype=np.int64)})
    results = postprocess_recognition_results(rec_raw, valid_indices, crop_metadata, rec_threshold, len(text_boxes))

    # draw character boxes and collect the recognized text
    output_image = input_image.copy()
    full_text = []
    for res in results:
        if res['text']: full_text.append(res['text'])
        for char_info in res['chars']:
            x1, y1, x2, y2 = char_info['bbox']
            cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)

    json_output = json.dumps(results, indent=2, ensure_ascii=False)

    return output_image, "\n".join(full_text), json_output


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# meikiocr: japanese video game ocr")
    gr.Markdown(
        "upload a screenshot from a japanese video game to see the high-accuracy ocr in action. "
        "the pipeline first detects text lines, then recognizes the characters in each line. "
        "adjust the confidence sliders if text is missed or incorrectly detected."
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="numpy", label="upload image")
            det_threshold = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="detection confidence")
            rec_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="recognition confidence")
            run_button = gr.Button("run ocr", variant="primary")

        with gr.Column(scale=2):
            output_image = gr.Image(type="numpy", label="ocr result")
            output_text = gr.Textbox(label="recognized text", lines=5)
            output_json = gr.Code(label="json output", language="json", lines=5)

    def process_example(img):
        return run_ocr_pipeline(img, 0.5, 0.1)

    example_image_path = os.path.join(os.path.dirname(__file__), "example.jpg")
    if os.path.exists(example_image_path):
        gr.Examples(
            examples=[example_image_path],
            inputs=[input_image],
            outputs=[output_image, output_text, output_json],
            fn=process_example,
            cache_examples=True
        )

    run_button.click(
        fn=run_ocr_pipeline,
        inputs=[input_image, det_threshold, rec_threshold],
        outputs=[output_image, output_text, output_json]
    )

    gr.Markdown(
        """
---
### official github repository
the full source code, documentation, and local command-line script for `meikiocr` are available on github.

**► [github.com/rtr46/meikiocr](https://github.com/rtr46/meikiocr)**
"""
    )


demo.launch()
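
# to run this demo locally (a sketch; assumes the script is saved as app.py and
# that an optional example.jpg sits next to it for the cached example):
#   pip install gradio onnxruntime opencv-python numpy huggingface_hub
#   python app.py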