Upload app.py

# Implement Real-Time Streaming for Chat Responses
## Description
This PR introduces real-time streaming to our chat interface, aiming to enhance the user experience by delivering immediate, token-by-token responses.
## Changes
- Enabled streaming in the HuggingFaceEndpoint configuration
- Implemented an asynchronous streaming process using `astream()`
- Modified the chat function to yield partial results in real-time
- Updated the Gradio event chains to support streaming responses (`queue=False` on the `start_chat` steps, queued streaming on the `chat` steps); a minimal sketch of this pattern follows the list
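
The wiring follows a simple pattern: an un-queued step posts the user's turn to the chatbot immediately, a queued step streams the answer by yielding partial results, and a final step (`finish_chat` in app.py) cleans up. The snippet below is a hypothetical, minimal reproduction of that pattern, not app.py itself; the handler names and the echoed text are illustrative only.

```python
import asyncio
import gradio as gr

def start_chat(message, history):
    # Runs with queue=False so the user's turn appears without waiting in line.
    return "", (history or []) + [(message, None)]

async def chat(history):
    # Queued streaming step: Gradio re-renders the Chatbot on every yield.
    query = history[-1][0]
    answer = ""
    for word in f"(streamed echo of: {query})".split():
        answer += word + " "
        await asyncio.sleep(0.05)          # stand-in for waiting on the next token
        history[-1] = (query, answer)
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox(placeholder="Ask something...")
    (textbox
        .submit(start_chat, [textbox, chatbot], [textbox, chatbot], queue=False)
        .then(chat, [chatbot], [chatbot], queue=True))   # streaming generators run on the queue

if __name__ == "__main__":
    demo.queue().launch()
```

The split matters: generator handlers stream their partial outputs through Gradio's queue, while keeping the first step un-queued lets the user's message appear instantly instead of waiting behind other requests.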
## Expected Behavior
- Responses should start appearing immediately after a question is asked
- Text should stream in smoothly, word by word or token by token
- The final response should be identical to the non-streaming version
## Technical Details
Key components of the implementation (a condensed sketch of all four follows the list):
1. **Streaming Callback**: Attached a `StreamingStdOutCallbackHandler` to handle tokens as they arrive.
2. **LLM Configuration**: Added `streaming=True` to `HuggingFaceEndpoint` setup.
3. **Asynchronous Streaming**: Added an async `process_stream()` generator that consumes `astream()` for token-by-token response generation.
4. **Real-Time Updates**: Modified main loop to yield updates as they become available.
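
Taken together, the four pieces reduce to roughly the sketch below. It is a condensed, parameterized sketch rather than the code itself: the import paths are an assumption and vary with the installed LangChain version (app.py's own imports apply), and in app.py the endpoint URL and token come from `model_config` and `HF_token`, while `process_stream()` is a closure inside `chat()` that updates `answer_yet` via `nonlocal`.

```python
# Condensed sketch of the streaming path; assumes the langchain-huggingface /
# langchain-core import locations, which may differ from those used in app.py.
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.callbacks import StreamingStdOutCallbackHandler

def build_streaming_chat_model(endpoint_url: str, hf_token: str) -> ChatHuggingFace:
    # (1) + (2): endpoint configured with streaming enabled and a stdout callback attached
    llm_qa = HuggingFaceEndpoint(
        endpoint_url=endpoint_url,
        max_new_tokens=512,
        repetition_penalty=1.03,
        timeout=70,
        huggingfacehub_api_token=hf_token,
        streaming=True,
        callbacks=[StreamingStdOutCallbackHandler()],
    )
    return ChatHuggingFace(llm=llm_qa)

async def stream_answer(chat_model, messages, query, history, docs_html):
    # (3) + (4): accumulate tokens from astream() and yield one partial update per token
    answer_yet = ""
    async for chunk in chat_model.astream(messages):
        answer_yet += chunk.content          # app.py additionally parses this into
        history[-1] = (query, answer_yet)    # parsed_answer before display
        yield [tuple(x) for x in history], docs_html
```

Each yielded `(history, docs_html)` pair maps onto the `[chatbot, sources_textbox]` outputs in the Gradio event chain, which is what produces the token-by-token effect in the UI.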
Relevant hunks from the `app.py` diff:

```diff
@@ -217,29 +217,37 @@ async def chat(query,history,sources,reports,subtype,year):
 
     ##-----------------------getting inference endpoints------------------------------
 
+    # Set up the streaming callback handler
     callback = StreamingStdOutCallbackHandler()
 
+    # Initialize the HuggingFaceEndpoint with streaming enabled
     llm_qa = HuggingFaceEndpoint(
         endpoint_url=model_config.get('reader', 'ENDPOINT'),
         max_new_tokens=512,
         repetition_penalty=1.03,
         timeout=70,
         huggingfacehub_api_token=HF_token,
-        streaming=True,
-        callbacks=[callback]
+        streaming=True,       # Enable streaming for real-time token generation
+        callbacks=[callback]  # Add the streaming callback handler
     )
 
+    # Create a ChatHuggingFace instance with the streaming-enabled endpoint
     chat_model = ChatHuggingFace(llm=llm_qa)
 
+    # Prepare the HTML for displaying source documents
     docs_html = []
     for i, d in enumerate(context_retrieved, 1):
         docs_html.append(make_html_source(d, i))
     docs_html = "".join(docs_html)
 
+    # Initialize the variable to store the accumulated answer
     answer_yet = ""
 
+    # Define an asynchronous generator function to process the streaming response
     async def process_stream():
-        nonlocal answer_yet
+        # Without nonlocal, Python would create a new local variable answer_yet inside process_stream(), instead of modifying the one from the outer scope.
+        nonlocal answer_yet  # Use the outer scope's answer_yet variable
+        # Iterate over the streaming response chunks
         async for chunk in chat_model.astream(messages):
             token = chunk.content
             answer_yet += token
@@ -247,9 +255,10 @@ async def chat(query,history,sources,reports,subtype,year):
             history[-1] = (query, parsed_answer)
             yield [tuple(x) for x in history], docs_html
 
+    # Stream the response updates
     async for update in process_stream():
         yield update
-
+
     # #callbacks = [StreamingStdOutCallbackHandler()]
     # llm_qa = HuggingFaceEndpoint(
     #     endpoint_url= model_config.get('reader','ENDPOINT'),
@@ -508,11 +517,13 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-
     # https://www.gradio.app/docs/gradio/textbox#event-listeners-arguments
     (textbox
         .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
+        # queue must be set as False (default) so the process is not waiting for another to be finished
         .then(chat, [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], queue=True, concurrency_limit=8, api_name="chat_textbox")
         .then(finish_chat, None, [textbox], api_name="finish_chat_textbox"))
 
     (examples_hidden
         .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
+        # queue must be set as False (default) so the process is not waiting for another to be finished
         .then(chat, [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], concurrency_limit=8, api_name="chat_examples")
         .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
     )
```