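# Local Thinking RAG — Gradio Space
#
# Pipeline overview: uploaded PDF/TXT files are parsed into hierarchical
# chunks with LlamaIndex, leaf chunks are embedded into a per-session vector
# index, an AutoMergingRetriever assembles context at query time, and
# Llama-3.2-3B-Instruct streams the answer token by token into the chat UI.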
import gradio as gr
import os
import PyPDF2
import logging
import torch
import threading
import time
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers import logging as hf_logging
import spaces
from llama_index.core import (
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
    Document as LlamaDocument,
)
from llama_index.core import Settings
from llama_index.core.node_parser import (
    HierarchicalNodeParser,
    get_leaf_nodes,
    get_root_nodes,
)
from llama_index.core.retrievers import AutoMergingRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
hf_logging.set_verbosity_error()

MODEL = "unsloth/Llama-3.2-3B-Instruct"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in environment variables")
# --- UI Settings ---
TITLE = "<h1 style='text-align:center; margin-bottom: 20px;'>Local Thinking RAG: Llama 3.2 3B</h1>"
DISCORD_BADGE = """<p style="text-align:center; margin-top: -10px;">
  <a href="https://discord.gg/openfreeai" target="_blank">
    <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="badge">
  </a>
</p>
"""
CSS = """
.upload-section {
    max-width: 400px;
    margin: 0 auto;
    padding: 10px;
    border: 2px dashed #ccc;
    border-radius: 10px;
}
.upload-button {
    background: #34c759 !important;
    color: white !important;
    border-radius: 25px !important;
}
.chatbot-container {
    margin-top: 20px;
}
.status-output {
    margin-top: 10px;
    font-size: 14px;
}
.processing-info {
    margin-top: 5px;
    font-size: 12px;
    color: #666;
}
.info-container {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
}
.file-list {
    margin-top: 0;
    max-height: 200px;
    overflow-y: auto;
    padding: 5px;
    border: 1px solid #eee;
    border-radius: 5px;
}
.stats-box {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
    font-size: 12px;
}
.submit-btn {
    background: #1a73e8 !important;
    color: white !important;
    border-radius: 25px !important;
    margin-left: 10px;
    padding: 5px 10px;
    font-size: 16px;
}
.input-row {
    display: flex;
    align-items: center;
}
@media (min-width: 768px) {
    .main-container {
        display: flex;
        justify-content: space-between;
        gap: 20px;
    }
    .upload-section {
        flex: 1;
        max-width: 300px;
    }
    .chatbot-container {
        flex: 2;
        margin-top: 0;
    }
}
"""
global_model = None
global_tokenizer = None
global_file_info = {}
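# Lazily load the chat model and tokenizer once per process and reuse them
# across requests; reloading a multi-billion-parameter checkpoint per call
# would dominate latency and exhaust memory.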
def initialize_model_and_tokenizer():
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        logger.info("Initializing model and tokenizer...")
        global_tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
        global_model = AutoModelForCausalLM.from_pretrained(
            MODEL,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
            torch_dtype=torch.float16,
        )
        logger.info("Model and tokenizer initialized successfully")
def get_llm(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        initialize_model_and_tokenizer()
    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=global_tokenizer,
        model=global_model,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        },
    )
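# Read an uploaded .txt or .pdf file and return (text, word_count, error).
# Errors are returned rather than raised so one bad file does not abort a
# multi-file upload.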
def extract_text_from_document(file):
    file_name = file.name
    file_extension = os.path.splitext(file_name)[1].lower()
    if file_extension == '.txt':
        # Read from the temp-file path; gr.File objects are not reliably
        # readable file handles across Gradio versions.
        with open(file_name, 'r', encoding='utf-8', errors='ignore') as f:
            text = f.read()
        return text, len(text.split()), None
    elif file_extension == '.pdf':
        with open(file_name, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            # extract_text() can return None for image-only pages.
            text = "\n\n".join((page.extract_text() or "") for page in pdf_reader.pages)
        return text, len(text.split()), None
    else:
        return None, 0, ValueError(f"Unsupported file format: {file_extension}")
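# Build (or extend) a per-session index: documents are split by a
# HierarchicalNodeParser with chunk sizes [2048, 512, 128], all nodes go into
# the docstore, only leaf nodes are embedded into the vector index, and the
# result is persisted under a directory keyed by the Gradio session hash.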
def create_or_update_index(files, request: gr.Request):
    global global_file_info
    if not files:
        return "Please provide files.", ""
    start_time = time.time()
    user_id = request.session_hash
    save_dir = f"./{user_id}_index"
    # Initialize LlamaIndex modules
    llm = get_llm()
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
    Settings.llm = llm
    Settings.embed_model = embed_model
    file_stats = []
    new_documents = []
    for file in tqdm(files, desc="Processing files"):
        file_basename = os.path.basename(file.name)
        text, word_count, error = extract_text_from_document(file)
        if error:
            logger.error(f"Error processing file {file_basename}: {str(error)}")
            file_stats.append({
                "name": file_basename,
                "words": 0,
                "status": f"error: {str(error)}"
            })
            continue
        doc = LlamaDocument(
            text=text,
            metadata={
                "file_name": file_basename,
                "word_count": word_count,
                "source": "user_upload"
            }
        )
        new_documents.append(doc)
        file_stats.append({
            "name": file_basename,
            "words": word_count,
            "status": "processed"
        })
        global_file_info[file_basename] = {
            "word_count": word_count,
            "processed_at": time.time()
        }
    if not new_documents:
        return "No supported files could be processed. Please upload PDF or TXT files.", ""
    node_parser = HierarchicalNodeParser.from_defaults(
        chunk_sizes=[2048, 512, 128],
        chunk_overlap=20
    )
    logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
    new_nodes = node_parser.get_nodes_from_documents(new_documents)
    new_leaf_nodes = get_leaf_nodes(new_nodes)
    new_root_nodes = get_root_nodes(new_nodes)
    logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
    if os.path.exists(save_dir):
        logger.info(f"Loading existing index from {save_dir}")
        storage_context = StorageContext.from_defaults(persist_dir=save_dir)
        index = load_index_from_storage(storage_context)
        docstore = storage_context.docstore
        docstore.add_documents(new_nodes)
        for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
            index.insert_nodes([node])
        total_docs = len(docstore.docs)
        logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
    else:
        logger.info("Creating new index")
        docstore = SimpleDocumentStore()
        storage_context = StorageContext.from_defaults(docstore=docstore)
        docstore.add_documents(new_nodes)
        index = VectorStoreIndex(
            new_leaf_nodes,
            storage_context=storage_context
        )
        total_docs = len(new_documents)
        logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")
    index.storage_context.persist(persist_dir=save_dir)
    # Build the HTML summary shown under the upload panel.
    file_list_html = "<div class='file-list'>"
    for stat in file_stats:
        status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
        file_list_html += f"<div><span style='color:{status_color}'>●</span> {stat['name']} - {stat['words']} words</div>"
    file_list_html += "</div>"
    processing_time = time.time() - start_time
    stats_output = "<div class='stats-box'>"
    stats_output += f"✓ Processed {len(files)} files in {processing_time:.2f} seconds<br>"
    stats_output += f"✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)<br>"
    stats_output += f"✓ Total documents in index: {total_docs}<br>"
    stats_output += f"✓ Index saved to: {save_dir}<br>"
    stats_output += "</div>"
    output_container = "<div class='info-container'>"
    output_container += file_list_html
    output_container += stats_output
    output_container += "</div>"
    return f"Successfully indexed {len(files)} files.", output_container
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float,
    retriever_k: int,
    merge_threshold: float,
    request: gr.Request
):
    if not request:
        yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}]
        return
    user_id = request.session_hash
    index_dir = f"./{user_id}_index"
    if not os.path.exists(index_dir):
        yield history + [{"role": "assistant", "content": "Please upload documents first."}]
        return
    # Fall back to sane defaults if a UI value arrives with an unexpected type.
    max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 1024
    temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.9
    top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
    top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
    penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
    retriever_k = int(retriever_k) if isinstance(retriever_k, (int, float)) else 15
    merge_threshold = float(merge_threshold) if isinstance(merge_threshold, (int, float)) else 0.5
    llm = get_llm(temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k)
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
    Settings.llm = llm
    Settings.embed_model = embed_model
    storage_context = StorageContext.from_defaults(persist_dir=index_dir)
    index = load_index_from_storage(storage_context)
    base_retriever = index.as_retriever(similarity_top_k=retriever_k)
    auto_merging_retriever = AutoMergingRetriever(
        base_retriever,
        storage_context=storage_context,
        simple_ratio_thresh=merge_threshold,
        verbose=True
    )
    logger.info(f"Query: {message}")
    retrieval_start = time.time()
    base_nodes = base_retriever.retrieve(message)
    logger.info(f"Retrieved {len(base_nodes)} base nodes in {time.time() - retrieval_start:.2f}s")
    base_file_sources = {}
    for node in base_nodes:
        if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
            file_name = node.node.metadata['file_name']
            base_file_sources[file_name] = base_file_sources.get(file_name, 0) + 1
    logger.info(f"Base retrieval file distribution: {base_file_sources}")
    merging_start = time.time()
    merged_nodes = auto_merging_retriever.retrieve(message)
    logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - merging_start:.2f}s")
    merged_file_sources = {}
    for node in merged_nodes:
        if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
            file_name = node.node.metadata['file_name']
            merged_file_sources[file_name] = merged_file_sources.get(file_name, 0) + 1
    logger.info(f"Merged retrieval file distribution: {merged_file_sources}")
    context = "\n\n".join([n.node.text for n in merged_nodes])
    source_info = ""
    if merged_file_sources:
        source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
    formatted_system_prompt = f"{system_prompt}\n\nDocument Context:\n{context}{source_info}"
    messages = [{"role": "system", "content": formatted_system_prompt}]
    for entry in history:
        messages.append(entry)
    messages.append({"role": "user", "content": message})
    prompt = global_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Cooperative stop flag so the generation thread exits if the client
    # disconnects mid-stream.
    stop_event = threading.Event()

    class StopOnEvent(StoppingCriteria):
        def __init__(self, stop_event):
            super().__init__()
            self.stop_event = stop_event

        def __call__(self, input_ids, scores, **kwargs):
            return self.stop_event.is_set()

    stopping_criteria = StoppingCriteriaList([StopOnEvent(stop_event)])
    streamer = TextIteratorStreamer(
        global_tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=penalty,
        do_sample=True,
        stopping_criteria=stopping_criteria
    )
    thread = threading.Thread(target=global_model.generate, kwargs=generation_kwargs)
    thread.start()
    updated_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": ""}
    ]
    yield updated_history
    partial_response = ""
    try:
        for new_text in streamer:
            partial_response += new_text
            updated_history[-1]["content"] = partial_response
            yield updated_history
        yield updated_history
    except GeneratorExit:
        stop_event.set()
        thread.join()
        raise
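# Gradio UI: an upload/index panel on the left, a streaming chatbot on the
# right, and an accordion with generation and retrieval hyperparameters.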
def create_demo():
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        # Title
        gr.HTML(TITLE)
        # Discord badge immediately under the title
        gr.HTML(DISCORD_BADGE)
        with gr.Row(elem_classes="main-container"):
            with gr.Column(elem_classes="upload-section"):
                file_upload = gr.File(
                    file_count="multiple",
                    label="Drag & Drop PDF/TXT Files Here",
                    file_types=[".pdf", ".txt"],
                    elem_id="file-upload"
                )
                upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Upload files to start...",
                    interactive=False
                )
                file_info_output = gr.HTML(
                    label="File Information",
                    elem_classes="processing-info"
                )
                upload_button.click(
                    fn=create_or_update_index,
                    inputs=[file_upload],
                    outputs=[status_output, file_info_output]
                )
            with gr.Column(elem_classes="chatbot-container"):
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="Chat with your documents...",
                    show_label=False,
                    type="messages"
                )
                with gr.Row(elem_classes="input-row"):
                    message_input = gr.Textbox(
                        placeholder="Type your question here...",
                        show_label=False,
                        container=False,
                        lines=1,
                        scale=8
                    )
                    submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
                with gr.Accordion("Advanced Settings", open=False):
                    system_prompt = gr.Textbox(
                        value="You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. As a knowledgeable assistant, provide detailed answers using the relevant information from all uploaded documents.",
                        label="System Prompt",
                        lines=3
                    )
| with gr.Tab("Generation Parameters"): | |
| temperature = gr.Slider( | |
| minimum=0, | |
| maximum=1, | |
| step=0.1, | |
| value=0.9, | |
| label="Temperature" | |
| ) | |
| max_new_tokens = gr.Slider( | |
| minimum=128, | |
| maximum=8192, | |
| step=64, | |
| value=1024, | |
| label="Max New Tokens", | |
| ) | |
| top_p = gr.Slider( | |
| minimum=0.0, | |
| maximum=1.0, | |
| step=0.1, | |
| value=0.95, | |
| label="Top P" | |
| ) | |
| top_k = gr.Slider( | |
| minimum=1, | |
| maximum=100, | |
| step=1, | |
| value=50, | |
| label="Top K" | |
| ) | |
| penalty = gr.Slider( | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.1, | |
| value=1.2, | |
| label="Repetition Penalty" | |
| ) | |
| with gr.Tab("Retrieval Parameters"): | |
| retriever_k = gr.Slider( | |
| minimum=5, | |
| maximum=30, | |
| step=1, | |
| value=15, | |
| label="Initial Retrieval Size (Top K)" | |
| ) | |
| merge_threshold = gr.Slider( | |
| minimum=0.1, | |
| maximum=0.9, | |
| step=0.1, | |
| value=0.5, | |
| label="Merge Threshold (lower = more merging)" | |
| ) | |
| submit_button.click( | |
| fn=stream_chat, | |
| inputs=[ | |
| message_input, | |
| chatbot, | |
| system_prompt, | |
| temperature, | |
| max_new_tokens, | |
| top_p, | |
| top_k, | |
| penalty, | |
| retriever_k, | |
| merge_threshold | |
| ], | |
| outputs=chatbot | |
| ) | |
| message_input.submit( | |
| fn=stream_chat, | |
| inputs=[ | |
| message_input, | |
| chatbot, | |
| system_prompt, | |
| temperature, | |
| max_new_tokens, | |
| top_p, | |
| top_k, | |
| penalty, | |
| retriever_k, | |
| merge_threshold | |
| ], | |
| outputs=chatbot | |
| ) | |
| return demo | |
if __name__ == "__main__":
    initialize_model_and_tokenizer()
    demo = create_demo()
    demo.launch()
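# To run this app outside of Spaces (a sketch; the exact package set is an
# assumption, since this file does not pin its requirements):
#   pip install gradio spaces torch transformers PyPDF2 tqdm \
#       llama-index-core llama-index-llms-huggingface llama-index-embeddings-huggingface
#   export HF_TOKEN=<your Hugging Face token>
#   python app.py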