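# Gradio Space: a Python-refactoring chatbot backed by a Llama-2-style model served
# through a Text Generation Inference (TGI) endpoint, with a streaming tab
# (AsyncInferenceClient) and a batch tab (plain HTTP requests).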
import json
import os

import gradio as gr
import requests
from huggingface_hub import AsyncInferenceClient
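# Configuration comes from environment variables: HF_TOKEN authenticates the requests,
# and API_URL is expected to point at the TGI endpoint serving the model.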
HF_TOKEN = os.getenv('HF_TOKEN')
api_url = os.getenv('API_URL')
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
client = AsyncInferenceClient(api_url)
system_message = """
Refactor the provided Python code to improve its maintainability and efficiency and to reduce its complexity. Include the refactored code along with comments explaining the changes made to improve these metrics.
"""
| title = "Python Refactoring" | |
| description = """ | |
| Please give it 3 to 4 minutes for the model to load and Run , consider using Python code with less than 120 lines of code due to GPU constrainst | |
| """ | |
| css = """.toast-wrap { display: none !important } """ | |
examples = [
    ["""
import pandas as pd
import re
import ast
from code_bert_score import score
import numpy as np

def preprocess_code(source_text):

    def remove_comments_and_docstrings(source_code):
        source_code = re.sub(r'#.*', '', source_code)
        source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
        return source_code

    pattern = r"```python\s+(.+?)\s+```"
    matches = re.findall(pattern, source_text, re.DOTALL)
    code_to_process = '\n'.join(matches) if matches else source_text
    cleaned_code = remove_comments_and_docstrings(code_to_process)
    return cleaned_code

def evaluate_dataframe(df):
    results = {'P': [], 'R': [], 'F1': [], 'F3': []}
    for index, row in df.iterrows():
        try:
            cands = [preprocess_code(row['generated_text'])]
            refs = [preprocess_code(row['output'])]
            P, R, F1, F3 = score(cands, refs, lang='python')
            results['P'].append(P[0])
            results['R'].append(R[0])
            results['F1'].append(F1[0])
            results['F3'].append(F3[0])
        except Exception as e:
            print(f"Error processing row {index}: {e}")
            for key in results.keys():
                results[key].append(None)

    df_metrics = pd.DataFrame(results)
    return df_metrics

def evaluate_dataframe_multiple_runs(df, runs=3):
    all_results = []

    for run in range(runs):
        df_metrics = evaluate_dataframe(df)
        all_results.append(df_metrics)

    # Calculate mean and std deviation of metrics across runs
    df_metrics_mean = pd.concat(all_results).groupby(level=0).mean()
    df_metrics_std = pd.concat(all_results).groupby(level=0).std()

    return df_metrics_mean, df_metrics_std
"""],
| [""" | |
| def analyze_sales_data(sales_records): | |
| active_sales = filter(lambda record: record['status'] == 'active', sales_records) | |
| sales_by_category = {} | |
| for record in active_sales: | |
| category = record['category'] | |
| total_sales = record['units_sold'] * record['price_per_unit'] | |
| if category not in sales_by_category: | |
| sales_by_category[category] = {'total_sales': 0, 'total_units': 0} | |
| sales_by_category[category]['total_sales'] += total_sales | |
| sales_by_category[category]['total_units'] += record['units_sold'] | |
| average_sales_data = [] | |
| for category, data in sales_by_category.items(): | |
| average_sales = data['total_sales'] / data['total_units'] | |
| sales_by_category[category]['average_sales'] = average_sales | |
| average_sales_data.append((category, average_sales)) | |
| average_sales_data.sort(key=lambda x: x[1], reverse=True) | |
| for rank, (category, _) in enumerate(average_sales_data, start=1): | |
| sales_by_category[category]['rank'] = rank | |
| return sales_by_category | |
| """]] | |
# Note: The default system prompt has been removed, as requested by the paper authors [Dated: 13/Oct/2023].
# Prompting style for Llama 2 without a system prompt:
# <s>[INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
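# For illustration, with one prior exchange in the chat history and a new user message,
# `predict` below assembles a prompt of the form:
#   <s>[INST] refactor this loop [/INST] here is the refactored loop </s><s>[INST] now add type hints [/INST]
# (the example turns are hypothetical; the tags follow the Llama 2 format shown above).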
# Streaming: stream tokens from TGI with the AsyncInferenceClient.
async def predict(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1):
    # Start the prompt with the optional system block.
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "

    # Clamp temperature to a small positive value.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Replay the chat history in the Llama 2 [INST] ... [/INST] format, then append the new message.
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
    input_prompt = input_prompt + str(message) + " [/INST] "

    # Stream tokens from the endpoint and yield the growing partial response.
    partial_message = ""
    async for token in await client.text_generation(prompt=input_prompt,
                                                    max_new_tokens=max_new_tokens,
                                                    stream=True,
                                                    best_of=1,
                                                    temperature=temperature,
                                                    top_p=top_p,
                                                    do_sample=True,
                                                    repetition_penalty=repetition_penalty):
        partial_message = partial_message + token
        yield partial_message
# No streaming: produce the full response in a single request to the TGI inference endpoint.
def predict_batch(message, chatbot, system_prompt="", temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.1):
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "

    # Clamp temperature to a small positive value.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Replay the chat history in the Llama 2 format, then append the new message.
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
    input_prompt = input_prompt + str(message) + " [/INST] "
    print(f"input_prompt - {input_prompt}")

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    response = requests.post(api_url, headers=headers, json=data)

    if response.status_code == 200:  # check if the request was successful
        try:
            json_obj = response.json()
            if 'generated_text' in json_obj[0] and len(json_obj[0]['generated_text']) > 0:
                return json_obj[0]['generated_text']
            elif 'error' in json_obj[0]:
                return json_obj[0]['error'] + ' Please refresh and try again with a smaller input prompt.'
            else:
                print(f"Unexpected response: {json_obj[0]}")
        except json.JSONDecodeError:
            print(f"Failed to decode response as JSON: {response.text}")
    else:
        print(f"Request failed with status code {response.status_code}")
def vote(data: gr.LikeData):
    if data.liked:
        print("You upvoted this response: " + data.value)
    else:
        print("You downvoted this response: " + data.value)
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum number of new tokens to generate",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
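# ChatInterface passes these additional inputs to the prediction function after
# (message, history), so they map, in order, to system_prompt, temperature,
# max_new_tokens, top_p and repetition_penalty in predict / predict_batch.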
chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)
chatbot_batch = gr.Chatbot(avatar_images=('user1.png', 'bot1.png'), bubble_full_width=False)
chat_interface_stream = gr.ChatInterface(
    predict,
    title=title,
    description=description,
    textbox=gr.Textbox(),
    chatbot=chatbot_stream,
    css=css,
    examples=examples,
    # cache_examples=True,
    additional_inputs=additional_inputs,
)
chat_interface_batch = gr.ChatInterface(
    predict_batch,
    title=title,
    description=description,
    textbox=gr.Textbox(),
    chatbot=chatbot_batch,
    css=css,
    examples=examples,
    # cache_examples=True,
    additional_inputs=additional_inputs,
)
# Gradio demo
with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        # streaming chatbot
        chatbot_stream.like(vote, None, None)
        chat_interface_stream.render()
    with gr.Tab("Batch"):
        # non-streaming chatbot
        chatbot_batch.like(vote, None, None)
        chat_interface_batch.render()

demo.queue(max_size=2).launch()
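# queue(max_size=2) caps the number of queued requests at two; launch() starts the app.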