Spaces:
Build error
Build error
import json
import logging
import os

import gradio as gr
from distilabel.llms import InferenceEndpointsLLM, LlamaCppLLM
from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
# Local GGUF model used by the llama.cpp fallback labeller.
# NOTE(review): the local filename says "f16" but the URL downloads the Q5_K_S
# quantization; the name is kept as-is so an already-downloaded file is reused.
file_path = os.path.join(os.path.dirname(__file__), "Qwen2-5-0.5B-Instruct-f16.gguf")
download_url = "https://huggingface.co/gaianet/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/Qwen2.5-0.5B-Instruct-Q5_K_S.gguf?download=true"
if not os.path.exists(file_path):
    import requests
    import tqdm

    # Stream into a ".part" file and move it into place only once complete, so
    # an interrupted download is not mistaken for a valid model on next start.
    partial_path = file_path + ".part"
    with requests.get(download_url, stream=True, timeout=60) as response:
        # Fail loudly instead of saving an HTML error page as the model file.
        response.raise_for_status()
        # content-length can be missing (e.g. chunked transfer); tqdm accepts
        # total=None and then just shows a running byte count.
        total_bytes = int(response.headers.get("content-length", 0)) or None
        with open(partial_path, "wb") as f, tqdm.tqdm(
            total=total_bytes, unit="B", unit_scale=True, unit_divisor=1024
        ) as progress:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
                progress.update(len(chunk))
    os.replace(partial_path, file_path)
# Local llama.cpp-backed labeller built on the downloaded GGUF model; the
# request handler falls back to it when the endpoint labeller raises.
llm_cpp = LlamaCppLLM(
    model_path=file_path,
    # NOTE(review): 114k tokens likely exceeds the context window the
    # Qwen2.5-0.5B model was trained with (32k) — confirm this is intended.
    n_ctx=114_000,
    # -1 asks llama.cpp to offload every layer to the GPU when one exists.
    n_gpu_layers=-1,
    generation_kwargs={"max_new_tokens": 14_000},
)
task_cpp = ArgillaLabeller(llm=llm_cpp)
task_cpp.load()
# Primary labeller: Llama 3.1 8B Instruct via distilabel's
# InferenceEndpointsLLM. The same repo id serves as model and tokenizer.
_ENDPOINT_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
llm_ep = InferenceEndpointsLLM(
    model_id=_ENDPOINT_MODEL,
    tokenizer_id=_ENDPOINT_MODEL,
    generation_kwargs={"max_new_tokens": 1_000},
)
task_ep = ArgillaLabeller(llm=llm_ep)
task_ep.load()
def load_examples(path=None):
    """Load the Gradio example rows from a JSON file.

    Args:
        path: Location of the JSON file. Defaults to ``examples.json`` in the
            same directory as this module, so loading does not depend on the
            process's current working directory.

    Returns:
        The decoded JSON content (presumably a list of example rows, one value
        per input component — confirm against examples.json).
    """
    if path is None:
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples.json")
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Example rows offered below the Gradio interface; only the first one is used.
MAX_EXAMPLES = 1
examples = load_examples()[:MAX_EXAMPLES]
def process_fields(fields):
    """Coerce ``fields`` into a list of field dictionaries.

    ``fields`` may be a JSON string, a single dict, or an iterable whose items
    are dicts or JSON strings. Strings are decoded with ``json.loads``, and a
    lone dict is wrapped in a one-element list.
    """

    def _to_dict(item):
        # Items that are already dicts are kept; anything else is JSON-decoded.
        return item if isinstance(item, dict) else json.loads(item)

    decoded = json.loads(fields) if isinstance(fields, str) else fields
    if isinstance(decoded, dict):
        return [decoded]
    return list(map(_to_dict, decoded))
def _parse_json_input(value):
    """Decode one JSON text input as received from a ``gr.Code`` box.

    Blank strings (including whitespace-only ones) mean "not provided" and
    become ``None``. Other strings are decoded with ``json.loads``. Any other
    value (already-decoded data sent through the API, or ``None``) is returned
    unchanged.
    """
    if isinstance(value, str):
        return json.loads(value) if value.strip() else None
    return value


def process_records_gradio(records, fields, question, example_records=None):
    """Label Argilla records with an LLM and return the suggestions as JSON.

    Each argument may be passed already decoded or as a JSON string, as the
    ``gr.Code`` inputs produce. Blank strings count as "not provided".

    Args:
        records: Record(s) to label, as a list of record dicts or a single
            record dict.
        fields: Field definition(s) for the labeller; a single field dict is
            wrapped into a list.
        question: Question definition for the labeller.
        example_records: Optional already-labelled records used as examples.

    Returns:
        A JSON string ``{"results": [...]}`` holding the ``suggestions`` of
        every record for which the labeller produced any.

    Raises:
        gr.Error: If the inputs are invalid or labelling fails.
    """
    try:
        records = _parse_json_input(records)
        fields = _parse_json_input(fields)
        question = _parse_json_input(question)
        example_records = _parse_json_input(example_records)

        if isinstance(records, dict):
            # Allow a single record as well as a list of records.
            records = [records]
        if not records:
            raise ValueError("At least one record must be provided")
        if not fields and not question:
            raise ValueError("Either fields or question must be provided")
        if fields:
            # Accept a single field definition as well as a list of them.
            fields = process_fields(fields)

        # example_records is always passed, as None when absent. The tasks are
        # shared module-level objects and set_runtime_parameters only updates
        # the keys it receives. Omitting the key would presumably keep the
        # examples from an earlier request (confirm against distilabel).
        # NOTE(review): concurrent requests could still interleave these
        # settings; check the Space's Gradio concurrency configuration.
        runtime_parameters = {
            "fields": fields,
            "question": question,
            "example_records": example_records or None,
        }
        task_ep.set_runtime_parameters(runtime_parameters)
        task_cpp.set_runtime_parameters(runtime_parameters)

        try:
            output = next(
                task_ep.process(inputs=[{"record": record} for record in records])
            )
        except Exception:
            # Best effort: fall back to the local model, but keep a trace of
            # why the endpoint failed instead of swallowing it silently.
            logging.getLogger(__name__).warning(
                "Inference endpoint labelling failed; falling back to local llama.cpp model",
                exc_info=True,
            )
            # Fresh input dicts, in case the failed attempt mutated the first ones.
            output = next(
                task_cpp.process(inputs=[{"record": record} for record in records])
            )

        # NOTE: entries without suggestions are skipped, so positions in
        # "results" do not necessarily line up with positions in `records`.
        results = [entry["suggestions"] for entry in output if entry["suggestions"]]
        return json.dumps({"results": results}, indent=2)
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}") from e
# Markdown rendered at the top of the Gradio interface. It shows how to call
# this Space's "/predict" endpoint from Python with gradio_client, building the
# payload from an Argilla dataset's settings and records.
description = """
An example workflow for JSON payload.
```python
import json
import os
from gradio_client import Client
import argilla as rg
# Initialize Argilla client
gradio_client = Client("davidberenstein1957/distilabel-argilla-labeller")
argilla_client = rg.Argilla(
api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
)
# Load the dataset
dataset = argilla_client.datasets(name="my_dataset", workspace="my_workspace")
# Get the field and question
field = dataset.settings.fields["text"]
question = dataset.settings.questions["sentiment"]
# Get completed and pending records
completed_records_filter = rg.Filter(("status", "==", "completed"))
pending_records_filter = rg.Filter(("status", "==", "pending"))
example_records = list(
dataset.records(
query=rg.Query(filter=completed_records_filter),
limit=5,
)
)
some_pending_records = list(
dataset.records(
query=rg.Query(filter=pending_records_filter),
limit=5,
)
)
# Process the records
payload = {
"records": [record.to_dict() for record in some_pending_records],
"fields": [field.serialize()],
"question": question.serialize(),
"example_records": [record.to_dict() for record in example_records],
"api_name": "/predict",
}
response = gradio_client.predict(**payload)
```
"""
# Gradio passes the values of `inputs` to `fn` positionally, so the component
# order must match process_records_gradio(records, fields, question,
# example_records). Previously the "Example Records" box was fed to `fields`,
# "Fields" to `question` and "Question" to `example_records`.
# Each row in `examples` supplies one value per component in this same order.
# NOTE(review): confirm the rows in examples.json follow this order.
interface = gr.Interface(
    fn=process_records_gradio,
    inputs=[
        gr.Code(label="Records (JSON)", language="json", lines=5),
        gr.Code(label="Fields (JSON, optional)", language="json"),
        gr.Code(label="Question (JSON, optional)", language="json"),
        gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
    ],
    examples=examples,
    cache_examples=False,
    outputs=gr.Code(label="Suggestions", language="json", lines=10),
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
)

if __name__ == "__main__":
    interface.launch()