madibaalbert commited on
Commit
c65a8ef
·
verified ·
1 Parent(s): 94ec00f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -57
app.py CHANGED
@@ -1,72 +1,100 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
- import os
 
 
4
 
5
- # --- CONFIGURATION DU MODÈLE ---
6
- # Utilisation d'un modèle léger pour garantir la fluidité sur le CPU gratuit de Hugging Face
7
- MODEL_ID = "distilbert-base-uncased-finetuned-sst-2-english"
8
 
9
- print(f"Loading model: {MODEL_ID}...")
10
- # On initialise le pipeline d'analyse de sentiment
11
- # Note: Le téléchargement se fait automatiquement au premier lancement dans le Space
12
- classifier = pipeline("sentiment-analysis", model=MODEL_ID)
13
 
14
- def predict_sentiment(text):
 
 
 
 
15
  """
16
- Fonction de traitement pour l'inférence.
17
- Elle sera exposée via l'endpoint API.
18
  """
19
- if not text or text.strip() == "":
20
- return "Veuillez entrer un texte valide."
21
 
22
- try:
23
- results = classifier(text)
24
- label = results[0]['label']
25
- score = round(results[0]['score'], 4)
26
- return f"Sentiment: {label} (Confiance: {score})"
27
- except Exception as e:
28
- return f"Erreur lors de l'analyse: {str(e)}"
29
-
30
- # --- CONSTRUCTION DE L'INTERFACE GRADIO ---
31
- # Utilisation de gr.Blocks pour un design plus "Pro" (OmniGroup Style)
32
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
33
- gr.Markdown(
34
- """
35
- # 🌐 OmniGroup AI - Sentiment Analysis API
36
- ### Déploiement souverain pour l'écosystème Pangea.
37
- Ce Space expose un endpoint gratuit pour analyser les sentiments textuels.
38
- """
39
  )
40
 
41
- with gr.Row():
42
- with gr.Column():
43
- input_text = gr.Textbox(
44
- label="Texte à analyser",
45
- placeholder="Entrez votre phrase ici...",
46
- lines=3
47
- )
48
- submit_btn = gr.Button("Analyser", variant="primary")
49
-
50
- with gr.Column():
51
- output_text = gr.Textbox(label="Résultat de l'API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # Liaison du bouton et de la touche 'Entrée' à la fonction de prédiction
54
- submit_btn.click(fn=predict_sentiment, inputs=input_text, outputs=output_text, api_name="predict")
55
- input_text.submit(fn=predict_sentiment, inputs=input_text, outputs=output_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- gr.Markdown(
58
- """
59
- ---
60
- **Note technique :** Pour utiliser cet endpoint via Python, utilisez le `gradio_client` :
61
- ```python
62
- from gradio_client import Client
63
- client = Client("votre-username/nom-du-space")
64
- result = client.predict("Texte", api_name="/predict")
65
- ```
66
- """
67
  )
68
 
69
- # --- LANCEMENT ---
70
  if __name__ == "__main__":
71
- # Hugging Face Spaces nécessite que le serveur écoute sur 0.0.0.0:7860 (par défaut via launch)
72
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
+ import time
4
+ import json
5
+ import torch
6
 
7
+ # --- CONFIGURATION OMNIGROUP ---
8
+ # On utilise un modèle compact mais puissant pour le CPU gratuit
9
+ MODEL_ID = "HuggingFaceTB/SmolLM-135M-Instruct"
10
 
11
+ print(f"Initialisation du moteur Pangea sur {MODEL_ID}...")
 
 
 
12
 
13
+ # Chargement du tokenizer et du modèle
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
16
+
17
+ def generate_response(prompt, max_tokens=128, temperature=0.7):
18
  """
19
+ Génère une réponse avec calcul du débit (tokens/s)
 
20
  """
21
+ start_time = time.time()
 
22
 
23
+ # Encodage
24
+ inputs = tokenizer(prompt, return_tensors="pt")
25
+ input_length = inputs.input_ids.shape[1]
26
+
27
+ # Génération
28
+ outputs = model.generate(
29
+ **inputs,
30
+ max_new_tokens=max_tokens,
31
+ temperature=temperature,
32
+ do_sample=True,
33
+ pad_token_id=tokenizer.eos_token_id
 
 
 
 
 
 
34
  )
35
 
36
+ end_time = time.time()
37
+
38
+ # Décodage
39
+ full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
40
+ # Extraire uniquement la nouvelle réponse (après le prompt)
41
+ new_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
42
+
43
+ # Métriques
44
+ duration = end_time - start_time
45
+ tokens_generated = len(outputs[0]) - input_length
46
+ tokens_per_sec = round(tokens_generated / duration, 2) if duration > 0 else 0
47
+
48
+ # Construction du JSON (Format Gemini-like)
49
+ json_output = {
50
+ "id": f"omni-{int(start_time)}",
51
+ "object": "text_completion",
52
+ "created": int(start_time),
53
+ "model": MODEL_ID,
54
+ "choices": [{
55
+ "text": new_text,
56
+ "index": 0,
57
+ "finish_reason": "stop"
58
+ }],
59
+ "usage": {
60
+ "prompt_tokens": input_length,
61
+ "completion_tokens": tokens_generated,
62
+ "total_tokens": input_length + tokens_generated,
63
+ "speed": f"{tokens_per_sec} tokens/s"
64
+ }
65
+ }
66
+
67
+ return new_text, json.dumps(json_output, indent=2), f"{tokens_per_sec} t/s"
68
 
69
+ # --- INTERFACE GRADIO PRO ---
70
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
71
+ gr.Markdown("# 🚀 OmniGroup Pangea API v2")
72
+ gr.Markdown("Endpoint haute performance avec métriques de débit en temps réel.")
73
+
74
+ with gr.Row():
75
+ with gr.Column(scale=2):
76
+ input_text = gr.Textbox(label="Prompt", placeholder="Posez une question à l'IA...", lines=5)
77
+ with gr.Row():
78
+ slider_tokens = gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max New Tokens")
79
+ slider_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Température")
80
+ submit_btn = gr.Button("Générer l'inférence", variant="primary")
81
+
82
+ with gr.Column(scale=1):
83
+ speed_metric = gr.Label(label="Vitesse d'exécution (Débit)")
84
+
85
+ with gr.Tabs():
86
+ with gr.TabItem("Réponse Texte"):
87
+ output_text = gr.Textbox(label="Sortie Brute", lines=10)
88
+ with gr.TabItem("Réponse JSON (Format API)"):
89
+ output_json = gr.Code(label="JSON Payload", language="json")
90
 
91
+ # Mapping des fonctions
92
+ submit_btn.click(
93
+ fn=generate_response,
94
+ inputs=[input_text, slider_tokens, slider_temp],
95
+ outputs=[output_text, output_json, speed_metric],
96
+ api_name="chat" # L'endpoint sera /chat
 
 
 
 
97
  )
98
 
 
99
  if __name__ == "__main__":
 
100
  demo.launch(server_name="0.0.0.0", server_port=7860)