Sgridda committed
Commit d1e0f9b · 1 Parent(s): 8eb5cff

made it simple

Files changed (3)
  1. main.py +9 -4
  2. main_lightweight.py +139 -0
  3. main_simple.py +66 -0
main.py CHANGED
@@ -99,12 +99,17 @@ def run_ai_inference(diff: str) -> str:
 
     inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
 
+    # Create attention mask to avoid warnings and improve reliability
+    attention_mask = torch.ones_like(inputs)
+
     # Optimized generation parameters for speed
     outputs = model.generate(
-        inputs,
-        max_new_tokens=256,  # Reduced from 1024
-        do_sample=False,
-        temperature=0.1,  # Lower temperature for more focused output
+        inputs,
+        attention_mask=attention_mask,
+        max_new_tokens=128,  # Further reduced for faster generation
+        do_sample=True,  # Enable sampling to use temperature
+        temperature=0.3,  # Lower temperature for more focused output
+        top_p=0.9,  # Nucleus sampling for better quality
         num_return_sequences=1,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.eos_token_id,
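Two details in this hunk are worth calling out. First, temperature and top_p only take effect when do_sample=True; with the previous do_sample=False, greedy decoding silently ignored temperature=0.1 (recent transformers versions warn about exactly this). Second, the explicit attention_mask suppresses the "attention mask is not set" warning that generate emits when pad_token_id equals eos_token_id. A minimal standalone sketch of the updated call, assuming a chat-tuned model such as TinyLlama/TinyLlama-1.1B-Chat-v1.0 (the model name is an assumption; this diff does not show what main.py loads):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumed; main.py's model is not shown in this diff
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

messages = [{"role": "user", "content": "Review this diff:\n+ x = 1"}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

# An all-ones mask is valid here because apply_chat_template returns a single
# unpadded sequence; batched, padded input would need zeros over the padding.
attention_mask = torch.ones_like(inputs)

outputs = model.generate(
    inputs,
    attention_mask=attention_mask,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.3,
    top_p=0.9,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
# Decode only the newly generated tokens, not the prompt
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))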
main_lightweight.py ADDED
@@ -0,0 +1,139 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+import torch
+import logging
+import json
+import re
+
+# Ultra-lightweight version with minimal AI
+app = FastAPI(
+    title="AI Code Review Service",
+    description="An API to get AI-powered code reviews for pull request diffs.",
+    version="1.0.0",
+)
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Try to load a very small model, fall back to mock if it fails
+model = None
+tokenizer = None
+
+def load_simple_model():
+    """Try to load the smallest possible model."""
+    global model, tokenizer
+    try:
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+
+        # Use the smallest possible model
+        model_name = "distilgpt2"  # Much smaller than TinyLlama
+        logger.info("Loading lightweight model: %s", model_name)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto" if torch.cuda.is_available() else None,
+        )
+        logger.info("Model loaded successfully")
+        return True
+    except Exception as e:
+        logger.warning("Failed to load AI model: %s. Using mock responses.", str(e))
+        return False
+
+# Try to load model on startup
+model_loaded = load_simple_model()
+
+class DiffRequest(BaseModel):
+    diff: str
+
+class ReviewComment(BaseModel):
+    file_path: str
+    line_number: int
+    comment_text: str
+
+class ReviewResponse(BaseModel):
+    comments: list[ReviewComment]
+
+@app.get("/health")
+def health_check():
+    """Health check endpoint."""
+    return {
+        "status": "healthy",
+        "service": "AI Code Review Service",
+        "model_loaded": model_loaded,
+        "model_name": "distilgpt2" if model_loaded else "mock",
+        "device": "cuda" if torch.cuda.is_available() else "cpu"
+    }
+
+def simple_ai_review(diff: str):
+    """Very simple AI review using the lightweight model."""
+    if not model_loaded or not model or not tokenizer:
+        return None
+
+    try:
+        # Very simple prompt
+        prompt = f"Review this code change and suggest improvements:\n{diff[:200]}\nSuggestion:"
+
+        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=256, truncation=True)
+
+        # Very conservative generation
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs,
+                max_new_tokens=50,  # Very short response
+                do_sample=False,
+                num_return_sequences=1,
+                pad_token_id=tokenizer.eos_token_id,
+                use_cache=True
+            )
+
+        response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
+        return response.strip()
+    except Exception as e:
+        logger.warning("AI generation failed: %s", str(e))
+        return None
+
+@app.post("/review", response_model=ReviewResponse)
+def review_diff(request: DiffRequest):
+    """Review endpoint with fallback to mock data."""
+    logger.info("Received diff for review (length: %d chars)", len(request.diff))
+
+    # Try AI first, fall back to mock
+    ai_suggestion = None
+    if model_loaded:
+        ai_suggestion = simple_ai_review(request.diff)
+
+    if ai_suggestion:
+        # Use AI suggestion
+        comments = [{
+            "file_path": "reviewed_file.py",
+            "line_number": 1,
+            "comment_text": ai_suggestion
+        }]
+        logger.info("Returning AI-generated review")
+    else:
+        # Fall back to mock comments
+        comments = [
+            {
+                "file_path": "example.py",
+                "line_number": 1,
+                "comment_text": "Consider adding error handling and input validation."
+            },
+            {
+                "file_path": "example.py",
+                "line_number": 5,
+                "comment_text": "This function could benefit from better documentation."
+            }
+        ]
+        logger.info("Returning mock review comments")
+
+    return ReviewResponse(comments=comments)
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
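main_lightweight.py degrades gracefully: if distilgpt2 cannot be loaded, /review silently falls back to the two canned comments, and only the /health endpoint reveals which backend is active. A quick client sketch, assuming the service is running locally on the port hard-coded above (requests is a third-party package, not a dependency of the service itself):

import requests  # pip install requests; hypothetical client, not part of the service

BASE = "http://localhost:7860"  # assumes a local run of main_lightweight.py

health = requests.get(f"{BASE}/health", timeout=10).json()
print("model_loaded:", health["model_loaded"], "| backend:", health["model_name"])

resp = requests.post(
    f"{BASE}/review",
    json={"diff": "-    return x\n+    return x + 1"},
    timeout=120,  # CPU generation can be slow
)
resp.raise_for_status()
for c in resp.json()["comments"]:
    print(f"{c['file_path']}:{c['line_number']}  {c['comment_text']}")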
main_simple.py ADDED
@@ -0,0 +1,66 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+import json
+import logging
+
+# Simple version without AI model for testing
+app = FastAPI(
+    title="AI Code Review Service",
+    description="An API to get AI-powered code reviews for pull request diffs.",
+    version="1.0.0",
+)
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class DiffRequest(BaseModel):
+    diff: str
+
+class ReviewComment(BaseModel):
+    file_path: str
+    line_number: int
+    comment_text: str
+
+class ReviewResponse(BaseModel):
+    comments: list[ReviewComment]
+
+@app.get("/health")
+def health_check():
+    """Health check endpoint."""
+    return {
+        "status": "healthy",
+        "service": "AI Code Review Service",
+        "model_loaded": False,  # No model in simple version
+        "message": "Simple version - returns mock reviews"
+    }
+
+@app.post("/review", response_model=ReviewResponse)
+def review_diff(request: DiffRequest):
+    """
+    Mock review endpoint that returns sample comments.
+    Replace this with actual AI logic once the Space is working.
+    """
+    logger.info("Received diff for review (length: %d chars)", len(request.diff))
+
+    # Mock review comments
+    mock_comments = [
+        {
+            "file_path": "example.py",
+            "line_number": 1,
+            "comment_text": "Consider adding docstrings to improve code documentation."
+        },
+        {
+            "file_path": "example.py",
+            "line_number": 5,
+            "comment_text": "This function could benefit from error handling."
+        }
+    ]
+
+    logger.info("Returning %d mock review comments", len(mock_comments))
+
+    return ReviewResponse(comments=mock_comments)
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
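Because main_simple.py has no model dependency at all, it can be exercised in-process without starting uvicorn; a minimal sketch using FastAPI's TestClient (which requires the httpx package):

from fastapi.testclient import TestClient

from main_simple import app

client = TestClient(app)

# The simple version always reports that no model is loaded
assert client.get("/health").json()["model_loaded"] is False

resp = client.post("/review", json={"diff": "+ print('hello')"})
assert resp.status_code == 200
assert len(resp.json()["comments"]) == 2  # the two hard-coded mock comments
print(resp.json())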