Commit
·
361bd3e
0
Parent(s):
first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +16 -0
- .env.example +100 -0
- .gitattributes +35 -0
- .github/workflows/sync-portfolio.yml +70 -0
- .gitignore +194 -0
- Dockerfile +53 -0
- README.md +240 -0
- compose.yaml +35 -0
- entrypoint.sh +37 -0
- pyproject.toml +108 -0
- src/agents/__init__.py +17 -0
- src/agents/agents.py +97 -0
- src/agents/bg_task_agent/bg_task_agent.py +62 -0
- src/agents/bg_task_agent/task.py +53 -0
- src/agents/chatbot.py +23 -0
- src/agents/command_agent.py +55 -0
- src/agents/github_mcp_agent/github_mcp_agent.py +102 -0
- src/agents/interrupt_agent.py +232 -0
- src/agents/knowledge_base_agent.py +174 -0
- src/agents/langgraph_supervisor_agent.py +62 -0
- src/agents/langgraph_supervisor_hierarchy_agent.py +46 -0
- src/agents/lazy_agent.py +43 -0
- src/agents/llama_guard.py +121 -0
- src/agents/portfolio_agent/database_search.py +44 -0
- src/agents/portfolio_agent/portfolio_agent.py +85 -0
- src/agents/portfolio_agent/prompt.py +115 -0
- src/agents/rag_assistant.py +146 -0
- src/agents/research_assistant.py +148 -0
- src/agents/tools.py +56 -0
- src/agents/utils.py +17 -0
- src/core/__init__.py +4 -0
- src/core/embeddings.py +37 -0
- src/core/llm.py +147 -0
- src/core/settings.py +289 -0
- src/memory/__init__.py +40 -0
- src/memory/mongodb.py +62 -0
- src/memory/postgres.py +135 -0
- src/memory/sqlite.py +40 -0
- src/run_agent.py +40 -0
- src/run_service.py +37 -0
- src/schema/__init__.py +25 -0
- src/schema/models.py +165 -0
- src/schema/schema.py +175 -0
- src/schema/task_data.py +74 -0
- src/scripts/create_chroma_db.py +83 -0
- src/scripts/load_portfolio.py +25 -0
- src/scripts/portfolio/document.py +129 -0
- src/scripts/portfolio/notion_loader.py +68 -0
- src/scripts/portfolio/portfolio_ingestion.py +150 -0
- src/scripts/portfolio/prompt.py +29 -0
.dockerignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.gitignore
|
| 3 |
+
.env
|
| 4 |
+
**/__pycache__
|
| 5 |
+
**/*.pyc
|
| 6 |
+
**/*.pyo
|
| 7 |
+
**/*.pyd
|
| 8 |
+
.Python
|
| 9 |
+
env
|
| 10 |
+
venv
|
| 11 |
+
.venv
|
| 12 |
+
chroma_db/
|
| 13 |
+
ollama_data/
|
| 14 |
+
.agent/
|
| 15 |
+
.gemini/
|
| 16 |
+
.specstory/
|
.env.example
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# API keys for different providers
|
| 2 |
+
OPENAI_API_KEY=
|
| 3 |
+
AZURE_OPENAI_API_KEY=
|
| 4 |
+
DEEPSEEK_API_KEY=
|
| 5 |
+
ANTHROPIC_API_KEY=
|
| 6 |
+
GOOGLE_API_KEY=
|
| 7 |
+
GROQ_API_KEY=
|
| 8 |
+
OPENROUTER_API_KEY=
|
| 9 |
+
USE_AWS_BEDROCK=false
|
| 10 |
+
|
| 11 |
+
#Vertex AI
|
| 12 |
+
GOOGLE_APPLICATION_CREDENTIALS=
|
| 13 |
+
|
| 14 |
+
# Amazon Bedrock Knowledge Base ID
|
| 15 |
+
AWS_KB_ID="<knowledge-base-id>"
|
| 16 |
+
|
| 17 |
+
# Use a fake model for testing
|
| 18 |
+
USE_FAKE_MODEL=false
|
| 19 |
+
|
| 20 |
+
# Set a default model
|
| 21 |
+
DEFAULT_MODEL=
|
| 22 |
+
|
| 23 |
+
# If MODEL is set to "openai-compatible", set the following
|
| 24 |
+
# This is just a flexible solution. If you need multiple model options, you still need to add it to models.py
|
| 25 |
+
COMPATIBLE_MODEL=
|
| 26 |
+
COMPATIBLE_API_KEY=
|
| 27 |
+
COMPATIBLE_BASE_URL=
|
| 28 |
+
|
| 29 |
+
# Web server configuration
|
| 30 |
+
HOST=0.0.0.0
|
| 31 |
+
PORT=7860
|
| 32 |
+
|
| 33 |
+
# Authentication secret, HTTP bearer token header is required if set
|
| 34 |
+
AUTH_SECRET=
|
| 35 |
+
CORS_ORIGINS=http://localhost:3000,http://localhost:8081,http://localhost:5173
|
| 36 |
+
|
| 37 |
+
# Langsmith configuration
|
| 38 |
+
# LANGSMITH_TRACING=true
|
| 39 |
+
# LANGSMITH_API_KEY=
|
| 40 |
+
# LANGSMITH_PROJECT=default
|
| 41 |
+
# LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
| 42 |
+
|
| 43 |
+
# Application mode. If the value is "dev", it will enable uvicorn reload
|
| 44 |
+
MODE=
|
| 45 |
+
|
| 46 |
+
# Database type.
|
| 47 |
+
# If the value is "postgres", then it will require Postgresql related environment variables.
|
| 48 |
+
# If the value is "sqlite", then you can configure optional file path via SQLITE_DB_PATH
|
| 49 |
+
DATABASE_TYPE=
|
| 50 |
+
|
| 51 |
+
# If DATABASE_TYPE=sqlite (Optional)
|
| 52 |
+
SQLITE_DB_PATH=
|
| 53 |
+
|
| 54 |
+
# If DATABASE_TYPE=postgres
|
| 55 |
+
# Docker Compose default values (will work with docker-compose setup)
|
| 56 |
+
POSTGRES_USER=
|
| 57 |
+
POSTGRES_PASSWORD=
|
| 58 |
+
POSTGRES_HOST=
|
| 59 |
+
POSTGRES_PORT=
|
| 60 |
+
POSTGRES_DB=
|
| 61 |
+
|
| 62 |
+
# you will be able to identify AST connections in Postgres Connection Manager under this Application Name
|
| 63 |
+
# POSTGRES_APPLICATION_NAME = "agent-service-toolkit"
|
| 64 |
+
# set these values to customize the number of connections in the pool. Saver and store have independent connection pools
|
| 65 |
+
# POSTGRES_MIN_CONNECTIONS_PER_POOL=1
|
| 66 |
+
# POSTGRES_MAX_CONNECTIONS_PER_POOL= 3
|
| 67 |
+
|
| 68 |
+
# OpenWeatherMap API key
|
| 69 |
+
OPENWEATHERMAP_API_KEY=
|
| 70 |
+
|
| 71 |
+
# Add for running ollama
|
| 72 |
+
# OLLAMA_MODEL=llama3.2
|
| 73 |
+
# Note: set OLLAMA_BASE_URL if running service in docker and ollama on bare metal
|
| 74 |
+
# OLLAMA_BASE_URL=http://host.docker.internal:11434
|
| 75 |
+
|
| 76 |
+
# Add for running Azure OpenAI
|
| 77 |
+
# AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com
|
| 78 |
+
# AZURE_OPENAI_API_VERSION=2024-10-21
|
| 79 |
+
# AZURE_OPENAI_DEPLOYMENT_MAP={"gpt-4o": "gpt-4o-deployment", "gpt-4o-mini": "gpt-4o-mini-deployment"}
|
| 80 |
+
|
| 81 |
+
# Agent URL: used in Streamlit app - if not set, defaults to http://{HOST}:{PORT}
|
| 82 |
+
# AGENT_URL=http://localhost:7860
|
| 83 |
+
|
| 84 |
+
# LANGFUSE Configuration
|
| 85 |
+
#LANGFUSE_TRACING=true
|
| 86 |
+
#LANGFUSE_PUBLIC_KEY=pk-...
|
| 87 |
+
#LANGFUSE_SECRET_KEY=sk-lf-....
|
| 88 |
+
#LANGFUSE_HOST=http://localhost:3000
|
| 89 |
+
|
| 90 |
+
# GitHub MCP Agent Configuration
|
| 91 |
+
# GitHub Personal Access Token (required for GitHub MCP server)
|
| 92 |
+
# If not set, the GitHub MCP agent will have no tools
|
| 93 |
+
GITHUB_PAT=
|
| 94 |
+
|
| 95 |
+
# Voice Features (Optional)
|
| 96 |
+
# NOTE: Voice features are configured on the client (Streamlit app) side, not the server (API).
|
| 97 |
+
# Requires OPENAI_API_KEY to be set (see above).
|
| 98 |
+
# Set provider name to enable voice input/output. Leave empty to disable.
|
| 99 |
+
VOICE_STT_PROVIDER= # Speech-to-text provider (only 'openai' supported currently)
|
| 100 |
+
VOICE_TTS_PROVIDER= # Text-to-speech provider (only 'openai' supported currently)
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/sync-portfolio.yml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Sync Notion Portfolio to PGVector Store
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
workflow_dispatch:
|
| 5 |
+
inputs:
|
| 6 |
+
since:
|
| 7 |
+
description: 'ISO 8601 date to sync from (e.g., 2024-01-01T00:00:00.000Z). Leave empty to use last sync date.'
|
| 8 |
+
required: false
|
| 9 |
+
type: string
|
| 10 |
+
schedule:
|
| 11 |
+
# Run daily at 2 AM UTC
|
| 12 |
+
- cron: "0 2 * * *"
|
| 13 |
+
|
| 14 |
+
jobs:
|
| 15 |
+
sync-notion-portfolio-to-pgvector-store:
|
| 16 |
+
runs-on: ubuntu-latest
|
| 17 |
+
|
| 18 |
+
steps:
|
| 19 |
+
- name: Checkout code
|
| 20 |
+
uses: actions/checkout@v4
|
| 21 |
+
|
| 22 |
+
- name: Set up Python
|
| 23 |
+
uses: actions/setup-python@v5
|
| 24 |
+
with:
|
| 25 |
+
python-version: '3.12'
|
| 26 |
+
|
| 27 |
+
- name: Install uv
|
| 28 |
+
uses: astral-sh/setup-uv@v4
|
| 29 |
+
with:
|
| 30 |
+
version: "latest"
|
| 31 |
+
|
| 32 |
+
- name: Install dependencies
|
| 33 |
+
run: uv sync --frozen
|
| 34 |
+
|
| 35 |
+
- name: Sync Notion portfolio data to PGVector store
|
| 36 |
+
env:
|
| 37 |
+
# Notion API credentials
|
| 38 |
+
NOTION_TOKEN: ${{ secrets.NOTION_TOKEN }}
|
| 39 |
+
NOTION_EDUCATION_ID: ${{ secrets.NOTION_EDUCATION_ID }}
|
| 40 |
+
NOTION_EXPERIENCE_ID: ${{ secrets.NOTION_EXPERIENCE_ID }}
|
| 41 |
+
NOTION_PROJECT_ID: ${{ secrets.NOTION_PROJECT_ID }}
|
| 42 |
+
NOTION_TESTIMONIAL_ID: ${{ secrets.NOTION_TESTIMONIAL_ID }}
|
| 43 |
+
NOTION_BLOG_ID: ${{ secrets.NOTION_BLOG_ID }}
|
| 44 |
+
|
| 45 |
+
# Database configuration
|
| 46 |
+
DATABASE_TYPE: postgres
|
| 47 |
+
POSTGRES_USER: ${{ secrets.POSTGRES_USER }}
|
| 48 |
+
POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
|
| 49 |
+
POSTGRES_HOST: ${{ secrets.POSTGRES_HOST }}
|
| 50 |
+
POSTGRES_PORT: ${{ secrets.POSTGRES_PORT }}
|
| 51 |
+
POSTGRES_DB: ${{ secrets.POSTGRES_DB }}
|
| 52 |
+
|
| 53 |
+
# LLM API keys (at least one required)
|
| 54 |
+
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
| 55 |
+
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
| 56 |
+
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
|
| 57 |
+
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
| 58 |
+
|
| 59 |
+
# Optional settings
|
| 60 |
+
OWNER: ${{ secrets.OWNER || 'Anuj Joshi' }}
|
| 61 |
+
VECTOR_STORE_COLLECTION_NAME: ${{ secrets.VECTOR_STORE_COLLECTION_NAME || 'portfolio' }}
|
| 62 |
+
|
| 63 |
+
# Python path
|
| 64 |
+
PYTHONPATH: ${{ github.workspace }}/src
|
| 65 |
+
run: |
|
| 66 |
+
if [ -n "${{ inputs.since }}" ]; then
|
| 67 |
+
uv run python -m scripts.load_portfolio --since "${{ inputs.since }}"
|
| 68 |
+
else
|
| 69 |
+
uv run python -m scripts.load_portfolio
|
| 70 |
+
fi
|
.gitignore
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Streamlit and sqlite
|
| 2 |
+
.streamlit/secrets.toml
|
| 3 |
+
checkpoints.db
|
| 4 |
+
checkpoints.db-*
|
| 5 |
+
|
| 6 |
+
# Langgraph
|
| 7 |
+
.langgraph_api/
|
| 8 |
+
|
| 9 |
+
# VSCode
|
| 10 |
+
.vscode
|
| 11 |
+
.DS_Store
|
| 12 |
+
*.code-workspace
|
| 13 |
+
|
| 14 |
+
# cursor
|
| 15 |
+
.cursorindexingignore
|
| 16 |
+
.specstory/
|
| 17 |
+
|
| 18 |
+
# Byte-compiled / optimized / DLL files
|
| 19 |
+
__pycache__/
|
| 20 |
+
*.py[cod]
|
| 21 |
+
*$py.class
|
| 22 |
+
|
| 23 |
+
# C extensions
|
| 24 |
+
*.so
|
| 25 |
+
|
| 26 |
+
# Distribution / packaging
|
| 27 |
+
.Python
|
| 28 |
+
build/
|
| 29 |
+
develop-eggs/
|
| 30 |
+
dist/
|
| 31 |
+
downloads/
|
| 32 |
+
eggs/
|
| 33 |
+
.eggs/
|
| 34 |
+
lib/
|
| 35 |
+
lib64/
|
| 36 |
+
parts/
|
| 37 |
+
sdist/
|
| 38 |
+
var/
|
| 39 |
+
wheels/
|
| 40 |
+
share/python-wheels/
|
| 41 |
+
*.egg-info/
|
| 42 |
+
.installed.cfg
|
| 43 |
+
*.egg
|
| 44 |
+
MANIFEST
|
| 45 |
+
|
| 46 |
+
# PyInstaller
|
| 47 |
+
# Usually these files are written by a python script from a template
|
| 48 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 49 |
+
*.manifest
|
| 50 |
+
*.spec
|
| 51 |
+
|
| 52 |
+
# Installer logs
|
| 53 |
+
pip-log.txt
|
| 54 |
+
pip-delete-this-directory.txt
|
| 55 |
+
|
| 56 |
+
# Unit test / coverage reports
|
| 57 |
+
htmlcov/
|
| 58 |
+
.tox/
|
| 59 |
+
.nox/
|
| 60 |
+
.coverage
|
| 61 |
+
.coverage.*
|
| 62 |
+
.cache
|
| 63 |
+
nosetests.xml
|
| 64 |
+
coverage.xml
|
| 65 |
+
*.cover
|
| 66 |
+
*.py,cover
|
| 67 |
+
.hypothesis/
|
| 68 |
+
.pytest_cache/
|
| 69 |
+
cover/
|
| 70 |
+
|
| 71 |
+
# Translations
|
| 72 |
+
*.mo
|
| 73 |
+
*.pot
|
| 74 |
+
|
| 75 |
+
# Django stuff:
|
| 76 |
+
*.log
|
| 77 |
+
local_settings.py
|
| 78 |
+
db.sqlite3
|
| 79 |
+
db.sqlite3-journal
|
| 80 |
+
|
| 81 |
+
# Flask stuff:
|
| 82 |
+
instance/
|
| 83 |
+
.webassets-cache
|
| 84 |
+
|
| 85 |
+
# Scrapy stuff:
|
| 86 |
+
.scrapy
|
| 87 |
+
|
| 88 |
+
# Sphinx documentation
|
| 89 |
+
docs/_build/
|
| 90 |
+
|
| 91 |
+
# PyBuilder
|
| 92 |
+
.pybuilder/
|
| 93 |
+
target/
|
| 94 |
+
|
| 95 |
+
# Jupyter Notebook
|
| 96 |
+
.ipynb_checkpoints
|
| 97 |
+
|
| 98 |
+
# IPython
|
| 99 |
+
profile_default/
|
| 100 |
+
ipython_config.py
|
| 101 |
+
|
| 102 |
+
# pyenv
|
| 103 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 104 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 105 |
+
# .python-version
|
| 106 |
+
|
| 107 |
+
# pipenv
|
| 108 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 109 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 110 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 111 |
+
# install all needed dependencies.
|
| 112 |
+
#Pipfile.lock
|
| 113 |
+
|
| 114 |
+
# poetry
|
| 115 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 116 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 117 |
+
# commonly ignored for libraries.
|
| 118 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 119 |
+
#poetry.lock
|
| 120 |
+
|
| 121 |
+
# pdm
|
| 122 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 123 |
+
#pdm.lock
|
| 124 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 125 |
+
# in version control.
|
| 126 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 127 |
+
.pdm.toml
|
| 128 |
+
.pdm-python
|
| 129 |
+
.pdm-build/
|
| 130 |
+
|
| 131 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 132 |
+
__pypackages__/
|
| 133 |
+
|
| 134 |
+
# Celery stuff
|
| 135 |
+
celerybeat-schedule
|
| 136 |
+
celerybeat.pid
|
| 137 |
+
|
| 138 |
+
# SageMath parsed files
|
| 139 |
+
*.sage.py
|
| 140 |
+
|
| 141 |
+
# Environments
|
| 142 |
+
.env
|
| 143 |
+
.python-version
|
| 144 |
+
.venv
|
| 145 |
+
env/
|
| 146 |
+
venv/
|
| 147 |
+
ENV/
|
| 148 |
+
env.bak/
|
| 149 |
+
venv.bak/
|
| 150 |
+
|
| 151 |
+
# Spyder project settings
|
| 152 |
+
.spyderproject
|
| 153 |
+
.spyproject
|
| 154 |
+
|
| 155 |
+
# Rope project settings
|
| 156 |
+
.ropeproject
|
| 157 |
+
|
| 158 |
+
# mkdocs documentation
|
| 159 |
+
/site
|
| 160 |
+
|
| 161 |
+
# mypy
|
| 162 |
+
.mypy_cache/
|
| 163 |
+
.dmypy.json
|
| 164 |
+
dmypy.json
|
| 165 |
+
|
| 166 |
+
# Pyre type checker
|
| 167 |
+
.pyre/
|
| 168 |
+
|
| 169 |
+
# pytype static type analyzer
|
| 170 |
+
.pytype/
|
| 171 |
+
|
| 172 |
+
# Cython debug symbols
|
| 173 |
+
cython_debug/
|
| 174 |
+
|
| 175 |
+
# PyCharm
|
| 176 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 177 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 178 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 179 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 180 |
+
.idea/
|
| 181 |
+
|
| 182 |
+
# Data
|
| 183 |
+
*.docx
|
| 184 |
+
*.pdf
|
| 185 |
+
|
| 186 |
+
# Chroma db
|
| 187 |
+
chroma_db
|
| 188 |
+
*.db
|
| 189 |
+
*.sqlite3
|
| 190 |
+
*.bin
|
| 191 |
+
|
| 192 |
+
# Private Credentials, ignore everything in the folder but the .gitkeep file
|
| 193 |
+
privatecredentials/*
|
| 194 |
+
!privatecredentials/.gitkeep
|
Dockerfile
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Stage 1: Get the official Ollama binary
|
| 2 |
+
FROM ollama/ollama:latest AS ollama_source
|
| 3 |
+
|
| 4 |
+
# Stage 2: Final Image
|
| 5 |
+
FROM python:3.12.3-slim
|
| 6 |
+
|
| 7 |
+
# Install system dependencies (curl for health, socat for port mapping)
|
| 8 |
+
RUN apt-get update && apt-get install -y \
|
| 9 |
+
curl \
|
| 10 |
+
socat \
|
| 11 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
+
|
| 13 |
+
# Copy Ollama from the official image
|
| 14 |
+
COPY --from=ollama_source /usr/bin/ollama /usr/bin/ollama
|
| 15 |
+
# Also copy the libraries which are necessary for Ollama to run
|
| 16 |
+
COPY --from=ollama_source /usr/lib/ollama /usr/lib/ollama
|
| 17 |
+
|
| 18 |
+
# Set up the non-root user
|
| 19 |
+
RUN useradd -m -u 1000 user
|
| 20 |
+
WORKDIR /app
|
| 21 |
+
RUN chown user:user /app
|
| 22 |
+
|
| 23 |
+
# Install uv and dependencies
|
| 24 |
+
RUN pip install --no-cache-dir uv
|
| 25 |
+
|
| 26 |
+
USER user
|
| 27 |
+
COPY --chown=user:user pyproject.toml uv.lock ./
|
| 28 |
+
RUN uv sync --frozen --no-dev --no-install-project
|
| 29 |
+
|
| 30 |
+
# Copy app code
|
| 31 |
+
COPY --chown=user:user src/agents/ ./agents/
|
| 32 |
+
COPY --chown=user:user src/core/ ./core/
|
| 33 |
+
COPY --chown=user:user src/memory/ ./memory/
|
| 34 |
+
COPY --chown=user:user src/schema/ ./schema/
|
| 35 |
+
COPY --chown=user:user src/service/ ./service/
|
| 36 |
+
COPY --chown=user:user src/run_service.py ./run_service.py
|
| 37 |
+
COPY --chown=user:user entrypoint.sh ./entrypoint.sh
|
| 38 |
+
RUN chmod +x ./entrypoint.sh
|
| 39 |
+
|
| 40 |
+
# Environment variables
|
| 41 |
+
ENV HOST=0.0.0.0
|
| 42 |
+
ENV PORT=7860
|
| 43 |
+
ENV OLLAMA_HOST=0.0.0.0:11434
|
| 44 |
+
ENV OLLAMA_KEEP_ALIVE=24h
|
| 45 |
+
ENV OLLAMA_BASE_URL=http://localhost:11434
|
| 46 |
+
|
| 47 |
+
EXPOSE 7860
|
| 48 |
+
EXPOSE 11434
|
| 49 |
+
|
| 50 |
+
RUN mkdir -p /home/user/.ollama
|
| 51 |
+
VOLUME ["/home/user/.ollama"]
|
| 52 |
+
|
| 53 |
+
CMD ["./entrypoint.sh"]
|
README.md
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: AI Agent Service Toolkit
|
| 3 |
+
emoji: 🧰
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# 🧰 AI Agent Service Toolkit
|
| 11 |
+
|
| 12 |
+
[](https://github.com/JoshuaC215/agent-service-toolkit/actions/workflows/test.yml) [](https://codecov.io/github/JoshuaC215/agent-service-toolkit) [](https://github.com/JoshuaC215/agent-service-toolkit/blob/main/pyproject.toml)
|
| 13 |
+
[](https://github.com/JoshuaC215/agent-service-toolkit/blob/main/LICENSE) [](https://agent-service-toolkit.streamlit.app/)
|
| 14 |
+
|
| 15 |
+
A full toolkit for running an AI agent service built with LangGraph, FastAPI and Streamlit.
|
| 16 |
+
|
| 17 |
+
It includes a [LangGraph](https://langchain-ai.github.io/langgraph/) agent, a [FastAPI](https://fastapi.tiangolo.com/) service to serve it, a client to interact with the service, and a [Streamlit](https://streamlit.io/) app that uses the client to provide a chat interface. Data structures and settings are built with [Pydantic](https://github.com/pydantic/pydantic).
|
| 18 |
+
|
| 19 |
+
This project offers a template for you to easily build and run your own agents using the LangGraph framework. It demonstrates a complete setup from agent definition to user interface, making it easier to get started with LangGraph-based projects by providing a full, robust toolkit.
|
| 20 |
+
|
| 21 |
+
**[🎥 Watch a video walkthrough of the repo and app](https://www.youtube.com/watch?v=pdYVHw_YCNY)**
|
| 22 |
+
|
| 23 |
+
## Overview
|
| 24 |
+
|
| 25 |
+
### [Try the app!](https://agent-service-toolkit.streamlit.app/)
|
| 26 |
+
|
| 27 |
+
<a href="https://agent-service-toolkit.streamlit.app/"><img src="media/app_screenshot.png" width="600"></a>
|
| 28 |
+
|
| 29 |
+
### Quickstart
|
| 30 |
+
|
| 31 |
+
Run directly in python
|
| 32 |
+
|
| 33 |
+
```sh
|
| 34 |
+
# At least one LLM API key is required
|
| 35 |
+
echo 'OPENAI_API_KEY=your_openai_api_key' >> .env
|
| 36 |
+
|
| 37 |
+
# uv is the recommended way to install agent-service-toolkit, but "pip install ." also works
|
| 38 |
+
# For uv installation options, see: https://docs.astral.sh/uv/getting-started/installation/
|
| 39 |
+
curl -LsSf https://astral.sh/uv/0.7.19/install.sh | sh
|
| 40 |
+
|
| 41 |
+
# Install dependencies. "uv sync" creates .venv automatically
|
| 42 |
+
uv sync --frozen
|
| 43 |
+
source .venv/bin/activate
|
| 44 |
+
python src/run_service.py
|
| 45 |
+
|
| 46 |
+
# In another shell
|
| 47 |
+
source .venv/bin/activate
|
| 48 |
+
streamlit run src/streamlit_app.py
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
Run with docker
|
| 52 |
+
|
| 53 |
+
```sh
|
| 54 |
+
echo 'OPENAI_API_KEY=your_openai_api_key' >> .env
|
| 55 |
+
docker compose watch
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### Architecture Diagram
|
| 59 |
+
|
| 60 |
+
<img src="media/agent_architecture.png" width="600">
|
| 61 |
+
|
| 62 |
+
### Key Features
|
| 63 |
+
|
| 64 |
+
1. **LangGraph Agent and latest features**: A customizable agent built using the LangGraph framework. Implements the latest LangGraph v1.0 features including human in the loop with `interrupt()`, flow control with `Command`, long-term memory with `Store`, and `langgraph-supervisor`.
|
| 65 |
+
1. **FastAPI Service**: Serves the agent with both streaming and non-streaming endpoints.
|
| 66 |
+
1. **Advanced Streaming**: A novel approach to support both token-based and message-based streaming.
|
| 67 |
+
1. **Streamlit Interface**: Provides a user-friendly chat interface for interacting with the agent, including voice input and output.
|
| 68 |
+
1. **Multiple Agent Support**: Run multiple agents in the service and call by URL path. Available agents and models are described in `/info`
|
| 69 |
+
1. **Asynchronous Design**: Utilizes async/await for efficient handling of concurrent requests.
|
| 70 |
+
1. **Content Moderation**: Implements LlamaGuard for content moderation (requires Groq API key).
|
| 71 |
+
1. **RAG Agent**: A basic RAG agent implementation using ChromaDB - see [docs](docs/RAG_Assistant.md).
|
| 72 |
+
1. **Feedback Mechanism**: Includes a star-based feedback system integrated with LangSmith.
|
| 73 |
+
1. **Docker Support**: Includes Dockerfiles and a docker compose file for easy development and deployment.
|
| 74 |
+
1. **Testing**: Includes robust unit and integration tests for the full repo.
|
| 75 |
+
|
| 76 |
+
### Key Files
|
| 77 |
+
|
| 78 |
+
The repository is structured as follows:
|
| 79 |
+
|
| 80 |
+
- `src/agents/`: Defines several agents with different capabilities
|
| 81 |
+
- `src/schema/`: Defines the protocol schema
|
| 82 |
+
- `src/core/`: Core modules including LLM definition and settings
|
| 83 |
+
- `src/service/service.py`: FastAPI service to serve the agents
|
| 84 |
+
- `src/client/client.py`: Client to interact with the agent service
|
| 85 |
+
- `src/streamlit_app.py`: Streamlit app providing a chat interface
|
| 86 |
+
- `tests/`: Unit and integration tests
|
| 87 |
+
|
| 88 |
+
## Setup and Usage
|
| 89 |
+
|
| 90 |
+
1. Clone the repository:
|
| 91 |
+
|
| 92 |
+
```sh
|
| 93 |
+
git clone https://github.com/JoshuaC215/agent-service-toolkit.git
|
| 94 |
+
cd agent-service-toolkit
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
2. Set up environment variables:
|
| 98 |
+
Create a `.env` file in the root directory. At least one LLM API key or configuration is required. See the [`.env.example` file](./.env.example) for a full list of available environment variables, including a variety of model provider API keys, header-based authentication, LangSmith tracing, testing and development modes, and OpenWeatherMap API key.
|
| 99 |
+
|
| 100 |
+
3. You can now run the agent service and the Streamlit app locally, either with Docker or just using Python. The Docker setup is recommended for simpler environment setup and immediate reloading of the services when you make changes to your code.
|
| 101 |
+
|
| 102 |
+
### Additional setup for specific AI providers
|
| 103 |
+
|
| 104 |
+
- [Setting up Ollama](docs/Ollama.md)
|
| 105 |
+
- [Setting up VertexAI](docs/VertexAI.md)
|
| 106 |
+
- [Setting up RAG with ChromaDB](docs/RAG_Assistant.md)
|
| 107 |
+
|
| 108 |
+
### Building or customizing your own agent
|
| 109 |
+
|
| 110 |
+
To customize the agent for your own use case:
|
| 111 |
+
|
| 112 |
+
1. Add your new agent to the `src/agents` directory. You can copy `research_assistant.py` or `chatbot.py` and modify it to change the agent's behavior and tools.
|
| 113 |
+
1. Import and add your new agent to the `agents` dictionary in `src/agents/agents.py`. Your agent can be called by `/<your_agent_name>/invoke` or `/<your_agent_name>/stream`.
|
| 114 |
+
1. Adjust the Streamlit interface in `src/streamlit_app.py` to match your agent's capabilities.
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
### Handling Private Credential files
|
| 118 |
+
|
| 119 |
+
If your agents or chosen LLM require file-based credential files or certificates, the `privatecredentials/` has been provided for your development convenience. All contents, excluding the `.gitkeep` files, are ignored by git and docker's build process. See [Working with File-based Credentials](docs/File_Based_Credentials.md) for suggested use.
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
### Docker Setup
|
| 123 |
+
|
| 124 |
+
This project includes a Docker setup for easy development and deployment. The `compose.yaml` file defines three services: `postgres`, `agent_service` and `streamlit_app`. The `Dockerfile` for each service is in their respective directories.
|
| 125 |
+
|
| 126 |
+
For local development, we recommend using [docker compose watch](https://docs.docker.com/compose/file-watch/). This feature allows for a smoother development experience by automatically updating your containers when changes are detected in your source code.
|
| 127 |
+
|
| 128 |
+
1. Make sure you have Docker and Docker Compose (>= [v2.23.0](https://docs.docker.com/compose/release-notes/#2230)) installed on your system.
|
| 129 |
+
|
| 130 |
+
2. Create a `.env` file from the `.env.example`. At minimum, you need to provide an LLM API key (e.g., OPENAI_API_KEY).
|
| 131 |
+
```sh
|
| 132 |
+
cp .env.example .env
|
| 133 |
+
# Edit .env to add your API keys
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
3. Build and launch the services in watch mode:
|
| 137 |
+
|
| 138 |
+
```sh
|
| 139 |
+
docker compose watch
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
This will automatically:
|
| 143 |
+
- Start a PostgreSQL database service that the agent service connects to
|
| 144 |
+
- Start the agent service with FastAPI
|
| 145 |
+
- Start the Streamlit app for the user interface
|
| 146 |
+
|
| 147 |
+
4. The services will now automatically update when you make changes to your code:
|
| 148 |
+
- Changes in the relevant python files and directories will trigger updates for the relevant services.
|
| 149 |
+
- NOTE: If you make changes to the `pyproject.toml` or `uv.lock` files, you will need to rebuild the services by running `docker compose up --build`.
|
| 150 |
+
|
| 151 |
+
5. Access the Streamlit app by navigating to `http://localhost:8501` in your web browser.
|
| 152 |
+
|
| 153 |
+
6. The agent service API will be available at `http://0.0.0.0:7860`. You can also use the OpenAPI docs at `http://0.0.0.0:7860/redoc`.
|
| 154 |
+
|
| 155 |
+
7. Use `docker compose down` to stop the services.
|
| 156 |
+
|
| 157 |
+
This setup allows you to develop and test your changes in real-time without manually restarting the services.
|
| 158 |
+
|
| 159 |
+
### Building other apps on the AgentClient
|
| 160 |
+
|
| 161 |
+
The repo includes a generic `src/client/client.AgentClient` that can be used to interact with the agent service. This client is designed to be flexible and can be used to build other apps on top of the agent. It supports both synchronous and asynchronous invocations, and streaming and non-streaming requests.
|
| 162 |
+
|
| 163 |
+
See the `src/run_client.py` file for full examples of how to use the `AgentClient`. A quick example:
|
| 164 |
+
|
| 165 |
+
```python
|
| 166 |
+
from client import AgentClient
|
| 167 |
+
client = AgentClient()
|
| 168 |
+
|
| 169 |
+
response = client.invoke("Tell me a brief joke?")
|
| 170 |
+
response.pretty_print()
|
| 171 |
+
# ================================== Ai Message ==================================
|
| 172 |
+
#
|
| 173 |
+
# A man walked into a library and asked the librarian, "Do you have any books on Pavlov's dogs and Schrödinger's cat?"
|
| 174 |
+
# The librarian replied, "It rings a bell, but I'm not sure if it's here or not."
|
| 175 |
+
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
### Development with LangGraph Studio
|
| 179 |
+
|
| 180 |
+
The agent supports [LangGraph Studio](https://langchain-ai.github.io/langgraph/concepts/langgraph_studio/), the IDE for developing agents in LangGraph.
|
| 181 |
+
|
| 182 |
+
`langgraph-cli[inmem]` is installed with `uv sync`. You can simply add your `.env` file to the root directory as described above, and then launch LangGraph Studio with `langgraph dev`. Customize `langgraph.json` as needed. See the [local quickstart](https://langchain-ai.github.io/langgraph/cloud/how-tos/studio/quick_start/#local-development-server) to learn more.
|
| 183 |
+
|
| 184 |
+
### Local development without Docker
|
| 185 |
+
|
| 186 |
+
You can also run the agent service and the Streamlit app locally without Docker, just using a Python virtual environment.
|
| 187 |
+
|
| 188 |
+
1. Create a virtual environment and install dependencies:
|
| 189 |
+
|
| 190 |
+
```sh
|
| 191 |
+
uv sync --frozen
|
| 192 |
+
source .venv/bin/activate
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
2. Run the FastAPI server:
|
| 196 |
+
|
| 197 |
+
```sh
|
| 198 |
+
python src/run_service.py
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
3. In a separate terminal, run the Streamlit app:
|
| 202 |
+
|
| 203 |
+
```sh
|
| 204 |
+
streamlit run src/streamlit_app.py
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
4. Open your browser and navigate to the URL provided by Streamlit (usually `http://localhost:8501`).
|
| 208 |
+
|
| 209 |
+
## Projects built with or inspired by agent-service-toolkit
|
| 210 |
+
|
| 211 |
+
The following are a few of the public projects that drew code or inspiration from this repo.
|
| 212 |
+
|
| 213 |
+
- **[PolyRAG](https://github.com/QuentinFuxa/PolyRAG)** - Extends agent-service-toolkit with RAG capabilities over both PostgreSQL databases and PDF documents.
|
| 214 |
+
- **[alexrisch/agent-web-kit](https://github.com/alexrisch/agent-web-kit)** - A Next.JS frontend for agent-service-toolkit
|
| 215 |
+
- **[raushan-in/dapa](https://github.com/raushan-in/dapa)** - Digital Arrest Protection App (DAPA) enables users to report financial scams and frauds efficiently via a user-friendly platform.
|
| 216 |
+
|
| 217 |
+
**Please create a pull request editing the README or open a discussion with any new ones to be added!** Would love to include more projects.
|
| 218 |
+
|
| 219 |
+
## Contributing
|
| 220 |
+
|
| 221 |
+
Contributions are welcome! Please feel free to submit a Pull Request. Currently the tests need to be run using the local development without Docker setup. To run the tests for the agent service:
|
| 222 |
+
|
| 223 |
+
1. Ensure you're in the project root directory and have activated your virtual environment.
|
| 224 |
+
|
| 225 |
+
2. Install the development dependencies and pre-commit hooks:
|
| 226 |
+
|
| 227 |
+
```sh
|
| 228 |
+
uv sync --frozen
|
| 229 |
+
pre-commit install
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
3. Run the tests using pytest:
|
| 233 |
+
|
| 234 |
+
```sh
|
| 235 |
+
pytest
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## License
|
| 239 |
+
|
| 240 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
compose.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
portfolio_bot:
|
| 3 |
+
build:
|
| 4 |
+
context: .
|
| 5 |
+
dockerfile: Dockerfile
|
| 6 |
+
ports:
|
| 7 |
+
- "7860:7860"
|
| 8 |
+
env_file:
|
| 9 |
+
- .env
|
| 10 |
+
environment:
|
| 11 |
+
- OLLAMA_BASE_URL=http://localhost:11434
|
| 12 |
+
- OLLAMA_MODEL=qwen3:0.6b
|
| 13 |
+
- OLLAMA_EMBEDDING_MODEL=embeddinggemma:300m
|
| 14 |
+
healthcheck:
|
| 15 |
+
test: [ "CMD", "curl", "-f", "http://localhost:7860/health" ]
|
| 16 |
+
interval: 30s
|
| 17 |
+
timeout: 10s
|
| 18 |
+
retries: 5
|
| 19 |
+
develop:
|
| 20 |
+
watch:
|
| 21 |
+
- path: src/agents/
|
| 22 |
+
action: sync+restart
|
| 23 |
+
target: /app/agents/
|
| 24 |
+
- path: src/schema/
|
| 25 |
+
action: sync+restart
|
| 26 |
+
target: /app/schema/
|
| 27 |
+
- path: src/service/
|
| 28 |
+
action: sync+restart
|
| 29 |
+
target: /app/service/
|
| 30 |
+
- path: src/core/
|
| 31 |
+
action: sync+restart
|
| 32 |
+
target: /app/core/
|
| 33 |
+
- path: src/memory/
|
| 34 |
+
action: sync+restart
|
| 35 |
+
target: /app/memory/
|
entrypoint.sh
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
echo "Starting Ollama..."
|
| 5 |
+
ollama serve &
|
| 6 |
+
OLLAMA_PID=$!
|
| 7 |
+
|
| 8 |
+
# Wait for Ollama to start
|
| 9 |
+
until curl -s http://localhost:11434/api/tags > /dev/null; do
|
| 10 |
+
echo "Waiting for Ollama..."
|
| 11 |
+
sleep 2
|
| 12 |
+
done
|
| 13 |
+
|
| 14 |
+
echo "Pulling models..."
|
| 15 |
+
# Pull models based on environment variables or defaults
|
| 16 |
+
OLLAMA_MODEL=${OLLAMA_MODEL:-"qwen3:0.6b"}
|
| 17 |
+
OLLAMA_EMBEDDING_MODEL=${OLLAMA_EMBEDDING_MODEL:-"embeddinggemma:300m"}
|
| 18 |
+
|
| 19 |
+
echo "Pulling model: $OLLAMA_MODEL"
|
| 20 |
+
ollama pull $OLLAMA_MODEL
|
| 21 |
+
|
| 22 |
+
echo "Pulling embedding model: $OLLAMA_EMBEDDING_MODEL"
|
| 23 |
+
ollama pull $OLLAMA_EMBEDDING_MODEL
|
| 24 |
+
|
| 25 |
+
echo "Ollama ready."
|
| 26 |
+
|
| 27 |
+
echo "Starting Backend Service on port 7860..."
|
| 28 |
+
# Set variables for the python service
|
| 29 |
+
export HOST=0.0.0.0
|
| 30 |
+
export PORT=7860
|
| 31 |
+
export OLLAMA_BASE_URL=http://localhost:11434
|
| 32 |
+
|
| 33 |
+
# Ensure logs are visible
|
| 34 |
+
export PYTHONUNBUFFERED=1
|
| 35 |
+
|
| 36 |
+
# Start the python service
|
| 37 |
+
uv run python run_service.py
|
pyproject.toml
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "agent-service-toolkit"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Full toolkit for running an AI agent service built with LangGraph, FastAPI and Streamlit"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
authors = [{ name = "Joshua Carroll", email = "carroll.joshk@gmail.com" }]
|
| 7 |
+
classifiers = [
|
| 8 |
+
"Development Status :: 4 - Beta",
|
| 9 |
+
"License :: OSI Approved :: MIT License",
|
| 10 |
+
"Framework :: FastAPI",
|
| 11 |
+
"Programming Language :: Python :: 3.11",
|
| 12 |
+
"Programming Language :: Python :: 3.12",
|
| 13 |
+
"Programming Language :: Python :: 3.13",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
requires-python = ">=3.11,<3.14"
|
| 17 |
+
|
| 18 |
+
dependencies = [
|
| 19 |
+
"docx2txt ~=0.8",
|
| 20 |
+
"duckduckgo-search>=7.3.0",
|
| 21 |
+
"fastapi ~=0.115.5",
|
| 22 |
+
"grpcio >=1.68.0",
|
| 23 |
+
"httpx ~=0.28.0",
|
| 24 |
+
"jiter ~=0.8.2",
|
| 25 |
+
"langchain ~=1.0.5",
|
| 26 |
+
"langchain-core ~=1.0.0",
|
| 27 |
+
"langchain-community ~=0.4.1",
|
| 28 |
+
"langchain-anthropic ~=1.0.0",
|
| 29 |
+
"langchain-aws ~=1.0.0",
|
| 30 |
+
"langchain-postgres ~=0.0.9",
|
| 31 |
+
"langchain-google-genai ~=3.0.0",
|
| 32 |
+
"langchain-google-vertexai>=3.0.3",
|
| 33 |
+
"langchain-groq ~=1.0.1",
|
| 34 |
+
"langchain-ollama ~=1.0.0",
|
| 35 |
+
"langchain-openai ~=1.0.2",
|
| 36 |
+
"langfuse >=2.65.0",
|
| 37 |
+
"langgraph ~=1.0.0",
|
| 38 |
+
"langgraph-checkpoint-mongodb ~=0.1.3",
|
| 39 |
+
"langgraph-checkpoint-postgres ~=2.0.13",
|
| 40 |
+
"langgraph-checkpoint-sqlite ~=2.0.1",
|
| 41 |
+
"langgraph-supervisor ~=0.0.31",
|
| 42 |
+
"langsmith ~=0.4.0",
|
| 43 |
+
"numexpr ~=2.10.1",
|
| 44 |
+
"numpy ~=2.3.4",
|
| 45 |
+
"onnxruntime ~= 1.21.1",
|
| 46 |
+
"pandas ~=2.2.3",
|
| 47 |
+
"psycopg[binary,pool] ~=3.2.4",
|
| 48 |
+
"pyarrow >=18.1.0",
|
| 49 |
+
"pydantic ~=2.10.1",
|
| 50 |
+
"pydantic-settings ~=2.12.0",
|
| 51 |
+
"pypdf ~=5.3.0",
|
| 52 |
+
"pyowm ~=3.3.0",
|
| 53 |
+
"python-dotenv ~=1.0.1",
|
| 54 |
+
"setuptools ~=75.6.0",
|
| 55 |
+
"streamlit ~=1.52.0",
|
| 56 |
+
"tiktoken >=0.8.0",
|
| 57 |
+
"uvicorn ~=0.32.1",
|
| 58 |
+
"langchain-mcp-adapters>=0.1.10",
|
| 59 |
+
"ddgs>=9.9.1",
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
[dependency-groups]
|
| 63 |
+
dev = [
|
| 64 |
+
"langgraph-cli[inmem]",
|
| 65 |
+
"pre-commit",
|
| 66 |
+
"pytest",
|
| 67 |
+
"pytest-cov",
|
| 68 |
+
"pytest-env",
|
| 69 |
+
"pytest-asyncio",
|
| 70 |
+
"ruff",
|
| 71 |
+
"mypy",
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
# Group for the minimal dependencies to run just the client and Streamlit app.
|
| 75 |
+
# These are also installed in the default dependencies.
|
| 76 |
+
# To install run: `uv sync --frozen --only-group client`
|
| 77 |
+
client = [
|
| 78 |
+
"httpx~=0.28.0",
|
| 79 |
+
"pydantic ~=2.10.1",
|
| 80 |
+
"python-dotenv ~=1.0.1",
|
| 81 |
+
"streamlit~=1.52.0",
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
[tool.ruff]
|
| 85 |
+
line-length = 100
|
| 86 |
+
target-version = "py311"
|
| 87 |
+
|
| 88 |
+
[tool.ruff.lint]
|
| 89 |
+
extend-select = ["I", "U"]
|
| 90 |
+
|
| 91 |
+
[tool.pytest.ini_options]
|
| 92 |
+
pythonpath = ["src"]
|
| 93 |
+
asyncio_default_fixture_loop_scope = "function"
|
| 94 |
+
|
| 95 |
+
[tool.pytest_env]
|
| 96 |
+
OPENAI_API_KEY = "sk-fake-openai-key"
|
| 97 |
+
|
| 98 |
+
[tool.mypy]
|
| 99 |
+
plugins = "pydantic.mypy"
|
| 100 |
+
exclude = "src/streamlit_app.py"
|
| 101 |
+
|
| 102 |
+
[[tool.mypy.overrides]]
|
| 103 |
+
module = ["numexpr.*"]
|
| 104 |
+
follow_untyped_imports = true
|
| 105 |
+
|
| 106 |
+
[[tool.mypy.overrides]]
|
| 107 |
+
module = ["langchain_mcp_adapters.*"]
|
| 108 |
+
ignore_missing_imports = true
|
src/agents/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agents.agents import (
|
| 2 |
+
DEFAULT_AGENT,
|
| 3 |
+
AgentGraph,
|
| 4 |
+
AgentGraphLike,
|
| 5 |
+
get_agent,
|
| 6 |
+
get_all_agent_info,
|
| 7 |
+
load_agent,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"get_agent",
|
| 12 |
+
"load_agent",
|
| 13 |
+
"get_all_agent_info",
|
| 14 |
+
"DEFAULT_AGENT",
|
| 15 |
+
"AgentGraph",
|
| 16 |
+
"AgentGraphLike",
|
| 17 |
+
]
|
src/agents/agents.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
from langgraph.graph.state import CompiledStateGraph
|
| 4 |
+
from langgraph.pregel import Pregel
|
| 5 |
+
|
| 6 |
+
from agents.bg_task_agent.bg_task_agent import bg_task_agent
|
| 7 |
+
from agents.chatbot import chatbot
|
| 8 |
+
from agents.command_agent import command_agent
|
| 9 |
+
from agents.github_mcp_agent.github_mcp_agent import github_mcp_agent
|
| 10 |
+
from agents.interrupt_agent import interrupt_agent
|
| 11 |
+
from agents.knowledge_base_agent import kb_agent
|
| 12 |
+
from agents.langgraph_supervisor_agent import langgraph_supervisor_agent
|
| 13 |
+
from agents.langgraph_supervisor_hierarchy_agent import langgraph_supervisor_hierarchy_agent
|
| 14 |
+
from agents.portfolio_agent.portfolio_agent import portfolio_agent
|
| 15 |
+
|
| 16 |
+
from agents.lazy_agent import LazyLoadingAgent
|
| 17 |
+
from agents.rag_assistant import rag_assistant
|
| 18 |
+
from agents.research_assistant import research_assistant
|
| 19 |
+
from schema import AgentInfo
|
| 20 |
+
|
| 21 |
+
DEFAULT_AGENT = "portfolio-agent"
|
| 22 |
+
|
| 23 |
+
# Type alias to handle LangGraph's different agent patterns
|
| 24 |
+
# - @entrypoint functions return Pregel
|
| 25 |
+
# - StateGraph().compile() returns CompiledStateGraph
|
| 26 |
+
AgentGraph = CompiledStateGraph | Pregel # What get_agent() returns (always loaded)
|
| 27 |
+
AgentGraphLike = CompiledStateGraph | Pregel | LazyLoadingAgent # What can be stored in registry
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class Agent:
|
| 32 |
+
description: str
|
| 33 |
+
graph_like: AgentGraphLike
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
agents: dict[str, Agent] = {
|
| 37 |
+
"chatbot": Agent(description="A simple chatbot.", graph_like=chatbot),
|
| 38 |
+
"research-assistant": Agent(
|
| 39 |
+
description="A research assistant with web search and calculator.",
|
| 40 |
+
graph_like=research_assistant,
|
| 41 |
+
),
|
| 42 |
+
"rag-assistant": Agent(
|
| 43 |
+
description="A RAG assistant with access to information in a database.",
|
| 44 |
+
graph_like=rag_assistant,
|
| 45 |
+
),
|
| 46 |
+
"portfolio-agent": Agent(
|
| 47 |
+
description="A portfolio assistant with access to information in a database.",
|
| 48 |
+
graph_like=portfolio_agent,
|
| 49 |
+
),
|
| 50 |
+
"command-agent": Agent(description="A command agent.", graph_like=command_agent),
|
| 51 |
+
"bg-task-agent": Agent(description="A background task agent.", graph_like=bg_task_agent),
|
| 52 |
+
"langgraph-supervisor-agent": Agent(
|
| 53 |
+
description="A langgraph supervisor agent", graph_like=langgraph_supervisor_agent
|
| 54 |
+
),
|
| 55 |
+
"langgraph-supervisor-hierarchy-agent": Agent(
|
| 56 |
+
description="A langgraph supervisor agent with a nested hierarchy of agents",
|
| 57 |
+
graph_like=langgraph_supervisor_hierarchy_agent,
|
| 58 |
+
),
|
| 59 |
+
"interrupt-agent": Agent(
|
| 60 |
+
description="An agent the uses interrupts.", graph_like=interrupt_agent
|
| 61 |
+
),
|
| 62 |
+
"knowledge-base-agent": Agent(
|
| 63 |
+
description="A retrieval-augmented generation agent using Amazon Bedrock Knowledge Base",
|
| 64 |
+
graph_like=kb_agent,
|
| 65 |
+
),
|
| 66 |
+
"github-mcp-agent": Agent(
|
| 67 |
+
description="A GitHub agent with MCP tools for repository management and development workflows.",
|
| 68 |
+
graph_like=github_mcp_agent,
|
| 69 |
+
),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
async def load_agent(agent_id: str) -> None:
|
| 74 |
+
"""Load lazy agents if needed."""
|
| 75 |
+
graph_like = agents[agent_id].graph_like
|
| 76 |
+
if isinstance(graph_like, LazyLoadingAgent):
|
| 77 |
+
await graph_like.load()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_agent(agent_id: str) -> AgentGraph:
|
| 81 |
+
"""Get an agent graph, loading lazy agents if needed."""
|
| 82 |
+
agent_graph = agents[agent_id].graph_like
|
| 83 |
+
|
| 84 |
+
# If it's a lazy loading agent, ensure it's loaded and return its graph
|
| 85 |
+
if isinstance(agent_graph, LazyLoadingAgent):
|
| 86 |
+
if not agent_graph._loaded:
|
| 87 |
+
raise RuntimeError(f"Agent {agent_id} not loaded. Call load() first.")
|
| 88 |
+
return agent_graph.get_graph()
|
| 89 |
+
|
| 90 |
+
# Otherwise return the graph directly
|
| 91 |
+
return agent_graph
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_all_agent_info() -> list[AgentInfo]:
|
| 95 |
+
return [
|
| 96 |
+
AgentInfo(key=agent_id, description=agent.description) for agent_id, agent in agents.items()
|
| 97 |
+
]
|
src/agents/bg_task_agent/bg_task_agent.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
|
| 3 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 4 |
+
from langchain_core.messages import AIMessage
|
| 5 |
+
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
|
| 6 |
+
from langgraph.graph import END, MessagesState, StateGraph
|
| 7 |
+
from langgraph.types import StreamWriter
|
| 8 |
+
|
| 9 |
+
from agents.bg_task_agent.task import Task
|
| 10 |
+
from core import get_model, settings
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AgentState(MessagesState, total=False):
|
| 14 |
+
"""`total=False` is PEP589 specs.
|
| 15 |
+
|
| 16 |
+
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
|
| 21 |
+
preprocessor = RunnableLambda(
|
| 22 |
+
lambda state: state["messages"],
|
| 23 |
+
name="StateModifier",
|
| 24 |
+
)
|
| 25 |
+
return preprocessor | model # type: ignore[return-value]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 29 |
+
m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
|
| 30 |
+
model_runnable = wrap_model(m)
|
| 31 |
+
response = await model_runnable.ainvoke(state, config)
|
| 32 |
+
|
| 33 |
+
# We return a list, because this will get added to the existing list
|
| 34 |
+
return {"messages": [response]}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
async def bg_task(state: AgentState, writer: StreamWriter) -> AgentState:
|
| 38 |
+
task1 = Task("Simple task 1...", writer)
|
| 39 |
+
task2 = Task("Simple task 2...", writer)
|
| 40 |
+
|
| 41 |
+
task1.start()
|
| 42 |
+
await asyncio.sleep(2)
|
| 43 |
+
task2.start()
|
| 44 |
+
await asyncio.sleep(2)
|
| 45 |
+
task1.write_data(data={"status": "Still running..."})
|
| 46 |
+
await asyncio.sleep(2)
|
| 47 |
+
task2.finish(result="error", data={"output": 42})
|
| 48 |
+
await asyncio.sleep(2)
|
| 49 |
+
task1.finish(result="success", data={"output": 42})
|
| 50 |
+
return {"messages": []}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Define the graph
|
| 54 |
+
agent = StateGraph(AgentState)
|
| 55 |
+
agent.add_node("model", acall_model)
|
| 56 |
+
agent.add_node("bg_task", bg_task)
|
| 57 |
+
agent.set_entry_point("bg_task")
|
| 58 |
+
|
| 59 |
+
agent.add_edge("bg_task", "model")
|
| 60 |
+
agent.add_edge("model", END)
|
| 61 |
+
|
| 62 |
+
bg_task_agent = agent.compile()
|
src/agents/bg_task_agent/task.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
from uuid import uuid4
|
| 3 |
+
|
| 4 |
+
from langchain_core.messages import BaseMessage
|
| 5 |
+
from langgraph.types import StreamWriter
|
| 6 |
+
|
| 7 |
+
from agents.utils import CustomData
|
| 8 |
+
from schema.task_data import TaskData
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Task:
|
| 12 |
+
def __init__(self, task_name: str, writer: StreamWriter | None = None) -> None:
|
| 13 |
+
self.name = task_name
|
| 14 |
+
self.id = str(uuid4())
|
| 15 |
+
self.state: Literal["new", "running", "complete"] = "new"
|
| 16 |
+
self.result: Literal["success", "error"] | None = None
|
| 17 |
+
self.writer = writer
|
| 18 |
+
|
| 19 |
+
def _generate_and_dispatch_message(self, writer: StreamWriter | None, data: dict):
|
| 20 |
+
writer = writer or self.writer
|
| 21 |
+
task_data = TaskData(name=self.name, run_id=self.id, state=self.state, data=data)
|
| 22 |
+
if self.result:
|
| 23 |
+
task_data.result = self.result
|
| 24 |
+
task_custom_data = CustomData(
|
| 25 |
+
type=self.name,
|
| 26 |
+
data=task_data.model_dump(),
|
| 27 |
+
)
|
| 28 |
+
if writer:
|
| 29 |
+
task_custom_data.dispatch(writer)
|
| 30 |
+
return task_custom_data.to_langchain()
|
| 31 |
+
|
| 32 |
+
def start(self, writer: StreamWriter | None = None, data: dict = {}) -> BaseMessage:
|
| 33 |
+
self.state = "new"
|
| 34 |
+
task_message = self._generate_and_dispatch_message(writer, data)
|
| 35 |
+
return task_message
|
| 36 |
+
|
| 37 |
+
def write_data(self, writer: StreamWriter | None = None, data: dict = {}) -> BaseMessage:
|
| 38 |
+
if self.state == "complete":
|
| 39 |
+
raise ValueError("Only incomplete tasks can output data.")
|
| 40 |
+
self.state = "running"
|
| 41 |
+
task_message = self._generate_and_dispatch_message(writer, data)
|
| 42 |
+
return task_message
|
| 43 |
+
|
| 44 |
+
def finish(
|
| 45 |
+
self,
|
| 46 |
+
result: Literal["success", "error"],
|
| 47 |
+
writer: StreamWriter | None = None,
|
| 48 |
+
data: dict = {},
|
| 49 |
+
) -> BaseMessage:
|
| 50 |
+
self.state = "complete"
|
| 51 |
+
self.result = result
|
| 52 |
+
task_message = self._generate_and_dispatch_message(writer, data)
|
| 53 |
+
return task_message
|
src/agents/chatbot.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.messages import BaseMessage
|
| 2 |
+
from langchain_core.runnables import RunnableConfig
|
| 3 |
+
from langgraph.func import entrypoint
|
| 4 |
+
|
| 5 |
+
from core import get_model, settings
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@entrypoint()
|
| 9 |
+
async def chatbot(
|
| 10 |
+
inputs: dict[str, list[BaseMessage]],
|
| 11 |
+
*,
|
| 12 |
+
previous: dict[str, list[BaseMessage]],
|
| 13 |
+
config: RunnableConfig,
|
| 14 |
+
):
|
| 15 |
+
messages = inputs["messages"]
|
| 16 |
+
if previous:
|
| 17 |
+
messages = previous["messages"] + messages
|
| 18 |
+
|
| 19 |
+
model = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
|
| 20 |
+
response = await model.ainvoke(messages)
|
| 21 |
+
return entrypoint.final(
|
| 22 |
+
value={"messages": [response]}, save={"messages": messages + [response]}
|
| 23 |
+
)
|
src/agents/command_agent.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
from langchain_core.messages import AIMessage
|
| 5 |
+
from langgraph.graph import START, MessagesState, StateGraph
|
| 6 |
+
from langgraph.types import Command
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AgentState(MessagesState, total=False):
|
| 10 |
+
"""`total=False` is PEP589 specs.
|
| 11 |
+
|
| 12 |
+
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Define the nodes
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def node_a(state: AgentState) -> Command[Literal["node_b", "node_c"]]:
|
| 20 |
+
print("Called A")
|
| 21 |
+
value = random.choice(["a", "b"])
|
| 22 |
+
goto: Literal["node_b", "node_c"]
|
| 23 |
+
# this is a replacement for a conditional edge function
|
| 24 |
+
if value == "a":
|
| 25 |
+
goto = "node_b"
|
| 26 |
+
else:
|
| 27 |
+
goto = "node_c"
|
| 28 |
+
|
| 29 |
+
# note how Command allows you to BOTH update the graph state AND route to the next node
|
| 30 |
+
return Command(
|
| 31 |
+
# this is the state update
|
| 32 |
+
update={"messages": [AIMessage(content=f"Hello {value}")]},
|
| 33 |
+
# this is a replacement for an edge
|
| 34 |
+
goto=goto,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def node_b(state: AgentState):
|
| 39 |
+
print("Called B")
|
| 40 |
+
return {"messages": [AIMessage(content="Hello B")]}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def node_c(state: AgentState):
|
| 44 |
+
print("Called C")
|
| 45 |
+
return {"messages": [AIMessage(content="Hello C")]}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
builder = StateGraph(AgentState)
|
| 49 |
+
builder.add_edge(START, "node_a")
|
| 50 |
+
builder.add_node(node_a)
|
| 51 |
+
builder.add_node(node_b)
|
| 52 |
+
builder.add_node(node_c)
|
| 53 |
+
# NOTE: there are no edges between nodes A, B and C!
|
| 54 |
+
|
| 55 |
+
command_agent = builder.compile()
|
src/agents/github_mcp_agent/github_mcp_agent.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GitHub MCP Agent - An agent that uses GitHub MCP tools for repository management."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
from langchain.agents import create_agent
|
| 7 |
+
from langchain_core.tools import BaseTool
|
| 8 |
+
from langchain_mcp_adapters.client import MultiServerMCPClient
|
| 9 |
+
from langchain_mcp_adapters.sessions import StreamableHttpConnection
|
| 10 |
+
from langgraph.graph.state import CompiledStateGraph
|
| 11 |
+
|
| 12 |
+
from agents.lazy_agent import LazyLoadingAgent
|
| 13 |
+
from core import get_model, settings
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
current_date = datetime.now().strftime("%B %d, %Y")
|
| 18 |
+
prompt = f"""
|
| 19 |
+
You are GitHubBot, a specialized assistant for GitHub repository management and development workflows.
|
| 20 |
+
You have access to GitHub MCP tools that allow you to interact with GitHub repositories, issues, pull requests,
|
| 21 |
+
and other GitHub resources. Today's date is {current_date}.
|
| 22 |
+
|
| 23 |
+
Your capabilities include:
|
| 24 |
+
- Repository management (create, clone, browse)
|
| 25 |
+
- Issue management (create, list, update, close)
|
| 26 |
+
- Pull request management (create, review, merge)
|
| 27 |
+
- Branch management (create, switch, merge)
|
| 28 |
+
- File operations (read, write, search)
|
| 29 |
+
- Commit operations (create, view history)
|
| 30 |
+
|
| 31 |
+
Guidelines:
|
| 32 |
+
- Always be helpful and provide clear explanations of GitHub operations
|
| 33 |
+
- When creating or modifying content, ensure it follows best practices
|
| 34 |
+
- Be cautious with destructive operations (deletes, force pushes, etc.)
|
| 35 |
+
- Provide context about what you're doing and why
|
| 36 |
+
- Use appropriate commit messages and PR descriptions
|
| 37 |
+
- Respect repository permissions and access controls
|
| 38 |
+
|
| 39 |
+
NOTE: You have access to GitHub MCP tools that provide direct GitHub API access.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class GitHubMCPAgent(LazyLoadingAgent):
|
| 44 |
+
"""GitHub MCP Agent with async initialization."""
|
| 45 |
+
|
| 46 |
+
def __init__(self) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
self._mcp_tools: list[BaseTool] = []
|
| 49 |
+
self._mcp_client: MultiServerMCPClient | None = None
|
| 50 |
+
|
| 51 |
+
async def load(self) -> None:
|
| 52 |
+
"""Initialize the GitHub MCP agent by loading MCP tools."""
|
| 53 |
+
if not settings.GITHUB_PAT:
|
| 54 |
+
logger.info("GITHUB_PAT is not set, GitHub MCP agent will have no tools")
|
| 55 |
+
self._mcp_tools = []
|
| 56 |
+
self._graph = self._create_graph()
|
| 57 |
+
self._loaded = True
|
| 58 |
+
return
|
| 59 |
+
|
| 60 |
+
try:
|
| 61 |
+
# Initialize MCP client directly
|
| 62 |
+
github_pat = settings.GITHUB_PAT.get_secret_value()
|
| 63 |
+
connections = {
|
| 64 |
+
"github": StreamableHttpConnection(
|
| 65 |
+
transport="streamable_http",
|
| 66 |
+
url=settings.MCP_GITHUB_SERVER_URL,
|
| 67 |
+
headers={
|
| 68 |
+
"Authorization": f"Bearer {github_pat}",
|
| 69 |
+
},
|
| 70 |
+
)
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
self._mcp_client = MultiServerMCPClient(connections)
|
| 74 |
+
logger.info("MCP client initialized successfully")
|
| 75 |
+
|
| 76 |
+
# Get tools from the client
|
| 77 |
+
self._mcp_tools = await self._mcp_client.get_tools()
|
| 78 |
+
logger.info(f"GitHub MCP agent initialized with {len(self._mcp_tools)} tools")
|
| 79 |
+
|
| 80 |
+
except Exception as e:
|
| 81 |
+
logger.error(f"Failed to initialize GitHub MCP agent: {e}")
|
| 82 |
+
self._mcp_tools = []
|
| 83 |
+
self._mcp_client = None
|
| 84 |
+
|
| 85 |
+
# Create and store the graph
|
| 86 |
+
self._graph = self._create_graph()
|
| 87 |
+
self._loaded = True
|
| 88 |
+
|
| 89 |
+
def _create_graph(self) -> CompiledStateGraph:
|
| 90 |
+
"""Create the GitHub MCP agent graph."""
|
| 91 |
+
model = get_model(settings.DEFAULT_MODEL)
|
| 92 |
+
|
| 93 |
+
return create_agent(
|
| 94 |
+
model=model,
|
| 95 |
+
tools=self._mcp_tools,
|
| 96 |
+
name="github-mcp-agent",
|
| 97 |
+
system_prompt=prompt,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Create the agent instance
|
| 102 |
+
github_mcp_agent = GitHubMCPAgent()
|
src/agents/interrupt_agent.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from langchain_core.language_models.base import LanguageModelInput
|
| 6 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 7 |
+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
| 8 |
+
from langchain_core.prompts import SystemMessagePromptTemplate
|
| 9 |
+
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda, RunnableSerializable
|
| 10 |
+
from langgraph.graph import END, MessagesState, StateGraph
|
| 11 |
+
from langgraph.store.base import BaseStore
|
| 12 |
+
from langgraph.types import interrupt
|
| 13 |
+
from pydantic import BaseModel, Field
|
| 14 |
+
|
| 15 |
+
from core import get_model, settings
|
| 16 |
+
|
| 17 |
+
# Added logger
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AgentState(MessagesState, total=False):
|
| 22 |
+
"""`total=False` is PEP589 specs.
|
| 23 |
+
|
| 24 |
+
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
birthdate: datetime | None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def wrap_model(
|
| 31 |
+
model: BaseChatModel | Runnable[LanguageModelInput, Any], system_prompt: BaseMessage
|
| 32 |
+
) -> RunnableSerializable[AgentState, Any]:
|
| 33 |
+
preprocessor = RunnableLambda(
|
| 34 |
+
lambda state: [system_prompt] + state["messages"],
|
| 35 |
+
name="StateModifier",
|
| 36 |
+
)
|
| 37 |
+
return preprocessor | model
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
background_prompt = SystemMessagePromptTemplate.from_template("""
|
| 41 |
+
You are a helpful assistant that tells users there zodiac sign.
|
| 42 |
+
Provide a one sentence summary of the origin of zodiac signs.
|
| 43 |
+
Don't tell the user what their sign is, you are just demonstrating your knowledge on the topic.
|
| 44 |
+
""")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
async def background(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 48 |
+
"""This node is to demonstrate doing work before the interrupt"""
|
| 49 |
+
|
| 50 |
+
m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
|
| 51 |
+
model_runnable = wrap_model(m, background_prompt.format())
|
| 52 |
+
response = await model_runnable.ainvoke(state, config)
|
| 53 |
+
|
| 54 |
+
return {"messages": [AIMessage(content=response.content)]}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
birthdate_extraction_prompt = SystemMessagePromptTemplate.from_template("""
|
| 58 |
+
You are an expert at extracting birthdates from conversational text.
|
| 59 |
+
|
| 60 |
+
Rules for extraction:
|
| 61 |
+
- Look for user messages that mention birthdates
|
| 62 |
+
- Consider various date formats (MM/DD/YYYY, YYYY-MM-DD, Month Day, Year)
|
| 63 |
+
- Validate that the date is reasonable (not in the future)
|
| 64 |
+
- If no clear birthdate was provided by the user, return None
|
| 65 |
+
""")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class BirthdateExtraction(BaseModel):
|
| 69 |
+
birthdate: str | None = Field(
|
| 70 |
+
description="The extracted birthdate in YYYY-MM-DD format. If no birthdate is found, this should be None."
|
| 71 |
+
)
|
| 72 |
+
reasoning: str = Field(
|
| 73 |
+
description="Explanation of how the birthdate was extracted or why no birthdate was found"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
async def determine_birthdate(
|
| 78 |
+
state: AgentState, config: RunnableConfig, store: BaseStore
|
| 79 |
+
) -> AgentState:
|
| 80 |
+
"""This node examines the conversation history to determine user's birthdate, checking store first."""
|
| 81 |
+
|
| 82 |
+
# Attempt to get user_id for unique storage per user
|
| 83 |
+
user_id = config["configurable"].get("user_id")
|
| 84 |
+
logger.info(f"[determine_birthdate] Extracted user_id: {user_id}")
|
| 85 |
+
namespace = None
|
| 86 |
+
key = "birthdate"
|
| 87 |
+
birthdate = None # Initialize birthdate
|
| 88 |
+
|
| 89 |
+
if user_id:
|
| 90 |
+
# Use user_id in the namespace to ensure uniqueness per user
|
| 91 |
+
namespace = (user_id,)
|
| 92 |
+
|
| 93 |
+
# Check if we already have the birthdate in the store for this user
|
| 94 |
+
try:
|
| 95 |
+
result = await store.aget(namespace, key=key)
|
| 96 |
+
# Handle cases where store.aget might return Item directly or a list
|
| 97 |
+
user_data = None
|
| 98 |
+
if result: # Check if anything was returned
|
| 99 |
+
if isinstance(result, list):
|
| 100 |
+
if result: # Check if list is not empty
|
| 101 |
+
user_data = result[0]
|
| 102 |
+
else: # Assume it's the Item object directly
|
| 103 |
+
user_data = result
|
| 104 |
+
|
| 105 |
+
if user_data and user_data.value.get("birthdate"):
|
| 106 |
+
# Convert ISO format string back to datetime object
|
| 107 |
+
birthdate_str = user_data.value["birthdate"]
|
| 108 |
+
birthdate = datetime.fromisoformat(birthdate_str) if birthdate_str else None
|
| 109 |
+
# We already have the birthdate, return it
|
| 110 |
+
logger.info(
|
| 111 |
+
f"[determine_birthdate] Found birthdate in store for user {user_id}: {birthdate}"
|
| 112 |
+
)
|
| 113 |
+
return {
|
| 114 |
+
"birthdate": birthdate,
|
| 115 |
+
"messages": [],
|
| 116 |
+
}
|
| 117 |
+
except Exception as e:
|
| 118 |
+
# Log the error or handle cases where the store might be unavailable
|
| 119 |
+
logger.error(f"Error reading from store for namespace {namespace}, key {key}: {e}")
|
| 120 |
+
# Proceed with extraction if read fails
|
| 121 |
+
pass
|
| 122 |
+
else:
|
| 123 |
+
# If no user_id, we cannot reliably store/retrieve user-specific data.
|
| 124 |
+
# Consider logging this situation.
|
| 125 |
+
logger.warning(
|
| 126 |
+
"Warning: user_id not found in config. Skipping persistent birthdate storage/retrieval for this run."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# If birthdate wasn't retrieved from store, proceed with extraction
|
| 130 |
+
m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
|
| 131 |
+
model_runnable = wrap_model(
|
| 132 |
+
m.with_structured_output(BirthdateExtraction), birthdate_extraction_prompt.format()
|
| 133 |
+
).with_config(tags=["skip_stream"])
|
| 134 |
+
response: BirthdateExtraction = await model_runnable.ainvoke(state, config)
|
| 135 |
+
|
| 136 |
+
# If no birthdate found after extraction attempt, interrupt
|
| 137 |
+
if response.birthdate is None:
|
| 138 |
+
birthdate_input = interrupt(f"{response.reasoning}\nPlease tell me your birthdate?")
|
| 139 |
+
# Re-run extraction with the new input
|
| 140 |
+
state["messages"].append(HumanMessage(birthdate_input))
|
| 141 |
+
# Note: Recursive call might need careful handling of depth or state updates
|
| 142 |
+
return await determine_birthdate(state, config, store)
|
| 143 |
+
|
| 144 |
+
# Birthdate found - convert string to datetime
|
| 145 |
+
try:
|
| 146 |
+
birthdate = datetime.fromisoformat(response.birthdate)
|
| 147 |
+
except ValueError:
|
| 148 |
+
# If parsing fails, ask for clarification
|
| 149 |
+
birthdate_input = interrupt(
|
| 150 |
+
"I couldn't understand the date format. Please provide your birthdate in YYYY-MM-DD format."
|
| 151 |
+
)
|
| 152 |
+
# Re-run extraction with the new input
|
| 153 |
+
state["messages"].append(HumanMessage(birthdate_input))
|
| 154 |
+
# Note: Recursive call might need careful handling of depth or state updates
|
| 155 |
+
return await determine_birthdate(state, config, store)
|
| 156 |
+
|
| 157 |
+
# Store the newly extracted birthdate only if we have a user_id
|
| 158 |
+
if user_id and namespace:
|
| 159 |
+
# Convert datetime to ISO format string for JSON serialization
|
| 160 |
+
birthdate_str = birthdate.isoformat() if birthdate else None
|
| 161 |
+
try:
|
| 162 |
+
await store.aput(namespace, key, {"birthdate": birthdate_str})
|
| 163 |
+
except Exception as e:
|
| 164 |
+
# Log the error or handle cases where the store write might fail
|
| 165 |
+
logger.error(f"Error writing to store for namespace {namespace}, key {key}: {e}")
|
| 166 |
+
|
| 167 |
+
# Return the determined birthdate (either from store or extracted)
|
| 168 |
+
logger.info(f"[determine_birthdate] Returning birthdate {birthdate} for user {user_id}")
|
| 169 |
+
return {
|
| 170 |
+
"birthdate": birthdate,
|
| 171 |
+
"messages": [],
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
response_prompt = SystemMessagePromptTemplate.from_template("""
|
| 176 |
+
You are a helpful assistant.
|
| 177 |
+
|
| 178 |
+
Known information:
|
| 179 |
+
- The user's birthdate is {birthdate_str}
|
| 180 |
+
|
| 181 |
+
User's latest message: "{last_user_message}"
|
| 182 |
+
|
| 183 |
+
Based on the known information and the user's message, provide a helpful and relevant response.
|
| 184 |
+
If the user asked for their birthdate, confirm it.
|
| 185 |
+
If the user asked for their zodiac sign, calculate it and tell them.
|
| 186 |
+
Otherwise, respond conversationally based on their message.
|
| 187 |
+
""")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
async def generate_response(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 191 |
+
"""Generates the final response based on the user's query and the available birthdate."""
|
| 192 |
+
birthdate = state.get("birthdate")
|
| 193 |
+
if state.get("messages") and isinstance(state["messages"][-1], HumanMessage):
|
| 194 |
+
last_user_message = state["messages"][-1].content
|
| 195 |
+
else:
|
| 196 |
+
last_user_message = ""
|
| 197 |
+
|
| 198 |
+
if not birthdate:
|
| 199 |
+
# This should ideally not be reached if determine_birthdate worked correctly and possibly interrupted.
|
| 200 |
+
# Handle cases where birthdate might still be missing.
|
| 201 |
+
return {
|
| 202 |
+
"messages": [
|
| 203 |
+
AIMessage(
|
| 204 |
+
content="I couldn't determine your birthdate. Could you please provide it?"
|
| 205 |
+
)
|
| 206 |
+
]
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
birthdate_str = birthdate.strftime("%B %d, %Y") # Format for display
|
| 210 |
+
|
| 211 |
+
m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
|
| 212 |
+
model_runnable = wrap_model(
|
| 213 |
+
m, response_prompt.format(birthdate_str=birthdate_str, last_user_message=last_user_message)
|
| 214 |
+
)
|
| 215 |
+
response = await model_runnable.ainvoke(state, config)
|
| 216 |
+
|
| 217 |
+
return {"messages": [AIMessage(content=response.content)]}
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Define the graph
|
| 221 |
+
agent = StateGraph(AgentState)
|
| 222 |
+
agent.add_node("background", background)
|
| 223 |
+
agent.add_node("determine_birthdate", determine_birthdate)
|
| 224 |
+
agent.add_node("generate_response", generate_response)
|
| 225 |
+
|
| 226 |
+
agent.set_entry_point("background")
|
| 227 |
+
agent.add_edge("background", "determine_birthdate")
|
| 228 |
+
agent.add_edge("determine_birthdate", "generate_response")
|
| 229 |
+
agent.add_edge("generate_response", END)
|
| 230 |
+
|
| 231 |
+
interrupt_agent = agent.compile()
|
| 232 |
+
interrupt_agent.name = "interrupt-agent"
|
src/agents/knowledge_base_agent.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from langchain_aws import AmazonKnowledgeBasesRetriever
|
| 6 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 7 |
+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 8 |
+
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
|
| 9 |
+
from langchain_core.runnables.base import RunnableSequence
|
| 10 |
+
from langgraph.graph import END, MessagesState, StateGraph
|
| 11 |
+
from langgraph.managed import RemainingSteps
|
| 12 |
+
|
| 13 |
+
from core import get_model, settings
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Define the state
|
| 19 |
+
class AgentState(MessagesState, total=False):
|
| 20 |
+
"""State for Knowledge Base agent."""
|
| 21 |
+
|
| 22 |
+
remaining_steps: RemainingSteps
|
| 23 |
+
retrieved_documents: list[dict[str, Any]]
|
| 24 |
+
kb_documents: str
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Create the retriever
|
| 28 |
+
def get_kb_retriever():
|
| 29 |
+
"""Create and return a Knowledge Base retriever instance."""
|
| 30 |
+
# Get the Knowledge Base ID from environment
|
| 31 |
+
kb_id = os.environ.get("AWS_KB_ID", "")
|
| 32 |
+
if not kb_id:
|
| 33 |
+
raise ValueError("AWS_KB_ID environment variable must be set")
|
| 34 |
+
|
| 35 |
+
# Create the retriever with the specified Knowledge Base ID
|
| 36 |
+
retriever = AmazonKnowledgeBasesRetriever(
|
| 37 |
+
knowledge_base_id=kb_id,
|
| 38 |
+
retrieval_config={
|
| 39 |
+
"vectorSearchConfiguration": {
|
| 40 |
+
"numberOfResults": 3,
|
| 41 |
+
}
|
| 42 |
+
},
|
| 43 |
+
)
|
| 44 |
+
return retriever
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
|
| 48 |
+
"""Wrap the model with a system prompt for the Knowledge Base agent."""
|
| 49 |
+
|
| 50 |
+
def create_system_message(state):
|
| 51 |
+
base_prompt = """You are a helpful assistant that provides accurate information based on retrieved documents.
|
| 52 |
+
|
| 53 |
+
You will receive a query along with relevant documents retrieved from a knowledge base. Use these documents to inform your response.
|
| 54 |
+
|
| 55 |
+
Follow these guidelines:
|
| 56 |
+
1. Base your answer primarily on the retrieved documents
|
| 57 |
+
2. If the documents contain the answer, provide it clearly and concisely
|
| 58 |
+
3. If the documents are insufficient, state that you don't have enough information
|
| 59 |
+
4. Never make up facts or information not present in the documents
|
| 60 |
+
5. Always cite the source documents when referring to specific information
|
| 61 |
+
6. If the documents contradict each other, acknowledge this and explain the different perspectives
|
| 62 |
+
|
| 63 |
+
Format your response in a clear, conversational manner. Use markdown formatting when appropriate.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
# Check if documents were retrieved
|
| 67 |
+
if "kb_documents" in state:
|
| 68 |
+
# Append document information to the system prompt
|
| 69 |
+
document_prompt = f"\n\nI've retrieved the following documents that may be relevant to the query:\n\n{state['kb_documents']}\n\nPlease use these documents to inform your response to the user's query. Only use information from these documents and clearly indicate when you are unsure."
|
| 70 |
+
return [SystemMessage(content=base_prompt + document_prompt)] + state["messages"]
|
| 71 |
+
else:
|
| 72 |
+
# No documents were retrieved
|
| 73 |
+
no_docs_prompt = (
|
| 74 |
+
"\n\nNo relevant documents were found in the knowledge base for this query."
|
| 75 |
+
)
|
| 76 |
+
return [SystemMessage(content=base_prompt + no_docs_prompt)] + state["messages"]
|
| 77 |
+
|
| 78 |
+
preprocessor = RunnableLambda(
|
| 79 |
+
create_system_message,
|
| 80 |
+
name="StateModifier",
|
| 81 |
+
)
|
| 82 |
+
return RunnableSequence(preprocessor, model)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
async def retrieve_documents(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 86 |
+
"""Retrieve relevant documents from the knowledge base."""
|
| 87 |
+
# Get the last human message
|
| 88 |
+
human_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
|
| 89 |
+
if not human_messages:
|
| 90 |
+
# Include messages from original state
|
| 91 |
+
return {"messages": [], "retrieved_documents": []}
|
| 92 |
+
|
| 93 |
+
# Use the last human message as the query
|
| 94 |
+
query = human_messages[-1].content
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
# Initialize the retriever
|
| 98 |
+
retriever = get_kb_retriever()
|
| 99 |
+
|
| 100 |
+
# Retrieve documents
|
| 101 |
+
retrieved_docs = await retriever.ainvoke(query)
|
| 102 |
+
|
| 103 |
+
# Create document summaries for the state
|
| 104 |
+
document_summaries = []
|
| 105 |
+
for i, doc in enumerate(retrieved_docs, 1):
|
| 106 |
+
summary = {
|
| 107 |
+
"id": doc.metadata.get("id", f"doc-{i}"),
|
| 108 |
+
"source": doc.metadata.get("source", "Unknown"),
|
| 109 |
+
"title": doc.metadata.get("title", f"Document {i}"),
|
| 110 |
+
"content": doc.page_content,
|
| 111 |
+
"relevance_score": doc.metadata.get("score", 0),
|
| 112 |
+
}
|
| 113 |
+
document_summaries.append(summary)
|
| 114 |
+
|
| 115 |
+
logger.info(f"Retrieved {len(document_summaries)} documents for query: {query[:50]}...")
|
| 116 |
+
|
| 117 |
+
return {"retrieved_documents": document_summaries, "messages": []}
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
logger.error(f"Error retrieving documents: {str(e)}")
|
| 121 |
+
return {"retrieved_documents": [], "messages": []}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
async def prepare_augmented_prompt(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 125 |
+
"""Prepare a prompt augmented with retrieved document content."""
|
| 126 |
+
# Get retrieved documents
|
| 127 |
+
documents = state.get("retrieved_documents", [])
|
| 128 |
+
|
| 129 |
+
if not documents:
|
| 130 |
+
return {"messages": []}
|
| 131 |
+
|
| 132 |
+
# Format retrieved documents for the model
|
| 133 |
+
formatted_docs = "\n\n".join(
|
| 134 |
+
[
|
| 135 |
+
f"--- Document {i + 1} ---\n"
|
| 136 |
+
f"Source: {doc.get('source', 'Unknown')}\n"
|
| 137 |
+
f"Title: {doc.get('title', 'Unknown')}\n\n"
|
| 138 |
+
f"{doc.get('content', '')}"
|
| 139 |
+
for i, doc in enumerate(documents)
|
| 140 |
+
]
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Store formatted documents in the state
|
| 144 |
+
return {"kb_documents": formatted_docs, "messages": []}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 148 |
+
"""Generate a response based on the retrieved documents."""
|
| 149 |
+
m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
|
| 150 |
+
model_runnable = wrap_model(m)
|
| 151 |
+
|
| 152 |
+
response = await model_runnable.ainvoke(state, config)
|
| 153 |
+
|
| 154 |
+
return {"messages": [response]}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Define the graph
|
| 158 |
+
agent = StateGraph(AgentState)
|
| 159 |
+
|
| 160 |
+
# Add nodes
|
| 161 |
+
agent.add_node("retrieve_documents", retrieve_documents)
|
| 162 |
+
agent.add_node("prepare_augmented_prompt", prepare_augmented_prompt)
|
| 163 |
+
agent.add_node("model", acall_model)
|
| 164 |
+
|
| 165 |
+
# Set entry point
|
| 166 |
+
agent.set_entry_point("retrieve_documents")
|
| 167 |
+
|
| 168 |
+
# Add edges to define the flow
|
| 169 |
+
agent.add_edge("retrieve_documents", "prepare_augmented_prompt")
|
| 170 |
+
agent.add_edge("prepare_augmented_prompt", "model")
|
| 171 |
+
agent.add_edge("model", END)
|
| 172 |
+
|
| 173 |
+
# Compile the agent
|
| 174 |
+
kb_agent = agent.compile()
|
src/agents/langgraph_supervisor_agent.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from langchain.agents import create_agent
|
| 4 |
+
from langgraph_supervisor import create_supervisor
|
| 5 |
+
|
| 6 |
+
from core import get_model, settings
|
| 7 |
+
|
| 8 |
+
model = get_model(settings.DEFAULT_MODEL)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def add(a: float, b: float) -> float:
|
| 12 |
+
"""Add two numbers."""
|
| 13 |
+
return a + b
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def multiply(a: float, b: float) -> float:
|
| 17 |
+
"""Multiply two numbers."""
|
| 18 |
+
return a * b
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def web_search(query: str) -> str:
|
| 22 |
+
"""Search the web for information."""
|
| 23 |
+
return (
|
| 24 |
+
"Here are the headcounts for each of the FAANG companies in 2024:\n"
|
| 25 |
+
"1. **Facebook (Meta)**: 67,317 employees.\n"
|
| 26 |
+
"2. **Apple**: 164,000 employees.\n"
|
| 27 |
+
"3. **Amazon**: 1,551,000 employees.\n"
|
| 28 |
+
"4. **Netflix**: 14,000 employees.\n"
|
| 29 |
+
"5. **Google (Alphabet)**: 181,269 employees."
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
math_agent: Any = create_agent(
|
| 34 |
+
model=model,
|
| 35 |
+
tools=[add, multiply],
|
| 36 |
+
name="sub-agent-math_expert",
|
| 37 |
+
system_prompt="You are a math expert. Always use one tool at a time.",
|
| 38 |
+
).with_config(tags=["skip_stream"])
|
| 39 |
+
|
| 40 |
+
research_agent: Any = create_agent(
|
| 41 |
+
model=model,
|
| 42 |
+
tools=[web_search],
|
| 43 |
+
name="sub-agent-research_expert",
|
| 44 |
+
system_prompt="You are a world class researcher with access to web search. Do not do any math.",
|
| 45 |
+
).with_config(tags=["skip_stream"])
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Create supervisor workflow
|
| 49 |
+
workflow = create_supervisor(
|
| 50 |
+
[research_agent, math_agent],
|
| 51 |
+
model=model,
|
| 52 |
+
prompt=(
|
| 53 |
+
"You are a team supervisor managing a research expert and a math expert. "
|
| 54 |
+
"For current events, use research_agent. "
|
| 55 |
+
"For math problems, use math_agent."
|
| 56 |
+
),
|
| 57 |
+
add_handoff_back_messages=True,
|
| 58 |
+
# UI now expects this to be True so we don't have to guess when a handoff back occurs
|
| 59 |
+
output_mode="full_history", # otherwise when reloading conversations, the sub-agents' messages are not included
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
langgraph_supervisor_agent = workflow.compile()
|
src/agents/langgraph_supervisor_hierarchy_agent.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain.agents import create_agent
|
| 2 |
+
from langgraph_supervisor import create_supervisor
|
| 3 |
+
|
| 4 |
+
from agents.langgraph_supervisor_agent import add, multiply, web_search
|
| 5 |
+
from core import get_model, settings
|
| 6 |
+
|
| 7 |
+
model = get_model(settings.DEFAULT_MODEL)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def workflow(chosen_model):
|
| 11 |
+
math_agent = create_agent(
|
| 12 |
+
model=chosen_model,
|
| 13 |
+
tools=[add, multiply],
|
| 14 |
+
name="sub-agent-math_expert", # Identify the graph node as a sub-agent
|
| 15 |
+
system_prompt="You are a math expert. Always use one tool at a time.",
|
| 16 |
+
).with_config(tags=["skip_stream"])
|
| 17 |
+
|
| 18 |
+
research_agent = (
|
| 19 |
+
create_supervisor(
|
| 20 |
+
[math_agent],
|
| 21 |
+
model=chosen_model,
|
| 22 |
+
tools=[web_search],
|
| 23 |
+
prompt="You are a world class researcher with access to web search. Do not do any math, you have a math expert for that. ",
|
| 24 |
+
supervisor_name="supervisor-research_expert", # Identify the graph node as a supervisor to the math agent
|
| 25 |
+
)
|
| 26 |
+
.compile(
|
| 27 |
+
name="sub-agent-research_expert"
|
| 28 |
+
) # Identify the graph node as a sub-agent to the main supervisor
|
| 29 |
+
.with_config(tags=["skip_stream"])
|
| 30 |
+
) # Stream tokens are ignored for sub-agents in the UI
|
| 31 |
+
|
| 32 |
+
# Create supervisor workflow
|
| 33 |
+
return create_supervisor(
|
| 34 |
+
[research_agent],
|
| 35 |
+
model=chosen_model,
|
| 36 |
+
prompt=(
|
| 37 |
+
"You are a team supervisor managing a research expert with math capabilities."
|
| 38 |
+
"For current events, use research_agent. "
|
| 39 |
+
),
|
| 40 |
+
add_handoff_back_messages=True,
|
| 41 |
+
# UI now expects this to be True so we don't have to guess when a handoff back occurs
|
| 42 |
+
output_mode="full_history", # otherwise when reloading conversations, the sub-agents' messages are not included
|
| 43 |
+
) # default name for supervisor is "supervisor".
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
langgraph_supervisor_hierarchy_agent = workflow(model).compile()
|
src/agents/lazy_agent.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Agent types with async initialization and dynamic graph creation."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
|
| 5 |
+
from langgraph.graph.state import CompiledStateGraph
|
| 6 |
+
from langgraph.pregel import Pregel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LazyLoadingAgent(ABC):
|
| 10 |
+
"""Base class for agents that require async loading."""
|
| 11 |
+
|
| 12 |
+
def __init__(self) -> None:
|
| 13 |
+
"""Initialize the agent."""
|
| 14 |
+
self._loaded = False
|
| 15 |
+
self._graph: CompiledStateGraph | Pregel | None = None
|
| 16 |
+
|
| 17 |
+
@abstractmethod
|
| 18 |
+
async def load(self) -> None:
|
| 19 |
+
"""
|
| 20 |
+
Perform async loading for this agent.
|
| 21 |
+
|
| 22 |
+
This method is called during service startup and should handle:
|
| 23 |
+
- Setting up external connections (MCP clients, databases, etc.)
|
| 24 |
+
- Loading tools or resources
|
| 25 |
+
- Any other async setup required
|
| 26 |
+
- Creating the agent's graph
|
| 27 |
+
"""
|
| 28 |
+
raise NotImplementedError # pragma: no cover
|
| 29 |
+
|
| 30 |
+
def get_graph(self) -> CompiledStateGraph | Pregel:
|
| 31 |
+
"""
|
| 32 |
+
Get the agent's graph.
|
| 33 |
+
|
| 34 |
+
Returns the graph instance that was created during load().
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
The agent's graph (CompiledStateGraph or Pregel)
|
| 38 |
+
"""
|
| 39 |
+
if not self._loaded:
|
| 40 |
+
raise RuntimeError("Agent not loaded. Call load() first.")
|
| 41 |
+
if self._graph is None:
|
| 42 |
+
raise RuntimeError("Agent graph not created during load().")
|
| 43 |
+
return self._graph
|
src/agents/llama_guard.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
|
| 4 |
+
from langchain_core.prompts import PromptTemplate
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
from core import get_model, settings
|
| 8 |
+
from schema.models import GroqModelName
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SafetyAssessment(Enum):
|
| 12 |
+
SAFE = "safe"
|
| 13 |
+
UNSAFE = "unsafe"
|
| 14 |
+
ERROR = "error"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LlamaGuardOutput(BaseModel):
|
| 18 |
+
safety_assessment: SafetyAssessment = Field(description="The safety assessment of the content.")
|
| 19 |
+
unsafe_categories: list[str] = Field(
|
| 20 |
+
description="If content is unsafe, the list of unsafe categories.", default=[]
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
unsafe_content_categories = {
|
| 25 |
+
"S1": "Violent Crimes.",
|
| 26 |
+
"S2": "Non-Violent Crimes.",
|
| 27 |
+
"S3": "Sex Crimes.",
|
| 28 |
+
"S4": "Child Exploitation.",
|
| 29 |
+
"S5": "Defamation.",
|
| 30 |
+
"S6": "Specialized Advice.",
|
| 31 |
+
"S7": "Privacy.",
|
| 32 |
+
"S8": "Intellectual Property.",
|
| 33 |
+
"S9": "Indiscriminate Weapons.",
|
| 34 |
+
"S10": "Hate.",
|
| 35 |
+
"S11": "Self-Harm.",
|
| 36 |
+
"S12": "Sexual Content.",
|
| 37 |
+
"S13": "Elections.",
|
| 38 |
+
"S14": "Code Interpreter Abuse.",
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
categories_str = "\n".join([f"{k}: {v}" for k, v in unsafe_content_categories.items()])
|
| 42 |
+
|
| 43 |
+
llama_guard_instructions = f"""
|
| 44 |
+
Task: Check if there is unsafe content in '{{role}}' messages in conversations according to our safety policy with the below categories.
|
| 45 |
+
|
| 46 |
+
<BEGIN UNSAFE CONTENT CATEGORIES>
|
| 47 |
+
{categories_str}
|
| 48 |
+
<END UNSAFE CONTENT CATEGORIES>
|
| 49 |
+
|
| 50 |
+
<BEGIN CONVERSATION>
|
| 51 |
+
|
| 52 |
+
{{conversation_history}}
|
| 53 |
+
|
| 54 |
+
<END CONVERSATION>
|
| 55 |
+
|
| 56 |
+
Provide your safety assessment for ONLY THE LAST {{role}} in the above conversation:
|
| 57 |
+
- First line must read 'safe' or 'unsafe'.
|
| 58 |
+
- If unsafe, a second line must include a comma-separated list of violated categories."""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def parse_llama_guard_output(output: str) -> LlamaGuardOutput:
|
| 62 |
+
if output == "safe":
|
| 63 |
+
return LlamaGuardOutput(safety_assessment=SafetyAssessment.SAFE)
|
| 64 |
+
parsed_output = output.split("\n")
|
| 65 |
+
if len(parsed_output) != 2 or parsed_output[0] != "unsafe":
|
| 66 |
+
return LlamaGuardOutput(safety_assessment=SafetyAssessment.ERROR)
|
| 67 |
+
try:
|
| 68 |
+
categories = parsed_output[1].split(",")
|
| 69 |
+
readable_categories = [unsafe_content_categories[c.strip()].strip(".") for c in categories]
|
| 70 |
+
return LlamaGuardOutput(
|
| 71 |
+
safety_assessment=SafetyAssessment.UNSAFE,
|
| 72 |
+
unsafe_categories=readable_categories,
|
| 73 |
+
)
|
| 74 |
+
except KeyError:
|
| 75 |
+
return LlamaGuardOutput(safety_assessment=SafetyAssessment.ERROR)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class LlamaGuard:
|
| 79 |
+
def __init__(self) -> None:
|
| 80 |
+
if settings.GROQ_API_KEY is None:
|
| 81 |
+
print("GROQ_API_KEY not set, skipping LlamaGuard")
|
| 82 |
+
self.model = None
|
| 83 |
+
return
|
| 84 |
+
self.model = get_model(GroqModelName.LLAMA_GUARD_4_12B).with_config(tags=["skip_stream"])
|
| 85 |
+
self.prompt = PromptTemplate.from_template(llama_guard_instructions)
|
| 86 |
+
|
| 87 |
+
def _compile_prompt(self, role: str, messages: list[AnyMessage]) -> str:
|
| 88 |
+
role_mapping = {"ai": "Agent", "human": "User"}
|
| 89 |
+
messages_str = [
|
| 90 |
+
f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"]
|
| 91 |
+
]
|
| 92 |
+
conversation_history = "\n\n".join(messages_str)
|
| 93 |
+
return self.prompt.format(role=role, conversation_history=conversation_history)
|
| 94 |
+
|
| 95 |
+
def invoke(self, role: str, messages: list[AnyMessage]) -> LlamaGuardOutput:
|
| 96 |
+
if self.model is None:
|
| 97 |
+
return LlamaGuardOutput(safety_assessment=SafetyAssessment.SAFE)
|
| 98 |
+
compiled_prompt = self._compile_prompt(role, messages)
|
| 99 |
+
result = self.model.invoke([HumanMessage(content=compiled_prompt)])
|
| 100 |
+
return parse_llama_guard_output(str(result.content))
|
| 101 |
+
|
| 102 |
+
async def ainvoke(self, role: str, messages: list[AnyMessage]) -> LlamaGuardOutput:
|
| 103 |
+
if self.model is None:
|
| 104 |
+
return LlamaGuardOutput(safety_assessment=SafetyAssessment.SAFE)
|
| 105 |
+
compiled_prompt = self._compile_prompt(role, messages)
|
| 106 |
+
result = await self.model.ainvoke([HumanMessage(content=compiled_prompt)])
|
| 107 |
+
return parse_llama_guard_output(str(result.content))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if __name__ == "__main__":
|
| 111 |
+
llama_guard = LlamaGuard()
|
| 112 |
+
output = llama_guard.invoke(
|
| 113 |
+
"Agent",
|
| 114 |
+
[
|
| 115 |
+
HumanMessage(content="What's a good way to harm an animal?"),
|
| 116 |
+
AIMessage(
|
| 117 |
+
content="There are many ways to harm animals, but some include hitting them with a stick, throwing rocks at them, or poisoning them."
|
| 118 |
+
),
|
| 119 |
+
],
|
| 120 |
+
)
|
| 121 |
+
print(output)
|
src/agents/portfolio_agent/database_search.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
from langchain_core.tools import BaseTool, tool
|
| 5 |
+
from memory.postgres import load_pgvector_retriever
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def hybrid_search(query: str, k: int = 6) -> List[Document]:
|
| 9 |
+
retriever = load_pgvector_retriever(k)
|
| 10 |
+
return retriever.invoke(query)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def format_contexts(documents: List[Document]) -> str:
|
| 14 |
+
formatted_docs = []
|
| 15 |
+
|
| 16 |
+
for i, doc in enumerate(documents):
|
| 17 |
+
source = doc.metadata.get("source_db", "Unknown")
|
| 18 |
+
content = doc.page_content.strip()
|
| 19 |
+
|
| 20 |
+
formatted_docs.append(
|
| 21 |
+
f"--- Document {i + 1} (Source: {source}) ---\n{content}"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
return "\n\n".join(formatted_docs)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# DATABASE SEARCH TOOL
|
| 28 |
+
def database_search_func(query: str) -> str:
|
| 29 |
+
"""
|
| 30 |
+
Search Anuj Joshi's portfolio database for any information.
|
| 31 |
+
Use this tool whenever the user asks about education, experience, projects, blog posts, testimonials, or any other information NOT available
|
| 32 |
+
in your system prompt.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
documents = hybrid_search(query, k=6)
|
| 36 |
+
|
| 37 |
+
if not documents:
|
| 38 |
+
return "No relevant portfolio information found."
|
| 39 |
+
|
| 40 |
+
return format_contexts(documents)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
database_search: BaseTool = tool(database_search_func)
|
| 44 |
+
database_search.name = "Database_Search"
|
src/agents/portfolio_agent/portfolio_agent.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal, List
|
| 2 |
+
from langchain_core.messages import AIMessage, SystemMessage
|
| 3 |
+
from langchain_core.runnables import RunnableConfig, RunnableLambda
|
| 4 |
+
from langgraph.graph import END, MessagesState, StateGraph
|
| 5 |
+
from agents.portfolio_agent.prompt import SYSTEM_PROMPT, FOLLOWUP_GENERATION_PROMPT
|
| 6 |
+
from langgraph.prebuilt import ToolNode
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
from core import get_model, settings
|
| 9 |
+
from agents.portfolio_agent.database_search import database_search
|
| 10 |
+
|
| 11 |
+
# State Extension
|
| 12 |
+
class AgentState(MessagesState):
|
| 13 |
+
"""Agent state for the portfolio assistant."""
|
| 14 |
+
summary: str
|
| 15 |
+
follow_up: List[str]
|
| 16 |
+
|
| 17 |
+
class FollowUpQuestions(BaseModel):
|
| 18 |
+
questions: List[str] = Field(description="A list of follow-up questions")
|
| 19 |
+
|
| 20 |
+
# Tools
|
| 21 |
+
tools = [database_search]
|
| 22 |
+
|
| 23 |
+
# Portfolio Model Wrapper (WITH tools)
|
| 24 |
+
def wrap_portfolio_model(model):
|
| 25 |
+
model = model.bind_tools(tools)
|
| 26 |
+
|
| 27 |
+
def prepare_messages(state: AgentState):
|
| 28 |
+
summary = state.get("summary", "")
|
| 29 |
+
system_content = SYSTEM_PROMPT
|
| 30 |
+
if summary:
|
| 31 |
+
system_content += f"\n\n### Previous Conversation Summary:\n{summary}"
|
| 32 |
+
|
| 33 |
+
return [SystemMessage(content=system_content)] + state["messages"]
|
| 34 |
+
|
| 35 |
+
return RunnableLambda(prepare_messages) | model
|
| 36 |
+
|
| 37 |
+
async def call_portfolio_model(state: AgentState, config: RunnableConfig):
|
| 38 |
+
model = get_model(settings.DEFAULT_MODEL)
|
| 39 |
+
runnable = wrap_portfolio_model(model)
|
| 40 |
+
response = await runnable.ainvoke(state, config)
|
| 41 |
+
return {"messages": [response]}
|
| 42 |
+
|
| 43 |
+
async def generate_followup(state: AgentState, config: RunnableConfig):
|
| 44 |
+
model = get_model(settings.DEFAULT_MODEL)
|
| 45 |
+
chain = model.with_structured_output(FollowUpQuestions)
|
| 46 |
+
print("Followup messages: ", state["messages"])
|
| 47 |
+
messages = [SystemMessage(content=FOLLOWUP_GENERATION_PROMPT)] + state["messages"]
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
response = await chain.ainvoke(messages, config)
|
| 51 |
+
return {"follow_up": response.questions}
|
| 52 |
+
except Exception:
|
| 53 |
+
return {"follow_up": ["What are his contact details?", "What are his projects?", "What are his skills?", "What are his work experiences?", "What are his achievements?", "What are his leadership roles?"]}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Tool Routing
|
| 57 |
+
def route_after_model(state: AgentState) -> Literal["tools", "generate_followup"]:
|
| 58 |
+
last_message = state["messages"][-1]
|
| 59 |
+
|
| 60 |
+
if isinstance(last_message, AIMessage) and last_message.tool_calls:
|
| 61 |
+
return "tools"
|
| 62 |
+
return "generate_followup"
|
| 63 |
+
|
| 64 |
+
# Graph Definition
|
| 65 |
+
graph = StateGraph(AgentState)
|
| 66 |
+
|
| 67 |
+
graph.add_node("agent", call_portfolio_model)
|
| 68 |
+
graph.add_node("tools", ToolNode(tools))
|
| 69 |
+
graph.add_node("generate_followup", generate_followup)
|
| 70 |
+
|
| 71 |
+
graph.set_entry_point("agent")
|
| 72 |
+
|
| 73 |
+
graph.add_conditional_edges(
|
| 74 |
+
"agent",
|
| 75 |
+
route_after_model,
|
| 76 |
+
{
|
| 77 |
+
"tools": "tools",
|
| 78 |
+
"generate_followup": "generate_followup",
|
| 79 |
+
},
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
graph.add_edge("tools", "agent")
|
| 83 |
+
graph.add_edge("generate_followup", END)
|
| 84 |
+
|
| 85 |
+
portfolio_agent = graph.compile(debug=True)
|
src/agents/portfolio_agent/prompt.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PORTFOLIO_URL = "https://anujjoshi.netlify.app"
|
| 2 |
+
|
| 3 |
+
SYSTEM_PROMPT = f"""
|
| 4 |
+
You are an award-winning Professional Portfolio Assistant representing Anuj Joshi.
|
| 5 |
+
Your goal is to answer questions from visitors about Anuj's skills, projects, and experience and also cite the sources of information.
|
| 6 |
+
You are NOT chatting with Anuj Joshi himself, but with a recruiter, potential employer, or visitor of his portfolio.
|
| 7 |
+
|
| 8 |
+
Name: Anuj Joshi
|
| 9 |
+
Role: Full Stack Developer | AI & Machine Learning Engineer
|
| 10 |
+
Location: New Delhi, India
|
| 11 |
+
|
| 12 |
+
# CONTACT ({PORTFOLIO_URL}/contact)
|
| 13 |
+
- Email: anujjoshi3105@gmail.com
|
| 14 |
+
- LinkedIn: https://www.linkedin.com/in/anujjoshi3105/
|
| 15 |
+
- X (Twitter): https://x.com/anujjoshi3105
|
| 16 |
+
|
| 17 |
+
# Competitive Programming ({PORTFOLIO_URL}/about)
|
| 18 |
+
- LeetCode: https://leetcode.com/u/anujjoshi3105 Max Rating: 1910, Level: Knight, 750+ Problems Solved
|
| 19 |
+
- Codeforces: https://codeforces.com/profile/anujjoshi3105 Max Rating: 1434, Level: specialist
|
| 20 |
+
- AtCoder: https://atcoder.jp/users/anujjoshi3105 Max Rating: 929, Level: Green
|
| 21 |
+
- GeeksforGeeks: https://www.geeksforgeeks.org/profile/anujjoshi3105 Institute Rank: 46
|
| 22 |
+
|
| 23 |
+
# Summary ({PORTFOLIO_URL}/about)
|
| 24 |
+
Anuj Joshi is a full-stack developer and AI engineer with hands-on experience building
|
| 25 |
+
production-grade AI agents, healthcare assistants, computer vision systems, and scalable
|
| 26 |
+
web platforms. His work spans AI agents, LLM systems, backend engineering, and applied ML,
|
| 27 |
+
with experience in startups, research-driven teams, and academic organizations.
|
| 28 |
+
|
| 29 |
+
# EDUCATION ({PORTFOLIO_URL}/about#education)
|
| 30 |
+
## Bachelor of Technology (B.Tech)
|
| 31 |
+
- Field: Computer Science & Engineering
|
| 32 |
+
- Minor: Machine Learning
|
| 33 |
+
- Institution: Delhi Technological University (DTU), Delhi, India
|
| 34 |
+
- Duration: November 2022 – May 2026
|
| 35 |
+
- CGPA: 9.35 / 10
|
| 36 |
+
- Strong academic performance in Machine Learning, Deep Learning, and Artificial Intelligence
|
| 37 |
+
- Core coursework includes: Machine Learning, Deep Learning, Artificial Intelligence, Operating Systems, Database Management Systems, Computer Networks
|
| 38 |
+
|
| 39 |
+
## CBSE Class XII
|
| 40 |
+
https://drive.google.com/file/d/14EcEdGaikR0dynanY7NGLPGMNMUhx9Ds/view?usp=sharing
|
| 41 |
+
- Stream: PCM (Physics, Chemistry, Mathematics)
|
| 42 |
+
- Institution: Vivekanand International School, Delhi, India
|
| 43 |
+
- Duration: April 2021 – July 2022
|
| 44 |
+
- Score: 98.8%
|
| 45 |
+
- Perfect scores in Mathematics and Chemistry
|
| 46 |
+
|
| 47 |
+
## CBSE Class X
|
| 48 |
+
https://drive.google.com/file/d/14CHsmHp3kvbjze9o3cMxxihrKo5IMRyI/view?usp=sharing
|
| 49 |
+
- Institution: Vivekanand International School, Delhi, India
|
| 50 |
+
- Duration: April 2019 – March 2020
|
| 51 |
+
- Score: 97.0%
|
| 52 |
+
- Scored 99 in Mathematics, Science, and Computer Science
|
| 53 |
+
|
| 54 |
+
# WORK EXPERIENCE ({PORTFOLIO_URL}/about#experience)
|
| 55 |
+
## 1) Full Stack Developer Intern – Quickintell (Remote, 2025)
|
| 56 |
+
- Built secure AI voice and chat agents integrated with EHR APIs for HIPAA-compliant patient identity verification.
|
| 57 |
+
- Developed production-ready healthcare assistants using LangChain, AWS Lambda, and vector databases.
|
| 58 |
+
- Improved retrieval performance and system scalability.
|
| 59 |
+
|
| 60 |
+
## 2) Software Developer Intern – ITP Electronics (Delhi, 2024)
|
| 61 |
+
https://drive.google.com/file/d/15Jzu-oujhKUDZiWoQf1WWWN6ZJnmBqQH/view?usp=drive_link
|
| 62 |
+
- Developed computer vision solutions for automated wire harness detection.
|
| 63 |
+
- Reduced manual inspection time using OpenCV-based vision pipelines.
|
| 64 |
+
- Built image-to-BoM automation systems using backend services.
|
| 65 |
+
|
| 66 |
+
## 3) Web Developer Intern – USIP-DTU (Delhi, 2024)
|
| 67 |
+
https://drive.google.com/file/d/150EAtBVjP1DV-b_v0JKhVYzhIVoCvAWO/view
|
| 68 |
+
- Built a full-stack ERP system using PHP, Node.js, and MySQL.
|
| 69 |
+
- Implemented JWT-based role-based access control (RBAC).
|
| 70 |
+
- Automated proposal lifecycle and document management for 50+ users.
|
| 71 |
+
|
| 72 |
+
# PROJECTS ({PORTFOLIO_URL}/project#projects)
|
| 73 |
+
- Ekalavya – AI-powered EdTech SaaS with adaptive learning, quizzes, and AI tutors.
|
| 74 |
+
- BITLOG – Developer-focused blogging platform with authentication and scalability.
|
| 75 |
+
- Industrial Research & Development Centre Portal – Research workflow automation platform.
|
| 76 |
+
- NicoGauge – ML-based evaluation platform for learning analytics.
|
| 77 |
+
- Fictiora – Entertainment discovery platform using Next.js and TMDB APIs.
|
| 78 |
+
For more info visit {PORTFOLIO_URL}/project
|
| 79 |
+
|
| 80 |
+
# TECHNICAL SKILLS ({PORTFOLIO_URL}/about#skills)
|
| 81 |
+
- AI / ML: PyTorch, TensorFlow, Keras, Scikit-learn, CNNs, GANs, Transformers, LangChain, LLM-based agents
|
| 82 |
+
- Backend / Full Stack: Python, Node.js, PHP, FastAPI, Flask, Django, Express.js, REST APIs, JWT Authentication, RBAC
|
| 83 |
+
- Databases: PostgreSQL, MySQL, MongoDB, Vector Databases
|
| 84 |
+
- Frontend: React, Next.js, TailwindCSS, JavaScript
|
| 85 |
+
- DevOps / Tools: AWS Lambda, Docker, Git
|
| 86 |
+
|
| 87 |
+
# LEADERSHIP ROLES & VOLUNTEERING ({PORTFOLIO_URL}/about#experience)
|
| 88 |
+
- General Secretary, Society of Robotics (DTU): Led a 50+ member team and organized robotics events with 200+ participants.
|
| 89 |
+
- Volunteer, Summer School on AI (DTU) https://drive.google.com/file/d/10Jx3yC8gmFYHkl0KXucaUOZJqtf9QkJq/view?usp=drive_link: Supported hands-on sessions on deep learning, transformers, and generative AI.
|
| 90 |
+
|
| 91 |
+
# For other relevant information always use **Database_Search** tool and cite {PORTFOLIO_URL}/blog**
|
| 92 |
+
|
| 93 |
+
# TOOL USAGE RULES (STRICT & CRITICAL):
|
| 94 |
+
1. **INCOMPLETE INFORMATION**: This system prompt ONLY contains a basic overview. It does NOT contain blog posts, specific contest results, latest activities, or deep technical details.
|
| 95 |
+
2. **SEARCH FIRST**: If a user asks about Contests/Competitions, Blog Posts/Articles or Technical Deep-Dives, You **MUST** call the `Database_Search` tool immediately.
|
| 96 |
+
3. **NO "I DON'T KNOW" WITHOUT SEARCH**: Never tell the user information is unavailable until AFTER you have used `Database_Search`.
|
| 97 |
+
4. **DIRECT ACTION**: Do NOT ask for permission to search. Do NOT explain that you are searching. Just call the tool.
|
| 98 |
+
5. **CITE SOURCES**: Always mention that more info is available at {PORTFOLIO_URL}/blog?q=[relevant-query] and use the content from the tool accurately.
|
| 99 |
+
|
| 100 |
+
# STYLE: Professional, concise, witty and helpful.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
FOLLOWUP_GENERATION_PROMPT = """
|
| 104 |
+
Generate 3–5 unique short follow-up options for the user based on the last AI message.
|
| 105 |
+
|
| 106 |
+
Rules:
|
| 107 |
+
1. If the AI asked a question or offered a choice (e.g., to search the database), include relevant replies like "Accept", "Reject".
|
| 108 |
+
2. If the AI provided information, generate follow-up questions to explore Anuj Joshi’s portfolio (skills, projects, experience, etc.).
|
| 109 |
+
3. Keep options concise and unique (2-6 words).
|
| 110 |
+
|
| 111 |
+
Return ONLY valid JSON:
|
| 112 |
+
{
|
| 113 |
+
"questions": ["option 1", "option 2", ...]
|
| 114 |
+
}
|
| 115 |
+
"""
|
src/agents/rag_assistant.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 5 |
+
from langchain_core.messages import AIMessage, SystemMessage
|
| 6 |
+
from langchain_core.runnables import (
|
| 7 |
+
RunnableConfig,
|
| 8 |
+
RunnableLambda,
|
| 9 |
+
RunnableSerializable,
|
| 10 |
+
)
|
| 11 |
+
from langgraph.graph import END, MessagesState, StateGraph
|
| 12 |
+
from langgraph.managed import RemainingSteps
|
| 13 |
+
from langgraph.prebuilt import ToolNode
|
| 14 |
+
|
| 15 |
+
from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment
|
| 16 |
+
from agents.tools import database_search
|
| 17 |
+
from core import get_model, settings
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AgentState(MessagesState, total=False):
|
| 21 |
+
"""`total=False` is PEP589 specs.
|
| 22 |
+
|
| 23 |
+
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
safety: LlamaGuardOutput
|
| 27 |
+
remaining_steps: RemainingSteps
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
tools = [database_search]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
current_date = datetime.now().strftime("%B %d, %Y")
|
| 34 |
+
instructions = f"""
|
| 35 |
+
You are AcmeBot, a helpful and knowledgeable virtual assistant designed to support employees by retrieving
|
| 36 |
+
and answering questions based on AcmeTech's official Employee Handbook. Your primary role is to provide
|
| 37 |
+
accurate, concise, and friendly information about company policies, values, procedures, and employee resources.
|
| 38 |
+
Today's date is {current_date}.
|
| 39 |
+
|
| 40 |
+
NOTE: THE USER CAN'T SEE THE TOOL RESPONSE.
|
| 41 |
+
|
| 42 |
+
A few things to remember:
|
| 43 |
+
- If you have access to multiple databases, gather information from a diverse range of sources before crafting your response.
|
| 44 |
+
- Please include markdown-formatted links to any citations used in your response. Only include one
|
| 45 |
+
or two citations per response unless more are needed. ONLY USE LINKS RETURNED BY THE TOOLS.
|
| 46 |
+
- Only use information from the database. Do not use information from outside sources.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
|
| 51 |
+
bound_model = model.bind_tools(tools)
|
| 52 |
+
preprocessor = RunnableLambda(
|
| 53 |
+
lambda state: [SystemMessage(content=instructions)] + state["messages"],
|
| 54 |
+
name="StateModifier",
|
| 55 |
+
)
|
| 56 |
+
return preprocessor | bound_model # type: ignore[return-value]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def format_safety_message(safety: LlamaGuardOutput) -> AIMessage:
|
| 60 |
+
content = (
|
| 61 |
+
f"This conversation was flagged for unsafe content: {', '.join(safety.unsafe_categories)}"
|
| 62 |
+
)
|
| 63 |
+
return AIMessage(content=content)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 67 |
+
m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
|
| 68 |
+
model_runnable = wrap_model(m)
|
| 69 |
+
response = await model_runnable.ainvoke(state, config)
|
| 70 |
+
|
| 71 |
+
# Run llama guard check here to avoid returning the message if it's unsafe
|
| 72 |
+
llama_guard = LlamaGuard()
|
| 73 |
+
safety_output = await llama_guard.ainvoke("Agent", state["messages"] + [response])
|
| 74 |
+
if safety_output.safety_assessment == SafetyAssessment.UNSAFE:
|
| 75 |
+
return {
|
| 76 |
+
"messages": [format_safety_message(safety_output)],
|
| 77 |
+
"safety": safety_output,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
if state["remaining_steps"] < 2 and response.tool_calls:
|
| 81 |
+
return {
|
| 82 |
+
"messages": [
|
| 83 |
+
AIMessage(
|
| 84 |
+
id=response.id,
|
| 85 |
+
content="Sorry, need more steps to process this request.",
|
| 86 |
+
)
|
| 87 |
+
]
|
| 88 |
+
}
|
| 89 |
+
# We return a list, because this will get added to the existing list
|
| 90 |
+
return {"messages": [response]}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
async def llama_guard_input(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 94 |
+
llama_guard = LlamaGuard()
|
| 95 |
+
safety_output = await llama_guard.ainvoke("User", state["messages"])
|
| 96 |
+
return {"safety": safety_output, "messages": []}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
async def block_unsafe_content(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 100 |
+
safety: LlamaGuardOutput = state["safety"]
|
| 101 |
+
return {"messages": [format_safety_message(safety)]}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Define the graph
|
| 105 |
+
agent = StateGraph(AgentState)
|
| 106 |
+
agent.add_node("model", acall_model)
|
| 107 |
+
agent.add_node("tools", ToolNode(tools))
|
| 108 |
+
agent.add_node("guard_input", llama_guard_input)
|
| 109 |
+
agent.add_node("block_unsafe_content", block_unsafe_content)
|
| 110 |
+
agent.set_entry_point("guard_input")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Check for unsafe input and block further processing if found
|
| 114 |
+
def check_safety(state: AgentState) -> Literal["unsafe", "safe"]:
|
| 115 |
+
safety: LlamaGuardOutput = state["safety"]
|
| 116 |
+
match safety.safety_assessment:
|
| 117 |
+
case SafetyAssessment.UNSAFE:
|
| 118 |
+
return "unsafe"
|
| 119 |
+
case _:
|
| 120 |
+
return "safe"
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
agent.add_conditional_edges(
|
| 124 |
+
"guard_input", check_safety, {"unsafe": "block_unsafe_content", "safe": "model"}
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Always END after blocking unsafe content
|
| 128 |
+
agent.add_edge("block_unsafe_content", END)
|
| 129 |
+
|
| 130 |
+
# Always run "model" after "tools"
|
| 131 |
+
agent.add_edge("tools", "model")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# After "model", if there are tool calls, run "tools". Otherwise END.
|
| 135 |
+
def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
|
| 136 |
+
last_message = state["messages"][-1]
|
| 137 |
+
if not isinstance(last_message, AIMessage):
|
| 138 |
+
raise TypeError(f"Expected AIMessage, got {type(last_message)}")
|
| 139 |
+
if last_message.tool_calls:
|
| 140 |
+
return "tools"
|
| 141 |
+
return "done"
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", "done": END})
|
| 145 |
+
|
| 146 |
+
rag_assistant = agent.compile()
|
src/agents/research_assistant.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
from langchain_community.tools import DuckDuckGoSearchResults, OpenWeatherMapQueryRun
|
| 5 |
+
from langchain_community.utilities import OpenWeatherMapAPIWrapper
|
| 6 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 7 |
+
from langchain_core.messages import AIMessage, SystemMessage
|
| 8 |
+
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
|
| 9 |
+
from langgraph.graph import END, MessagesState, StateGraph
|
| 10 |
+
from langgraph.managed import RemainingSteps
|
| 11 |
+
from langgraph.prebuilt import ToolNode
|
| 12 |
+
|
| 13 |
+
from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment
|
| 14 |
+
from agents.tools import calculator
|
| 15 |
+
from core import get_model, settings
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AgentState(MessagesState, total=False):
|
| 19 |
+
"""`total=False` is PEP589 specs.
|
| 20 |
+
|
| 21 |
+
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
safety: LlamaGuardOutput
|
| 25 |
+
remaining_steps: RemainingSteps
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
web_search = DuckDuckGoSearchResults(name="WebSearch")
|
| 29 |
+
tools = [web_search, calculator]
|
| 30 |
+
|
| 31 |
+
# Add weather tool if API key is set
|
| 32 |
+
# Register for an API key at https://openweathermap.org/api/
|
| 33 |
+
if settings.OPENWEATHERMAP_API_KEY:
|
| 34 |
+
wrapper = OpenWeatherMapAPIWrapper(
|
| 35 |
+
openweathermap_api_key=settings.OPENWEATHERMAP_API_KEY.get_secret_value()
|
| 36 |
+
)
|
| 37 |
+
tools.append(OpenWeatherMapQueryRun(name="Weather", api_wrapper=wrapper))
|
| 38 |
+
|
| 39 |
+
current_date = datetime.now().strftime("%B %d, %Y")
|
| 40 |
+
instructions = f"""
|
| 41 |
+
You are a helpful research assistant with the ability to search the web and use other tools.
|
| 42 |
+
Today's date is {current_date}.
|
| 43 |
+
|
| 44 |
+
NOTE: THE USER CAN'T SEE THE TOOL RESPONSE.
|
| 45 |
+
|
| 46 |
+
A few things to remember:
|
| 47 |
+
- Please include markdown-formatted links to any citations used in your response. Only include one
|
| 48 |
+
or two citations per response unless more are needed. ONLY USE LINKS RETURNED BY THE TOOLS.
|
| 49 |
+
- Use calculator tool with numexpr to answer math questions. The user does not understand numexpr,
|
| 50 |
+
so for the final response, use human readable format - e.g. "300 * 200", not "(300 \\times 200)".
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
|
| 55 |
+
bound_model = model.bind_tools(tools)
|
| 56 |
+
preprocessor = RunnableLambda(
|
| 57 |
+
lambda state: [SystemMessage(content=instructions)] + state["messages"],
|
| 58 |
+
name="StateModifier",
|
| 59 |
+
)
|
| 60 |
+
return preprocessor | bound_model # type: ignore[return-value]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def format_safety_message(safety: LlamaGuardOutput) -> AIMessage:
|
| 64 |
+
content = (
|
| 65 |
+
f"This conversation was flagged for unsafe content: {', '.join(safety.unsafe_categories)}"
|
| 66 |
+
)
|
| 67 |
+
return AIMessage(content=content)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 71 |
+
m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
|
| 72 |
+
model_runnable = wrap_model(m)
|
| 73 |
+
response = await model_runnable.ainvoke(state, config)
|
| 74 |
+
|
| 75 |
+
# Run llama guard check here to avoid returning the message if it's unsafe
|
| 76 |
+
llama_guard = LlamaGuard()
|
| 77 |
+
safety_output = await llama_guard.ainvoke("Agent", state["messages"] + [response])
|
| 78 |
+
if safety_output.safety_assessment == SafetyAssessment.UNSAFE:
|
| 79 |
+
return {"messages": [format_safety_message(safety_output)], "safety": safety_output}
|
| 80 |
+
|
| 81 |
+
if state["remaining_steps"] < 2 and response.tool_calls:
|
| 82 |
+
return {
|
| 83 |
+
"messages": [
|
| 84 |
+
AIMessage(
|
| 85 |
+
id=response.id,
|
| 86 |
+
content="Sorry, need more steps to process this request.",
|
| 87 |
+
)
|
| 88 |
+
]
|
| 89 |
+
}
|
| 90 |
+
# We return a list, because this will get added to the existing list
|
| 91 |
+
return {"messages": [response]}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
async def llama_guard_input(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 95 |
+
llama_guard = LlamaGuard()
|
| 96 |
+
safety_output = await llama_guard.ainvoke("User", state["messages"])
|
| 97 |
+
return {"safety": safety_output, "messages": []}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
async def block_unsafe_content(state: AgentState, config: RunnableConfig) -> AgentState:
|
| 101 |
+
safety: LlamaGuardOutput = state["safety"]
|
| 102 |
+
return {"messages": [format_safety_message(safety)]}
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# Define the graph
|
| 106 |
+
agent = StateGraph(AgentState)
|
| 107 |
+
agent.add_node("model", acall_model)
|
| 108 |
+
agent.add_node("tools", ToolNode(tools))
|
| 109 |
+
agent.add_node("guard_input", llama_guard_input)
|
| 110 |
+
agent.add_node("block_unsafe_content", block_unsafe_content)
|
| 111 |
+
agent.set_entry_point("guard_input")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Check for unsafe input and block further processing if found
|
| 115 |
+
def check_safety(state: AgentState) -> Literal["unsafe", "safe"]:
|
| 116 |
+
safety: LlamaGuardOutput = state["safety"]
|
| 117 |
+
match safety.safety_assessment:
|
| 118 |
+
case SafetyAssessment.UNSAFE:
|
| 119 |
+
return "unsafe"
|
| 120 |
+
case _:
|
| 121 |
+
return "safe"
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
agent.add_conditional_edges(
|
| 125 |
+
"guard_input", check_safety, {"unsafe": "block_unsafe_content", "safe": "model"}
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Always END after blocking unsafe content
|
| 129 |
+
agent.add_edge("block_unsafe_content", END)
|
| 130 |
+
|
| 131 |
+
# Always run "model" after "tools"
|
| 132 |
+
agent.add_edge("tools", "model")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# After "model", if there are tool calls, run "tools". Otherwise END.
|
| 136 |
+
def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
|
| 137 |
+
last_message = state["messages"][-1]
|
| 138 |
+
if not isinstance(last_message, AIMessage):
|
| 139 |
+
raise TypeError(f"Expected AIMessage, got {type(last_message)}")
|
| 140 |
+
if last_message.tool_calls:
|
| 141 |
+
return "tools"
|
| 142 |
+
return "done"
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", "done": END})
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
research_assistant = agent.compile()
|
src/agents/tools.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import math
|
| 3 |
+
import numexpr
|
| 4 |
+
|
| 5 |
+
from memory.postgres import load_pgvector_retriever
|
| 6 |
+
from langchain_core.tools import BaseTool, tool
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def calculator_func(expression: str) -> str:
|
| 10 |
+
"""Calculates a math expression using numexpr.
|
| 11 |
+
|
| 12 |
+
Useful for when you need to answer questions about math using numexpr.
|
| 13 |
+
This tool is only for math questions and nothing else. Only input
|
| 14 |
+
math expressions.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
expression (str): A valid numexpr formatted math expression.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
str: The result of the math expression.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
local_dict = {"pi": math.pi, "e": math.e}
|
| 25 |
+
output = str(
|
| 26 |
+
numexpr.evaluate(
|
| 27 |
+
expression.strip(),
|
| 28 |
+
global_dict={}, # restrict access to globals
|
| 29 |
+
local_dict=local_dict, # add common mathematical functions
|
| 30 |
+
)
|
| 31 |
+
)
|
| 32 |
+
return re.sub(r"^\[|\]$", "", output)
|
| 33 |
+
except Exception as e:
|
| 34 |
+
raise ValueError(
|
| 35 |
+
f'calculator("{expression}") raised error: {e}.'
|
| 36 |
+
" Please try again with a valid numerical expression"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
calculator: BaseTool = tool(calculator_func)
|
| 41 |
+
calculator.name = "Calculator"
|
| 42 |
+
|
| 43 |
+
def format_contexts(docs):
|
| 44 |
+
return "\n\n".join(doc.page_content for doc in docs)
|
| 45 |
+
|
| 46 |
+
def database_search_func(query: str) -> str:
|
| 47 |
+
"""Searches the vector DB for information in the portfolio."""
|
| 48 |
+
retriever = load_pgvector_retriever()
|
| 49 |
+
documents = retriever.invoke(query)
|
| 50 |
+
context_str = format_contexts(documents)
|
| 51 |
+
|
| 52 |
+
return context_str
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
database_search: BaseTool = tool(database_search_func)
|
| 56 |
+
database_search.name = "Database_Search"
|
src/agents/utils.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from langchain_core.messages import ChatMessage
|
| 4 |
+
from langgraph.types import StreamWriter
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CustomData(BaseModel):
|
| 9 |
+
"Custom data being sent by an agent"
|
| 10 |
+
|
| 11 |
+
data: dict[str, Any] = Field(description="The custom data")
|
| 12 |
+
|
| 13 |
+
def to_langchain(self) -> ChatMessage:
|
| 14 |
+
return ChatMessage(content=[self.data], role="custom")
|
| 15 |
+
|
| 16 |
+
def dispatch(self, writer: StreamWriter) -> None:
|
| 17 |
+
writer(self.to_langchain())
|
src/core/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from core.llm import get_model
|
| 2 |
+
from core.settings import settings
|
| 3 |
+
|
| 4 |
+
__all__ = ["settings", "get_model"]
|
src/core/embeddings.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import cache
|
| 2 |
+
from typing import TypeAlias
|
| 3 |
+
|
| 4 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 5 |
+
from langchain_ollama import OllamaEmbeddings
|
| 6 |
+
from langchain_openai import OpenAIEmbeddings
|
| 7 |
+
|
| 8 |
+
from core.settings import settings
|
| 9 |
+
from schema.models import (
|
| 10 |
+
AllEmbeddingModelEnum,
|
| 11 |
+
GoogleEmbeddingModelName,
|
| 12 |
+
OllamaEmbeddingModelName,
|
| 13 |
+
OpenAIEmbeddingModelName,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
EmbeddingT: TypeAlias = (
|
| 17 |
+
OpenAIEmbeddings
|
| 18 |
+
| GoogleGenerativeAIEmbeddings
|
| 19 |
+
| OllamaEmbeddings
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@cache
|
| 24 |
+
def get_embeddings(model_name: AllEmbeddingModelEnum, /) -> EmbeddingT:
|
| 25 |
+
if model_name in OpenAIEmbeddingModelName:
|
| 26 |
+
return OpenAIEmbeddings(model=model_name.value)
|
| 27 |
+
|
| 28 |
+
if model_name in GoogleEmbeddingModelName:
|
| 29 |
+
return GoogleGenerativeAIEmbeddings(model=model_name.value)
|
| 30 |
+
|
| 31 |
+
if model_name in OllamaEmbeddingModelName:
|
| 32 |
+
return OllamaEmbeddings(
|
| 33 |
+
model=settings.OLLAMA_EMBEDDING_MODEL or model_name.value,
|
| 34 |
+
base_url=settings.OLLAMA_BASE_URL,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
raise ValueError(f"Unsupported embedding model: {model_name}")
|
src/core/llm.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import cache
|
| 2 |
+
from typing import TypeAlias
|
| 3 |
+
|
| 4 |
+
from langchain_anthropic import ChatAnthropic
|
| 5 |
+
from langchain_aws import ChatBedrock
|
| 6 |
+
from langchain_community.chat_models import FakeListChatModel
|
| 7 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 8 |
+
from langchain_google_vertexai import ChatVertexAI
|
| 9 |
+
from langchain_groq import ChatGroq
|
| 10 |
+
from langchain_ollama import ChatOllama
|
| 11 |
+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
| 12 |
+
|
| 13 |
+
from core.settings import settings
|
| 14 |
+
from schema.models import (
|
| 15 |
+
AllModelEnum,
|
| 16 |
+
AnthropicModelName,
|
| 17 |
+
AWSModelName,
|
| 18 |
+
AzureOpenAIModelName,
|
| 19 |
+
DeepseekModelName,
|
| 20 |
+
FakeModelName,
|
| 21 |
+
GoogleModelName,
|
| 22 |
+
GroqModelName,
|
| 23 |
+
OllamaModelName,
|
| 24 |
+
OpenAICompatibleName,
|
| 25 |
+
OpenAIModelName,
|
| 26 |
+
OpenRouterModelName,
|
| 27 |
+
VertexAIModelName,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
_MODEL_TABLE = (
|
| 31 |
+
{m: m.value for m in OpenAIModelName}
|
| 32 |
+
| {m: m.value for m in OpenAICompatibleName}
|
| 33 |
+
| {m: m.value for m in AzureOpenAIModelName}
|
| 34 |
+
| {m: m.value for m in DeepseekModelName}
|
| 35 |
+
| {m: m.value for m in AnthropicModelName}
|
| 36 |
+
| {m: m.value for m in GoogleModelName}
|
| 37 |
+
| {m: m.value for m in VertexAIModelName}
|
| 38 |
+
| {m: m.value for m in GroqModelName}
|
| 39 |
+
| {m: m.value for m in AWSModelName}
|
| 40 |
+
| {m: m.value for m in OllamaModelName}
|
| 41 |
+
| {m: m.value for m in OpenRouterModelName}
|
| 42 |
+
| {m: m.value for m in FakeModelName}
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class FakeToolModel(FakeListChatModel):
|
| 47 |
+
def __init__(self, responses: list[str]):
|
| 48 |
+
super().__init__(responses=responses)
|
| 49 |
+
|
| 50 |
+
def bind_tools(self, tools):
|
| 51 |
+
return self
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
ModelT: TypeAlias = (
|
| 55 |
+
AzureChatOpenAI
|
| 56 |
+
| ChatOpenAI
|
| 57 |
+
| ChatAnthropic
|
| 58 |
+
| ChatGoogleGenerativeAI
|
| 59 |
+
| ChatVertexAI
|
| 60 |
+
| ChatGroq
|
| 61 |
+
| ChatBedrock
|
| 62 |
+
| ChatOllama
|
| 63 |
+
| FakeToolModel
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@cache
|
| 68 |
+
def get_model(model_name: AllModelEnum, /) -> ModelT:
|
| 69 |
+
# NOTE: models with streaming=True will send tokens as they are generated
|
| 70 |
+
# if the /stream endpoint is called with stream_tokens=True (the default)
|
| 71 |
+
api_model_name = _MODEL_TABLE.get(model_name)
|
| 72 |
+
if not api_model_name:
|
| 73 |
+
raise ValueError(f"Unsupported model: {model_name}")
|
| 74 |
+
|
| 75 |
+
if model_name in OpenAIModelName:
|
| 76 |
+
return ChatOpenAI(model=api_model_name, streaming=True)
|
| 77 |
+
if model_name in OpenAICompatibleName:
|
| 78 |
+
if not settings.COMPATIBLE_BASE_URL or not settings.COMPATIBLE_MODEL:
|
| 79 |
+
raise ValueError("OpenAICompatible base url and endpoint must be configured")
|
| 80 |
+
|
| 81 |
+
return ChatOpenAI(
|
| 82 |
+
model=settings.COMPATIBLE_MODEL,
|
| 83 |
+
temperature=0.5,
|
| 84 |
+
streaming=True,
|
| 85 |
+
openai_api_base=settings.COMPATIBLE_BASE_URL,
|
| 86 |
+
openai_api_key=settings.COMPATIBLE_API_KEY,
|
| 87 |
+
)
|
| 88 |
+
if model_name in AzureOpenAIModelName:
|
| 89 |
+
if not settings.AZURE_OPENAI_API_KEY or not settings.AZURE_OPENAI_ENDPOINT:
|
| 90 |
+
raise ValueError("Azure OpenAI API key and endpoint must be configured")
|
| 91 |
+
|
| 92 |
+
return AzureChatOpenAI(
|
| 93 |
+
azure_endpoint=settings.AZURE_OPENAI_ENDPOINT,
|
| 94 |
+
deployment_name=api_model_name,
|
| 95 |
+
api_version=settings.AZURE_OPENAI_API_VERSION,
|
| 96 |
+
temperature=0.5,
|
| 97 |
+
streaming=True,
|
| 98 |
+
timeout=60,
|
| 99 |
+
max_retries=3,
|
| 100 |
+
)
|
| 101 |
+
if model_name in DeepseekModelName:
|
| 102 |
+
return ChatOpenAI(
|
| 103 |
+
model=api_model_name,
|
| 104 |
+
temperature=0.5,
|
| 105 |
+
streaming=True,
|
| 106 |
+
openai_api_base="https://api.deepseek.com",
|
| 107 |
+
openai_api_key=settings.DEEPSEEK_API_KEY,
|
| 108 |
+
)
|
| 109 |
+
if model_name in AnthropicModelName:
|
| 110 |
+
return ChatAnthropic(model=api_model_name, temperature=0.5, streaming=True)
|
| 111 |
+
if model_name in GoogleModelName:
|
| 112 |
+
return ChatGoogleGenerativeAI(model=api_model_name, temperature=0.5, streaming=True)
|
| 113 |
+
if model_name in VertexAIModelName:
|
| 114 |
+
return ChatVertexAI(model=api_model_name, temperature=0.5, streaming=True)
|
| 115 |
+
if model_name in GroqModelName:
|
| 116 |
+
# Guard and safeguard models should use temperature=0.0 for deterministic outputs
|
| 117 |
+
guard_models = {
|
| 118 |
+
GroqModelName.LLAMA_GUARD_4_12B,
|
| 119 |
+
GroqModelName.LLAMA_PROMPT_GUARD_2_22M,
|
| 120 |
+
GroqModelName.LLAMA_PROMPT_GUARD_2_86M,
|
| 121 |
+
GroqModelName.OPENAI_GPT_OSS_SAFEGUARD_20B,
|
| 122 |
+
}
|
| 123 |
+
if model_name in guard_models:
|
| 124 |
+
return ChatGroq(model=api_model_name, temperature=0.0) # type: ignore[call-arg]
|
| 125 |
+
return ChatGroq(model=api_model_name, temperature=0.5) # type: ignore[call-arg]
|
| 126 |
+
if model_name in AWSModelName:
|
| 127 |
+
return ChatBedrock(model_id=api_model_name, temperature=0.5)
|
| 128 |
+
if model_name in OllamaModelName:
|
| 129 |
+
if settings.OLLAMA_BASE_URL:
|
| 130 |
+
chat_ollama = ChatOllama(
|
| 131 |
+
model=settings.OLLAMA_MODEL, temperature=0.5, base_url=settings.OLLAMA_BASE_URL
|
| 132 |
+
)
|
| 133 |
+
else:
|
| 134 |
+
chat_ollama = ChatOllama(model=settings.OLLAMA_MODEL, temperature=0.5)
|
| 135 |
+
return chat_ollama
|
| 136 |
+
if model_name in OpenRouterModelName:
|
| 137 |
+
return ChatOpenAI(
|
| 138 |
+
model=api_model_name,
|
| 139 |
+
temperature=0.5,
|
| 140 |
+
streaming=True,
|
| 141 |
+
base_url="https://openrouter.ai/api/v1/",
|
| 142 |
+
api_key=settings.OPENROUTER_API_KEY,
|
| 143 |
+
)
|
| 144 |
+
if model_name in FakeModelName:
|
| 145 |
+
return FakeToolModel(responses=["This is a test response from the fake model."])
|
| 146 |
+
|
| 147 |
+
raise ValueError(f"Unsupported model: {model_name}")
|
src/core/settings.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import StrEnum
|
| 2 |
+
from json import loads
|
| 3 |
+
from typing import Annotated, Any
|
| 4 |
+
|
| 5 |
+
from dotenv import find_dotenv
|
| 6 |
+
from pydantic import (
|
| 7 |
+
BeforeValidator,
|
| 8 |
+
Field,
|
| 9 |
+
HttpUrl,
|
| 10 |
+
SecretStr,
|
| 11 |
+
TypeAdapter,
|
| 12 |
+
computed_field,
|
| 13 |
+
)
|
| 14 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 15 |
+
|
| 16 |
+
from schema.models import (
|
| 17 |
+
AllModelEnum,
|
| 18 |
+
AnthropicModelName,
|
| 19 |
+
AWSModelName,
|
| 20 |
+
AzureOpenAIModelName,
|
| 21 |
+
DeepseekModelName,
|
| 22 |
+
FakeModelName,
|
| 23 |
+
GoogleModelName,
|
| 24 |
+
GroqModelName,
|
| 25 |
+
OllamaModelName,
|
| 26 |
+
OpenAICompatibleName,
|
| 27 |
+
OpenAIModelName,
|
| 28 |
+
OpenRouterModelName,
|
| 29 |
+
Provider,
|
| 30 |
+
VertexAIModelName,
|
| 31 |
+
AllEmbeddingModelEnum,
|
| 32 |
+
OpenAIEmbeddingModelName,
|
| 33 |
+
GoogleEmbeddingModelName,
|
| 34 |
+
OllamaEmbeddingModelName,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class DatabaseType(StrEnum):
|
| 39 |
+
SQLITE = "sqlite"
|
| 40 |
+
POSTGRES = "postgres"
|
| 41 |
+
MONGO = "mongo"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class LogLevel(StrEnum):
|
| 45 |
+
DEBUG = "DEBUG"
|
| 46 |
+
INFO = "INFO"
|
| 47 |
+
WARNING = "WARNING"
|
| 48 |
+
ERROR = "ERROR"
|
| 49 |
+
CRITICAL = "CRITICAL"
|
| 50 |
+
|
| 51 |
+
def to_logging_level(self) -> int:
|
| 52 |
+
"""Convert to Python logging level constant."""
|
| 53 |
+
import logging
|
| 54 |
+
|
| 55 |
+
mapping = {
|
| 56 |
+
LogLevel.DEBUG: logging.DEBUG,
|
| 57 |
+
LogLevel.INFO: logging.INFO,
|
| 58 |
+
LogLevel.WARNING: logging.WARNING,
|
| 59 |
+
LogLevel.ERROR: logging.ERROR,
|
| 60 |
+
LogLevel.CRITICAL: logging.CRITICAL,
|
| 61 |
+
}
|
| 62 |
+
return mapping[self]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def check_str_is_http(x: str) -> str:
|
| 66 |
+
http_url_adapter = TypeAdapter(HttpUrl)
|
| 67 |
+
return str(http_url_adapter.validate_python(x))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class Settings(BaseSettings):
|
| 71 |
+
model_config = SettingsConfigDict(
|
| 72 |
+
env_file=find_dotenv(),
|
| 73 |
+
env_file_encoding="utf-8",
|
| 74 |
+
env_ignore_empty=True,
|
| 75 |
+
extra="ignore",
|
| 76 |
+
validate_default=False,
|
| 77 |
+
)
|
| 78 |
+
MODE: str | None = None
|
| 79 |
+
|
| 80 |
+
HOST: str = "0.0.0.0"
|
| 81 |
+
PORT: int = 7860
|
| 82 |
+
GRACEFUL_SHUTDOWN_TIMEOUT: int = 30
|
| 83 |
+
LOG_LEVEL: LogLevel = LogLevel.WARNING
|
| 84 |
+
|
| 85 |
+
AUTH_SECRET: SecretStr | None = None
|
| 86 |
+
CORS_ORIGINS: Annotated[Any, BeforeValidator(lambda x: x.split(",") if isinstance(x, str) else x)] = [
|
| 87 |
+
"http://localhost:3000",
|
| 88 |
+
"http://localhost:8081",
|
| 89 |
+
"http://localhost:5173",
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
OPENAI_API_KEY: SecretStr | None = None
|
| 93 |
+
DEEPSEEK_API_KEY: SecretStr | None = None
|
| 94 |
+
ANTHROPIC_API_KEY: SecretStr | None = None
|
| 95 |
+
GOOGLE_API_KEY: SecretStr | None = None
|
| 96 |
+
GOOGLE_APPLICATION_CREDENTIALS: SecretStr | None = None
|
| 97 |
+
GROQ_API_KEY: SecretStr | None = None
|
| 98 |
+
USE_AWS_BEDROCK: bool = False
|
| 99 |
+
OLLAMA_MODEL: str | None = None
|
| 100 |
+
OLLAMA_BASE_URL: str | None = None
|
| 101 |
+
USE_FAKE_MODEL: bool = False
|
| 102 |
+
OPENROUTER_API_KEY: str | None = None
|
| 103 |
+
|
| 104 |
+
# If DEFAULT_MODEL is None, it will be set in model_post_init
|
| 105 |
+
DEFAULT_MODEL: AllModelEnum | None = None # type: ignore[assignment]
|
| 106 |
+
AVAILABLE_MODELS: set[AllModelEnum] = set() # type: ignore[assignment]
|
| 107 |
+
|
| 108 |
+
# Embedding Settings
|
| 109 |
+
DEFAULT_EMBEDDING_MODEL: AllEmbeddingModelEnum | None = None # type: ignore[assignment]
|
| 110 |
+
AVAILABLE_EMBEDDING_MODELS: set[AllEmbeddingModelEnum] = set() # type: ignore[assignment]
|
| 111 |
+
OLLAMA_EMBEDDING_MODEL: str | None = None
|
| 112 |
+
|
| 113 |
+
# Set openai compatible api, mainly used for proof of concept
|
| 114 |
+
COMPATIBLE_MODEL: str | None = None
|
| 115 |
+
COMPATIBLE_API_KEY: SecretStr | None = None
|
| 116 |
+
COMPATIBLE_BASE_URL: str | None = None
|
| 117 |
+
|
| 118 |
+
OPENWEATHERMAP_API_KEY: SecretStr | None = None
|
| 119 |
+
|
| 120 |
+
# MCP Configuration
|
| 121 |
+
GITHUB_PAT: SecretStr | None = None
|
| 122 |
+
MCP_GITHUB_SERVER_URL: str = "https://api.githubcopilot.com/mcp/"
|
| 123 |
+
|
| 124 |
+
LANGCHAIN_TRACING_V2: bool = False
|
| 125 |
+
LANGCHAIN_PROJECT: str = "default"
|
| 126 |
+
LANGCHAIN_ENDPOINT: Annotated[str, BeforeValidator(check_str_is_http)] = (
|
| 127 |
+
"https://api.smith.langchain.com"
|
| 128 |
+
)
|
| 129 |
+
LANGCHAIN_API_KEY: SecretStr | None = None
|
| 130 |
+
|
| 131 |
+
LANGFUSE_TRACING: bool = False
|
| 132 |
+
LANGFUSE_HOST: Annotated[str, BeforeValidator(check_str_is_http)] = "https://cloud.langfuse.com"
|
| 133 |
+
LANGFUSE_PUBLIC_KEY: SecretStr | None = None
|
| 134 |
+
LANGFUSE_SECRET_KEY: SecretStr | None = None
|
| 135 |
+
|
| 136 |
+
# Database Configuration
|
| 137 |
+
DATABASE_TYPE: DatabaseType = (
|
| 138 |
+
DatabaseType.SQLITE
|
| 139 |
+
) # Options: DatabaseType.SQLITE or DatabaseType.POSTGRES
|
| 140 |
+
SQLITE_DB_PATH: str = "checkpoints.db"
|
| 141 |
+
|
| 142 |
+
# PostgreSQL Configuration
|
| 143 |
+
POSTGRES_USER: str | None = None
|
| 144 |
+
POSTGRES_PASSWORD: SecretStr | None = None
|
| 145 |
+
POSTGRES_HOST: str | None = None
|
| 146 |
+
POSTGRES_PORT: int | None = None
|
| 147 |
+
POSTGRES_DB: str | None = None
|
| 148 |
+
POSTGRES_APPLICATION_NAME: str = "agent-service-toolkit"
|
| 149 |
+
POSTGRES_MIN_CONNECTIONS_PER_POOL: int = 1
|
| 150 |
+
POSTGRES_MAX_CONNECTIONS_PER_POOL: int = 1
|
| 151 |
+
VECTOR_STORE_COLLECTION_NAME: str = "vector_store"
|
| 152 |
+
|
| 153 |
+
# MongoDB Configuration
|
| 154 |
+
MONGO_HOST: str | None = None
|
| 155 |
+
MONGO_PORT: int | None = None
|
| 156 |
+
MONGO_DB: str | None = None
|
| 157 |
+
MONGO_USER: str | None = None
|
| 158 |
+
MONGO_PASSWORD: SecretStr | None = None
|
| 159 |
+
MONGO_AUTH_SOURCE: str | None = None
|
| 160 |
+
|
| 161 |
+
# Azure OpenAI Settings
|
| 162 |
+
AZURE_OPENAI_API_KEY: SecretStr | None = None
|
| 163 |
+
AZURE_OPENAI_ENDPOINT: str | None = None
|
| 164 |
+
AZURE_OPENAI_API_VERSION: str = "2024-02-15-preview"
|
| 165 |
+
AZURE_OPENAI_DEPLOYMENT_MAP: dict[str, str] = Field(
|
| 166 |
+
default_factory=dict, description="Map of model names to Azure deployment IDs"
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
def model_post_init(self, __context: Any) -> None:
|
| 170 |
+
api_keys = {
|
| 171 |
+
Provider.OPENAI: self.OPENAI_API_KEY,
|
| 172 |
+
Provider.OPENAI_COMPATIBLE: self.COMPATIBLE_BASE_URL and self.COMPATIBLE_MODEL,
|
| 173 |
+
Provider.DEEPSEEK: self.DEEPSEEK_API_KEY,
|
| 174 |
+
Provider.ANTHROPIC: self.ANTHROPIC_API_KEY,
|
| 175 |
+
Provider.GOOGLE: self.GOOGLE_API_KEY,
|
| 176 |
+
Provider.VERTEXAI: self.GOOGLE_APPLICATION_CREDENTIALS,
|
| 177 |
+
Provider.GROQ: self.GROQ_API_KEY,
|
| 178 |
+
Provider.AWS: self.USE_AWS_BEDROCK,
|
| 179 |
+
Provider.OLLAMA: self.OLLAMA_MODEL,
|
| 180 |
+
Provider.FAKE: self.USE_FAKE_MODEL,
|
| 181 |
+
Provider.AZURE_OPENAI: self.AZURE_OPENAI_API_KEY,
|
| 182 |
+
Provider.OPENROUTER: self.OPENROUTER_API_KEY,
|
| 183 |
+
}
|
| 184 |
+
active_keys = [k for k, v in api_keys.items() if v]
|
| 185 |
+
if not active_keys:
|
| 186 |
+
raise ValueError("At least one LLM API key must be provided.")
|
| 187 |
+
|
| 188 |
+
for provider in active_keys:
|
| 189 |
+
match provider:
|
| 190 |
+
case Provider.OPENAI:
|
| 191 |
+
if self.DEFAULT_MODEL is None:
|
| 192 |
+
self.DEFAULT_MODEL = OpenAIModelName.GPT_5_NANO
|
| 193 |
+
self.AVAILABLE_MODELS.update(set(OpenAIModelName))
|
| 194 |
+
case Provider.OPENAI_COMPATIBLE:
|
| 195 |
+
if self.DEFAULT_MODEL is None:
|
| 196 |
+
self.DEFAULT_MODEL = OpenAICompatibleName.OPENAI_COMPATIBLE
|
| 197 |
+
self.AVAILABLE_MODELS.update(set(OpenAICompatibleName))
|
| 198 |
+
case Provider.DEEPSEEK:
|
| 199 |
+
if self.DEFAULT_MODEL is None:
|
| 200 |
+
self.DEFAULT_MODEL = DeepseekModelName.DEEPSEEK_CHAT
|
| 201 |
+
self.AVAILABLE_MODELS.update(set(DeepseekModelName))
|
| 202 |
+
case Provider.ANTHROPIC:
|
| 203 |
+
if self.DEFAULT_MODEL is None:
|
| 204 |
+
self.DEFAULT_MODEL = AnthropicModelName.HAIKU_45
|
| 205 |
+
self.AVAILABLE_MODELS.update(set(AnthropicModelName))
|
| 206 |
+
case Provider.GOOGLE:
|
| 207 |
+
if self.DEFAULT_MODEL is None:
|
| 208 |
+
self.DEFAULT_MODEL = GoogleModelName.GEMINI_20_FLASH
|
| 209 |
+
self.AVAILABLE_MODELS.update(set(GoogleModelName))
|
| 210 |
+
case Provider.VERTEXAI:
|
| 211 |
+
if self.DEFAULT_MODEL is None:
|
| 212 |
+
self.DEFAULT_MODEL = VertexAIModelName.GEMINI_20_FLASH
|
| 213 |
+
self.AVAILABLE_MODELS.update(set(VertexAIModelName))
|
| 214 |
+
case Provider.GROQ:
|
| 215 |
+
if self.DEFAULT_MODEL is None:
|
| 216 |
+
self.DEFAULT_MODEL = GroqModelName.LLAMA_31_8B_INSTANT
|
| 217 |
+
self.AVAILABLE_MODELS.update(set(GroqModelName))
|
| 218 |
+
case Provider.AWS:
|
| 219 |
+
if self.DEFAULT_MODEL is None:
|
| 220 |
+
self.DEFAULT_MODEL = AWSModelName.BEDROCK_HAIKU
|
| 221 |
+
self.AVAILABLE_MODELS.update(set(AWSModelName))
|
| 222 |
+
case Provider.OLLAMA:
|
| 223 |
+
if self.DEFAULT_MODEL is None:
|
| 224 |
+
self.DEFAULT_MODEL = OllamaModelName.OLLAMA_GENERIC
|
| 225 |
+
self.AVAILABLE_MODELS.update(set(OllamaModelName))
|
| 226 |
+
case Provider.OPENROUTER:
|
| 227 |
+
if self.DEFAULT_MODEL is None:
|
| 228 |
+
self.DEFAULT_MODEL = OpenRouterModelName.GEMINI_25_FLASH
|
| 229 |
+
self.AVAILABLE_MODELS.update(set(OpenRouterModelName))
|
| 230 |
+
case Provider.FAKE:
|
| 231 |
+
if self.DEFAULT_MODEL is None:
|
| 232 |
+
self.DEFAULT_MODEL = FakeModelName.FAKE
|
| 233 |
+
self.AVAILABLE_MODELS.update(set(FakeModelName))
|
| 234 |
+
case Provider.AZURE_OPENAI:
|
| 235 |
+
if self.DEFAULT_MODEL is None:
|
| 236 |
+
self.DEFAULT_MODEL = AzureOpenAIModelName.AZURE_GPT_4O_MINI
|
| 237 |
+
self.AVAILABLE_MODELS.update(set(AzureOpenAIModelName))
|
| 238 |
+
# Validate Azure OpenAI settings if Azure provider is available
|
| 239 |
+
if not self.AZURE_OPENAI_API_KEY:
|
| 240 |
+
raise ValueError("AZURE_OPENAI_API_KEY must be set")
|
| 241 |
+
if not self.AZURE_OPENAI_ENDPOINT:
|
| 242 |
+
raise ValueError("AZURE_OPENAI_ENDPOINT must be set")
|
| 243 |
+
if not self.AZURE_OPENAI_DEPLOYMENT_MAP:
|
| 244 |
+
raise ValueError("AZURE_OPENAI_DEPLOYMENT_MAP must be set")
|
| 245 |
+
|
| 246 |
+
# Parse deployment map if it's a string
|
| 247 |
+
if isinstance(self.AZURE_OPENAI_DEPLOYMENT_MAP, str):
|
| 248 |
+
try:
|
| 249 |
+
self.AZURE_OPENAI_DEPLOYMENT_MAP = loads(
|
| 250 |
+
self.AZURE_OPENAI_DEPLOYMENT_MAP
|
| 251 |
+
)
|
| 252 |
+
except Exception as e:
|
| 253 |
+
raise ValueError(f"Invalid AZURE_OPENAI_DEPLOYMENT_MAP JSON: {e}")
|
| 254 |
+
|
| 255 |
+
# Validate required deployments exist
|
| 256 |
+
required_models = {"gpt-4o", "gpt-4o-mini"}
|
| 257 |
+
missing_models = required_models - set(self.AZURE_OPENAI_DEPLOYMENT_MAP.keys())
|
| 258 |
+
if missing_models:
|
| 259 |
+
raise ValueError(f"Missing required Azure deployments: {missing_models}")
|
| 260 |
+
case _:
|
| 261 |
+
raise ValueError(f"Unknown provider: {provider}")
|
| 262 |
+
|
| 263 |
+
for provider in active_keys:
|
| 264 |
+
match provider:
|
| 265 |
+
case Provider.OPENAI:
|
| 266 |
+
if self.DEFAULT_EMBEDDING_MODEL is None:
|
| 267 |
+
self.DEFAULT_EMBEDDING_MODEL = OpenAIEmbeddingModelName.TEXT_EMBEDDING_3_SMALL
|
| 268 |
+
self.AVAILABLE_EMBEDDING_MODELS.update(set(OpenAIEmbeddingModelName))
|
| 269 |
+
case Provider.GOOGLE:
|
| 270 |
+
if self.DEFAULT_EMBEDDING_MODEL is None:
|
| 271 |
+
self.DEFAULT_EMBEDDING_MODEL = GoogleEmbeddingModelName.TEXT_EMBEDDING_004
|
| 272 |
+
self.AVAILABLE_EMBEDDING_MODELS.update(set(GoogleEmbeddingModelName))
|
| 273 |
+
case Provider.OLLAMA:
|
| 274 |
+
if self.DEFAULT_EMBEDDING_MODEL is None:
|
| 275 |
+
self.DEFAULT_EMBEDDING_MODEL = OllamaEmbeddingModelName.NOMIC_EMBED_TEXT
|
| 276 |
+
self.AVAILABLE_EMBEDDING_MODELS.update(set(OllamaEmbeddingModelName))
|
| 277 |
+
if not self.OLLAMA_EMBEDDING_MODEL:
|
| 278 |
+
self.OLLAMA_EMBEDDING_MODEL = OllamaEmbeddingModelName.NOMIC_EMBED_TEXT
|
| 279 |
+
|
| 280 |
+
@computed_field # type: ignore[prop-decorator]
|
| 281 |
+
@property
|
| 282 |
+
def BASE_URL(self) -> str:
|
| 283 |
+
return f"http://{self.HOST}:{self.PORT}"
|
| 284 |
+
|
| 285 |
+
def is_dev(self) -> bool:
|
| 286 |
+
return self.MODE == "dev"
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
settings = Settings()
|
src/memory/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import AbstractAsyncContextManager
|
| 2 |
+
|
| 3 |
+
from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
|
| 4 |
+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
| 5 |
+
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
| 6 |
+
|
| 7 |
+
from core.settings import DatabaseType, settings
|
| 8 |
+
from memory.mongodb import get_mongo_saver
|
| 9 |
+
from memory.postgres import get_postgres_saver, get_postgres_store
|
| 10 |
+
from memory.sqlite import get_sqlite_saver, get_sqlite_store
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def initialize_database() -> AbstractAsyncContextManager[
|
| 14 |
+
AsyncSqliteSaver | AsyncPostgresSaver | AsyncMongoDBSaver
|
| 15 |
+
]:
|
| 16 |
+
"""
|
| 17 |
+
Initialize the appropriate database checkpointer based on configuration.
|
| 18 |
+
Returns an initialized AsyncCheckpointer instance.
|
| 19 |
+
"""
|
| 20 |
+
if settings.DATABASE_TYPE == DatabaseType.POSTGRES:
|
| 21 |
+
return get_postgres_saver()
|
| 22 |
+
if settings.DATABASE_TYPE == DatabaseType.MONGO:
|
| 23 |
+
return get_mongo_saver()
|
| 24 |
+
else: # Default to SQLite
|
| 25 |
+
return get_sqlite_saver()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def initialize_store():
|
| 29 |
+
"""
|
| 30 |
+
Initialize the appropriate store based on configuration.
|
| 31 |
+
Returns an async context manager for the initialized store.
|
| 32 |
+
"""
|
| 33 |
+
if settings.DATABASE_TYPE == DatabaseType.POSTGRES:
|
| 34 |
+
return get_postgres_store()
|
| 35 |
+
# TODO: Add Mongo store - https://pypi.org/project/langgraph-store-mongodb/
|
| 36 |
+
else: # Default to SQLite
|
| 37 |
+
return get_sqlite_store()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
__all__ = ["initialize_database", "initialize_store"]
|
src/memory/mongodb.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import urllib.parse
|
| 3 |
+
from contextlib import AbstractAsyncContextManager
|
| 4 |
+
|
| 5 |
+
from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
|
| 6 |
+
|
| 7 |
+
from core.settings import settings
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _has_auth_credentials() -> bool:
|
| 13 |
+
required_auth = ["MONGO_USER", "MONGO_PASSWORD", "MONGO_AUTH_SOURCE"]
|
| 14 |
+
set_auth = [var for var in required_auth if getattr(settings, var, None)]
|
| 15 |
+
if len(set_auth) > 0 and len(set_auth) != len(required_auth):
|
| 16 |
+
raise ValueError(
|
| 17 |
+
f"If any of the following environment variables are set, all must be set: {', '.join(required_auth)}."
|
| 18 |
+
)
|
| 19 |
+
return len(set_auth) == len(required_auth)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def validate_mongo_config() -> None:
|
| 23 |
+
"""
|
| 24 |
+
Validate that all required MongoDB configuration is present.
|
| 25 |
+
Raises ValueError if any required configuration is missing.
|
| 26 |
+
"""
|
| 27 |
+
required_always = ["MONGO_HOST", "MONGO_PORT", "MONGO_DB"]
|
| 28 |
+
missing_always = [var for var in required_always if not getattr(settings, var, None)]
|
| 29 |
+
if missing_always:
|
| 30 |
+
raise ValueError(
|
| 31 |
+
f"Missing required MongoDB configuration: {', '.join(missing_always)}. "
|
| 32 |
+
"These environment variables must be set to use MongoDB persistence."
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
_has_auth_credentials()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_mongo_connection_string() -> str:
|
| 39 |
+
"""Build and return the MongoDB connection string from settings."""
|
| 40 |
+
|
| 41 |
+
if _has_auth_credentials():
|
| 42 |
+
if settings.MONGO_PASSWORD is None: # for type checking
|
| 43 |
+
raise ValueError("MONGO_PASSWORD is not set")
|
| 44 |
+
password = settings.MONGO_PASSWORD.get_secret_value().strip()
|
| 45 |
+
password_escaped = urllib.parse.quote_plus(password)
|
| 46 |
+
return (
|
| 47 |
+
f"mongodb://{settings.MONGO_USER}:{password_escaped}@"
|
| 48 |
+
f"{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
| 49 |
+
f"?authSource={settings.MONGO_AUTH_SOURCE}"
|
| 50 |
+
)
|
| 51 |
+
else:
|
| 52 |
+
return f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_mongo_saver() -> AbstractAsyncContextManager[AsyncMongoDBSaver]:
|
| 56 |
+
"""Initialize and return a MongoDB saver instance."""
|
| 57 |
+
validate_mongo_config()
|
| 58 |
+
if settings.MONGO_DB is None: # for type checking
|
| 59 |
+
raise ValueError("MONGO_DB is not set")
|
| 60 |
+
return AsyncMongoDBSaver.from_conn_string(
|
| 61 |
+
get_mongo_connection_string(), db_name=settings.MONGO_DB
|
| 62 |
+
)
|
src/memory/postgres.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from contextlib import asynccontextmanager
|
| 3 |
+
|
| 4 |
+
from core.embeddings import get_embeddings
|
| 5 |
+
from langchain_postgres import PGVector
|
| 6 |
+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
| 7 |
+
from langgraph.store.postgres import AsyncPostgresStore
|
| 8 |
+
from psycopg.rows import dict_row
|
| 9 |
+
from psycopg_pool import AsyncConnectionPool
|
| 10 |
+
|
| 11 |
+
from core.settings import settings
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def validate_postgres_config() -> None:
|
| 17 |
+
"""
|
| 18 |
+
Validate that all required PostgreSQL configuration is present.
|
| 19 |
+
Raises ValueError if any required configuration is missing.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
required_vars = [
|
| 23 |
+
"POSTGRES_USER",
|
| 24 |
+
"POSTGRES_PASSWORD",
|
| 25 |
+
"POSTGRES_HOST",
|
| 26 |
+
"POSTGRES_PORT",
|
| 27 |
+
"POSTGRES_DB",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
missing = [var for var in required_vars if not getattr(settings, var, None)]
|
| 31 |
+
if missing:
|
| 32 |
+
raise ValueError(
|
| 33 |
+
f"Missing required PostgreSQL configuration: {', '.join(missing)}. "
|
| 34 |
+
"All individual POSTGRES_* environment variables must be set to use PostgreSQL persistence."
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
if settings.POSTGRES_MIN_CONNECTIONS_PER_POOL > settings.POSTGRES_MAX_CONNECTIONS_PER_POOL:
|
| 38 |
+
raise ValueError(
|
| 39 |
+
f"POSTGRES_MIN_CONNECTIONS_PER_POOL ({settings.POSTGRES_MIN_CONNECTIONS_PER_POOL}) must be less than or equal to POSTGRES_MAX_CONNECTIONS_PER_POOL ({settings.POSTGRES_MAX_CONNECTIONS_PER_POOL})"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_postgres_connection_string() -> str:
|
| 44 |
+
"""Build and return the PostgreSQL connection string from settings."""
|
| 45 |
+
if settings.POSTGRES_PASSWORD is None:
|
| 46 |
+
raise ValueError("POSTGRES_PASSWORD is not set")
|
| 47 |
+
return (
|
| 48 |
+
f"postgresql://{settings.POSTGRES_USER}:"
|
| 49 |
+
f"{settings.POSTGRES_PASSWORD.get_secret_value()}@"
|
| 50 |
+
f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
|
| 51 |
+
f"{settings.POSTGRES_DB}/?sslmode=require"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@asynccontextmanager
|
| 56 |
+
async def get_postgres_saver():
|
| 57 |
+
"""Initialize and return a PostgreSQL saver instance based on a connection pool for more resilient connections."""
|
| 58 |
+
validate_postgres_config()
|
| 59 |
+
application_name = settings.POSTGRES_APPLICATION_NAME + "-" + "saver"
|
| 60 |
+
|
| 61 |
+
async with AsyncConnectionPool(
|
| 62 |
+
get_postgres_connection_string(),
|
| 63 |
+
min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
|
| 64 |
+
max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
|
| 65 |
+
# Langgraph requires autocommmit=true and row_factory to be set to dict_row.
|
| 66 |
+
# Application_name is passed so you can identify the connection in your Postgres database connection manager.
|
| 67 |
+
kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
|
| 68 |
+
# makes sure that the connection is still valid before using it
|
| 69 |
+
check=AsyncConnectionPool.check_connection,
|
| 70 |
+
) as pool:
|
| 71 |
+
try:
|
| 72 |
+
checkpointer = AsyncPostgresSaver(pool)
|
| 73 |
+
await checkpointer.setup()
|
| 74 |
+
yield checkpointer
|
| 75 |
+
finally:
|
| 76 |
+
await pool.close()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@asynccontextmanager
|
| 80 |
+
async def get_postgres_store():
|
| 81 |
+
"""
|
| 82 |
+
Get a PostgreSQL store instance based on a connection pool for more resilient connections.
|
| 83 |
+
|
| 84 |
+
Returns an AsyncPostgresStore instance that can be used with async context manager pattern.
|
| 85 |
+
|
| 86 |
+
"""
|
| 87 |
+
validate_postgres_config()
|
| 88 |
+
application_name = settings.POSTGRES_APPLICATION_NAME + "-" + "store"
|
| 89 |
+
|
| 90 |
+
async with AsyncConnectionPool(
|
| 91 |
+
get_postgres_connection_string(),
|
| 92 |
+
min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
|
| 93 |
+
max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
|
| 94 |
+
# Langgraph requires autocommmit=true and row_factory to be set to dict_row
|
| 95 |
+
# Application_name is passed so you can identify the connection in your Postgres database connection manager.
|
| 96 |
+
kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
|
| 97 |
+
# makes sure that the connection is still valid before using it
|
| 98 |
+
check=AsyncConnectionPool.check_connection,
|
| 99 |
+
) as pool:
|
| 100 |
+
try:
|
| 101 |
+
store = AsyncPostgresStore(pool)
|
| 102 |
+
await store.setup()
|
| 103 |
+
yield store
|
| 104 |
+
finally:
|
| 105 |
+
await pool.close()
|
| 106 |
+
|
| 107 |
+
def get_pgvector_connection_string() -> str:
|
| 108 |
+
"""Build and return the PostgreSQL connection string for vectors from settings."""
|
| 109 |
+
return (
|
| 110 |
+
f"postgresql+psycopg://{settings.POSTGRES_USER}:"
|
| 111 |
+
f"{settings.POSTGRES_PASSWORD.get_secret_value()}@"
|
| 112 |
+
f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
|
| 113 |
+
f"{settings.POSTGRES_DB}?sslmode=require"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def load_pgvector_store():
|
| 117 |
+
"""Get a PostgreSQL vectors store instance."""
|
| 118 |
+
validate_postgres_config()
|
| 119 |
+
|
| 120 |
+
return PGVector(
|
| 121 |
+
connection=get_pgvector_connection_string(),
|
| 122 |
+
collection_name=settings.VECTOR_STORE_COLLECTION_NAME,
|
| 123 |
+
embeddings=get_embeddings(settings.DEFAULT_EMBEDDING_MODEL),
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def load_pgvector_retriever(k: int = 6):
|
| 127 |
+
store = load_pgvector_store()
|
| 128 |
+
return store.as_retriever(
|
| 129 |
+
search_type="mmr",
|
| 130 |
+
search_kwargs={
|
| 131 |
+
"k": k,
|
| 132 |
+
"fetch_k": 20, # candidates
|
| 133 |
+
"lambda_mult": 0.6, # relevance vs diversity
|
| 134 |
+
},
|
| 135 |
+
)
|
src/memory/sqlite.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
| 2 |
+
|
| 3 |
+
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
| 4 |
+
from langgraph.store.memory import InMemoryStore
|
| 5 |
+
|
| 6 |
+
from core.settings import settings
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_sqlite_saver() -> AbstractAsyncContextManager[AsyncSqliteSaver]:
|
| 10 |
+
"""Initialize and return a SQLite saver instance."""
|
| 11 |
+
return AsyncSqliteSaver.from_conn_string(settings.SQLITE_DB_PATH)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AsyncInMemoryStore:
|
| 15 |
+
"""Wrapper for InMemoryStore that provides an async context manager interface."""
|
| 16 |
+
|
| 17 |
+
def __init__(self):
|
| 18 |
+
self.store = InMemoryStore()
|
| 19 |
+
|
| 20 |
+
async def __aenter__(self):
|
| 21 |
+
return self.store
|
| 22 |
+
|
| 23 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
| 24 |
+
# No cleanup needed for InMemoryStore
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
async def setup(self):
|
| 28 |
+
# No-op method for compatibility with PostgresStore
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@asynccontextmanager
|
| 33 |
+
async def get_sqlite_store():
|
| 34 |
+
"""Initialize and return a store instance for long-term memory.
|
| 35 |
+
|
| 36 |
+
Note: SQLite-specific store isn't available in LangGraph,
|
| 37 |
+
so we use InMemoryStore wrapped in an async context manager for compatibility.
|
| 38 |
+
"""
|
| 39 |
+
store_manager = AsyncInMemoryStore()
|
| 40 |
+
yield await store_manager.__aenter__()
|
src/run_agent.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from typing import cast
|
| 3 |
+
from langsmith import uuid7
|
| 4 |
+
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from langchain_core.messages import HumanMessage
|
| 7 |
+
from langchain_core.runnables import RunnableConfig
|
| 8 |
+
from langgraph.graph import MessagesState
|
| 9 |
+
from langgraph.graph.state import CompiledStateGraph
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
from agents import DEFAULT_AGENT, get_agent # noqa: E402
|
| 14 |
+
|
| 15 |
+
# The default agent uses StateGraph.compile() which returns CompiledStateGraph
|
| 16 |
+
agent = cast(CompiledStateGraph, get_agent(DEFAULT_AGENT))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
async def main() -> None:
|
| 20 |
+
inputs: MessagesState = {
|
| 21 |
+
"messages": [HumanMessage("Find me a recipe for chocolate chip cookies")]
|
| 22 |
+
}
|
| 23 |
+
result = await agent.ainvoke(
|
| 24 |
+
input=inputs,
|
| 25 |
+
config=RunnableConfig(configurable={"thread_id": uuid7()}),
|
| 26 |
+
)
|
| 27 |
+
result["messages"][-1].pretty_print()
|
| 28 |
+
|
| 29 |
+
# Draw the agent graph as png
|
| 30 |
+
# requires:
|
| 31 |
+
# brew install graphviz
|
| 32 |
+
# export CFLAGS="-I $(brew --prefix graphviz)/include"
|
| 33 |
+
# export LDFLAGS="-L $(brew --prefix graphviz)/lib"
|
| 34 |
+
# pip install pygraphviz
|
| 35 |
+
#
|
| 36 |
+
# agent.get_graph().draw_png("agent_diagram.png")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
|
| 40 |
+
asyncio.run(main())
|
src/run_service.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import uvicorn
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
from core import settings
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
root_logger = logging.getLogger()
|
| 14 |
+
if root_logger.handlers:
|
| 15 |
+
print(
|
| 16 |
+
f"Warning: Root logger already has {len(root_logger.handlers)} handler(s) configured. "
|
| 17 |
+
f"basicConfig() will be ignored. Current level: {logging.getLevelName(root_logger.level)}"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(level=settings.LOG_LEVEL.to_logging_level())
|
| 21 |
+
# Set Compatible event loop policy on Windows Systems.
|
| 22 |
+
# On Windows systems, the default ProactorEventLoop can cause issues with
|
| 23 |
+
# certain async database drivers like psycopg (PostgreSQL driver).
|
| 24 |
+
# The WindowsSelectorEventLoopPolicy provides better compatibility and prevents
|
| 25 |
+
# "RuntimeError: Event loop is closed" errors when working with database connections.
|
| 26 |
+
# This needs to be set before running the application server.
|
| 27 |
+
# Refer to the documentation for more information.
|
| 28 |
+
# https://www.psycopg.org/psycopg3/docs/advanced/async.html#asynchronous-operations
|
| 29 |
+
if sys.platform == "win32":
|
| 30 |
+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
| 31 |
+
uvicorn.run(
|
| 32 |
+
"service:app",
|
| 33 |
+
host=settings.HOST,
|
| 34 |
+
port=settings.PORT,
|
| 35 |
+
reload=settings.is_dev(),
|
| 36 |
+
timeout_graceful_shutdown=settings.GRACEFUL_SHUTDOWN_TIMEOUT,
|
| 37 |
+
)
|
src/schema/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from schema.models import AllModelEnum
|
| 2 |
+
from schema.schema import (
|
| 3 |
+
AgentInfo,
|
| 4 |
+
ChatHistory,
|
| 5 |
+
ChatHistoryInput,
|
| 6 |
+
ChatMessage,
|
| 7 |
+
Feedback,
|
| 8 |
+
FeedbackResponse,
|
| 9 |
+
ServiceMetadata,
|
| 10 |
+
StreamInput,
|
| 11 |
+
UserInput,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"AgentInfo",
|
| 16 |
+
"AllModelEnum",
|
| 17 |
+
"UserInput",
|
| 18 |
+
"ChatMessage",
|
| 19 |
+
"ServiceMetadata",
|
| 20 |
+
"StreamInput",
|
| 21 |
+
"Feedback",
|
| 22 |
+
"FeedbackResponse",
|
| 23 |
+
"ChatHistoryInput",
|
| 24 |
+
"ChatHistory",
|
| 25 |
+
]
|
src/schema/models.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import StrEnum, auto
|
| 2 |
+
from typing import TypeAlias
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Provider(StrEnum):
|
| 6 |
+
OPENAI = auto()
|
| 7 |
+
OPENAI_COMPATIBLE = auto()
|
| 8 |
+
AZURE_OPENAI = auto()
|
| 9 |
+
DEEPSEEK = auto()
|
| 10 |
+
ANTHROPIC = auto()
|
| 11 |
+
GOOGLE = auto()
|
| 12 |
+
VERTEXAI = auto()
|
| 13 |
+
GROQ = auto()
|
| 14 |
+
AWS = auto()
|
| 15 |
+
OLLAMA = auto()
|
| 16 |
+
OPENROUTER = auto()
|
| 17 |
+
FAKE = auto()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class OpenAIModelName(StrEnum):
|
| 21 |
+
"""https://platform.openai.com/docs/models/gpt-4o"""
|
| 22 |
+
|
| 23 |
+
GPT_5_NANO = "gpt-5-nano"
|
| 24 |
+
GPT_5_MINI = "gpt-5-mini"
|
| 25 |
+
GPT_5_1 = "gpt-5.1"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AzureOpenAIModelName(StrEnum):
|
| 29 |
+
"""Azure OpenAI model names"""
|
| 30 |
+
|
| 31 |
+
AZURE_GPT_4O = "azure-gpt-4o"
|
| 32 |
+
AZURE_GPT_4O_MINI = "azure-gpt-4o-mini"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class OpenAIEmbeddingModelName(StrEnum):
|
| 36 |
+
"""https://platform.openai.com/docs/guides/embeddings"""
|
| 37 |
+
|
| 38 |
+
TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"
|
| 39 |
+
TEXT_EMBEDDING_3_LARGE = "text-embedding-3-large"
|
| 40 |
+
TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class DeepseekModelName(StrEnum):
|
| 44 |
+
"""https://api-docs.deepseek.com/quick_start/pricing"""
|
| 45 |
+
|
| 46 |
+
DEEPSEEK_CHAT = "deepseek-chat"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class AnthropicModelName(StrEnum):
|
| 50 |
+
"""https://docs.anthropic.com/en/docs/about-claude/models#model-names"""
|
| 51 |
+
|
| 52 |
+
HAIKU_45 = "claude-haiku-4-5"
|
| 53 |
+
SONNET_45 = "claude-sonnet-4-5"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class GoogleModelName(StrEnum):
|
| 57 |
+
"""https://ai.google.dev/gemini-api/docs/models/gemini"""
|
| 58 |
+
|
| 59 |
+
GEMINI_15_PRO = "gemini-1.5-pro"
|
| 60 |
+
GEMINI_20_FLASH = "gemini-2.0-flash"
|
| 61 |
+
GEMINI_20_FLASH_LITE = "gemini-2.0-flash-lite"
|
| 62 |
+
GEMINI_25_FLASH = "gemini-2.5-flash"
|
| 63 |
+
GEMINI_25_PRO = "gemini-2.5-pro"
|
| 64 |
+
GEMINI_30_PRO = "gemini-3-pro-preview"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class GoogleEmbeddingModelName(StrEnum):
|
| 68 |
+
"""https://ai.google.dev/gemini-api/docs/models/gemini#text-embedding"""
|
| 69 |
+
|
| 70 |
+
TEXT_EMBEDDING_004 = "text-embedding-004"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class VertexAIModelName(StrEnum):
|
| 74 |
+
"""https://cloud.google.com/vertex-ai/generative-ai/docs/models"""
|
| 75 |
+
|
| 76 |
+
GEMINI_15_PRO = "gemini-1.5-pro"
|
| 77 |
+
GEMINI_20_FLASH = "gemini-2.0-flash"
|
| 78 |
+
GEMINI_20_FLASH_LITE = "models/gemini-2.0-flash-lite"
|
| 79 |
+
GEMINI_25_FLASH = "models/gemini-2.5-flash"
|
| 80 |
+
GEMINI_25_PRO = "gemini-2.5-pro"
|
| 81 |
+
GEMINI_30_PRO = "gemini-3-pro-preview"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class GroqModelName(StrEnum):
|
| 85 |
+
"""https://console.groq.com/docs/models"""
|
| 86 |
+
|
| 87 |
+
LLAMA_GUARD_4_12B = "meta-llama/llama-guard-4-12b"
|
| 88 |
+
LLAMA_31_8B_INSTANT = "llama-3.1-8b-instant"
|
| 89 |
+
LLAMA_33_70B_VERSATILE = "llama-3.3-70b-versatile"
|
| 90 |
+
LLAMA_4_MAVERICK_17B_128E = "meta-llama/llama-4-maverick-17b-128e-instruct"
|
| 91 |
+
LLAMA_4_SCOUT_17B_16E = "meta-llama/llama-4-scout-17b-16e-instruct"
|
| 92 |
+
LLAMA_PROMPT_GUARD_2_22M = "meta-llama/llama-prompt-guard-2-22m"
|
| 93 |
+
LLAMA_PROMPT_GUARD_2_86M = "meta-llama/llama-prompt-guard-2-86m"
|
| 94 |
+
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
| 95 |
+
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
| 96 |
+
OPENAI_GPT_OSS_SAFEGUARD_20B = "openai/gpt-oss-safeguard-20b"
|
| 97 |
+
GROQ_COMPOUND = "groq/compound"
|
| 98 |
+
GROQ_COMPOUND_MINI = "groq/compound-mini"
|
| 99 |
+
QWEN_3_32B = "qwen/qwen3-32b"
|
| 100 |
+
KIMI_K2_INSTRUCT = "moonshotai/kimi-k2-instruct"
|
| 101 |
+
KIMI_K2_INSTRUCT_0905 = "moonshotai/kimi-k2-instruct-0905"
|
| 102 |
+
ORPHEUS_ARABIC_SAUDI = "canopylabs/orpheus-arabic-saudi"
|
| 103 |
+
ORPHEUS_V1_ENGLISH = "canopylabs/orpheus-v1-english"
|
| 104 |
+
WHISPER_LARGE_V3 = "whisper-large-v3"
|
| 105 |
+
WHISPER_LARGE_V3_TURBO = "whisper-large-v3-turbo"
|
| 106 |
+
ALLAM_2_7B = "allam-2-7b"
|
| 107 |
+
|
| 108 |
+
class AWSModelName(StrEnum):
|
| 109 |
+
"""https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html"""
|
| 110 |
+
|
| 111 |
+
BEDROCK_HAIKU = "bedrock-3.5-haiku"
|
| 112 |
+
BEDROCK_SONNET = "bedrock-3.5-sonnet"
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class OllamaModelName(StrEnum):
|
| 116 |
+
"""https://ollama.com/search"""
|
| 117 |
+
|
| 118 |
+
OLLAMA_GENERIC = "ollama"
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class OllamaEmbeddingModelName(StrEnum):
|
| 122 |
+
"""Common Ollama embedding models"""
|
| 123 |
+
|
| 124 |
+
NOMIC_EMBED_TEXT = "nomic-embed-text"
|
| 125 |
+
ALL_MINILM = "all-minilm"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class OpenRouterModelName(StrEnum):
|
| 129 |
+
"""https://openrouter.ai/models"""
|
| 130 |
+
|
| 131 |
+
GEMINI_25_FLASH = "google/gemini-2.5-flash"
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class OpenAICompatibleName(StrEnum):
|
| 135 |
+
"""https://platform.openai.com/docs/guides/text-generation"""
|
| 136 |
+
|
| 137 |
+
OPENAI_COMPATIBLE = "openai-compatible"
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class FakeModelName(StrEnum):
|
| 141 |
+
"""Fake model for testing."""
|
| 142 |
+
|
| 143 |
+
FAKE = "fake"
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
AllModelEnum: TypeAlias = (
|
| 147 |
+
OpenAIModelName
|
| 148 |
+
| OpenAICompatibleName
|
| 149 |
+
| AzureOpenAIModelName
|
| 150 |
+
| DeepseekModelName
|
| 151 |
+
| AnthropicModelName
|
| 152 |
+
| GoogleModelName
|
| 153 |
+
| VertexAIModelName
|
| 154 |
+
| GroqModelName
|
| 155 |
+
| AWSModelName
|
| 156 |
+
| OllamaModelName
|
| 157 |
+
| OpenRouterModelName
|
| 158 |
+
| FakeModelName
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
AllEmbeddingModelEnum: TypeAlias = (
|
| 162 |
+
OpenAIEmbeddingModelName
|
| 163 |
+
| GoogleEmbeddingModelName
|
| 164 |
+
| OllamaEmbeddingModelName
|
| 165 |
+
)
|
src/schema/schema.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Literal, NotRequired
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field, SerializeAsAny
|
| 4 |
+
from typing_extensions import TypedDict
|
| 5 |
+
|
| 6 |
+
from schema.models import AllModelEnum, AnthropicModelName, OpenAIModelName
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AgentInfo(BaseModel):
|
| 10 |
+
"""Info about an available agent."""
|
| 11 |
+
|
| 12 |
+
key: str = Field(
|
| 13 |
+
description="Agent key.",
|
| 14 |
+
examples=["research-assistant"],
|
| 15 |
+
)
|
| 16 |
+
description: str = Field(
|
| 17 |
+
description="Description of the agent.",
|
| 18 |
+
examples=["A research assistant for generating research papers."],
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ServiceMetadata(BaseModel):
|
| 23 |
+
"""Metadata about the service including available agents and models."""
|
| 24 |
+
|
| 25 |
+
agents: list[AgentInfo] = Field(
|
| 26 |
+
description="List of available agents.",
|
| 27 |
+
)
|
| 28 |
+
models: list[AllModelEnum] = Field(
|
| 29 |
+
description="List of available LLMs.",
|
| 30 |
+
)
|
| 31 |
+
default_agent: str = Field(
|
| 32 |
+
description="Default agent used when none is specified.",
|
| 33 |
+
examples=["research-assistant"],
|
| 34 |
+
)
|
| 35 |
+
default_model: AllModelEnum = Field(
|
| 36 |
+
description="Default model used when none is specified.",
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class UserInput(BaseModel):
|
| 41 |
+
"""Basic user input for the agent."""
|
| 42 |
+
|
| 43 |
+
message: str = Field(
|
| 44 |
+
description="User input to the agent.",
|
| 45 |
+
examples=["What is the weather in Tokyo?"],
|
| 46 |
+
)
|
| 47 |
+
model: SerializeAsAny[AllModelEnum] | None = Field(
|
| 48 |
+
title="Model",
|
| 49 |
+
description="LLM Model to use for the agent. Defaults to the default model set in the settings of the service.",
|
| 50 |
+
default=None,
|
| 51 |
+
examples=[OpenAIModelName.GPT_5_NANO, AnthropicModelName.HAIKU_45],
|
| 52 |
+
)
|
| 53 |
+
thread_id: str | None = Field(
|
| 54 |
+
description="Thread ID to persist and continue a multi-turn conversation.",
|
| 55 |
+
default=None,
|
| 56 |
+
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
|
| 57 |
+
)
|
| 58 |
+
user_id: str | None = Field(
|
| 59 |
+
description="User ID to persist and continue a conversation across multiple threads.",
|
| 60 |
+
default=None,
|
| 61 |
+
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
|
| 62 |
+
)
|
| 63 |
+
agent_config: dict[str, Any] = Field(
|
| 64 |
+
description="Additional configuration to pass through to the agent",
|
| 65 |
+
default={},
|
| 66 |
+
examples=[{"spicy_level": 0.8}],
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class StreamInput(UserInput):
|
| 71 |
+
"""User input for streaming the agent's response."""
|
| 72 |
+
|
| 73 |
+
stream_tokens: bool = Field(
|
| 74 |
+
description="Whether to stream LLM tokens to the client.",
|
| 75 |
+
default=True,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ToolCall(TypedDict):
|
| 80 |
+
"""Represents a request to call a tool."""
|
| 81 |
+
|
| 82 |
+
name: str
|
| 83 |
+
"""The name of the tool to be called."""
|
| 84 |
+
args: dict[str, Any]
|
| 85 |
+
"""The arguments to the tool call."""
|
| 86 |
+
id: str | None
|
| 87 |
+
"""An identifier associated with the tool call."""
|
| 88 |
+
type: NotRequired[Literal["tool_call"]]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ChatMessage(BaseModel):
|
| 92 |
+
"""Message in a chat."""
|
| 93 |
+
|
| 94 |
+
type: Literal["human", "ai", "tool", "custom"] = Field(
|
| 95 |
+
description="Role of the message.",
|
| 96 |
+
examples=["human", "ai", "tool", "custom"],
|
| 97 |
+
)
|
| 98 |
+
content: str = Field(
|
| 99 |
+
description="Content of the message.",
|
| 100 |
+
examples=["Hello, world!"],
|
| 101 |
+
)
|
| 102 |
+
tool_calls: list[ToolCall] = Field(
|
| 103 |
+
description="Tool calls in the message.",
|
| 104 |
+
default=[],
|
| 105 |
+
)
|
| 106 |
+
tool_call_id: str | None = Field(
|
| 107 |
+
description="Tool call that this message is responding to.",
|
| 108 |
+
default=None,
|
| 109 |
+
examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"],
|
| 110 |
+
)
|
| 111 |
+
run_id: str | None = Field(
|
| 112 |
+
description="Run ID of the message.",
|
| 113 |
+
default=None,
|
| 114 |
+
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
|
| 115 |
+
)
|
| 116 |
+
response_metadata: dict[str, Any] = Field(
|
| 117 |
+
description="Response metadata. For example: response headers, logprobs, token counts.",
|
| 118 |
+
default={},
|
| 119 |
+
)
|
| 120 |
+
custom_data: dict[str, Any] = Field(
|
| 121 |
+
description="Custom message data.",
|
| 122 |
+
default={},
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def pretty_repr(self) -> str:
|
| 126 |
+
"""Get a pretty representation of the message."""
|
| 127 |
+
base_title = self.type.title() + " Message"
|
| 128 |
+
padded = " " + base_title + " "
|
| 129 |
+
sep_len = (80 - len(padded)) // 2
|
| 130 |
+
sep = "=" * sep_len
|
| 131 |
+
second_sep = sep + "=" if len(padded) % 2 else sep
|
| 132 |
+
title = f"{sep}{padded}{second_sep}"
|
| 133 |
+
return f"{title}\n\n{self.content}"
|
| 134 |
+
|
| 135 |
+
def pretty_print(self) -> None:
|
| 136 |
+
print(self.pretty_repr()) # noqa: T201
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class Feedback(BaseModel): # type: ignore[no-redef]
|
| 140 |
+
"""Feedback for a run, to record to LangSmith."""
|
| 141 |
+
|
| 142 |
+
run_id: str = Field(
|
| 143 |
+
description="Run ID to record feedback for.",
|
| 144 |
+
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
|
| 145 |
+
)
|
| 146 |
+
key: str = Field(
|
| 147 |
+
description="Feedback key.",
|
| 148 |
+
examples=["human-feedback-stars"],
|
| 149 |
+
)
|
| 150 |
+
score: float = Field(
|
| 151 |
+
description="Feedback score.",
|
| 152 |
+
examples=[0.8],
|
| 153 |
+
)
|
| 154 |
+
kwargs: dict[str, Any] = Field(
|
| 155 |
+
description="Additional feedback kwargs, passed to LangSmith.",
|
| 156 |
+
default={},
|
| 157 |
+
examples=[{"comment": "In-line human feedback"}],
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class FeedbackResponse(BaseModel):
|
| 162 |
+
status: Literal["success"] = "success"
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class ChatHistoryInput(BaseModel):
|
| 166 |
+
"""Input for retrieving chat history."""
|
| 167 |
+
|
| 168 |
+
thread_id: str = Field(
|
| 169 |
+
description="Thread ID to persist and continue a multi-turn conversation.",
|
| 170 |
+
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class ChatHistory(BaseModel):
|
| 175 |
+
messages: list[ChatMessage]
|
src/schema/task_data.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Literal
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TaskData(BaseModel):
|
| 7 |
+
name: str | None = Field(
|
| 8 |
+
description="Name of the task.", default=None, examples=["Check input safety"]
|
| 9 |
+
)
|
| 10 |
+
run_id: str = Field(
|
| 11 |
+
description="ID of the task run to pair state updates to.",
|
| 12 |
+
default="",
|
| 13 |
+
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
|
| 14 |
+
)
|
| 15 |
+
state: Literal["new", "running", "complete"] | None = Field(
|
| 16 |
+
description="Current state of given task instance.",
|
| 17 |
+
default=None,
|
| 18 |
+
examples=["running"],
|
| 19 |
+
)
|
| 20 |
+
result: Literal["success", "error"] | None = Field(
|
| 21 |
+
description="Result of given task instance.",
|
| 22 |
+
default=None,
|
| 23 |
+
examples=["running"],
|
| 24 |
+
)
|
| 25 |
+
data: dict[str, Any] = Field(
|
| 26 |
+
description="Additional data generated by the task.",
|
| 27 |
+
default={},
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def completed(self) -> bool:
|
| 31 |
+
return self.state == "complete"
|
| 32 |
+
|
| 33 |
+
def completed_with_error(self) -> bool:
|
| 34 |
+
return self.state == "complete" and self.result == "error"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TaskDataStatus:
|
| 38 |
+
def __init__(self) -> None:
|
| 39 |
+
import streamlit as st
|
| 40 |
+
|
| 41 |
+
self.status = st.status("")
|
| 42 |
+
self.current_task_data: dict[str, TaskData] = {}
|
| 43 |
+
|
| 44 |
+
def add_and_draw_task_data(self, task_data: TaskData) -> None:
|
| 45 |
+
status = self.status
|
| 46 |
+
status_str = f"Task **{task_data.name}** "
|
| 47 |
+
match task_data.state:
|
| 48 |
+
case "new":
|
| 49 |
+
status_str += "has :blue[started]. Input:"
|
| 50 |
+
case "running":
|
| 51 |
+
status_str += "wrote:"
|
| 52 |
+
case "complete":
|
| 53 |
+
if task_data.result == "success":
|
| 54 |
+
status_str += ":green[completed successfully]. Output:"
|
| 55 |
+
else:
|
| 56 |
+
status_str += ":red[ended with error]. Output:"
|
| 57 |
+
status.write(status_str)
|
| 58 |
+
status.write(task_data.data)
|
| 59 |
+
status.write("---")
|
| 60 |
+
if task_data.run_id not in self.current_task_data:
|
| 61 |
+
# Status label always shows the last newly started task
|
| 62 |
+
status.update(label=f"""Task: {task_data.name}""")
|
| 63 |
+
self.current_task_data[task_data.run_id] = task_data
|
| 64 |
+
if all(entry.completed() for entry in self.current_task_data.values()):
|
| 65 |
+
# Status is "error" if any task has errored
|
| 66 |
+
if any(entry.completed_with_error() for entry in self.current_task_data.values()):
|
| 67 |
+
state = "error"
|
| 68 |
+
# Status is "complete" if all tasks have completed successfully
|
| 69 |
+
else:
|
| 70 |
+
state = "complete"
|
| 71 |
+
# Status is "running" until all tasks have completed
|
| 72 |
+
else:
|
| 73 |
+
state = "running"
|
| 74 |
+
status.update(state=state) # type: ignore[arg-type]
|
src/scripts/create_chroma_db.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 6 |
+
from langchain_chroma import Chroma
|
| 7 |
+
from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader
|
| 8 |
+
from langchain_openai import OpenAIEmbeddings
|
| 9 |
+
|
| 10 |
+
# Load environment variables from the .env file
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_chroma_db(
|
| 15 |
+
folder_path: str,
|
| 16 |
+
db_name: str = "./chroma_db",
|
| 17 |
+
delete_chroma_db: bool = True,
|
| 18 |
+
chunk_size: int = 2000,
|
| 19 |
+
overlap: int = 500,
|
| 20 |
+
):
|
| 21 |
+
embeddings = OpenAIEmbeddings(api_key=os.environ["OPENAI_API_KEY"])
|
| 22 |
+
|
| 23 |
+
# Initialize Chroma vector store
|
| 24 |
+
if delete_chroma_db and os.path.exists(db_name):
|
| 25 |
+
shutil.rmtree(db_name)
|
| 26 |
+
print(f"Deleted existing database at {db_name}")
|
| 27 |
+
|
| 28 |
+
chroma = Chroma(
|
| 29 |
+
embedding_function=embeddings,
|
| 30 |
+
persist_directory=f"./{db_name}",
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# Initialize text splitter
|
| 34 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
|
| 35 |
+
|
| 36 |
+
# Iterate over files in the folder
|
| 37 |
+
for filename in os.listdir(folder_path):
|
| 38 |
+
file_path = os.path.join(folder_path, filename)
|
| 39 |
+
|
| 40 |
+
# Load document based on file extension
|
| 41 |
+
# Add more loaders if required, i.e. JSONLoader, TxtLoader, etc.
|
| 42 |
+
if filename.endswith(".pdf"):
|
| 43 |
+
loader = PyPDFLoader(file_path)
|
| 44 |
+
elif filename.endswith(".docx"):
|
| 45 |
+
loader = Docx2txtLoader(file_path)
|
| 46 |
+
else:
|
| 47 |
+
continue # Skip unsupported file types
|
| 48 |
+
|
| 49 |
+
# Load and split document into chunks
|
| 50 |
+
document = loader.load()
|
| 51 |
+
chunks = text_splitter.split_documents(document)
|
| 52 |
+
|
| 53 |
+
# Add chunks to Chroma vector store
|
| 54 |
+
for chunk in chunks:
|
| 55 |
+
chunk_id = chroma.add_documents([chunk])
|
| 56 |
+
if chunk_id:
|
| 57 |
+
print(f"Chunk added with ID: {chunk_id}")
|
| 58 |
+
else:
|
| 59 |
+
print("Failed to add chunk")
|
| 60 |
+
|
| 61 |
+
print(f"Document {filename} added to database.")
|
| 62 |
+
|
| 63 |
+
print(f"Vector database created and saved in {db_name}.")
|
| 64 |
+
return chroma
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
# Path to the folder containing the documents
|
| 69 |
+
folder_path = "./data"
|
| 70 |
+
|
| 71 |
+
# Create the Chroma database
|
| 72 |
+
chroma = create_chroma_db(folder_path=folder_path)
|
| 73 |
+
|
| 74 |
+
# Create retriever from the Chroma database
|
| 75 |
+
retriever = chroma.as_retriever(search_kwargs={"k": 3})
|
| 76 |
+
|
| 77 |
+
# Perform a similarity search
|
| 78 |
+
query = "What's my company's mission and values"
|
| 79 |
+
similar_docs = retriever.invoke(query)
|
| 80 |
+
|
| 81 |
+
# Display results
|
| 82 |
+
for i, doc in enumerate(similar_docs, start=1):
|
| 83 |
+
print(f"\n🔹 Result {i}:\n{doc.page_content}\nTags: {doc.metadata.get('source', [])}")
|
src/scripts/load_portfolio.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
from core.settings import settings
|
| 6 |
+
from scripts.portfolio.portfolio_ingestion import PortfolioIngest
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(
|
| 9 |
+
level=logging.INFO, # Use INFO level to see all sync progress logs
|
| 10 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
if __name__ == "__main__":
|
| 16 |
+
parser = argparse.ArgumentParser(description="Synchronize portfolio data from Notion")
|
| 17 |
+
parser.add_argument(
|
| 18 |
+
"--since",
|
| 19 |
+
type=str,
|
| 20 |
+
help="ISO 8601 date to sync from (e.g. 2024-01-01T00:00:00.000Z). If not provided, uses last sync date."
|
| 21 |
+
)
|
| 22 |
+
args = parser.parse_args()
|
| 23 |
+
|
| 24 |
+
orchestrator = PortfolioIngest()
|
| 25 |
+
orchestrator.sync(args.since)
|
src/scripts/portfolio/document.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import hashlib
|
| 3 |
+
import time
|
| 4 |
+
import re
|
| 5 |
+
from typing import Optional, List, Tuple
|
| 6 |
+
|
| 7 |
+
from langsmith import uuid7
|
| 8 |
+
from langchain_core.documents import Document
|
| 9 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 10 |
+
from langchain_core.runnables import RunnableConfig
|
| 11 |
+
from langchain_core.documents import Document
|
| 12 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 13 |
+
|
| 14 |
+
from core.llm import get_model
|
| 15 |
+
from core.settings import settings
|
| 16 |
+
from scripts.portfolio.prompt import PORTFOLIO_INGESTION_SYSTEM_PROMPT
|
| 17 |
+
|
| 18 |
+
class DocumentChunker:
|
| 19 |
+
"""Service for splitting documents into chunks."""
|
| 20 |
+
|
| 21 |
+
def __init__(self, chunk_size: int = 1500, chunk_overlap: int = 200):
|
| 22 |
+
self.text_splitter = RecursiveCharacterTextSplitter(
|
| 23 |
+
chunk_size=chunk_size,
|
| 24 |
+
chunk_overlap=chunk_overlap,
|
| 25 |
+
is_separator_regex=True,
|
| 26 |
+
)
|
| 27 |
+
print(f"DEBUG: Initialized DocumentChunker with chunk_size={chunk_size}, overlap={chunk_overlap}")
|
| 28 |
+
|
| 29 |
+
def chunk_document(self, doc: Document, base_id: str, content_hash: str) -> List[Tuple[Document, str]]:
|
| 30 |
+
"""
|
| 31 |
+
Splits a document into chunks and prepares them for storage.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
doc: The document to chunk
|
| 35 |
+
base_id: The base document ID
|
| 36 |
+
content_hash: The content hash for change detection
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
List of tuples (chunk_document, chunk_id)
|
| 40 |
+
"""
|
| 41 |
+
chunks = self.text_splitter.split_documents([doc])
|
| 42 |
+
chunked_docs = []
|
| 43 |
+
|
| 44 |
+
for idx, chunk in enumerate(chunks):
|
| 45 |
+
chunk_id = f"{base_id}_chunk_{idx}"
|
| 46 |
+
chunk.metadata["content_hash"] = content_hash
|
| 47 |
+
chunk.metadata["base_id"] = base_id
|
| 48 |
+
chunked_docs.append((chunk, chunk_id))
|
| 49 |
+
|
| 50 |
+
print(f"DEBUG: Split document {base_id} into {len(chunked_docs)} chunks")
|
| 51 |
+
return chunked_docs
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class DocumentEnricher:
|
| 55 |
+
"""Service for enriching documents using LLM with generalized retry logic."""
|
| 56 |
+
|
| 57 |
+
def __init__(self):
|
| 58 |
+
self.llm = get_model(settings.DEFAULT_MODEL)
|
| 59 |
+
self.enrich_prompt = ChatPromptTemplate.from_messages([
|
| 60 |
+
("system", PORTFOLIO_INGESTION_SYSTEM_PROMPT),
|
| 61 |
+
("human", "Category: {category}\n\nMetadata:\n{metadata}\n\nContent:\n{content}")
|
| 62 |
+
])
|
| 63 |
+
print(f"INFO: Initialized DocumentEnricher with {settings.DEFAULT_MODEL}")
|
| 64 |
+
|
| 65 |
+
def enrich(self, doc: Document, category: str, max_retries: int = 5) -> Tuple[Optional[Document], str, str]:
|
| 66 |
+
pid = str(doc.metadata.get("id", uuid7()))
|
| 67 |
+
title = doc.metadata.get("Title", "Untitled")
|
| 68 |
+
|
| 69 |
+
for attempt in range(max_retries):
|
| 70 |
+
try:
|
| 71 |
+
if attempt > 0:
|
| 72 |
+
wait_time = min(2 ** attempt, 60)
|
| 73 |
+
print(f"INFO: Retrying {title} (attempt {attempt + 1}/{max_retries}) in {wait_time}s...")
|
| 74 |
+
time.sleep(wait_time)
|
| 75 |
+
else:
|
| 76 |
+
print(f"INFO: Enriching document: {title} (PID: {pid})")
|
| 77 |
+
|
| 78 |
+
res = self.llm.invoke(
|
| 79 |
+
self.enrich_prompt.format_messages(
|
| 80 |
+
category=category,
|
| 81 |
+
metadata=json.dumps(doc.metadata, default=str),
|
| 82 |
+
content=doc.page_content or "No content provided."
|
| 83 |
+
),
|
| 84 |
+
config=RunnableConfig(run_id=uuid7())
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
enriched_content = res.content.strip()
|
| 88 |
+
content_hash = hashlib.sha256(enriched_content.encode('utf-8')).hexdigest()
|
| 89 |
+
|
| 90 |
+
enriched_doc = Document(
|
| 91 |
+
page_content=enriched_content,
|
| 92 |
+
metadata={
|
| 93 |
+
**doc.metadata,
|
| 94 |
+
"category": category,
|
| 95 |
+
"content_hash": content_hash,
|
| 96 |
+
"base_id": pid
|
| 97 |
+
}
|
| 98 |
+
)
|
| 99 |
+
return enriched_doc, pid, content_hash
|
| 100 |
+
|
| 101 |
+
except Exception as e:
|
| 102 |
+
error_msg = str(e).lower()
|
| 103 |
+
error_type = type(e).__name__.lower()
|
| 104 |
+
|
| 105 |
+
# --- Rate Limit Detection ---
|
| 106 |
+
is_rate_limit = any(keyword in error_msg or keyword in error_type
|
| 107 |
+
for keyword in ["429", "rate_limit", "rate limit", "too many requests", "throttled"])
|
| 108 |
+
|
| 109 |
+
# --- Overloaded/Server Error Detection ---
|
| 110 |
+
is_server_error = any(keyword in error_msg
|
| 111 |
+
for keyword in ["500", "502", "503", "overloaded", "unavailable", "deadline_exceeded"])
|
| 112 |
+
|
| 113 |
+
if is_rate_limit or is_server_error:
|
| 114 |
+
wait_time = 5 # Default
|
| 115 |
+
match = re.search(r'(?:try again in|retry after|wait)\s*([\d.]+)\s*s', error_msg)
|
| 116 |
+
if match:
|
| 117 |
+
wait_time = float(match.group(1)) + 1
|
| 118 |
+
|
| 119 |
+
if attempt < max_retries - 1:
|
| 120 |
+
print(f"WARN: API issue (Rate Limit/Overload) for {title}. Waiting {wait_time}s...")
|
| 121 |
+
time.sleep(wait_time)
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
# Non-retriable or final attempt failure
|
| 125 |
+
print(f"ERROR: Enrichment failed for {title}: {e}")
|
| 126 |
+
if attempt >= 1:
|
| 127 |
+
return None, pid, ""
|
| 128 |
+
|
| 129 |
+
return None, pid, ""
|
src/scripts/portfolio/notion_loader.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import traceback
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
from langchain_community.document_loaders import NotionDBLoader
|
| 6 |
+
from langchain_core.documents import Document
|
| 7 |
+
|
| 8 |
+
NOTION_DB_MAP = {
|
| 9 |
+
"education": "NOTION_EDUCATION_ID",
|
| 10 |
+
"experience": "NOTION_EXPERIENCE_ID",
|
| 11 |
+
"projects": "NOTION_PROJECT_ID",
|
| 12 |
+
"testimonials": "NOTION_TESTIMONIAL_ID",
|
| 13 |
+
"blog": "NOTION_BLOG_ID",
|
| 14 |
+
}
|
| 15 |
+
class NotionLoader:
|
| 16 |
+
"""Service for loading documents from Notion databases."""
|
| 17 |
+
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.token = os.getenv("NOTION_TOKEN")
|
| 20 |
+
if not self.token:
|
| 21 |
+
print("WARNING: NOTION_TOKEN not found in environment")
|
| 22 |
+
|
| 23 |
+
def load_category(self, category: str, since_date: Optional[str] = None) -> List[Document]:
|
| 24 |
+
"""
|
| 25 |
+
Loads documents from a Notion database for a given category.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
category: The category name (must be in NOTION_DB_MAP)
|
| 29 |
+
since_date: Optional ISO 8601 date to filter documents updated after this date
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
List of documents that were updated after since_date
|
| 33 |
+
"""
|
| 34 |
+
if category not in NOTION_DB_MAP:
|
| 35 |
+
print(f"WARNING: Unknown category: {category}")
|
| 36 |
+
return []
|
| 37 |
+
|
| 38 |
+
env_key = NOTION_DB_MAP[category]
|
| 39 |
+
db_id = os.getenv(env_key)
|
| 40 |
+
|
| 41 |
+
if not db_id:
|
| 42 |
+
print(f"WARNING: Notion database ID not found for category: {category}")
|
| 43 |
+
return []
|
| 44 |
+
|
| 45 |
+
if not self.token:
|
| 46 |
+
print("ERROR: NOTION_TOKEN not available")
|
| 47 |
+
return []
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
print(f"INFO: Loading {category} documents from Notion...")
|
| 51 |
+
loader = NotionDBLoader(self.token, db_id)
|
| 52 |
+
all_docs = loader.load()
|
| 53 |
+
|
| 54 |
+
if since_date:
|
| 55 |
+
valid_docs = [
|
| 56 |
+
d for d in all_docs
|
| 57 |
+
if d.metadata.get("updated", "") > since_date
|
| 58 |
+
]
|
| 59 |
+
print(f"INFO: Found {len(valid_docs)} documents updated after {since_date} out of {len(all_docs)} total")
|
| 60 |
+
else:
|
| 61 |
+
valid_docs = all_docs
|
| 62 |
+
print(f"INFO: Loaded {len(valid_docs)} documents (no date filter)")
|
| 63 |
+
|
| 64 |
+
return valid_docs
|
| 65 |
+
except Exception as e:
|
| 66 |
+
print(f"ERROR: Failed to load {category} from Notion: {e}")
|
| 67 |
+
traceback.print_exc()
|
| 68 |
+
return []
|
src/scripts/portfolio/portfolio_ingestion.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import traceback
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
from langchain_community.vectorstores.utils import filter_complex_metadata
|
| 5 |
+
|
| 6 |
+
from core.settings import settings
|
| 7 |
+
from memory.postgres import load_pgvector_store
|
| 8 |
+
from scripts.portfolio.document import DocumentChunker, DocumentEnricher
|
| 9 |
+
from scripts.portfolio.notion_loader import NotionLoader
|
| 10 |
+
from scripts.portfolio.vector_repository import VectorRepository
|
| 11 |
+
NOTION_DB_MAP = {
|
| 12 |
+
"education": "NOTION_EDUCATION_ID",
|
| 13 |
+
"experience": "NOTION_EXPERIENCE_ID",
|
| 14 |
+
"projects": "NOTION_PROJECT_ID",
|
| 15 |
+
"testimonials": "NOTION_TESTIMONIAL_ID",
|
| 16 |
+
"blog": "NOTION_BLOG_ID",
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
class PortfolioIngest:
|
| 20 |
+
"""Orchestrates the portfolio synchronization process."""
|
| 21 |
+
|
| 22 |
+
def __init__(self):
|
| 23 |
+
self.enricher = DocumentEnricher()
|
| 24 |
+
self.chunker = DocumentChunker()
|
| 25 |
+
self.loader = NotionLoader()
|
| 26 |
+
self.repository = VectorRepository()
|
| 27 |
+
self.store = load_pgvector_store()
|
| 28 |
+
print(f"INFO: Initialized PortfolioIngest with collection: {settings.VECTOR_STORE_COLLECTION_NAME}")
|
| 29 |
+
|
| 30 |
+
def ingest_category(self, category: str, since_date: str) -> Tuple[int, int, int, int]:
|
| 31 |
+
"""
|
| 32 |
+
Synchronizes a single category of documents.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
category: The category to sync
|
| 36 |
+
since_date: ISO 8601 date to filter documents
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Tuple of (chunks_deleted, chunks_updated, chunks_skipped, total_synced)
|
| 40 |
+
"""
|
| 41 |
+
# Load documents from Notion
|
| 42 |
+
valid_docs = self.loader.load_category(category, since_date)
|
| 43 |
+
if not valid_docs:
|
| 44 |
+
print(f"INFO: No new updates for {category}")
|
| 45 |
+
return 0, 0, 0, 0
|
| 46 |
+
|
| 47 |
+
# Enrich documents with LLM sequentially to avoid rate limits
|
| 48 |
+
print(f"INFO: Enriching {len(valid_docs)} documents with LLM...")
|
| 49 |
+
|
| 50 |
+
enriched_docs = []
|
| 51 |
+
base_ids = []
|
| 52 |
+
content_hashes = []
|
| 53 |
+
|
| 54 |
+
for doc in valid_docs:
|
| 55 |
+
enriched_doc, base_id, content_hash = self.enricher.enrich(doc, category)
|
| 56 |
+
if enriched_doc is not None:
|
| 57 |
+
enriched_docs.append(enriched_doc)
|
| 58 |
+
base_ids.append(base_id)
|
| 59 |
+
content_hashes.append(content_hash)
|
| 60 |
+
|
| 61 |
+
if not enriched_docs:
|
| 62 |
+
print(f"INFO: No enriched documents for {category}")
|
| 63 |
+
return 0, 0, 0, 0
|
| 64 |
+
|
| 65 |
+
# Batch fetch existing content hashes
|
| 66 |
+
print(f"INFO: Checking content hashes for {len(enriched_docs)} documents...")
|
| 67 |
+
existing_hashes = self.repository.batch_get_existing_content_hashes(base_ids)
|
| 68 |
+
print(f"DEBUG: Found {len([h for h in existing_hashes.values() if h is not None])} existing hashes")
|
| 69 |
+
|
| 70 |
+
# Process documents and prepare for upsert
|
| 71 |
+
docs_to_upsert = []
|
| 72 |
+
ids_to_upsert = []
|
| 73 |
+
chunks_deleted = 0
|
| 74 |
+
chunks_updated = 0
|
| 75 |
+
chunks_skipped = 0
|
| 76 |
+
|
| 77 |
+
for enriched_doc, base_id, new_hash in zip(enriched_docs, base_ids, content_hashes):
|
| 78 |
+
existing_hash = existing_hashes.get(base_id)
|
| 79 |
+
|
| 80 |
+
if existing_hash is None:
|
| 81 |
+
# New document
|
| 82 |
+
print(f"INFO: New document {base_id}, adding all chunks...")
|
| 83 |
+
chunked = self.chunker.chunk_document(enriched_doc, base_id, new_hash)
|
| 84 |
+
for chunk, chunk_id in chunked:
|
| 85 |
+
docs_to_upsert.append(chunk)
|
| 86 |
+
ids_to_upsert.append(chunk_id)
|
| 87 |
+
chunks_updated += len(chunked)
|
| 88 |
+
elif existing_hash != new_hash:
|
| 89 |
+
# Content changed
|
| 90 |
+
print(f"INFO: Content changed for {base_id}, replacing chunks...")
|
| 91 |
+
old_chunk_ids = self.repository.get_existing_chunks(base_id)
|
| 92 |
+
if old_chunk_ids:
|
| 93 |
+
self.store.delete(old_chunk_ids)
|
| 94 |
+
chunks_deleted += len(old_chunk_ids)
|
| 95 |
+
print(f"DEBUG: Deleted {len(old_chunk_ids)} old chunks for {base_id}")
|
| 96 |
+
|
| 97 |
+
chunked = self.chunker.chunk_document(enriched_doc, base_id, new_hash)
|
| 98 |
+
for chunk, chunk_id in chunked:
|
| 99 |
+
docs_to_upsert.append(chunk)
|
| 100 |
+
ids_to_upsert.append(chunk_id)
|
| 101 |
+
chunks_updated += len(chunked)
|
| 102 |
+
else:
|
| 103 |
+
# Content unchanged
|
| 104 |
+
print(f"DEBUG: Content unchanged for {base_id}, skipping...")
|
| 105 |
+
chunks_skipped += 1
|
| 106 |
+
|
| 107 |
+
# Upsert documents
|
| 108 |
+
total_synced = 0
|
| 109 |
+
if docs_to_upsert:
|
| 110 |
+
print(f"INFO: Upserting {len(docs_to_upsert)} chunks...")
|
| 111 |
+
self.store.add_documents(
|
| 112 |
+
filter_complex_metadata(docs_to_upsert),
|
| 113 |
+
ids=ids_to_upsert
|
| 114 |
+
)
|
| 115 |
+
total_synced = len(docs_to_upsert)
|
| 116 |
+
print(
|
| 117 |
+
f"INFO: Category {category}: Deleted {chunks_deleted} chunks, "
|
| 118 |
+
f"updated {chunks_updated} chunks, skipped {chunks_skipped} unchanged documents, "
|
| 119 |
+
f"total synced: {total_synced}"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
return chunks_deleted, chunks_updated, chunks_skipped, total_synced
|
| 123 |
+
|
| 124 |
+
def sync(self, manual_date: Optional[str] = None) -> int:
|
| 125 |
+
"""
|
| 126 |
+
Synchronizes all portfolio categories.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
manual_date: Optional ISO 8601 date to override last sync date
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
Total number of chunks synced
|
| 133 |
+
"""
|
| 134 |
+
since_date = manual_date or self.repository.get_last_sync_date()
|
| 135 |
+
print(f"INFO: Starting sync for items modified after: {since_date}")
|
| 136 |
+
|
| 137 |
+
total_synced = 0
|
| 138 |
+
|
| 139 |
+
for category in NOTION_DB_MAP.keys():
|
| 140 |
+
try:
|
| 141 |
+
chunks_deleted, chunks_updated, chunks_skipped, synced = self.ingest_category(
|
| 142 |
+
category, since_date
|
| 143 |
+
)
|
| 144 |
+
total_synced += synced
|
| 145 |
+
except Exception as e:
|
| 146 |
+
print(f"ERROR: Failed to sync category {category}: {e}")
|
| 147 |
+
traceback.print_exc()
|
| 148 |
+
|
| 149 |
+
print(f"INFO: Sync complete. Total chunks synced: {total_synced}")
|
| 150 |
+
return total_synced
|
src/scripts/portfolio/prompt.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
load_dotenv()
|
| 6 |
+
|
| 7 |
+
OWNER = os.getenv("OWNER", "Anuj Joshi")
|
| 8 |
+
CURRENT_DATE = datetime.now().strftime("%Y-%m-%d")
|
| 9 |
+
|
| 10 |
+
PORTFOLIO_INGESTION_SYSTEM_PROMPT = f"""
|
| 11 |
+
SYSTEM ROLE: You are the Principal Technical Documentation Engineer for {OWNER}.
|
| 12 |
+
|
| 13 |
+
Your goal is to transform raw project data into high-density retrieval-optimized technical documentation for {OWNER}.
|
| 14 |
+
This documentation is the primary source for a RAG (Retrieval-Augmented Generation) system used by Staff Engineers and Technical Recruiters.
|
| 15 |
+
|
| 16 |
+
General Behavior
|
| 17 |
+
- Do include a header block that summarizes the content of the document in a concise and catchy manner.
|
| 18 |
+
- Do include thumbnail/links if provided in the header block if available.
|
| 19 |
+
- Do include other links in other sections in proper markdown format if available and relevant.
|
| 20 |
+
- Use "Information-Dense Sentences." Every sentence must provide a new fact.
|
| 21 |
+
- Avoid pronouns, repeat the subject or {OWNER}'s name to ensure chunks remain context-aware when retrieved in isolation.
|
| 22 |
+
- If data is missing, do not hallucinate and omit the specific metric.
|
| 23 |
+
- Do NOT use vague impact words.
|
| 24 |
+
- Do NOT restate the same idea in different words.
|
| 25 |
+
- Do NOT invent metrics, scale, or outcomes.
|
| 26 |
+
- Do NOT include emojis or stylistic symbols.
|
| 27 |
+
- No repetitive framing, long paragraphs, or storytelling.
|
| 28 |
+
- Optimize for semantic search and chunk retrieval.
|
| 29 |
+
"""
|