AshenH committed
Commit 0f166dc · verified · 1 Parent(s): 7162e44

Uploaded support files
templates/report_styles.css ADDED
@@ -0,0 +1,6 @@
+ body { font-family: system-ui, -apple-system, "Segoe UI", Roboto, Arial, sans-serif; padding: 24px; line-height: 1.5; }
+ h1, h2, h3 { margin-top: 1.2em; }
+ code, pre { background: #f6f8fa; padding: 2px 4px; border-radius: 4px; }
+ table { border-collapse: collapse; width: 100%; }
+ th, td { border: 1px solid #ddd; padding: 8px; }
+ th { background: #fafafa; }
templates/report_template.md ADDED
@@ -0,0 +1,26 @@
+ # Insight Report
+
+ **User Query**: {{ user_query }}
+
+ **Plan**: {{ plan.steps }}
+ **Rationale**: {{ plan.rationale }}
+
+ {% if sql_preview %}
+ ## SQL Preview
+ {{ sql_preview }}
+ {% endif %}
+
+ {% if predict_preview %}
+ ## Predictions Preview
+ {{ predict_preview }}
+ {% endif %}
+
+ {% if explain_images.global_bar %}
+ ## Global Feature Importance (SHAP)
+ <img src="{{ explain_images.global_bar }}" style="max-width: 100%;" />
+ {% endif %}
+
+ {% if explain_images.beeswarm %}
+ ## SHAP Beeswarm
+ <img src="{{ explain_images.beeswarm }}" style="max-width: 100%;" />
+ {% endif %}
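For orientation, the context this template expects looks roughly like the sketch below; every key appears verbatim in the template, while the values are purely illustrative (the data-URI strings are produced by tools/explain_tool.py later in this commit).

# Illustrative sketch (not part of the commit): a context dict for report_template.md.
context = {
    "user_query": "avg revenue by month from dataset.sales",
    "plan": {"steps": ["sql", "predict", "explain", "report"], "rationale": "why these steps were chosen"},
    "sql_preview": "| month | avg_metric |\n| --- | --- |\n| 2024-01 | 12.3 |",   # Markdown table
    "predict_preview": "| id | prediction |\n| --- | --- |\n| 1 | 0.87 |",        # Markdown table
    "explain_images": {
        "global_bar": "data:image/png;base64,...",   # inline PNGs, see ExplainTool below
        "beeswarm": "data:image/png;base64,...",
    },
}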
tools/__init__.py ADDED
File without changes
tools/explain_tool.py ADDED
@@ -0,0 +1,48 @@
+ import os
+ import io
+ import base64
+ import shap
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ from huggingface_hub import hf_hub_download
+ from ..utils.config import AppConfig
+ from ..utils.tracing import Tracer
+
+ class ExplainTool:
+     def __init__(self, cfg: AppConfig, tracer: Tracer):
+         self.cfg = cfg
+         self.tracer = tracer
+         self._model = None
+
+     def _ensure_model(self):
+         if self._model is None:
+             import joblib
+             path = hf_hub_download(repo_id=self.cfg.hf_model_repo, filename="model.pkl", token=os.getenv("HF_TOKEN"))
+             self._model = joblib.load(path)
+
+     def _to_data_uri(self, fig) -> str:
+         # Encode a matplotlib figure as an inline PNG data URI for the HTML report.
+         buf = io.BytesIO()
+         fig.savefig(buf, format="png", bbox_inches="tight")
+         buf.seek(0)
+         return "data:image/png;base64," + base64.b64encode(buf.read()).decode()
+
+     def run(self, df: pd.DataFrame):
+         self._ensure_model()
+         # Use a small sample for speed on CPU Spaces
+         sample = df.sample(min(len(df), 500), random_state=42)
+         explainer = shap.Explainer(self._model, sample, feature_names=list(sample.columns))
+         shap_values = explainer(sample)
+
+         # Global summary plot; shap.plots.* draw onto the current figure, so grab it via plt.gcf()
+         shap.plots.bar(shap_values, show=False)
+         img1 = self._to_data_uri(plt.gcf())
+         plt.close("all")
+
+         # Beeswarm (optional)
+         shap.plots.beeswarm(shap_values, show=False)
+         img2 = self._to_data_uri(plt.gcf())
+         plt.close("all")
+
+         self.tracer.trace_event("explain", {"rows": len(sample)})
+         return {"global_bar": img1, "beeswarm": img2}
tools/predict_tool.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import joblib
4
+ from huggingface_hub import hf_hub_download
5
+ from ..utils.config import AppConfig
6
+ from ..utils.tracing import Tracer
7
+
8
+ class PredictTool:
9
+ def __init__(self, cfg: AppConfig, tracer: Tracer):
10
+ self.cfg = cfg
11
+ self.tracer = tracer
12
+ self._model = None
13
+ self._feature_meta = None
14
+
15
+ def _ensure_loaded(self):
16
+ if self._model is None:
17
+ path = hf_hub_download(repo_id=self.cfg.hf_model_repo, filename="model.pkl", token=os.getenv("HF_TOKEN"))
18
+ self._model = joblib.load(path)
19
+ meta = hf_hub_download(repo_id=self.cfg.hf_model_repo, filename="feature_metadata.json", token=os.getenv("HF_TOKEN"))
20
+ import json
21
+ with open(meta, "r") as f:
22
+ self._feature_meta = json.load(f)
23
+
24
+ def run(self, df: pd.DataFrame) -> pd.DataFrame:
25
+ self._ensure_loaded()
26
+ use_cols = self._feature_meta.get("feature_order", list(df.columns))
27
+ X = df[use_cols].copy()
28
+ preds = self._model.predict_proba(X)[:, 1] if hasattr(self._model, "predict_proba") else self._model.predict(X)
29
+ out = df.copy()
30
+ out[self._feature_meta.get("prediction_column", "prediction")] = preds
31
+ self.tracer.trace_event("predict", {"rows": len(out)})
32
+ return out
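PredictTool assumes the model repo also contains a feature_metadata.json holding a feature_order list and, optionally, a prediction_column name (both keys are read above). A sketch of how that file might be produced at training time; the feature names are placeholders and the repo id matches the default in utils/config.py.

# Illustrative sketch (not part of the commit): writing the metadata PredictTool expects.
import json
from huggingface_hub import upload_file

feature_metadata = {
    "feature_order": ["age", "tenure_months", "monthly_spend"],   # placeholder feature names
    "prediction_column": "churn_probability",                     # falls back to "prediction" if omitted
}
with open("feature_metadata.json", "w") as f:
    json.dump(feature_metadata, f, indent=2)

upload_file(path_or_fileobj="feature_metadata.json", path_in_repo="feature_metadata.json",
            repo_id="your-username/your-private-tabular-model")   # same repo as model.pkl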
tools/report_tool.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from jinja2 import Environment, FileSystemLoader
3
+ import pandas as pd
4
+ from ..utils.tracing import Tracer
5
+
6
+ class ReportTool:
7
+ def __init__(self, cfg, tracer: Tracer):
8
+ self.cfg = cfg
9
+ self.tracer = tracer
10
+ self.env = Environment(loader=FileSystemLoader("templates"))
11
+
12
+ def render_and_save(self, user_query: str, sql_preview: pd.DataFrame | None, predict_preview: pd.DataFrame | None, explain_images: dict, plan: dict):
13
+ tmpl = self.env.get_template("report_template.md")
14
+ html = tmpl.render(
15
+ user_query=user_query,
16
+ plan=plan,
17
+ sql_preview=sql_preview.to_markdown(index=False) if sql_preview is not None else "",
18
+ predict_preview=predict_preview.to_markdown(index=False) if predict_preview is not None else "",
19
+ explain_images=explain_images,
20
+ )
21
+ out_path = f"report_{pd.Timestamp.utcnow().strftime('%Y%m%d_%H%M%S')}.html"
22
+ with open(out_path, "w", encoding="utf-8") as f:
23
+ f.write("<link rel=\"stylesheet\" href=\"templates/report_styles.css\">\n" + html)
24
+ self.tracer.trace_event("report", {"path": out_path})
25
+ return out_path
tools/sql_tool.py ADDED
@@ -0,0 +1,52 @@
+ import os
+ import re
+ import json
+ import pandas as pd
+ from ..utils.config import AppConfig
+ from ..utils.tracing import Tracer
+
+ class SQLTool:
+     def __init__(self, cfg: AppConfig, tracer: Tracer):
+         self.cfg = cfg
+         self.tracer = tracer
+         self.backend = cfg.sql_backend  # "bigquery" or "motherduck"
+         if self.backend == "bigquery":
+             from google.cloud import bigquery
+             from google.oauth2 import service_account
+             key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
+             if not key_json:
+                 raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
+             # Parse the secret with json.loads (never eval); it may hold inline JSON or a path to a key file.
+             if key_json.strip().startswith("{"):
+                 creds = service_account.Credentials.from_service_account_info(json.loads(key_json))
+             else:
+                 creds = service_account.Credentials.from_service_account_file(key_json)
+             self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
+         elif self.backend == "motherduck":
+             import duckdb
+             token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
+             db_name = self.cfg.motherduck_db or "default"
+             # MotherDuck connection strings use the form "md:<database>?motherduck_token=<token>".
+             self.client = duckdb.connect(f"md:{db_name}?motherduck_token={token}")
+         else:
+             raise RuntimeError(f"Unknown SQL backend: {self.backend}")
+
+     def _nl_to_sql(self, message: str) -> str:
+         # Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
+         # Expect users to include table names. Example: "avg revenue by month from dataset.sales"
+         m = message.lower()
+         # Pass through untouched if the user typed SQL explicitly.
+         if re.match(r"^\s*select\b", m):
+             return message
+         if "avg" in m and " by " in m:
+             return "-- Example template; edit me\nSELECT DATE_TRUNC(month, date_col) AS month, AVG(metric) AS avg_metric FROM dataset.table GROUP BY 1 ORDER BY 1;"
+         return "SELECT * FROM dataset.table LIMIT 100;"
+
+     def run(self, message: str) -> pd.DataFrame:
+         sql = self._nl_to_sql(message)
+         self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
+         if self.backend == "bigquery":
+             df = self.client.query(sql).to_dataframe()
+         else:
+             df = self.client.execute(sql).fetch_df()
+         return df
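A usage sketch for the MotherDuck path, assuming SQL_BACKEND and MOTHERDUCK_TOKEN are set as Space secrets and the import paths are adjusted to your layout; AppConfig is defined in utils/config.py below, and the query text is only an example.

# Illustrative sketch (not part of the commit).
from utils.config import AppConfig     # adjust import paths to your package layout
from utils.tracing import Tracer
from tools.sql_tool import SQLTool

cfg = AppConfig.from_env()             # reads SQL_BACKEND, MOTHERDUCK_DB, MOTHERDUCK_TOKEN, ...
sql_tool = SQLTool(cfg, Tracer.from_env())
df = sql_tool.run("SELECT * FROM dataset.table LIMIT 100;")   # explicit SQL passes straight through _nl_to_sql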
utils/config.py ADDED
@@ -0,0 +1,21 @@
+ import os
+ from dataclasses import dataclass
+
+ @dataclass
+ class AppConfig:
+     hf_model_repo: str
+     sql_backend: str  # "bigquery" or "motherduck"
+     gcp_project: str | None = None
+     motherduck_db: str | None = None
+     motherduck_token: str | None = None
+
+
+     @classmethod
+     def from_env(cls):
+         return cls(
+             hf_model_repo=os.getenv("HF_MODEL_REPO", "your-username/your-private-tabular-model"),
+             sql_backend=os.getenv("SQL_BACKEND", "motherduck"),
+             gcp_project=os.getenv("GCP_PROJECT"),
+             motherduck_db=os.getenv("MOTHERDUCK_DB", "default"),
+             motherduck_token=os.getenv("MOTHERDUCK_TOKEN"),
+         )
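For reference, the secrets these modules read are: HF_MODEL_REPO, HF_TOKEN, SQL_BACKEND, GCP_PROJECT, GCP_SERVICE_ACCOUNT_JSON, MOTHERDUCK_DB, MOTHERDUCK_TOKEN, and optionally LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY, LANGFUSE_HOST (every name appears verbatim in this commit). A local smoke test might stub the minimum set before calling AppConfig.from_env(); the token value below is a placeholder.

# Illustrative sketch (not part of the commit): minimal local environment for the MotherDuck backend.
import os
os.environ.setdefault("HF_MODEL_REPO", "your-username/your-private-tabular-model")
os.environ.setdefault("SQL_BACKEND", "motherduck")
os.environ.setdefault("MOTHERDUCK_TOKEN", "<your-motherduck-token>")   # placeholder

from utils.config import AppConfig     # adjust the import path to your package layout
cfg = AppConfig.from_env()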
utils/hf_io.py ADDED
File without changes
utils/tracing.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ import json
+ from typing import Optional
+
+ class Tracer:
+     def __init__(self, client=None, trace_url: Optional[str] = None):
+         self.client = client
+         self.trace_url = trace_url
+
+     @classmethod
+     def from_env(cls):
+         try:
+             from langfuse import Langfuse
+             pk = os.getenv("LANGFUSE_PUBLIC_KEY")
+             sk = os.getenv("LANGFUSE_SECRET_KEY")
+             host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
+             if pk and sk:
+                 client = Langfuse(public_key=pk, secret_key=sk, host=host)
+                 session = client.trace(name="tabular-agentic-xai")
+                 return cls(client=session, trace_url=session.get_trace_url() if hasattr(session, "get_trace_url") else None)
+         except Exception:
+             pass
+         return cls()  # no-op tracer when Langfuse is not configured
+
+     def trace_event(self, name: str, payload: dict):
+         if self.client:
+             try:
+                 self.client.event(name=name, input=json.dumps(payload))
+             except Exception:
+                 pass
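Putting the pieces together, a hypothetical app.py might orchestrate these tools roughly as follows; the import paths, query text, and plan dict are all illustrative and not part of this commit.

# Illustrative sketch (not part of the commit): query -> predictions -> SHAP -> HTML report.
from utils.config import AppConfig     # adjust import paths to your package layout
from utils.tracing import Tracer
from tools.sql_tool import SQLTool
from tools.predict_tool import PredictTool
from tools.explain_tool import ExplainTool
from tools.report_tool import ReportTool

cfg = AppConfig.from_env()
tracer = Tracer.from_env()

query = "avg revenue by month from dataset.sales"
sql_df = SQLTool(cfg, tracer).run(query)
pred_df = PredictTool(cfg, tracer).run(sql_df)
images = ExplainTool(cfg, tracer).run(sql_df)

report_path = ReportTool(cfg, tracer).render_and_save(
    user_query=query,
    sql_preview=sql_df.head(20),
    predict_preview=pred_df.head(20),
    explain_images=images,
    plan={"steps": ["sql", "predict", "explain", "report"], "rationale": "illustrative plan"},
)
print(report_path, tracer.trace_url)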