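"""Minimal Gradio app: fine-tune Llama 3.2 Vision on a small handwritten-text dataset
and extract text from uploaded images."""
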
import os
import torch
from PIL import Image
import gradio as gr
from datasets import load_dataset
from transformers import AutoProcessor, MllamaForConditionalGeneration, Trainer, TrainingArguments
# === CONFIG ===
CKPT = "unsloth/Llama-3.2-11B-Vision-Instruct"
MODEL_SAVE_PATH = "./llama3-handwriting-ocr-simple"
DATASET_NAME = "om440/partial-rimes-handwritten-small"
MAX_TOKENS = 128
TRAIN_EXAMPLES = 20
# === PROMPT ===
TRAINING_PROMPT = """Extract all text from this image exactly as written.
Include both handwritten and printed text.
Preserve original formatting and line breaks."""
# === LOAD MODEL AND PROCESSOR ===
def load_model():
    # Reuse the fine-tuned checkpoint if it exists, otherwise start from the base model.
    if os.path.exists(MODEL_SAVE_PATH):
        model = MllamaForConditionalGeneration.from_pretrained(MODEL_SAVE_PATH)
        processor = AutoProcessor.from_pretrained(MODEL_SAVE_PATH)
    else:
        model = MllamaForConditionalGeneration.from_pretrained(CKPT)
        processor = AutoProcessor.from_pretrained(CKPT)
    return model, processor

# === SIMPLE PREPROCESSING ===
def preprocess(examples, processor):
    # Resize images and tokenize prompt/target pairs for one batch of examples.
    images = []
    for img in examples["image"]:
        if hasattr(img, "convert"):
            image = img.convert("RGB")
        else:
            image = Image.fromarray(img).convert("RGB")
        image.thumbnail((336, 336), Image.Resampling.LANCZOS)
        images.append(image)
    # The Mllama processor expects the <|image|> placeholder in the text whenever images
    # are passed; pad inputs and labels to the same fixed length so the Trainer can
    # stack examples and compute the loss.
    inputs = processor(
        text=["<|image|>" + TRAINING_PROMPT] * len(images),
        images=images,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_TOKENS,
    )
    labels = processor.tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_TOKENS,
        return_tensors="pt",
    )
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "pixel_values": inputs["pixel_values"],
        "aspect_ratio_ids": inputs["aspect_ratio_ids"],
        "aspect_ratio_mask": inputs["aspect_ratio_mask"],
        "cross_attention_mask": inputs["cross_attention_mask"],
        "labels": labels["input_ids"],
    }

# === TRAINING ===
def train(num_examples=TRAIN_EXAMPLES):
    dataset = load_dataset(DATASET_NAME)["train"].select(range(num_examples))
    model, processor = load_model()
    tokenized = dataset.map(
        lambda ex: preprocess(ex, processor),
        batched=True,
        batch_size=1,
        remove_columns=dataset.column_names,
    )
    # map() stores features as plain Python lists; switch back to tensors
    # so the collator below can stack them.
    tokenized.set_format("torch")
    training_args = TrainingArguments(
        output_dir=MODEL_SAVE_PATH,
        per_device_train_batch_size=1,
        num_train_epochs=1,
        logging_steps=1,
        save_steps=10,
        save_total_limit=1,
        remove_unused_columns=False,
        report_to="none",
        gradient_accumulation_steps=1,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized,
        tokenizer=processor.tokenizer,
        data_collator=lambda data: {
            key: torch.stack([f[key] for f in data])
            for key in (
                "input_ids", "attention_mask", "pixel_values",
                "aspect_ratio_ids", "aspect_ratio_mask",
                "cross_attention_mask", "labels",
            )
        },
    )
    trainer.train()
    model.save_pretrained(MODEL_SAVE_PATH)
    processor.save_pretrained(MODEL_SAVE_PATH)
    return f"✅ Training finished on {num_examples} examples."

# === INFERENCE ===
def extract_text(image_path):
    model, processor = load_model()
    image = Image.open(image_path).convert("RGB")
    image.thumbnail((336, 336), Image.Resampling.LANCZOS)
    # Prefix the <|image|> placeholder so the processor knows where the image goes.
    prompt = "<|image|>" + TRAINING_PROMPT
    inputs = processor(
        text=prompt,
        images=[image],
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            do_sample=False,
        )
    result = processor.decode(outputs[0], skip_special_tokens=True)
    # The decoded output echoes the instruction; strip it to keep only the prediction.
    if TRAINING_PROMPT in result:
        result = result.replace(TRAINING_PROMPT, "").strip()
    return result or "No text detected."

# === GRADIO UI ===
def launch_gradio():
    with gr.Blocks() as demo:
        gr.Markdown("# Simple LLaMA 3.2 Vision OCR")
        with gr.Tab("Training"):
            train_btn = gr.Button("🚀 Train on 20 examples")
            train_output = gr.Textbox(lines=4, interactive=False)
            train_btn.click(fn=train, inputs=[], outputs=train_output)
        with gr.Tab("Extraction"):
            image_input = gr.Image(type="filepath", label="Upload an image")
            extract_btn = gr.Button("🔍 Extract text")
            text_output = gr.Textbox(lines=10, interactive=False)
            extract_btn.click(fn=extract_text, inputs=image_input, outputs=text_output)
    return demo

if __name__ == "__main__":
    demo = launch_gradio()
    demo.launch(share=True)