import os

import torch
from PIL import Image
import gradio as gr
from datasets import load_dataset
from transformers import AutoProcessor, MllamaForConditionalGeneration, Trainer, TrainingArguments


CKPT = "unsloth/Llama-3.2-11B-Vision-Instruct" |
|
|
MODEL_SAVE_PATH = "./llama3-handwriting-ocr-simple" |
|
|
DATASET_NAME = "om440/partial-rimes-handwritten-small" |
|
|
MAX_TOKENS = 128 |
|
|
TRAIN_EXAMPLES = 20 |
|
|
|
|
|
|
|
|
TRAINING_PROMPT = """Extract all text from this image exactly as written. |
|
|
Include both handwritten and printed text. |
|
|
Preserve original formatting and line breaks.""" |
|
|
|
|
|
|
|
|
def load_model():
    """Load the fine-tuned checkpoint if one was saved, otherwise the base model."""
    if os.path.exists(MODEL_SAVE_PATH):
        model = MllamaForConditionalGeneration.from_pretrained(MODEL_SAVE_PATH)
        processor = AutoProcessor.from_pretrained(MODEL_SAVE_PATH)
    else:
        model = MllamaForConditionalGeneration.from_pretrained(CKPT)
        processor = AutoProcessor.from_pretrained(CKPT)
    return model, processor


def preprocess(examples, processor):
    """Turn a batch of dataset rows into model-ready tensors."""
    images = []
    for img in examples["image"]:
        # Images may arrive as PIL objects or as raw arrays.
        if hasattr(img, "convert"):
            image = img.convert("RGB")
        else:
            image = Image.fromarray(img).convert("RGB")
        # Downscale to keep memory usage modest.
        image.thumbnail((336, 336), Image.Resampling.LANCZOS)
        images.append(image)

    # The Mllama processor expects one <|image|> token per image in the text.
    # Prompt and target text are concatenated so the labels can share the same
    # sequence length as the inputs.
    texts = [
        "<|image|>" + TRAINING_PROMPT + "\n" + target
        for target in examples["text"]
    ]
    inputs = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_TOKENS,
    )

    # Use the input sequence itself as labels and mask padding so it does not
    # contribute to the loss. (Prompt tokens could also be masked here.)
    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Keep every tensor the processor produced (e.g. aspect_ratio_ids,
    # cross_attention_mask); the model's forward pass needs them.
    batch = dict(inputs)
    batch["labels"] = labels
    return batch

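
# Optional sanity check (a sketch, meant to be run manually; the underscored
# names below are only for illustration): inspect the tensors preprocess()
# produces for a single example before launching a full training run.
#
#   _model, _processor = load_model()
#   _sample = load_dataset(DATASET_NAME)["train"].select(range(1))
#   _batch = preprocess(_sample[:1], _processor)
#   print({k: getattr(v, "shape", type(v)) for k, v in _batch.items()})
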
def train(num_examples=TRAIN_EXAMPLES):
    """Fine-tune the model on a small slice of the handwriting dataset."""
    dataset = load_dataset(DATASET_NAME)["train"].select(range(num_examples))
    model, processor = load_model()

    tokenized = dataset.map(
        lambda ex: preprocess(ex, processor),
        batched=True,
        batch_size=1,
        remove_columns=dataset.column_names,
    )

    training_args = TrainingArguments(
        output_dir=MODEL_SAVE_PATH,
        per_device_train_batch_size=1,
        num_train_epochs=1,
        logging_steps=1,
        save_steps=10,
        save_total_limit=1,
        remove_unused_columns=False,
        report_to="none",
        gradient_accumulation_steps=1,
    )

    # The mapped dataset stores values as plain Python lists, so the collator
    # converts them back to tensors. Stacking every key keeps the extra
    # Mllama inputs (aspect ratios, cross-attention mask) in the batch.
    def collate(features):
        return {
            key: torch.tensor([f[key] for f in features])
            for key in features[0]
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized,
        tokenizer=processor.tokenizer,
        data_collator=collate,
    )

    trainer.train()
    model.save_pretrained(MODEL_SAVE_PATH)
    processor.save_pretrained(MODEL_SAVE_PATH)
    return f"✅ Training finished on {num_examples} examples."


def extract_text(image_path):
    """Run OCR on a single uploaded image and return the decoded text."""
    model, processor = load_model()
    image = Image.open(image_path).convert("RGB")
    image.thumbnail((336, 336), Image.Resampling.LANCZOS)

    # The Mllama processor expects an <|image|> token marking where the image
    # belongs in the prompt.
    prompt = "<|image|>" + TRAINING_PROMPT

    inputs = processor(
        text=prompt,
        images=[image],
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            do_sample=False,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    result = processor.decode(generated, skip_special_tokens=True).strip()
    return result or "No text detected."


def launch_gradio():
    """Build the two-tab Gradio demo: training and text extraction."""
    with gr.Blocks() as demo:
        gr.Markdown("# LLaMA 3.2 Vision OCR (simple)")

        with gr.Tab("Training"):
            train_btn = gr.Button("🚀 Train on 20 examples")
            train_output = gr.Textbox(lines=4, interactive=False)
            train_btn.click(fn=train, inputs=[], outputs=train_output)

        with gr.Tab("Extraction"):
            image_input = gr.Image(type="filepath", label="Upload an image")
            extract_btn = gr.Button("🔍 Extract text")
            text_output = gr.Textbox(lines=10, interactive=False)
            extract_btn.click(fn=extract_text, inputs=image_input, outputs=text_output)

    return demo


if __name__ == "__main__":
    demo = launch_gradio()
    demo.launch(share=True)
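
# Usage sketch (assumes this script is saved as, e.g., ocr_app.py and that
# torch, transformers, datasets, gradio, and pillow are installed):
#
#   python ocr_app.py
#
# share=True asks Gradio for a temporary public link in addition to the local
# URL; set it to False to keep the demo local only.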