import os

import torch
from PIL import Image
import gradio as gr
from datasets import load_dataset
from transformers import (
    AutoProcessor,
    MllamaForConditionalGeneration,
    Trainer,
    TrainingArguments,
)

# === CONFIG ===
CKPT = "unsloth/Llama-3.2-11B-Vision-Instruct"
MODEL_SAVE_PATH = "./llama3-handwriting-ocr-simple"
DATASET_NAME = "om440/partial-rimes-handwritten-small"
MAX_TOKENS = 128
TRAIN_EXAMPLES = 20

# === PROMPT ===
TRAINING_PROMPT = """Extract all text from this image exactly as written.
Include both handwritten and printed text.
Preserve original formatting and line breaks."""


# === LOAD MODEL AND PROCESSOR ===
def load_model():
    """Load the fine-tuned checkpoint if it exists, otherwise the base model."""
    if os.path.exists(MODEL_SAVE_PATH):
        model = MllamaForConditionalGeneration.from_pretrained(MODEL_SAVE_PATH)
        processor = AutoProcessor.from_pretrained(MODEL_SAVE_PATH)
    else:
        model = MllamaForConditionalGeneration.from_pretrained(CKPT)
        processor = AutoProcessor.from_pretrained(CKPT)

    # Some checkpoints ship without a pad token; fall back to EOS so that
    # padding and batched generation work.
    if processor.tokenizer.pad_token is None:
        processor.tokenizer.pad_token = processor.tokenizer.eos_token

    return model, processor


# === SIMPLE PREPROCESSING ===
def preprocess(examples, processor):
    """Build model inputs for a batch of (image, transcription) pairs."""
    images = []
    for img in examples["image"]:
        if hasattr(img, "convert"):
            image = img.convert("RGB")
        else:
            image = Image.fromarray(img).convert("RGB")
        # Downscale in place to keep memory usage modest.
        image.thumbnail((336, 336), Image.Resampling.LANCZOS)
        images.append(image)

    # Mllama expects one <|image|> token per image in the text. The prompt and
    # the target transcription are tokenized together so that input_ids and
    # labels share the same length, which the causal-LM loss requires.
    texts = [
        "<|image|>" + TRAINING_PROMPT + "\n" + target
        for target in examples["text"]
    ]
    inputs = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_TOKENS,
    )

    # Use the input ids as labels and ignore padding positions in the loss.
    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    inputs["labels"] = labels
    return inputs


# === TRAINING ===
def train(num_examples=TRAIN_EXAMPLES):
    # Note: this is full fine-tuning of an 11B model and needs a large GPU.
    dataset = load_dataset(DATASET_NAME)["train"].select(range(num_examples))
    model, processor = load_model()

    tokenized = dataset.map(
        lambda ex: preprocess(ex, processor),
        batched=True,
        batch_size=1,
        remove_columns=dataset.column_names,
    )

    training_args = TrainingArguments(
        output_dir=MODEL_SAVE_PATH,
        per_device_train_batch_size=1,
        num_train_epochs=1,
        logging_steps=1,
        save_steps=10,
        save_total_limit=1,
        remove_unused_columns=False,
        report_to="none",
        gradient_accumulation_steps=1,
    )

    def collate(batch):
        # datasets stores the mapped features as plain Python lists, so every
        # field is converted back to a tensor here before batching.
        return {
            key: torch.tensor([example[key] for example in batch])
            for key in batch[0]
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized,
        tokenizer=processor.tokenizer,
        data_collator=collate,
    )

    trainer.train()
    model.save_pretrained(MODEL_SAVE_PATH)
    processor.save_pretrained(MODEL_SAVE_PATH)
    return f"✅ Training finished on {num_examples} examples."


# === INFERENCE ===
def extract_text(image_path):
    model, processor = load_model()

    image = Image.open(image_path).convert("RGB")
    image.thumbnail((336, 336), Image.Resampling.LANCZOS)

    # The <|image|> token marks where the image is attended to in the prompt.
    prompt = "<|image|>" + TRAINING_PROMPT
    inputs = processor(
        text=prompt,
        images=[image],
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            do_sample=False,
        )

    result = processor.decode(outputs[0], skip_special_tokens=True)
    # generate() echoes the prompt before the answer; strip it from the output.
    if TRAINING_PROMPT in result:
        result = result.replace(TRAINING_PROMPT, "").strip()
    return result or "No text detected."
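
# === OPTIONAL: CACHED MODEL LOADING (sketch) ===
# load_model() re-reads the checkpoint from disk on every call, which is slow
# for an 11B model. Below is a minimal sketch of a module-level cache; the
# helper name load_model_cached() is hypothetical and nothing in the UI calls
# it yet. Swap it in for load_model() in extract_text() if desired.
_MODEL_CACHE = {}


def load_model_cached():
    """Return a cached (model, processor) pair, loading them only once."""
    if "model" not in _MODEL_CACHE:
        _MODEL_CACHE["model"], _MODEL_CACHE["processor"] = load_model()
    return _MODEL_CACHE["model"], _MODEL_CACHE["processor"]
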
# === GRADIO UI ===
def launch_gradio():
    with gr.Blocks() as demo:
        gr.Markdown("# Simple LLaMA 3.2 Vision OCR")

        with gr.Tab("Training"):
            train_btn = gr.Button("🚀 Train on 20 examples")
            train_output = gr.Textbox(lines=4, interactive=False)
            train_btn.click(fn=train, inputs=[], outputs=train_output)

        with gr.Tab("Extraction"):
            image_input = gr.Image(type="filepath", label="Upload an image")
            extract_btn = gr.Button("🔍 Extract text")
            text_output = gr.Textbox(lines=10, interactive=False)
            extract_btn.click(fn=extract_text, inputs=image_input, outputs=text_output)

    return demo


if __name__ == "__main__":
    demo = launch_gradio()
    demo.launch(share=True)
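
# === USAGE SKETCH ===
# Running this file starts the Gradio app. The functions can also be called
# directly from Python; the module name "app" and the image path below are
# placeholders, not names defined by this script:
#
#   from app import train, extract_text
#   print(train(num_examples=20))
#   print(extract_text("page.png"))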