import gradio as gr
import onnxruntime as ort
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
import os
import json

# --- 1. configuration & model loading ---
# this section runs once when the space starts up.

print("loading models...")

# configuration
det_model_repo = "rtr46/meiki.text.detect.v0"
det_model_name = "meiki.text.detect.v0.1.960x544.onnx"
rec_model_repo = "rtr46/meiki.txt.recognition.v0"
rec_model_name = "meiki.text.rec.v0.960x32.onnx"

input_det_width = 960
input_det_height = 544
input_rec_height = 32
input_rec_width = 960
x_overlap_threshold = 0.3
epsilon = 1e-6

# load models from the hub
try:
    det_model_path = hf_hub_download(repo_id=det_model_repo, filename=det_model_name)
    rec_model_path = hf_hub_download(repo_id=rec_model_repo, filename=rec_model_name)
    
    # use cpu execution provider for broad compatibility in spaces
    providers = ['CPUExecutionProvider']
    det_session = ort.InferenceSession(det_model_path, providers=providers)
    rec_session = ort.InferenceSession(rec_model_path, providers=providers)
    
    print("models loaded successfully.")
except Exception as e:
    det_session, rec_session = None, None
    print(f"error loading models: {e}")
    raise gr.Error(f"failed to load models. please check space logs. error: {e}")

# --- 2. ocr pipeline helper functions ---

def preprocess_for_detection(image):
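    # resize to the detector's fixed 960x544 input, scale pixels to [0, 1], convert hwc -> nchw,
    # and return the scale factors needed to map detected boxes back to the original image.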
    h_orig, w_orig, _ = image.shape
    resized = cv2.resize(image, (input_det_width, input_det_height), interpolation=cv2.INTER_LINEAR)
    input_tensor = resized.astype(np.float32) / 255.0
    input_tensor = np.transpose(input_tensor, (2, 0, 1))
    input_tensor = np.expand_dims(input_tensor, axis=0)
    scale_x = w_orig / input_det_width
    scale_y = h_orig / input_det_height
    return input_tensor, scale_x, scale_y

def postprocess_detection_results(raw_outputs, scale_x, scale_y, conf_threshold):
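    # keep boxes above the confidence threshold, rescale them to original image coordinates,
    # and sort them top-to-bottom by y1.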
    _, boxes, scores = raw_outputs
    boxes, scores = boxes[0], scores[0]
    text_boxes = []
    for box, score in zip(boxes, scores):
        if score < conf_threshold: continue
        x1, y1, x2, y2 = box
        x1_orig, y1_orig = int(x1 * scale_x), int(y1 * scale_y)
        x2_orig, y2_orig = int(x2 * scale_x), int(y2 * scale_y)
        text_boxes.append({'bbox': [x1_orig, y1_orig, x2_orig, y2_orig]})
    text_boxes.sort(key=lambda tb: tb['bbox'][1])
    return text_boxes

def preprocess_for_recognition(image, text_boxes):
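    # for each detected line: skip vertical or degenerate boxes, resize the crop to a height of 32
    # while preserving aspect ratio, pad to 960x32, and record which boxes were kept.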
    tensors, valid_indices, crop_metadata = [], [], []
    for i, tb in enumerate(text_boxes):
        x1, y1, x2, y2 = tb['bbox']
        width, height = x2 - x1, y2 - y1
        if width < height or width == 0 or height == 0: continue
        crop = image[y1:y2, x1:x2]
        h, w, _ = crop.shape
        new_h, new_w = input_rec_height, int(round(w * (input_rec_height / h)))
        if new_w > input_rec_width:
            scale = input_rec_width / new_w
            new_w, new_h = input_rec_width, int(round(new_h * scale))
        resized = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        pad_w, pad_h = input_rec_width - new_w, input_rec_height - new_h
        padded = np.pad(resized, ((0, pad_h), (0, pad_w), (0, 0)), constant_values=0)
        tensor = (padded.astype(np.float32) / 255.0)
        tensor = np.transpose(tensor, (2, 0, 1))
        tensors.append(tensor)
        valid_indices.append(i)
        crop_metadata.append({'orig_bbox': [x1, y1, x2, y2], 'effective_w': new_w})
    if not tensors: return None, [], []
    return np.stack(tensors, axis=0), valid_indices, crop_metadata

def postprocess_recognition_results(raw_rec_outputs, valid_indices, crop_metadata, rec_conf_threshold, num_total_boxes):
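    # decode per-crop character predictions, map their boxes back to full-image coordinates,
    # and assemble each line's text from the surviving characters in left-to-right order.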
    labels_batch, boxes_batch, scores_batch = raw_rec_outputs
    full_results = [{'text': '', 'chars': []} for _ in range(num_total_boxes)]
    for i, (labels, boxes, scores) in enumerate(zip(labels_batch, boxes_batch, scores_batch)):
        meta = crop_metadata[i]
        gx1, gy1, gx2, gy2 = meta['orig_bbox']
        crop_w, crop_h = gx2 - gx1, gy2 - gy1
        effective_w = meta['effective_w']
        candidates = []
        for lbl, box, scr in zip(labels, boxes, scores):
            if scr < rec_conf_threshold: continue
            char = chr(lbl)
            rx1, ry1, rx2, ry2 = box
            rx1, rx2 = min(rx1, effective_w), min(rx2, effective_w)
            cx1, cx2 = (rx1 / effective_w) * crop_w, (rx2 / effective_w) * crop_w
            cy1, cy2 = (ry1 / input_rec_height) * crop_h, (ry2 / input_rec_height) * crop_h
            gx1_char, gy1_char = gx1 + int(cx1), gy1 + int(cy1)
            gx2_char, gy2_char = gx1 + int(cx2), gy1 + int(cy2)
            candidates.append({'char': char, 'bbox': [gx1_char, gy1_char, gx2_char, gy2_char], 'x_interval': (gx1_char, gx2_char), 'conf': float(scr)})
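        # greedy suppression along x: accept characters by descending confidence and drop any
        # candidate that overlaps an already accepted character by more than the threshold.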
        candidates.sort(key=lambda c: c['conf'], reverse=True)
        accepted = []
        for cand in candidates:
            x1_c, x2_c = cand['x_interval']
            width_c = x2_c - x1_c + epsilon
            is_overlap = any((max(0, min(x2_c, x2_a) - max(x1_c, x1_a)) / width_c) > x_overlap_threshold for x1_a, x2_a in (acc['x_interval'] for acc in accepted))
            if not is_overlap: accepted.append(cand)
        accepted.sort(key=lambda c: c['x_interval'][0])
        text = ''.join(c['char'] for c in accepted)
        final_chars = [{'char': c['char'], 'bbox': c['bbox'], 'conf': c['conf']} for c in accepted]
        full_results[valid_indices[i]] = {'text': text, 'chars': final_chars}
    return full_results

# --- 3. main gradio processing function ---

def run_ocr_pipeline(input_image, det_threshold, rec_threshold):
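    # full pipeline: detect text lines, recognize characters in each line, draw character boxes
    # on a copy of the image, and return the annotated image, plain text, and json results.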
    if input_image is None:
        raise gr.Error("please upload an image to process.")

    det_input, sx, sy = preprocess_for_detection(input_image)
    det_raw = det_session.run(None, {det_session.get_inputs()[0].name: det_input, det_session.get_inputs()[1].name: np.array([[input_det_width, input_det_height]], dtype=np.int64)})
    text_boxes = postprocess_detection_results(det_raw, sx, sy, det_threshold)

    if not text_boxes:
        return input_image, "no text detected. try lowering the 'detection confidence' slider.", ""

    rec_batch, valid_indices, crop_metadata = preprocess_for_recognition(input_image, text_boxes)
    if rec_batch is None:
        return input_image, "no text lines suitable for recognition were found. try adjusting the sliders.", ""
    rec_raw = rec_session.run(None, {"images": rec_batch, "orig_target_sizes": np.array([[input_rec_width, input_rec_height]], dtype=np.int64)})
    results = postprocess_recognition_results(rec_raw, valid_indices, crop_metadata, rec_threshold, len(text_boxes))

    output_image = input_image.copy()
    full_text = []
    for res in results:
        if res['text']: full_text.append(res['text'])
        for char_info in res['chars']:
            x1, y1, x2, y2 = char_info['bbox']
            cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
    
    json_output = json.dumps(results, indent=2, ensure_ascii=False)
    
    return output_image, "\n".join(full_text), json_output

# --- 4. gradio interface definition ---

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# meikiocr: japanese video game ocr")
    gr.Markdown(
        "upload a screenshot from a japanese video game to see the high-accuracy ocr in action. "
        "the pipeline first detects text lines, then recognizes the characters in each line. "
        "adjust the confidence sliders if text is missed or incorrectly detected."
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="numpy", label="upload image")
            det_threshold = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="detection confidence")
            rec_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="recognition confidence")
            run_button = gr.Button("run ocr", variant="primary")
            
        with gr.Column(scale=2):
            output_image = gr.Image(type="numpy", label="ocr result")
            output_text = gr.Textbox(label="recognized text", lines=5)
            output_json = gr.Code(label="json output", language="json", lines=5)

    def process_example(img):
        # examples are pre-loaded as numpy by gradio, so we can pass them directly
        return run_ocr_pipeline(img, 0.5, 0.1)

    example_image_path = os.path.join(os.path.dirname(__file__), "example.jpg")  # bundled example image, if present
    if os.path.exists(example_image_path):
        gr.Examples(
            examples=[example_image_path],
            inputs=[input_image],
            outputs=[output_image, output_text, output_json],
            fn=process_example,
            cache_examples=True
        )
            
    run_button.click(
        fn=run_ocr_pipeline,
        inputs=[input_image, det_threshold, rec_threshold],
        outputs=[output_image, output_text, output_json]
    )

    # footer linking to the project's github repository
    gr.Markdown(
        """
        ---
        ### official github repository
        the full source code, documentation, and local command-line script for `meikiocr` are available on github.
        **► [github.com/rtr46/meikiocr](https://github.com/rtr46/meikiocr)**
        """
    )

# --- 5. launch the app ---
demo.launch()