import anthropic
import json
import os
import re
from dotenv import load_dotenv

load_dotenv()


class NaturalLanguageParser:
    """Advanced Natural Language to SQL Engine using Claude"""

    def __init__(self):
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in .env file!")
        self.client = anthropic.Anthropic(api_key=api_key)

    def generate_sql(self, description, schema, dialect="PostgreSQL"):
        """
        Generate complete, optimized SQL directly from natural language.

        This is the main engine that handles all SQL generation logic.
        """
        schema_text = self._format_schema(schema)

        prompt = f"""You are an advanced Natural-Language-to-SQL Engine.
Your job is to convert user instructions into correct, executable, efficient SQL queries, based solely on the provided database schema.

šŸ”„ Core Responsibilities
1. NEVER hallucinate tables or columns. Only use what exists in the provided schema.
2. Follow the SQL dialect: {dialect}
3. Validate user intent and generate the best query, even if their natural language is unclear.
4. Fix common SQL mistakes:
   - Use IS NULL / IS NOT NULL (never = NULL)
   - Use >= and < for date ranges instead of BETWEEN
   - Proper JOIN syntax
   - Correct aggregate/grouping logic
5. Optimize the SQL:
   - Use proper filters
   - Use EXISTS/NOT EXISTS for subqueries
   - Use CTEs for clarity in complex queries
   - Avoid unnecessary computation
6. For complex analysis (cohorts, trends, rankings, window functions):
   - Use Common Table Expressions (CTEs)
   - Use window functions when appropriate

šŸ—ƒļø Available Database Schema:
{schema_text}

šŸ‘¤ User Request:
{description}

šŸŽÆ SQL Construction Rules:

WHERE clauses:
- Use IS NULL / IS NOT NULL, not = 'NULL'
- For date ranges use >= and < instead of BETWEEN

Aggregations:
- GROUP BY all non-aggregated fields
- Avoid GROUP BY on unnecessary columns

Joins:
- Prefer explicit JOIN syntax
- Use LEFT JOIN for "missing data" queries
- Always specify join conditions clearly

Subqueries:
- Prefer EXISTS/NOT EXISTS for performance
- Use CTEs for readability in complex queries

Window Functions (use when the user asks for):
- Top N per group
- Rankings
- Running totals
- Comparisons to averages
- Rolling windows

šŸ“‹ Output Format:
Return a JSON object with this exact structure:
{{
    "sql": "the complete SQL query here",
    "explanation": "brief explanation of what the query does",
    "query_type": "simple|aggregate|join|window|cte|analytical",
    "warnings": ["any warnings about schema limitations or assumptions"],
    "optimizations": ["list of optimizations applied"]
}}

CRITICAL RULES FOR JSON OUTPUT:
- Return ONLY valid JSON (no markdown, no code blocks, no extra text)
- Escape ALL special characters in the SQL string:
  * Newlines must be \\n
  * Quotes must be \\"
  * Backslashes must be \\\\
- The "sql" field must be a single-line string with \\n for line breaks
- Always end SQL with a semicolon
- If the request is impossible with the given schema, set "sql" to "-- ERROR: " and explain in "warnings"

Generate the SQL query now:"""

        try:
            response = self.client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=4000,
                messages=[{"role": "user", "content": prompt}]
            )
            content = response.content[0].text.strip()

            # Remove markdown code blocks if present
            content = content.replace("```json", "").replace("```", "").strip()

            # Try to parse JSON
            try:
                result = json.loads(content)

                # Ensure all required fields exist
                if "sql" not in result or not result["sql"]:
                    result["sql"] = "-- ERROR: No SQL generated"
                if "explanation" not in result:
                    result["explanation"] = "SQL query generated"
                if "query_type" not in result:
                    result["query_type"] = "select"
                if "warnings" not in result:
                    result["warnings"] = []
                if "optimizations" not in result:
                    result["optimizations"] = []

                return result

            except json.JSONDecodeError as e:
                # If JSON parsing fails, try to extract components manually using regex
                print(f"JSON parsing failed: {e}, attempting manual extraction...")

                # Try to extract SQL between "sql": " and the next unescaped quote
                sql_match = re.search(r'"sql"\s*:\s*"((?:[^"\\]|\\.)*)"', content, re.DOTALL)
                if sql_match:
                    sql = sql_match.group(1)
                    # Unescape the JSON string
                    sql = sql.replace('\\n', '\n').replace('\\"', '"').replace('\\\\', '\\')
                else:
                    # Fall back to using the entire response content as the SQL
                    sql = content

                # Try to extract explanation
                expl_match = re.search(r'"explanation"\s*:\s*"((?:[^"\\]|\\.)*)"', content, re.DOTALL)
                explanation = expl_match.group(1) if expl_match else "SQL generated (with parsing issues)"

                # Try to extract query type
                type_match = re.search(r'"query_type"\s*:\s*"([^"]*)"', content)
                query_type = type_match.group(1) if type_match else "select"

                result = {
                    "sql": sql,
                    "explanation": explanation,
                    "query_type": query_type,
                    "warnings": ["JSON parsing issue - SQL may need review"],
                    "optimizations": []
                }
                return result

        except Exception as e:
            return {
                "sql": f"-- ERROR: {str(e)}",
                "explanation": "An error occurred during SQL generation",
                "query_type": "error",
                "warnings": [str(e)],
                "optimizations": []
            }

    def _format_schema(self, schema):
        """Format the schema dict into readable text for the prompt."""
        lines = ["Tables and Columns:"]
        for table, columns in schema.items():
            lines.append(f"\nšŸ“Š {table}")
            for col in columns:
                nullable = "NULL" if col.get('nullable', True) else "NOT NULL"
                pk = " (PRIMARY KEY)" if col.get('primary_key', False) else ""
                lines.append(f"  - {col['name']}: {col['type']} {nullable}{pk}")
        return "\n".join(lines)

    # Legacy method for backward compatibility
    def parse(self, description, schema):
        """Legacy method - redirects to generate_sql"""
        result = self.generate_sql(description, schema)
        # Return just the SQL for backward compatibility
        return {"raw_sql": result.get("sql", ""), "metadata": result}