# space/tools/explain_tool.py
import os
import io
import json
import base64
import logging
from typing import Dict, Optional

import shap
import pandas as pd
import matplotlib

matplotlib.use('Agg')  # Non-interactive backend; must be set before pyplot import

import matplotlib.pyplot as plt
import joblib
from huggingface_hub import hf_hub_download

from utils.config import AppConfig
from utils.tracing import Tracer

logger = logging.getLogger(__name__)

# Constants
MAX_SAMPLE_SIZE = 1000
MIN_SAMPLE_SIZE = 10
DEFAULT_SAMPLE_SIZE = 500
MAX_IMAGE_SIZE_MB = 5


class ExplainToolError(Exception):
    """Custom exception for explanation tool errors."""
    pass


class ExplainTool:
    """
    Generates SHAP-based model explanations with global visualizations.
    CPU-friendly: large datasets are sampled before SHAP values are computed.
    """

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        self.cfg = cfg
        self.tracer = tracer
        self._model = None
        self._feature_order = None
        logger.info("ExplainTool initialized (lazy loading)")

    def _ensure_model(self):
        """Lazily load the model and feature metadata from the Hugging Face Hub."""
        if self._model is not None:
            return

        try:
            token = os.getenv("HF_TOKEN")
            repo = self.cfg.hf_model_repo
            if not repo:
                raise ExplainToolError("HF_MODEL_REPO not configured")

            logger.info(f"Loading model for explanations from: {repo}")

            # Download and load the pickled model
            try:
                model_path = hf_hub_download(
                    repo_id=repo, filename="model.pkl", token=token
                )
                self._model = joblib.load(model_path)
                logger.info(f"Model loaded: {type(self._model).__name__}")
            except Exception as e:
                raise ExplainToolError(f"Failed to load model: {e}") from e

            # Load feature metadata; the tool degrades gracefully without it
            try:
                meta_path = hf_hub_download(
                    repo_id=repo, filename="feature_metadata.json", token=token
                )
                with open(meta_path, "r", encoding="utf-8") as f:
                    meta = json.load(f) or {}
                self._feature_order = meta.get("feature_order")
                logger.info(
                    f"Loaded feature order: {len(self._feature_order or [])} features"
                )
            except Exception as e:
                logger.warning(f"Could not load feature metadata: {e}")
                self._feature_order = None

        except ExplainToolError:
            raise
        except Exception as e:
            raise ExplainToolError(f"Model initialization failed: {e}") from e

    def _validate_data(self, df: pd.DataFrame) -> tuple[bool, str]:
        """
        Validate the input dataframe.
        Returns (is_valid, error_message).
        """
        if df is None or df.empty:
            return False, "Input dataframe is empty"
        if len(df.columns) == 0:
            return False, "Dataframe has no columns"
        return True, ""

    def _prepare_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Prepare the feature matrix for SHAP analysis.
        Selects and orders features according to model expectations.
        """
        if self._feature_order:
            # Use the specified feature order
            available_features = [c for c in self._feature_order if c in df.columns]
            missing_features = [c for c in self._feature_order if c not in df.columns]

            if missing_features:
                logger.warning(
                    f"Missing {len(missing_features)} features for explanation: "
                    f"{missing_features[:5]}"
                )

            if not available_features:
                raise ExplainToolError(
                    f"No required features found in dataframe. "
                    f"Required: {self._feature_order}, "
                    f"Available: {list(df.columns)}"
                )

            X = df[available_features].copy()
            logger.info(f"Using {len(available_features)} features for explanation")
        else:
            # Fall back to all columns
            X = df.copy()
            logger.warning("No feature order specified - using all columns")

        # Drop non-numeric columns, which SHAP cannot handle directly
        numeric_cols = X.select_dtypes(include=['number']).columns
        if len(numeric_cols) < len(X.columns):
            dropped = set(X.columns) - set(numeric_cols)
            logger.warning(
                f"Dropping {len(dropped)} non-numeric columns: {list(dropped)[:5]}"
            )
            X = X[numeric_cols]

        if X.empty or len(X.columns) == 0:
            raise ExplainToolError("No numeric features available for explanation")

        return X

    def _sample_data(
        self, X: pd.DataFrame, sample_size: int = DEFAULT_SAMPLE_SIZE
    ) -> pd.DataFrame:
        """Sample data for SHAP analysis to keep computation manageable on CPU."""
        n = len(X)
        if n <= MIN_SAMPLE_SIZE:
            logger.info(f"Using all {n} rows (below minimum sample size)")
            return X

        # Clamp the requested sample size to [MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE]
        target_size = min(sample_size, MAX_SAMPLE_SIZE)
        target_size = max(target_size, MIN_SAMPLE_SIZE)

        if n <= target_size:
            logger.info(f"Using all {n} rows (below target sample size)")
            return X

        # Simple random sample; the fixed seed keeps explanations reproducible
        try:
            sample = X.sample(n=target_size, random_state=42)
            logger.info(f"Sampled {target_size} rows from {n} total")
            return sample
        except Exception as e:
            logger.warning(f"Sampling failed: {e}, using head()")
            return X.head(target_size)

    @staticmethod
    def _to_data_uri(fig) -> str:
        """
        Convert a matplotlib figure to a base64 PNG data URI.
        Warns (but does not fail) when the rendered image is unusually large.
        """
        try:
            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
            plt.close(fig)
            png_bytes = buf.getvalue()

            # Check size
            size_mb = len(png_bytes) / (1024 * 1024)
            if size_mb > MAX_IMAGE_SIZE_MB:
                logger.warning(f"Generated image is large: {size_mb:.2f} MB")

            data_uri = "data:image/png;base64," + base64.b64encode(png_bytes).decode()
            logger.debug(f"Generated data URI of size: {len(data_uri)} chars")
            return data_uri
        except Exception as e:
            logger.error(f"Failed to convert figure to data URI: {e}")
            raise ExplainToolError(f"Image conversion failed: {e}") from e

    def _generate_shap_values(self, X: pd.DataFrame) -> shap.Explanation:
        """Generate SHAP values for the sampled feature matrix."""
        try:
            logger.info("Creating SHAP explainer...")
            explainer = shap.Explainer(self._model, X)

            logger.info("Computing SHAP values...")
            shap_values = explainer(X)
            logger.info(f"SHAP values computed: shape={shap_values.values.shape}")

            return shap_values
        except Exception as e:
            raise ExplainToolError(f"SHAP computation failed: {e}") from e

    def _create_bar_plot(self, shap_values: shap.Explanation) -> str:
        """Create the global feature-importance bar plot."""
        try:
            logger.info("Creating bar plot...")
            fig = plt.figure(figsize=(10, 6))
            shap.plots.bar(shap_values, show=False, max_display=20)
            plt.title("Feature Importance (Global)", fontsize=14, pad=20)
            plt.xlabel("Mean |SHAP value|", fontsize=12)
            plt.tight_layout()
            uri = self._to_data_uri(fig)
            logger.info("Bar plot created successfully")
            return uri
        except Exception as e:
            logger.error(f"Bar plot creation failed: {e}")
            # Return an empty string rather than failing the whole run
            return ""

    def _create_beeswarm_plot(self, shap_values: shap.Explanation) -> str:
        """Create the beeswarm plot showing per-feature effect distributions."""
        try:
            logger.info("Creating beeswarm plot...")
            fig = plt.figure(figsize=(10, 8))
            shap.plots.beeswarm(shap_values, show=False, max_display=20)
            plt.title("Feature Effects Distribution", fontsize=14, pad=20)
            plt.tight_layout()
            uri = self._to_data_uri(fig)
            logger.info("Beeswarm plot created successfully")
            return uri
        except Exception as e:
            logger.error(f"Beeswarm plot creation failed: {e}")
            return ""

    def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
        """
        Generate SHAP explanations for input data.

        Args:
            df: Input dataframe with features

        Returns:
            Dictionary mapping plot names to base64 data URIs

        Raises:
            ExplainToolError: If explanation generation fails
        """
        try:
            # Validate input
            is_valid, error_msg = self._validate_data(df)
            if not is_valid:
                logger.warning(f"Invalid input: {error_msg}")
                return {}

            # Ensure the model is loaded
            self._ensure_model()

            # Prepare features
            X = self._prepare_features(df)
            logger.info(f"Prepared features: {X.shape}")

            # Sample data for efficiency
            sample = self._sample_data(X)

            # Generate SHAP values
            shap_values = self._generate_shap_values(sample)

            # Create visualizations; each plot fails independently
            result = {}

            # Bar plot (feature importance)
            bar_uri = self._create_bar_plot(shap_values)
            if bar_uri:
                result["global_bar"] = bar_uri

            # Beeswarm plot (feature effects)
            bee_uri = self._create_beeswarm_plot(shap_values)
            if bee_uri:
                result["beeswarm"] = bee_uri

            # Log success
            logger.info(f"Generated {len(result)} explanation visualizations")

            if self.tracer:
                self.tracer.trace_event("explain", {
                    "rows": len(sample),
                    "features": len(X.columns),
                    "visualizations": len(result),
                })

            return result

        except ExplainToolError:
            raise
        except Exception as e:
            error_msg = f"Explanation generation failed: {str(e)}"
            logger.error(error_msg)
            if self.tracer:
                self.tracer.trace_event("explain_error", {"error": error_msg})
            raise ExplainToolError(error_msg) from e