# space/app.py
import os
import re
import json
import logging
import gradio as gr
import pandas as pd
from typing import Optional, Tuple

from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool
from tools.ts_preprocess import build_timeseries
from tools.ts_forecast_tool import TimeseriesForecastTool
from utils.tracing import Tracer
from utils.config import AppConfig

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
MAX_RESPONSE_LENGTH = 10000
MAX_FORECAST_HORIZON = 365
DEFAULT_FORECAST_HORIZON = 96

# Optional LLM for planning
llm = None
LLM_ID = os.getenv("ORCHESTRATOR_MODEL")
if LLM_ID:
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
        logger.info(f"Loading orchestrator model: {LLM_ID}")
        _tok = AutoTokenizer.from_pretrained(LLM_ID)
        _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
        llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
        logger.info("Orchestrator model loaded successfully")
    except Exception as e:
        logger.warning(f"Failed to load orchestrator model: {e}. Using fallback planner.")
        llm = None

# Initialize configuration and tools
try:
    cfg = AppConfig.from_env()
    tracer = Tracer.from_env()
    sql_tool = SQLTool(cfg, tracer)
    predict_tool = PredictTool(cfg, tracer)
    explain_tool = ExplainTool(cfg, tracer)
    report_tool = ReportTool(cfg, tracer)
    ts_tool = TimeseriesForecastTool(cfg, tracer)
    logger.info("All tools initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize application: {e}")
    raise

SYSTEM_PROMPT = (
    "You are an analytical assistant for tabular data. "
    "Decide which tools to call in order: "
    "1) SQL (retrieve) 2) Predict (score) 3) Explain (SHAP) 4) Report (document) 5) Forecast (Granite TTM). "
    "Always disclose the steps taken."
)


def validate_message(message: str) -> Tuple[bool, str]:
    """Validate user input message."""
    if not message or not message.strip():
        return False, "Please enter a valid question."
    if len(message) > MAX_RESPONSE_LENGTH:
        return False, f"Message too long. Please limit to {MAX_RESPONSE_LENGTH} characters."

    # Basic SQL injection pattern detection
    suspicious_patterns = [
        r';\s*drop\s+table',
        r';\s*delete\s+from',
        r';\s*truncate',
        r'union\s+select.*from',
        r'exec\s*\(',
        r'execute\s*\(',
    ]
    msg_lower = message.lower()
    for pattern in suspicious_patterns:
        if re.search(pattern, msg_lower):
            logger.warning(f"Suspicious SQL pattern detected: {pattern}")
            return False, "Query contains potentially unsafe patterns. Please rephrase."

    return True, ""


def plan_actions(message: str) -> dict:
    """
    Determine which tools to execute based on the user message.
    Uses the LLM if available, otherwise falls back to keyword heuristics.
    """
    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array subset of "
            "['sql','predict','explain','report','forecast']), rationale."
        )
        try:
            out = llm(prompt)[0]["generated_text"]
            last = out.split("\n")[-1].strip()
            obj = json.loads(last) if last.startswith("{") else json.loads(out[out.rfind("{"):])
            if isinstance(obj, dict) and "steps" in obj:
                # Validate steps
                valid_steps = {'sql', 'predict', 'explain', 'report', 'forecast'}
                obj["steps"] = [s for s in obj["steps"] if s in valid_steps]
                if obj["steps"]:
                    logger.info(f"LLM plan: {obj['steps']}")
                    return obj
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse LLM output as JSON: {e}")
        except Exception as e:
            logger.warning(f"LLM planning failed: {e}")

    # Fallback heuristic planning
    m = message.lower()
    steps = []

    # SQL keywords
    if any(k in m for k in ["show", "average", "count", "trend", "top", "sql", "query", "kpi",
                            "data", "retrieve", "fetch", "list", "view"]):
        steps.append("sql")

    # Prediction keywords
    if any(k in m for k in ["predict", "score", "risk", "propensity", "probability",
                            "classification", "regression"]):
        steps.append("predict")
        if "sql" not in steps:
            steps.insert(0, "sql")  # Need data first

    # Explanation keywords
    if any(k in m for k in ["why", "explain", "shap", "feature", "attribution", "importance", "interpret"]):
        steps.append("explain")
        if "predict" not in steps:
            steps.insert(0, "predict")
        if "sql" not in steps:
            steps.insert(0, "sql")

    # Report keywords
    if any(k in m for k in ["report", "download", "pdf", "summary", "document", "export"]):
        steps.append("report")

    # Forecast keywords
    if any(k in m for k in ["forecast", "next", "horizon", "granite", "predict future",
                            "time series", "timeseries"]):
        steps.append("forecast")
        if "sql" not in steps:
            steps.insert(0, "sql")

    # Default to SQL if no steps identified
    if not steps:
        steps = ["sql"]

    rationale = f"Rule-based plan based on keywords: {', '.join(steps)}"
    logger.info(f"Heuristic plan: {steps}")
    return {"steps": steps, "rationale": rationale}


def run_agent(
    message: str,
    hitl_decision: str = "Approve",
    reviewer_note: str = ""
) -> Tuple[str, pd.DataFrame]:
    """
    Main agent execution function.

    Args:
        message: User query
        hitl_decision: Human-in-the-loop decision
        reviewer_note: Optional review notes

    Returns:
        Tuple of (response_text, preview_dataframe)
    """
    try:
        # Validate input
        is_valid, error_msg = validate_message(message)
        if not is_valid:
            logger.warning(f"Invalid message: {error_msg}")
            return f"❌ **Error:** {error_msg}", pd.DataFrame()

        tracer.trace_event("user_message", {"message": message[:500]})  # Limit traced message length

        # Plan actions
        try:
            plan = plan_actions(message)
            tracer.trace_event("plan", plan)
        except Exception as e:
            logger.error(f"Planning failed: {e}")
            return f"❌ **Planning Error:** Unable to create execution plan. {str(e)}", pd.DataFrame()

        # Initialize result containers
        sql_df = None
        predict_df = None
        explain_imgs = {}
        artifacts = {}
        ts_forecast_df = None
        errors = []

        # Execute SQL step
        if "sql" in plan["steps"]:
            try:
                sql_df = sql_tool.run(message)
                if isinstance(sql_df, pd.DataFrame):
                    artifacts["sql_rows"] = len(sql_df)
                    logger.info(f"SQL returned {len(sql_df)} rows")
                else:
                    errors.append("SQL query returned no data")
            except Exception as e:
                error_msg = f"SQL execution failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Execute prediction step
        if "predict" in plan["steps"]:
            try:
                if sql_df is not None and not sql_df.empty:
                    predict_df = predict_tool.run(sql_df)
                    if isinstance(predict_df, pd.DataFrame):
                        artifacts["predict_rows"] = len(predict_df)
                        logger.info(f"Predictions generated for {len(predict_df)} rows")
                else:
                    errors.append("Prediction skipped: no data available")
            except Exception as e:
                error_msg = f"Prediction failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Build time series if possible
        ts_df = None
        if sql_df is not None and not sql_df.empty:
            try:
                ts_df = build_timeseries(sql_df)
                logger.info(f"Time series built with {len(ts_df)} records")
            except Exception as e:
                # Not always an error - data might not be suitable for TS
                logger.info(f"Time series preprocessing skipped: {e}")

        # Execute forecast step
        if "forecast" in plan["steps"]:
            if ts_df is not None and not ts_df.empty:
                try:
                    # Aggregate portfolio value by timestamp
                    agg = ts_df.groupby("timestamp", as_index=True)["portfolio_value"].sum().sort_index()
                    if len(agg) < 2:
                        errors.append("Insufficient time series data for forecasting (need at least 2 points)")
                    else:
                        # Validate horizon
                        horizon = min(DEFAULT_FORECAST_HORIZON, MAX_FORECAST_HORIZON)
                        ts_forecast_df = ts_tool.zeroshot_forecast(agg, horizon=horizon)
                        if isinstance(ts_forecast_df, pd.DataFrame):
                            if "error" in ts_forecast_df.columns:
                                errors.append(f"Forecast error: {ts_forecast_df['error'].iloc[0]}")
                                ts_forecast_df = None
                            else:
                                artifacts["forecast_horizon"] = len(ts_forecast_df)
                                logger.info(f"Forecast generated for {len(ts_forecast_df)} periods")
                except Exception as e:
                    error_msg = f"Forecasting failed: {str(e)}"
                    logger.error(error_msg)
                    errors.append(error_msg)
            else:
                errors.append("Forecast skipped: no suitable time series data")

        # Execute explanation step
        if "explain" in plan["steps"]:
            try:
                explain_data = predict_df if predict_df is not None else sql_df
                if explain_data is not None and not explain_data.empty:
                    explain_imgs = explain_tool.run(explain_data)
                    artifacts["explain_charts"] = len(explain_imgs)
                    logger.info(f"Generated {len(explain_imgs)} explanation charts")
                else:
                    errors.append("Explanation skipped: no data available")
            except Exception as e:
                error_msg = f"Explanation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Execute report generation
        report_link = None
        if "report" in plan["steps"]:
            try:
                forecast_preview = ts_forecast_df.head(50) if isinstance(ts_forecast_df, pd.DataFrame) else None
                report_link = report_tool.render_and_save(
                    user_query=message,
                    sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
                    predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else forecast_preview,
                    explain_images=explain_imgs,
                    plan=plan,
                )
                logger.info(f"Report generated: {report_link}")
            except Exception as e:
                error_msg = f"Report generation failed: {str(e)}"
                logger.error(error_msg)
                errors.append(error_msg)

        # Log human-in-the-loop decision
        tracer.trace_event("hitl", {
            "message": message[:500],
            "decision": hitl_decision,
            "reviewer_note": reviewer_note[:500] if reviewer_note else "",
            "artifacts": artifacts,
            "plan": plan,
            "errors": errors,
        })

        # Compose response
        response = f"**Plan:** {', '.join(plan['steps'])}\n\n**Rationale:** {plan['rationale']}\n\n"

        # Add artifacts info
        if artifacts:
            response += "**Results:**\n"
            if "sql_rows" in artifacts:
                response += f"- SQL query returned {artifacts['sql_rows']} rows\n"
            if "predict_rows" in artifacts:
                response += f"- Generated predictions for {artifacts['predict_rows']} rows\n"
            if "forecast_horizon" in artifacts:
                response += f"- Forecast generated for {artifacts['forecast_horizon']} periods\n"
            if "explain_charts" in artifacts:
                response += f"- Created {artifacts['explain_charts']} explanation charts\n"
            response += "\n"

        # Add report link
        if report_link:
            response += f"📄 **Report:** {report_link}\n\n"

        # Add trace URL
        if tracer.trace_url:
            response += f"🔍 **Trace:** {tracer.trace_url}\n\n"

        # Add errors if any
        if errors:
            response += "**⚠️ Warnings/Errors:**\n"
            for err in errors:
                response += f"- {err}\n"

        # Determine preview dataframe
        if isinstance(ts_forecast_df, pd.DataFrame) and not ts_forecast_df.empty:
            preview_df = ts_forecast_df.head(100)
        elif isinstance(predict_df, pd.DataFrame) and not predict_df.empty:
            preview_df = predict_df.head(100)
        elif isinstance(sql_df, pd.DataFrame) and not sql_df.empty:
            preview_df = sql_df.head(100)
        else:
            preview_df = pd.DataFrame({"message": ["No data to display"]})

        return response, preview_df

    except Exception as e:
        error_msg = f"Unexpected error in agent execution: {str(e)}"
        logger.exception(error_msg)
        tracer.trace_event("error", {"message": error_msg})
        return f"❌ **Critical Error:** {error_msg}", pd.DataFrame()


# Gradio Interface
with gr.Blocks(title="Tabular Agentic XAI") as demo:
    gr.Markdown("""
    # 🤖 Tabular Agentic XAI (Enterprise Edition)

    An intelligent assistant for analyzing tabular data with ML predictions, explanations,
    and time-series forecasting.

    **Capabilities:**
    - 📊 SQL queries and data retrieval
    - 🎯 ML predictions with confidence scores
    - 🔍 SHAP-based model explanations
    - 📈 Time-series forecasting with Granite TTM
    - 📄 Automated report generation
    """)

    with gr.Row():
        msg = gr.Textbox(
            label="Ask your question",
            placeholder="e.g., Show me the top 10 customers by revenue, predict churn risk, forecast next quarter...",
            lines=3
        )

    with gr.Row():
        hitl = gr.Radio(
            ["Approve", "Needs Changes"],
            value="Approve",
            label="Human Review",
            info="Review the planned actions before execution"
        )
        note = gr.Textbox(
            label="Reviewer note (optional)",
            placeholder="Add any review comments...",
            lines=2
        )

    out_md = gr.Markdown(label="Response")
    out_df = gr.Dataframe(
        interactive=False,
        label="Data Preview (max 100 rows)",
        wrap=True
    )

    with gr.Row():
        ask = gr.Button("🚀 Run Analysis", variant="primary")
        clear = gr.Button("🔄 Clear")

    ask.click(
        run_agent,
        inputs=[msg, hitl, note],
        outputs=[out_md, out_df]
    )
    clear.click(
        lambda: ("", "Approve", "", "", pd.DataFrame()),
        outputs=[msg, hitl, note, out_md, out_df]
    )

    gr.Markdown("""
    ---
    **Tips:**
    - Be specific in your queries for better results
    - Use natural language - the system will interpret your intent
    - Review the execution plan before approving
    - Check the trace link for detailed execution logs
    """)


if __name__ == "__main__":
    logger.info("Starting Gradio application...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )