| |
| |
| |
| |
| |
|
|
"""
Sandboxed Python code executor for the REPL environment.

Uses smolagents.LocalPythonExecutor as the backend for battle-tested sandboxed
execution, with RLM-specific features on top:
- Context loading (set_context)
- Variable access (get_variable, list_variables)
- Function injection (inject_function for llm_query, llm_query_batched)
- Output capped at 8,192 characters per turn (configurable)
- Persistent namespace across code blocks
"""
|
|
| import json |
| import logging |
| import time |
| import traceback |
| from collections.abc import Callable |
| from typing import Any, Dict, List, Optional |
|
|
| from smolagents import LocalPythonExecutor |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Attach a no-op handler so this module installs no output handling of its
# own; records are only emitted if the application configures logging.
logger.addHandler(logging.NullHandler())
|
|
|
|
class PythonExecutor:
    """Sandboxed Python code executor with persistent namespace.

    Wraps smolagents.LocalPythonExecutor with RLM-specific features:
    - Context loading for RLM tasks
    - Variable tracking for observation
    - Function injection for llm_query, llm_query_batched
    - Configurable output length limit (default 8192 chars per Prime Intellect)
    """

    # Modules sandboxed code may import when the caller supplies no list.
    DEFAULT_ALLOWED_IMPORTS: tuple[str, ...] = (
        "re",
        "json",
        "math",
        "random",
        "collections",
        "itertools",
        "functools",
        "operator",
        "string",
        "textwrap",
        "difflib",
        "statistics",
        "decimal",
        "fractions",
        "datetime",
        "copy",
        "pprint",
        "typing",
        "dataclasses",
        "enum",
        "bisect",
        "heapq",
        "array",
        "struct",
        "base64",
        "hashlib",
        "hmac",
        "uuid",
    )

    # Longest repr() kept per variable in the ``locals_snapshot`` of execute().
    _MAX_REPR_LENGTH = 500

    def __init__(
        self,
        max_output_length: int = 8192,
        allowed_imports: Optional[List[str]] = None,
    ):
        """Initialize the executor.

        Args:
            max_output_length: Maximum characters for stdout/stderr (default 8192)
            allowed_imports: List of allowed module names for import. ``None``
                selects ``DEFAULT_ALLOWED_IMPORTS``; an explicit empty list is
                honoured and passes no additional modules to the backend.

        Note:
            smolagents.LocalPythonExecutor does NOT support wall-clock timeouts.
            Instead, it limits operations (10M ops) and while iterations (1M).
        """
        self.max_output_length = max_output_length
        self.allowed_imports = self._resolve_allowed_imports(allowed_imports)
        self._executor = self._create_backend()

        # Names set through this wrapper or created by executed code.
        self._user_variables: set[str] = set()

        # Callables handed to the backend via send_tools.
        self._callable_tools: Dict[str, Callable[..., Any]] = {}

        self._register_helpers()

    @classmethod
    def _resolve_allowed_imports(
        cls, allowed_imports: Optional[List[str]]
    ) -> List[str]:
        """Return the import allow-list to use, as a fresh list.

        Only ``None`` falls back to the defaults. Testing with ``is None``
        (rather than truthiness) keeps an explicit ``[]`` from silently
        widening the sandbox to the default module set.
        """
        if allowed_imports is None:
            return list(cls.DEFAULT_ALLOWED_IMPORTS)
        # Copy so later mutation of the caller's list cannot change ours.
        return list(allowed_imports)

    def _create_backend(self) -> Any:
        """Build a fresh LocalPythonExecutor using the current allow-list."""
        return LocalPythonExecutor(
            additional_authorized_imports=self.allowed_imports
        )

    def _register_helpers(self) -> None:
        """Register helper functions with the executor."""
        helpers = {
            "format_exc": traceback.format_exc,
            # Values json cannot encode are replaced by their repr().
            "safe_json_dumps": lambda obj: json.dumps(obj, default=repr),
        }
        for name, func in helpers.items():
            self.inject_function(name, func)

    def _sync_callable_tools(self) -> None:
        """Sync callable functions with the executor via send_tools.

        Best effort: a failure is logged at debug level and otherwise ignored.
        """
        if not self._callable_tools:
            return
        try:
            self._executor.send_tools(self._callable_tools)
        except Exception:
            logger.debug(
                "send_tools failed; continuing without extra tools",
                exc_info=True,
            )

    def set_context(self, context: str, variable_name: str = "context") -> None:
        """Load context into namespace as a variable.

        Args:
            context: The context string to load
            variable_name: Name of the variable (default "context")
        """
        self.set_variable(variable_name, context)

    def set_variable(self, name: str, value: Any) -> None:
        """Set a variable in the namespace.

        Args:
            name: Variable name
            value: Variable value
        """
        if hasattr(self._executor, "state"):
            self._executor.state[name] = value
        else:
            # No ``state`` mapping on the backend: keep the value in a side
            # dict on the backend object so get_variable can read it back.
            self._executor._injected_vars = getattr(
                self._executor, "_injected_vars", {}
            )
            self._executor._injected_vars[name] = value

        self._user_variables.add(name)

    def get_variable(self, name: str) -> Optional[Any]:
        """Retrieve a variable from namespace.

        Args:
            name: Variable name

        Returns:
            The variable value or None if not found
        """
        # Prefer the backend's ``state``; fall back to the side dict that
        # set_variable uses when ``state`` is unavailable.
        for attr in ("state", "_injected_vars"):
            store = getattr(self._executor, attr, None)
            if store is not None:
                return store.get(name)
        return None

    def list_variables(self) -> List[str]:
        """List non-private variables in namespace.

        Returns:
            Sorted list of variable names: keys of the backend state that do
            not start with ``_``, plus every name tracked by this wrapper
            (including injected functions).
        """
        state = getattr(self._executor, "state", None) or {}
        names = {key for key in state if not key.startswith("_")}
        names.update(self._user_variables)
        # Sorted so the result is deterministic across calls and runs.
        return sorted(names)

    def execute(self, code: str) -> Dict[str, Any]:
        """Execute Python code and return results.

        Args:
            code: Python code to execute

        Returns:
            Dictionary with these keys:
            - ``stdout``: backend logs plus the JSON (or repr) of the output,
              truncated to ``max_output_length``
            - ``stderr``: error text, truncated to ``max_output_length``
            - ``locals_snapshot``: name -> truncated repr of variables created
              by this call
            - ``execution_time``: elapsed seconds
            - ``success``: False if the backend raised or reported an error
            - ``exception``: error message, or None
        """
        # perf_counter is monotonic, so the duration cannot be skewed by
        # system clock adjustments the way time.time() differences can.
        started = time.perf_counter()
        success = True
        exception_msg: Optional[str] = None
        stdout_parts: list[str] = []
        stderr_parts: list[str] = []

        # Remember which names existed so new ones can be reported afterwards.
        state_before = getattr(self._executor, "state", None)
        pre_state_keys = set(state_before) if state_before is not None else set()

        try:
            exec_result = self._executor(code)
            success, exception_msg = self._collect_result(
                exec_result, stdout_parts, stderr_parts
            )
        except Exception as e:
            success = False
            exception_msg = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
            stderr_parts.append(exception_msg)

        execution_time = time.perf_counter() - started
        new_locals = self._snapshot_new_locals(pre_state_keys)

        stdout = "\n".join(part for part in stdout_parts if part)
        stderr = "\n".join(part for part in stderr_parts if part)

        return {
            "stdout": self._truncate(stdout),
            "stderr": self._truncate(stderr),
            "locals_snapshot": new_locals,
            "execution_time": execution_time,
            "success": success,
            "exception": exception_msg,
        }

    def _collect_result(
        self,
        exec_result: Any,
        stdout_parts: List[str],
        stderr_parts: List[str],
    ) -> "tuple[bool, Optional[str]]":
        """Extract output and error information from a backend result.

        Each attribute is probed defensively because the result object's shape
        is not assumed; a failed probe is logged at debug level and skipped.
        Text is appended to ``stdout_parts`` / ``stderr_parts`` in place.

        Returns:
            ``(success, exception_msg)`` derived from the result.
        """
        success = True
        exception_msg: Optional[str] = None

        try:
            logs = getattr(exec_result, "logs", None)
            if logs:
                stdout_parts.append(str(logs))
        except Exception:
            logger.debug("Failed to read exec_result.logs", exc_info=True)

        try:
            out_val = getattr(exec_result, "output", None)
            if out_val is not None:
                try:
                    stdout_parts.append(json.dumps(out_val))
                except Exception:
                    # Not JSON-serialisable: fall back to repr().
                    stdout_parts.append(repr(out_val))
        except Exception:
            logger.debug("Failed to read exec_result.output", exc_info=True)

        # When both are set, the later attribute ("exception") wins for
        # exception_msg while both texts are kept in stderr.
        for attr in ("error", "exception"):
            try:
                problem = getattr(exec_result, attr, None)
                if problem:
                    text = str(problem)
                    stderr_parts.append(text)
                    success = False
                    exception_msg = text
            except Exception:
                logger.debug("Failed to read exec_result.%s", attr, exc_info=True)

        try:
            if hasattr(exec_result, "exit_code"):
                if exec_result.exit_code is not None and exec_result.exit_code != 0:
                    success = False
            elif hasattr(exec_result, "success"):
                # Only ever downgrade: a truthy flag must not mask an error
                # that was already detected above.
                success = success and bool(exec_result.success)
        except Exception:
            logger.debug("Failed to determine exec_result exit code", exc_info=True)

        return success, exception_msg

    def _snapshot_new_locals(self, pre_state_keys: "set[str]") -> Dict[str, str]:
        """Return truncated reprs of non-private variables created since
        ``pre_state_keys`` was captured, and start tracking their names."""
        snapshot: Dict[str, str] = {}
        state = getattr(self._executor, "state", None)
        if state is None:
            return snapshot

        for key in state:
            if key in pre_state_keys or key.startswith("_"):
                continue
            # Track the name first so a failing repr() cannot leave it out.
            self._user_variables.add(key)
            try:
                text = repr(state[key])
            except Exception:
                snapshot[key] = "<unrepresentable>"
                continue
            if len(text) > self._MAX_REPR_LENGTH:
                text = text[: self._MAX_REPR_LENGTH] + "..."
            snapshot[key] = text
        return snapshot

    def _truncate(self, text: str) -> str:
        """Cap ``text`` at ``max_output_length`` chars, noting the full size."""
        if len(text) <= self.max_output_length:
            return text
        return (
            text[: self.max_output_length]
            + f"\n... (truncated, total {len(text)} chars)"
        )

    def reset(self) -> None:
        """Reset namespace to initial state.

        Replaces the backend with a fresh one and forgets all tracked
        variables and injected callables; only the built-in helpers are
        registered again, so any other function must be re-injected.
        """
        self._executor = self._create_backend()
        self._user_variables.clear()
        self._callable_tools.clear()
        self._register_helpers()

    def inject_function(self, name: str, func: Callable[..., Any]) -> None:
        """Inject a callable function into the namespace.

        Used for adding llm_query, llm_query_batched, FINAL, etc.

        Args:
            name: Function name in namespace
            func: The callable to inject
        """
        self._callable_tools[name] = func
        self._user_variables.add(name)
        self._sync_callable_tools()
|
|