import gradio as gr
import os
import json
from openai import OpenAI

# Load sensitive information from environment variables
RUNPOD_API_KEY = os.getenv('RUNPOD_API_KEY')
RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')
BASE_URL = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/openai/v1"
MODEL_NAME = "karths/coder_commit_32B"  # The specific model hosted on RunPod
MAX_TOKENS = 4096  # Max tokens for the model response

# --- OpenAI Client Initialization ---
# Check if the API key is provided
if not RUNPOD_API_KEY:
    raise ValueError("RunPod API key not found. Please set the RUNPOD_API_KEY environment variable or add it directly in the script.")

# Initialize the OpenAI client to connect to the RunPod endpoint
client = OpenAI(
    api_key=RUNPOD_API_KEY,
    base_url=BASE_URL,
)
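
# Optional connectivity check (illustrative, commented out). If the RunPod worker exposes
# the standard OpenAI-compatible /models route -- an assumption, not verified for this
# deployment -- listing models is a quick way to confirm the endpoint ID and API key:
# for model in client.models.list():
#     print(model.id)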

# --- Gradio App Configuration ---
title = "Python Maintainability Refactoring (RunPod)"

description = """
## Instructions for Using the Model

### Model Loading Time:
- Please allow time for the model on RunPod to initialize if it is starting fresh (a "cold start").

### Code Submission:
- Enter or paste the Python code you wish to have refactored, or use one of the provided examples.

### Python Code Constraints:
- Keep the code reasonably sized. The 120-line limit applied to the previous setup, but large code blocks may still hit limits depending on the RunPod instance and model constraints. The maximum response length is {} tokens.

### Understanding Changes:
- Read the "Changes made" section (if provided by the model) in the refactored response; it explains what modifications were made.

### Usage Recommendation:
- Intended for research and evaluation purposes.
""".format(MAX_TOKENS)

system_prompt = """### Instruction:
Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with comments on the changes made for improving the metrics.
### Input:
"""

css = """.toast-wrap { display: none !important } """

examples = [
    ["""def analyze_sales_data(sales_records):
    active_sales = filter(lambda record: record['status'] == 'active', sales_records)
    sales_by_category = {}
    for record in active_sales:
        category = record['category']
        total_sales = record['units_sold'] * record['price_per_unit']
        if category not in sales_by_category:
            sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
        sales_by_category[category]['total_sales'] += total_sales
        sales_by_category[category]['total_units'] += record['units_sold']
    average_sales_data = []
    for category, data in sales_by_category.items():
        average_sales = data['total_sales'] / data['total_units'] if data['total_units'] > 0 else 0  # Avoid division by zero
        sales_by_category[category]['average_sales'] = average_sales
        average_sales_data.append((category, average_sales))
    average_sales_data.sort(key=lambda x: x[1], reverse=True)
    for rank, (category, _) in enumerate(average_sales_data, start=1):
        sales_by_category[category]['rank'] = rank
    return sales_by_category"""],
    ["""import pandas as pd
import re
import ast
from code_bert_score import score  # Assuming this library is available in the environment
import numpy as np

def preprocess_code(source_text):
    def remove_comments_and_docstrings(source_code):
        # Remove single-line comments
        source_code = re.sub(r'#.*', '', source_code)
        # Remove multi-line strings (docstrings)
        source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
        return source_code.strip()  # Added strip

    # Pattern to extract code specifically from markdown blocks if present
    pattern = r"```python\s+(.+?)\s+```"
    matches = re.findall(pattern, source_text, re.DOTALL)
    code_to_process = '\n'.join(matches) if matches else source_text
    cleaned_code = remove_comments_and_docstrings(code_to_process)
    return cleaned_code

def evaluate_dataframe(df):
    results = {'P': [], 'R': [], 'F1': [], 'F3': []}
    for index, row in df.iterrows():
        try:
            # Ensure inputs are lists of strings
            cands = [preprocess_code(str(row['generated_text']))]  # Added str() conversion
            refs = [preprocess_code(str(row['output']))]  # Added str() conversion
            # Ensure code_bert_score.score returns four values
            score_results = score(cands, refs, lang='python')
            if len(score_results) == 4:
                P, R, F1, F3 = score_results
                results['P'].append(P.item() if hasattr(P, 'item') else P)  # Handle potential tensor output
                results['R'].append(R.item() if hasattr(R, 'item') else R)
                results['F1'].append(F1.item() if hasattr(F1, 'item') else F1)
                results['F3'].append(F3.item() if hasattr(F3, 'item') else F3)  # Assuming F3 is returned
            else:
                print(f"Warning: Unexpected number of return values from score function for row {index}. Got {len(score_results)} values.")
                for key in results.keys():
                    results[key].append(np.nan)  # Append NaN for unexpected format
        except Exception as e:
            print(f"Error processing row {index}: {e}")
            for key in results.keys():
                results[key].append(np.nan)  # Use NaN for errors
    df_metrics = pd.DataFrame(results)
    return df_metrics

def evaluate_dataframe_multiple_runs(df, runs=3):
    all_results = []
    print(f"Starting evaluation for {runs} runs...")
    for run in range(runs):
        print(f"Run {run + 1}/{runs}")
        df_metrics = evaluate_dataframe(df.copy())  # Use a copy to avoid side effects if df is modified
        all_results.append(df_metrics)
        print(f"Run {run + 1} completed.")
    if not all_results:
        print("No results collected.")
        return pd.DataFrame(), pd.DataFrame()
    # Concatenate results and calculate statistics
    try:
        concatenated_results = pd.concat(all_results)
        df_metrics_mean = concatenated_results.groupby(level=0).mean()
        df_metrics_std = concatenated_results.groupby(level=0).std()
        print("Mean and standard deviation calculated.")
    except Exception as e:
        print(f"Error calculating statistics: {e}")
        # Return empty DataFrames or handle as appropriate
        return pd.DataFrame(), pd.DataFrame()
    return df_metrics_mean, df_metrics_std"""]
]

# --- Core Logic ---
def gen_solution(prompt):
    """
    Generates a solution for a given problem prompt by calling the LLM via RunPod.

    Parameters:
    - prompt (str): The problem prompt including the system message and user input.

    Returns:
    - str: The generated solution text, or an error message.
    """
    try:
        # Call the OpenAI-compatible endpoint on RunPod
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # Keep temperature low for more deterministic refactoring
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            # stream=False  # Explicitly setting stream to False (default)
        )
        # Extract the response content
        response_content = completion.choices[0].message.content
        return response_content
    except Exception as e:
        print(f"Error calling RunPod API: {e}")
        # Provide a user-friendly error message
        return f"Error: Could not get response from the model. Details: {str(e)}"

# --- Gradio Interface Function ---
def predict(message, history):
    """
    Handles the user input, calls the backend model, and returns the response.
    The 'history' parameter is required by gr.ChatInterface but is not used here.
    """
    # Construct the full prompt
    input_prompt = system_prompt + str(message)  # Using the format from the original code
    # Get the refactored code from the backend
    refactored_code_response = gen_solution(input_prompt)
    # The response is returned directly to the ChatInterface
    return refactored_code_response

# --- Launch Gradio Interface ---
# Use gr.ChatInterface for a chat-like experience
gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=500, label="Refactored Code and Explanation"),
    textbox=gr.Textbox(lines=10, label="Python Code", placeholder="Enter or paste your Python code here..."),
    title=title,
    description=description,
    theme="abidlabs/Lime",  # Or choose another theme, e.g., gr.themes.Default()
    examples=examples,
    cache_examples=False,  # Consider enabling caching if examples don't change often
    submit_btn="Submit Code",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    css=css  # Apply custom CSS if needed
).queue().launch(share=True)  # share=True creates a public link (use with caution)
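
# --- Local usage note (illustrative) ---
# Running this Space locally typically needs the two environment variables set first.
# The file name below is a placeholder, not taken from this repository:
#   export RUNPOD_API_KEY=<your RunPod API key>
#   export RUNPOD_ENDPOINT_ID=<ID of the serverless endpoint hosting the model>
#   python app.py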