File size: 6,836 Bytes
ab1b163 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
"""FastAPI backend for LLM Council."""
import asyncio
import json
import logging
import uuid
from typing import List, Dict, Any

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from . import storage
from .council import (
    run_full_council,
    generate_conversation_title,
    stage1_collect_responses,
    stage2_collect_rankings,
    stage3_synthesize_final,
    calculate_aggregate_rankings,
)
app = FastAPI(title="LLM Council API")

# Browser origins allowed to call this API during local development.
_LOCAL_DEV_ORIGINS = [
    "http://localhost:5173",
    "http://localhost:3000",
]

# Enable CORS so the local frontend dev servers can reach the API
# (credentials allowed; every method and header accepted).
app.add_middleware(
    CORSMiddleware,
    allow_origins=_LOCAL_DEV_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class CreateConversationRequest(BaseModel):
    """Request body for creating a conversation; it currently carries no fields."""
class SendMessageRequest(BaseModel):
    """Request body for posting a user message into a conversation."""

    # Text of the user's message; it is passed verbatim to the council stages.
    content: str
class ConversationMetadata(BaseModel):
    """Summary of one conversation, used by the list endpoint (no message bodies)."""

    id: str  # conversation identifier
    created_at: str  # creation time, as stored (kept as a string)
    title: str  # display title
    message_count: int  # number of messages in the conversation
class Conversation(BaseModel):
    """A conversation together with its complete message history."""

    id: str  # conversation identifier
    created_at: str  # creation time, as stored (kept as a string)
    title: str  # display title
    # Messages are loosely-typed dicts; presumably their shape is defined by
    # the storage layer that writes them — verify against `storage`.
    messages: List[Dict[str, Any]]
@app.get("/")
async def root():
    """Health check: report that the service is up and identify it."""
    payload = {"status": "ok", "service": "LLM Council API"}
    return payload
@app.get("/api/conversations", response_model=List[ConversationMetadata])
async def list_conversations():
    """Return metadata (no messages) for every stored conversation."""
    summaries = storage.list_conversations()
    return summaries
@app.post("/api/conversations", response_model=Conversation)
async def create_conversation(request: CreateConversationRequest):
    """Create a conversation under a new random UUID4 id and return the stored record.

    ``request`` has no fields; it is accepted only so the endpoint takes a JSON body.
    """
    new_id = str(uuid.uuid4())
    return storage.create_conversation(new_id)
@app.get("/api/conversations/{conversation_id}", response_model=Conversation)
async def get_conversation(conversation_id: str):
    """Fetch one conversation, including all of its messages.

    Raises:
        HTTPException: 404 when no conversation has this id.
    """
    found = storage.get_conversation(conversation_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Conversation not found")
@app.post("/api/conversations/{conversation_id}/message")
async def send_message(conversation_id: str, request: SendMessageRequest):
    """
    Send a message and run the 3-stage council process.

    Returns the complete response with all stages: a dict with keys
    ``stage1``, ``stage2``, ``stage3`` and ``metadata``, all taken from
    ``run_full_council``.

    For the first message of a conversation, a title is generated
    concurrently with the council run (as the streaming endpoint does) and
    stored after the council output has been persisted.

    Raises:
        HTTPException: 404 if the conversation does not exist. Errors from
            the council or from title generation propagate to the caller.
    """
    # Check if conversation exists
    conversation = storage.get_conversation(conversation_id)
    if conversation is None:
        raise HTTPException(status_code=404, detail="Conversation not found")

    # Only the first message of a conversation triggers title generation.
    is_first_message = not conversation["messages"]

    # Add user message
    storage.add_user_message(conversation_id, request.content)

    # Start title generation without awaiting it, so it overlaps the council
    # run instead of adding a full extra round-trip before it.
    title_task = None
    if is_first_message:
        title_task = asyncio.create_task(generate_conversation_title(request.content))

    # Run the 3-stage council process
    try:
        stage1_results, stage2_results, stage3_result, metadata = await run_full_council(
            request.content
        )
    except BaseException:
        # Covers CancelledError too: don't leave the title task running
        # if the council fails or this request is cancelled.
        if title_task is not None:
            title_task.cancel()
        raise

    # Persist the council output before waiting on the title, so a title
    # failure cannot discard it.
    storage.add_assistant_message(
        conversation_id,
        stage1_results,
        stage2_results,
        stage3_result,
    )

    if title_task is not None:
        title = await title_task
        storage.update_conversation_title(conversation_id, title)

    # Return the complete response with metadata
    return {
        "stage1": stage1_results,
        "stage2": stage2_results,
        "stage3": stage3_result,
        "metadata": metadata,
    }
@app.post("/api/conversations/{conversation_id}/message/stream")
async def send_message_stream(conversation_id: str, request: SendMessageRequest):
    """
    Send a message and stream the 3-stage council process.

    Returns Server-Sent Events as each stage completes. Event ``type`` values,
    in order: ``stage1_start``, ``stage1_complete``, ``stage2_start``,
    ``stage2_complete``, ``stage3_start``, ``stage3_complete``,
    ``title_complete`` (first message only), then ``complete``. If a step
    raises, an ``error`` event carrying the exception message is sent and
    the stream ends.

    The council output is persisted as soon as stage 3 finishes, so a
    failure in title generation does not lose it.

    Raises:
        HTTPException: 404 if the conversation does not exist (raised before
            the stream starts).
    """
    # Check if conversation exists before committing to a streaming response.
    conversation = storage.get_conversation(conversation_id)
    if conversation is None:
        raise HTTPException(status_code=404, detail="Conversation not found")

    # Only the first message of a conversation triggers title generation.
    is_first_message = not conversation["messages"]

    def sse(payload: Dict[str, Any]) -> str:
        """Format one payload as a Server-Sent Events ``data:`` frame."""
        return f"data: {json.dumps(payload)}\n\n"

    async def event_generator():
        # Defined outside the try so the finally clause can always see it.
        title_task = None
        try:
            # Add user message
            storage.add_user_message(conversation_id, request.content)

            # Start title generation in parallel (awaited after stage 3).
            if is_first_message:
                title_task = asyncio.create_task(
                    generate_conversation_title(request.content)
                )

            # Stage 1: Collect responses
            yield sse({"type": "stage1_start"})
            stage1_results = await stage1_collect_responses(request.content)
            yield sse({"type": "stage1_complete", "data": stage1_results})

            # Stage 2: Collect rankings
            yield sse({"type": "stage2_start"})
            stage2_results, label_to_model = await stage2_collect_rankings(
                request.content, stage1_results
            )
            aggregate_rankings = calculate_aggregate_rankings(stage2_results, label_to_model)
            yield sse({
                "type": "stage2_complete",
                "data": stage2_results,
                "metadata": {
                    "label_to_model": label_to_model,
                    "aggregate_rankings": aggregate_rankings,
                },
            })

            # Stage 3: Synthesize final answer
            yield sse({"type": "stage3_start"})
            stage3_result = await stage3_synthesize_final(
                request.content, stage1_results, stage2_results
            )
            yield sse({"type": "stage3_complete", "data": stage3_result})

            # Save the complete assistant message BEFORE awaiting the title:
            # previously a title failure jumped to the except clause and the
            # council output was never stored.
            storage.add_assistant_message(
                conversation_id,
                stage1_results,
                stage2_results,
                stage3_result,
            )

            # Wait for title generation if it was started
            if title_task is not None:
                title = await title_task
                storage.update_conversation_title(conversation_id, title)
                yield sse({"type": "title_complete", "data": {"title": title}})

            # Send completion event
            yield sse({"type": "complete"})
        except Exception as e:
            # Keep the traceback server-side; the client only gets the message.
            logging.getLogger(__name__).exception(
                "Council stream failed for conversation %s", conversation_id
            )
            yield sse({"type": "error", "message": str(e)})
        finally:
            # If the generator exits early (error, or closed before finishing),
            # don't leave title generation running in the background.
            if title_task is not None and not title_task.done():
                title_task.cancel()

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on port 8001,
    # bound to all network interfaces ("0.0.0.0").
    import uvicorn

    uvicorn.run(app, port=8001, host="0.0.0.0")
|