| | """ |
| | StackNet API Client |
| | |
| | Handles all communication with the StackNet task network. |
| | SSE parsing and progress tracking are handled internally. |
| | """ |
| |
|
| | import json |
| | import tempfile |
| | import os |
| | from typing import AsyncGenerator, Optional, Any, Callable |
| | from dataclasses import dataclass |
| | from enum import Enum |
| |
|
| | import httpx |
| |
|
| | from ..config import config |
| |
|
| |
|
class MediaAction(str, Enum):
    """Media orchestration actions accepted by the task API.

    Members are ``str`` subclasses, so each one compares equal to its wire
    value (the string sent as the ``action`` field of a media task).
    """

    # Audio-oriented actions.
    GENERATE_MUSIC = "generate_music"
    CREATE_COVER = "create_cover"
    EXTRACT_STEMS = "extract_stems"

    # Visual / video-oriented actions.
    ANALYZE_VISUAL = "analyze_visual"
    DESCRIBE_VIDEO = "describe_video"
    CREATE_COMPOSITE = "create_composite"
| |
|
| |
|
@dataclass
class TaskProgress:
    """Snapshot of how far a running task has advanced."""

    # Completion value reported for the task.
    progress: float
    # Status label reported for the task.
    status: str
    # Human-readable description accompanying the update.
    message: str
| |
|
| |
|
@dataclass
class TaskResult:
    """Outcome of a task once it has finished or failed."""

    # True when the task produced a result, False on any failure.
    success: bool
    # Output payload of the task; empty on failure.
    data: dict
    # Failure description; None when no error was reported.
    error: Optional[str] = None
| |
|
| |
|
class StackNetClient:
    """
    Client for the StackNet task network API.

    Tasks are submitted with a streaming POST to ``{base_url}/tasks`` and the
    Server-Sent Events (SSE) response is parsed internally. Consumers receive
    progress through an optional callback and a final ``TaskResult``.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: float = 300.0
    ):
        """
        Create a client.

        Args:
            base_url: API root; falls back to ``config.stacknet_url``.
            api_key: Optional credential, with or without a "Bearer " prefix.
            timeout: httpx timeout (seconds) used for task submissions.
        """
        self.base_url = base_url or config.stacknet_url
        self.api_key = api_key
        self.timeout = timeout
        # Scratch directory used by download_file(); removed by cleanup().
        self._temp_dir = tempfile.mkdtemp(prefix="stacknet_")

    async def submit_tool_task(
        self,
        tool_name: str,
        parameters: dict,
        server_name: str = "geoff",
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """
        Submit an MCP tool task and wait for completion.

        Args:
            tool_name: The tool to invoke (e.g., generate_image_5)
            parameters: Tool parameters
            server_name: MCP server name (default: geoff)
            on_progress: Callback for progress updates (progress: 0-1, message: str)

        Returns:
            TaskResult with success status and output data. HTTP, timeout and
            network failures are reported through ``TaskResult.error`` rather
            than raised.
        """
        payload = {
            "type": "mcp-tool",
            "serverName": server_name,
            "toolName": tool_name,
            "stream": True,
            "parameters": parameters,
        }
        return await self._submit(payload, on_progress)

    async def submit_media_task(
        self,
        action: MediaAction,
        prompt: Optional[str] = None,
        media_url: Optional[str] = None,
        audio_url: Optional[str] = None,
        video_url: Optional[str] = None,
        options: Optional[dict] = None,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """
        Submit a media orchestration task and wait for completion.

        Args:
            action: The media action to perform
            prompt: Text prompt for generation
            media_url: URL for image input
            audio_url: URL for audio input
            video_url: URL for video input
            options: Additional options (tags, title, etc.)
            on_progress: Callback for progress updates (progress: 0-1, message: str)

        Returns:
            TaskResult with success status and output data. HTTP, timeout and
            network failures are reported through ``TaskResult.error`` rather
            than raised.
        """
        payload = {
            "type": config.TASK_TYPE_MEDIA,
            "action": action.value,
            "stream": True,
        }

        # Only inputs that were actually supplied (truthy) are sent.
        optional_fields = {
            "prompt": prompt,
            "mediaUrl": media_url,
            "audioUrl": audio_url,
            "videoUrl": video_url,
            "options": options,
        }
        payload.update(
            {key: value for key, value in optional_fields.items() if value}
        )

        return await self._submit(payload, on_progress)

    def _build_headers(self) -> dict:
        """Return request headers, adding a Bearer Authorization if a key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            # Accept keys given either bare or already prefixed with "Bearer ".
            if self.api_key.startswith("Bearer "):
                headers["Authorization"] = self.api_key
            else:
                headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    async def _submit(
        self,
        payload: dict,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """POST *payload* to the tasks endpoint and consume its SSE response."""
        # Strip a trailing slash so a base URL like "https://host/" does not
        # produce "https://host//tasks".
        url = f"{(self.base_url or '').rstrip('/')}/tasks"

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            try:
                async with client.stream(
                    "POST",
                    url,
                    json=payload,
                    headers=self._build_headers()
                ) as response:
                    if response.status_code != 200:
                        error_body = await response.aread()
                        # errors="replace": a non-UTF-8 error body must not
                        # turn an HTTP failure into a UnicodeDecodeError.
                        snippet = error_body.decode(errors="replace")[:200]
                        return TaskResult(
                            success=False,
                            data={},
                            error=f"API request failed ({response.status_code}): {snippet}"
                        )

                    return await self._process_sse_stream(response, on_progress)

            except httpx.TimeoutException:
                return TaskResult(
                    success=False,
                    data={},
                    error="Request timed out. The operation took too long."
                )
            except httpx.RequestError as e:
                return TaskResult(
                    success=False,
                    data={},
                    error=f"Network error: {str(e)}"
                )

    @staticmethod
    def _parse_sse_line(line: str) -> Optional[dict]:
        """Decode one SSE line into an event dict, or None if it carries none.

        Lines that are not ``data:`` fields, the ``[DONE]`` sentinel, empty
        payloads, malformed JSON and JSON values that are not objects all
        yield None (malformed input is skipped on a best-effort basis).
        """
        if not line.startswith("data:"):
            return None

        # The space after the colon is optional in SSE; strip() also removes
        # a trailing "\r" left over from CRLF line endings.
        raw_data = line[5:].strip()
        if not raw_data or raw_data == "[DONE]":
            return None

        try:
            event = json.loads(raw_data)
        except json.JSONDecodeError:
            return None

        return event if isinstance(event, dict) else None

    def _apply_event(
        self,
        event: dict,
        outcome: dict,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> None:
        """Fold one decoded event into *outcome* and fire the progress callback.

        *outcome* is a dict with the keys ``"result"`` and ``"error"`` that is
        updated in place.
        """
        event_type = event.get("type", "")
        # Events without a "data" envelope carry their fields at top level.
        event_data = event.get("data", event)
        # Guard against payloads such as "data": null or "data": "text".
        fields = event_data if isinstance(event_data, dict) else {}

        if event_type == "progress":
            if on_progress:
                on_progress(
                    self._calculate_progress(fields),
                    fields.get("message", "Processing...")
                )
        elif event_type == "result":
            if isinstance(event_data, dict):
                outcome["result"] = event_data.get("output", event_data)
            else:
                outcome["result"] = event_data
        elif event_type == "error":
            # "or" so that an empty/null message still registers as an error.
            outcome["error"] = fields.get("message") or "Unknown error occurred"
        # Other event types (e.g. "complete") carry nothing we need.

    async def _process_sse_stream(
        self,
        response: httpx.Response,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """Process SSE stream and extract final result.

        An error event takes precedence over a result event. If neither is
        seen, a failed TaskResult with a "No result received" error is returned.
        """
        outcome: dict = {"result": None, "error": None}
        buffer = ""

        async for chunk in response.aiter_text():
            buffer += chunk
            lines = buffer.split("\n")
            # The last piece may be an incomplete line; keep it for later.
            buffer = lines.pop()

            for line in lines:
                event = self._parse_sse_line(line)
                if event is not None:
                    self._apply_event(event, outcome, on_progress)

        # A final line without a trailing newline goes through the same path
        # as every other line, so a closing error/progress event is not lost.
        trailing_event = self._parse_sse_line(buffer)
        if trailing_event is not None:
            self._apply_event(trailing_event, outcome, on_progress)

        if outcome["error"]:
            return TaskResult(success=False, data={}, error=outcome["error"])

        # "is not None": an empty output is still a received result.
        if outcome["result"] is not None:
            return TaskResult(success=True, data=outcome["result"])

        return TaskResult(
            success=False,
            data={},
            error="No result received from the API"
        )

    def _calculate_progress(self, data: dict) -> float:
        """Calculate normalized progress (0.0 to 1.0).

        Maps the ``status`` field to a fixed value, except for ``"polling"``,
        which is interpolated between 0.2 and 0.8 from ``attempt`` /
        ``maxAttempts``. Missing, unknown or unusable data yields 0.5.
        """
        if not data:
            return 0.5

        status = data.get("status", "")

        if status == "completed":
            return 1.0
        if status == "polling":
            attempt = data.get("attempt", 1)
            max_attempts = data.get("maxAttempts", 30)
            try:
                ratio = float(attempt) / float(max_attempts)
            except (TypeError, ValueError, ZeroDivisionError):
                # Non-numeric or zero maxAttempts: progress is unknown.
                return 0.5
            # Clamp so the polling phase always stays within 0.2 .. 0.8.
            ratio = min(max(ratio, 0.0), 1.0)
            return 0.2 + ratio * 0.6
        if status == "submitted":
            return 0.1

        # "processing" and any unrecognized status.
        return 0.5

    @staticmethod
    def _safe_filename(url: str, filename: Optional[str] = None) -> str:
        """Return a bare file name for a download, never a path.

        Uses *filename* if given, otherwise the last URL path segment with the
        query string and fragment removed. Directory components are stripped
        so the result cannot point outside the target directory.
        """
        candidate = filename or url.split("/")[-1].split("?")[0].split("#")[0]
        candidate = os.path.basename(candidate)
        if candidate in ("", ".", ".."):
            return "download"
        return candidate

    async def download_file(self, url: str, filename: Optional[str] = None) -> str:
        """Download a file to the temp directory and return local path.

        Args:
            url: Location of the file to fetch.
            filename: Optional name for the local file; only its final path
                component is used.

        Returns:
            Path of the downloaded file inside the client's temp directory.

        Raises:
            httpx.HTTPStatusError: If the response status is an error status.
            httpx.RequestError: If the request itself fails.
        """
        # Recreate the directory in case cleanup() already removed it.
        os.makedirs(self._temp_dir, exist_ok=True)
        local_path = os.path.join(self._temp_dir, self._safe_filename(url, filename))

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(url)
            response.raise_for_status()

            with open(local_path, "wb") as f:
                f.write(response.content)

        return local_path

    def cleanup(self):
        """Clean up temporary files."""
        import shutil
        if os.path.exists(self._temp_dir):
            shutil.rmtree(self._temp_dir, ignore_errors=True)
| |
|