Spaces:
Build error
Build error
# HACK: swap the preinstalled gradio for a local patched build before anything
# imports it. The three commands share one shell invocation on purpose.
import subprocess

subprocess.run(
    'pip uninstall gradio -y; echo "pwd is: $(pwd)"; pip install ./gradio-12.34.57.tar.gz',
    shell=True,
)
| import json | |
| import os | |
| import shutil | |
| import threading | |
| import gradio as gr | |
| from dialogues import DialogueTemplate | |
| from huggingface_hub import Repository | |
| from text_generation import Client | |
| from utils import get_full_text, wrap_html_code | |
# CSS injected into the Blocks app.
# FIX: CSS has no "//" line comments — the original used them, which makes
# browsers treat the following rule text as invalid; rewritten as /* ... */.
STYLE = """
/* The "done" class is injected once the user has decided between the two
   candidate generated answers; it replays the highlight fade-out below. */
.message.bot.done {
    animation: colorTransition 2s ease-in-out;
}

/* Fade from the "selected" highlight back to the normal bubble background. */
@keyframes colorTransition {
    0% {
        background-color: var(--checkbox-background-color-selected);
    }
    100% {
        background-color: var(--background-fill-secondary);
    }
}
"""
# Hugging Face auth token; when unset, dataset syncing is skipped entirely.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Dataset repo that accumulates the A/B labeling results.
REPO_ID = "sheonhan/rm-test-data"
# Hosted inference endpoint for the fine-tuned StarCoder chat model.
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1"
# Hard-coded identifiers written with every labeled example — presumably
# placeholders for real labeler/session tracking; TODO confirm.
LABELER_ID = "labeler_123"
SESSION_ID = "session_123"
# Streaming text-generation client for the hosted model endpoint.
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

# Local clone of the dataset repo. Stays None when no token is configured;
# downstream writers must check for that before using it.
repo = None
if HF_TOKEN:
    # FIX: was `try: shutil.rmtree(...) except: pass` — the bare except hid
    # every failure (including KeyboardInterrupt). ignore_errors=True keeps
    # the intended "best effort, missing dir is fine" behavior explicitly.
    shutil.rmtree("./data/", ignore_errors=True)
    print("Pulling repo...")
    repo = Repository(
        local_dir="./data/",
        clone_from=REPO_ID,
        use_auth_token=HF_TOKEN,
        repo_type="dataset",
    )
    repo.git_pull()

# System prompt prepended to every dialogue (currently empty).
system_message = ""
def generate(user_message, history):
    """Generate two candidate replies for *user_message* and queue them for
    A/B labeling.

    Samples the model twice from the same prompt (temperature 0.1 and 0.9),
    appends a 3-tuple ``(user_message, "A: ...", "B: ...")`` to *history*
    (on_select later deletes the rejected candidate, leaving a normal pair),
    and returns ``("", history)`` so the textbox is cleared and the chatbot
    component is refreshed.
    """
    # Flatten gradio's (user, assistant) history pairs into role dicts.
    past_messages = []
    for user_data, model_data in history:
        past_messages.append({"role": "user", "content": user_data})
        past_messages.append({"role": "assistant", "content": model_data.rstrip()})

    # FIX: the original duplicated this construction in an
    # `if len(past_messages) < 1` / `else` pair whose only difference was the
    # messages list — `past_messages + [...]` is simply `[...]` on the first
    # turn, so one branch covers both cases.
    dialogue_template = DialogueTemplate(
        system=system_message,
        messages=past_messages + [{"role": "user", "content": user_message}],
        end_token="<|endoftext|>",
    )
    prompt = dialogue_template.get_inference_prompt()

    # Two samples: one conservative (0.1), one creative (0.9).
    response_1 = client.generate_stream(
        prompt, temperature=0.1, stop_sequences=["<|end|>"]
    )
    response_2 = client.generate_stream(
        prompt, temperature=0.9, stop_sequences=["<|end|>"]
    )

    response_1_text = get_full_text(response_1)
    response_2_text = get_full_text(response_2)

    option_a = f"A: {wrap_html_code(response_1_text.strip())}"
    option_b = f"B: {wrap_html_code(response_2_text.strip())}"

    # NOTE: deliberately a 3-tuple, not the usual (user, bot) pair — see
    # on_select, which removes the rejected candidate.
    history.append((user_message, option_a, option_b))
    return "", history
def save_labeling_data(last_dialogue, score):
    """Append one labeled comparison to data.jsonl and push it to the hub.

    Args:
        last_dialogue: ``(prompt, "A: <resp1>", "B: <resp2>")`` triple taken
            from the chat history.
        score: the user's choice as reported by the select event.

    Runs on a background thread (see on_select); a no-op when no dataset
    repo is configured.
    """
    (
        prompt,
        response_1,
        response_2,
    ) = last_dialogue
    response_1 = response_1[3:]  # Remove label "A: "
    response_2 = response_2[3:]  # Remove label "B: "
    file_name = "data.jsonl"
    # FIX: without a configured repo there is no ./data clone to write into,
    # and the unconditional repo.push_to_hub() raised AttributeError on None —
    # skip persistence entirely instead of crashing.
    if repo is None:
        return
    # Rebase-pull first to avoid push conflicts with concurrent labelers.
    repo.git_pull(rebase=True)
    with open(os.path.join("data", file_name), "a", encoding="utf-8") as f:
        data = {
            "labeler_id": LABELER_ID,
            "session_id": SESSION_ID,
            "prompt": prompt,
            "response_1": response_1,
            "response_2": response_2,
            "score": score,
        }
        json.dump(data, f, ensure_ascii=False)
        f.write("\n")
    repo.push_to_hub()
def on_select(event: gr.SelectData, history):
    """Persist the user's A/B choice and drop the rejected candidate.

    event.value carries the selected answer text (the score recorded with the
    example); event.index identifies which element of the last history entry
    to delete.
    """
    score = event.value
    index_to_delete = event.index
    # FIX: the original handed history[-1] itself to the worker thread and
    # then immediately mutated that same list with `del` — if the del ran
    # first, the 3-way unpack in save_labeling_data would fail. Snapshot the
    # entry before starting the thread.
    last_dialogue = tuple(history[-1])
    # Save off the main thread so the UI callback returns immediately.
    threading.Thread(target=save_labeling_data, args=(last_dialogue, score)).start()
    del history[-1][index_to_delete]
    return history
# UI wiring: a chat box where each submitted message yields two candidate
# answers (A/B); clicking one records the preference and keeps the choice.
with gr.Blocks(css=STYLE) as demo:
    chatbot = gr.Chatbot()
    user_message = gr.Textbox()
    clear = gr.Button("Clear")

    # On submit: run generate, then a client-side JS hook strips the "done"
    # class so the selection animation can replay on the next decision.
    # NOTE(review): `_js` is the legacy keyword for client-side callbacks —
    # consistent with the pinned local gradio build installed at startup.
    user_message.submit(
        generate,
        [user_message, chatbot],
        [user_message, chatbot],
        queue=False,
    ).then(
        None,
        None,
        None,
        _js="""()=>{
        let last_elem = document.querySelector("div.message.bot.done");
        last_elem.classList.remove("done");
    }
    """,
    )

    # Selecting a candidate saves the label (on_select), then the JS hook
    # swaps "latest" for "done" to trigger the fade-out highlight (see STYLE).
    chatbot.select(on_select, chatbot, chatbot).then(
        None,
        None,
        None,
        _js="""()=>{
        let last_elem = document.querySelector("div.message.bot.latest");
        last_elem.classList.remove("latest");
        last_elem.classList.add("done");
    }
    """,
    )

    # Clear button resets the chat window.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()