waliaMuskaan011 committed on
Commit
790eff3
·
verified ·
1 Parent(s): 5d38601

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +71 -62
app.py CHANGED
@@ -1,88 +1,97 @@
1
  import gradio as gr
2
  import json
3
  import torch
4
- from functools import lru_cache
5
- from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from peft import PeftModel
7
 
8
- # Load model and tokenizer
9
- @lru_cache(maxsize=1)
 
 
10
  def load_model():
11
- print("Loading model...")
12
- base_model = AutoModelForCausalLM.from_pretrained(
13
- "HuggingFaceTB/SmolLM-360M",
14
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
15
- device_map="auto"
16
- )
17
- tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M")
 
 
 
 
18
  if tokenizer.pad_token is None:
19
  tokenizer.pad_token = tokenizer.eos_token
20
 
21
- # Load LoRA adapters
 
 
 
 
 
 
22
  model = PeftModel.from_pretrained(base_model, "waliaMuskaan011/calendar-event-extractor-smollm")
23
- model.eval()
24
- print("Model loaded successfully!")
25
  return model, tokenizer
26
 
27
- model, tokenizer = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- def extract_calendar_event(event_text):
30
- """Extract calendar information from natural language text."""
31
-
32
  if not event_text.strip():
33
  return "Please enter some text describing a calendar event."
34
 
35
- # Build prompt
36
- prompt = f"""Extract calendar fields from: "{event_text}".
37
- Return ONLY valid JSON with keys [action,date,time,attendees,location,duration,recurrence,notes].
38
- Use null for unknown.
39
- """
40
-
41
  try:
42
- # Tokenize and generate
43
- inputs = tokenizer(prompt, return_tensors="pt")
44
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
 
 
 
 
45
 
 
46
  with torch.no_grad():
47
  outputs = model.generate(
48
- **inputs,
49
- max_new_tokens=160,
50
- temperature=0.0,
51
  do_sample=False,
52
- pad_token_id=tokenizer.eos_token_id
53
  )
54
 
55
- # Decode response
56
  full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
57
 
58
- # Robustly find the first complete JSON object in the output
59
- def _find_first_json(text: str):
60
- start = text.find("{")
61
- if start == -1:
62
- return None
63
- depth = 0
64
- for i in range(start, len(text)):
65
- ch = text[i]
66
- if ch == "{":
67
- depth += 1
68
- elif ch == "}":
69
- depth -= 1
70
- if depth == 0:
71
- return text[start:i+1]
72
- return None
73
-
74
- json_part = _find_first_json(full_response)
75
- if json_part is None and full_response.startswith(prompt):
76
- json_part = _find_first_json(full_response[len(prompt):])
77
-
78
- if json_part:
79
- try:
80
- parsed = json.loads(json_part)
81
- return json.dumps(parsed, indent=2, ensure_ascii=False)
82
- except json.JSONDecodeError:
83
- return "Generated (may need manual cleanup):\n" + json_part
84
  else:
85
- return "No JSON found.\n" + full_response
86
 
87
  except Exception as e:
88
  return f"Error processing request: {str(e)}"
@@ -127,17 +136,17 @@ with gr.Blocks(title="Calendar Event Extractor", theme=gr.themes.Soft()) as demo
127
  ],
128
  inputs=[input_text],
129
  outputs=[output_json],
130
- fn=extract_calendar_event,
131
  cache_examples=False
132
  )
133
 
134
  extract_btn.click(
135
- fn=extract_calendar_event,
136
  inputs=[input_text],
137
  outputs=[output_json]
138
  )
139
 
140
- gr.Markdown("""
141
  ---
142
  **Model Details**: Fine-tuned SmolLM-360M using LoRA • **Dataset**: ~2500 calendar events • **Training**: Custom augmentation pipeline
143
 
 
1
  import gradio as gr
2
  import json
3
  import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
5
  from peft import PeftModel
6
 
7
# Module-level cache so the model is only loaded once per process.
model = None
tokenizer = None

def load_model():
    """Load (and memoize) the fine-tuned model and tokenizer.

    Returns:
        tuple: ``(model, tokenizer)`` — a PEFT-wrapped SmolLM-360M causal
        LM with the calendar-extraction LoRA adapter attached, and its
        tokenizer. Subsequent calls return the cached pair without
        reloading.
    """
    global model, tokenizer

    # Fast path: everything already loaded on a previous call.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    print("🔄 Loading fine-tuned model...")

    # Base checkpoint the LoRA adapter was trained on.
    base_model_id = "HuggingFaceTB/SmolLM-360M"
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    # SmolLM ships without a pad token; reuse EOS so padded batches work.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # float32 keeps inference working on CPU-only Spaces hardware.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float32,
    )

    # Attach the fine-tuned LoRA adapter on top of the base weights.
    model = PeftModel.from_pretrained(base_model, "waliaMuskaan011/calendar-event-extractor-smollm")
    # Explicit eval mode (present in the previous revision, dropped in this
    # rewrite): disables dropout so generation is deterministic.
    model.eval()

    print("Model loaded successfully!")
    return model, tokenizer
37
 
38
def extract_json_from_text(text):
    """Return the first balanced JSON object parsed from *text*, or None.

    Locates the first ``{``, walks forward tracking brace depth to find
    its matching ``}``, and parses that slice. Returns None when no
    complete object exists or the candidate slice is not valid JSON.
    """
    try:
        opening = text.find('{')
        if opening == -1:
            return None

        level = 0
        for pos in range(opening, len(text)):
            ch = text[pos]
            if ch == '{':
                level += 1
            elif ch == '}':
                level -= 1
                if level == 0:
                    # Matching close brace found — parse the candidate.
                    return json.loads(text[opening:pos + 1])
        # Braces never balanced: no complete object in the text.
        return None
    except (json.JSONDecodeError, TypeError, ValueError):
        return None
58
 
59
def predict_calendar_event(event_text):
    """Extract calendar information from event text.

    Args:
        event_text: Free-form natural-language description of an event.

    Returns:
        str: Pretty-printed JSON of the extracted calendar fields, or a
        human-readable message when the input is empty, no valid JSON
        could be extracted, or an error occurred.
    """
    if not event_text.strip():
        return "Please enter some text describing a calendar event."

    try:
        # Load model (memoized after the first call)
        model, tokenizer = load_model()

        # Create prompt - same format as test_model.py
        prompt = f"Extract calendar information from: {event_text}\nCalendar JSON:"

        # Tokenize, then move tensors to the model's device so generation
        # works whether the model lives on CPU or GPU. (The previous
        # revision did this; the rewrite dropped it.)
        inputs = tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate deterministically (greedy decoding).
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=150,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode and strip the echoed prompt to keep only the new text.
        # NOTE(review): assumes decode reproduces the prompt verbatim;
        # holds for this tokenizer/prompt but is not guaranteed in general.
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = full_response[len(prompt):].strip()

        # Extract JSON. Compare against None explicitly: an empty-but-valid
        # object {} is falsy and must not be reported as a failure.
        extracted_json = extract_json_from_text(generated_text)

        if extracted_json is not None:
            return json.dumps(extracted_json, indent=2, ensure_ascii=False)
        else:
            return f"Could not extract valid JSON. Raw output: {generated_text[:200]}..."

    except Exception as e:
        return f"Error processing request: {str(e)}"
 
136
  ],
137
  inputs=[input_text],
138
  outputs=[output_json],
139
+ fn=predict_calendar_event,
140
  cache_examples=False
141
  )
142
 
143
  extract_btn.click(
144
+ fn=predict_calendar_event,
145
  inputs=[input_text],
146
  outputs=[output_json]
147
  )
148
 
149
+ gr.Markdown(f"""
150
  ---
151
  **Model Details**: Fine-tuned SmolLM-360M using LoRA • **Dataset**: ~2500 calendar events • **Training**: Custom augmentation pipeline
152