Spaces:
Sleeping
Sleeping
| # basic_handler.py | |
| import asyncio | |
| import base64 | |
| import json | |
| import os | |
| import traceback | |
| from websockets.asyncio.client import connect | |
# Connection settings for the Gemini API websocket endpoint.
host = "generativelanguage.googleapis.com"
# Model requested in the setup message; change this to use a different model.
model = "gemini-2.0-flash-live-001"
# API key read from the environment once, at import time ("" when unset).
api_key_env = os.getenv("GOOGLE_API_KEY", "")
# The literal "{api_key}" placeholder is filled in later with str.format().
uri_template = (
    "wss://"
    + host
    + "/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent"
    + "?key={api_key}"
)
class AudioLoop:
    """Relay messages between local asyncio queues and a Gemini websocket.

    Messages placed on ``out_queue`` are JSON-encoded and sent to Gemini.
    Inline audio data found in Gemini's responses is base64-decoded and
    placed on ``audio_in_queue``.  Call :meth:`stop` to request shutdown.
    """

    # Seconds to block on the queue / socket before re-checking the
    # shutdown flag.
    _POLL_INTERVAL = 0.5

    def __init__(self):
        self.ws = None
        # Queue for messages to be sent *to* Gemini
        self.out_queue = asyncio.Queue()
        # Queue for PCM audio received *from* Gemini
        self.audio_in_queue = asyncio.Queue()
        # Flag to signal shutdown
        self.shutdown_event = asyncio.Event()

    async def startup(self, api_key=None):
        """Open the websocket and send the model setup message to Gemini.

        On success the open connection is stored in ``self.ws``.  If the
        setup exchange fails after the socket was opened, the socket is
        closed again and ``self.ws`` is left untouched.

        Args:
            api_key: API key to use (overrides environment variable)

        Raises:
            ValueError: If no API key was passed and none is configured
                through the environment.
        """
        # Use provided API key or fallback to environment variable
        key = api_key or api_key_env
        if not key:
            raise ValueError("No API key provided and GOOGLE_API_KEY environment variable not set")

        uri = uri_template.format(api_key=key)
        ws = await connect(uri, additional_headers={"Content-Type": "application/json"})
        try:
            # Absolutely minimal setup message
            setup_msg = {"setup": {"model": f"models/{model}"}}
            await ws.send(json.dumps(setup_msg))
            setup_response = json.loads(await ws.recv())
        except BaseException:
            # Do not leak the open socket when the setup exchange fails or
            # is cancelled; the original error is re-raised afterwards.
            try:
                await ws.close()
            except Exception as close_err:
                print(f"[AudioLoop] Error closing Gemini connection: {close_err}")
            raise

        self.ws = ws
        print("[AudioLoop] Setup response from Gemini:", setup_response)

    async def send_realtime(self):
        """Read from out_queue and forward those messages to Gemini in real time.

        Runs until the shutdown event is set, the task is cancelled, or
        sending fails.  Errors are printed, not raised.
        """
        try:
            while not self.shutdown_event.is_set():
                # Poll with a timeout so the shutdown flag is re-checked
                # even when nothing is queued.
                try:
                    msg = await asyncio.wait_for(self.out_queue.get(), self._POLL_INTERVAL)
                except asyncio.TimeoutError:
                    continue
                await self.ws.send(json.dumps(msg))
        except asyncio.CancelledError:
            print("[AudioLoop] send_realtime task cancelled")
        except Exception as e:
            print(f"[AudioLoop] Error in send_realtime: {e}")
            traceback.print_exc()
        finally:
            print("[AudioLoop] send_realtime task ended")

    async def receive_audio(self):
        """Read from Gemini websocket and process responses.

        A message that cannot be parsed or processed is reported and
        skipped.  A failure of the receive call itself ends the task:
        retrying it (as earlier versions did) spins forever once the
        connection is closed, because every retry fails immediately.
        Errors are printed, not raised.
        """
        try:
            while not self.shutdown_event.is_set():
                try:
                    raw_response = await asyncio.wait_for(self.ws.recv(), self._POLL_INTERVAL)
                except asyncio.TimeoutError:
                    # No message received, continue checking
                    continue
                # Any other recv() error propagates to the outer handler.

                try:
                    response = json.loads(raw_response)
                    # Print for debugging
                    print(f"[AudioLoop] Received response: {json.dumps(response)[:500]}...")
                    try:
                        await self._queue_inline_audio(response)
                    except Exception as e:
                        # Keep going so a tool call in the same message is
                        # still answered.
                        print(f"[AudioLoop] Error extracting audio: {e}")
                    await self._answer_tool_call(response)
                except Exception as e:
                    print(f"[AudioLoop] Error processing message: {e}")
                    traceback.print_exc()
        except asyncio.CancelledError:
            print("[AudioLoop] receive_audio task cancelled")
        except Exception as e:
            print(f"[AudioLoop] Error in receive_audio: {e}")
            traceback.print_exc()
        finally:
            print("[AudioLoop] receive_audio task ended")

    async def _queue_inline_audio(self, response):
        """Decode each inline base64 audio part of *response* onto audio_in_queue."""
        parts = response.get("serverContent", {}).get("modelTurn", {}).get("parts", [])
        for part in parts:
            inline = part.get("inlineData", {})
            if "data" in inline:
                await self.audio_in_queue.put(base64.b64decode(inline["data"]))

    async def _answer_tool_call(self, response):
        """Reply to every function call in *response* with a fixed 'ok' result."""
        tool_call = response.pop('toolCall', None)
        if not tool_call:
            return
        print(f"[AudioLoop] Tool call received: {tool_call}")
        # Send simple OK response for now
        for fc in tool_call.get('functionCalls', []):
            resp_msg = {
                'tool_response': {
                    'function_responses': [{
                        'id': fc.get('id', ''),
                        'name': fc.get('name', ''),
                        'response': {'result': {'string_value': 'ok'}},
                    }]
                }
            }
            await self.ws.send(json.dumps(resp_msg))

    async def run(self):
        """Main entry point: connects to Gemini, starts send/receive tasks.

        Returns once :meth:`stop` has been called or either worker task has
        ended on its own (for example after a connection error).  The
        workers are always cancelled and the websocket closed before
        returning, including when ``run`` itself is cancelled.  Errors and
        cancellation are printed, not raised.
        """
        try:
            # Initialize the connection with Gemini
            await self.startup()

            tasks = []
            try:
                tasks.append(asyncio.create_task(self.send_realtime()))
                tasks.append(asyncio.create_task(self.receive_audio()))
                shutdown_waiter = asyncio.create_task(self.shutdown_event.wait())
                tasks.append(shutdown_waiter)

                # Wake up on stop() *or* when a worker dies; waiting on the
                # event alone would hang forever after a dropped connection.
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                if not shutdown_waiter.done():
                    print("[AudioLoop] A worker task ended before shutdown was requested; stopping")
            finally:
                try:
                    # Runs on cancellation too, so no worker is left behind
                    # using a closed socket.
                    for task in tasks:
                        task.cancel()
                    await asyncio.gather(*tasks, return_exceptions=True)
                finally:
                    # Clean up connection
                    try:
                        if self.ws is not None:
                            await self.ws.close()
                            print("[AudioLoop] Closed WebSocket connection")
                    except Exception as e:
                        print(f"[AudioLoop] Error closing Gemini connection: {e}")
        except asyncio.CancelledError:
            print("[AudioLoop] run task cancelled")
        except Exception as e:
            print(f"[AudioLoop] Error in run: {e}")
            traceback.print_exc()
        finally:
            print("[AudioLoop] run task ended")

    def stop(self):
        """Signal tasks to stop and clean up resources."""
        self.shutdown_event.set()