Spaces:
Build error
Build error
File size: 4,596 Bytes
d6703a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import logging
import sys
import inspect
from typing import Any
from fastapi import Request
from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.socket.main import get_event_call, get_event_emitter
from open_webui.utils.plugin import get_function_module_from_cache
from open_webui.utils.models import get_all_models
from open_webui.utils.middleware import process_tool_result
from open_webui.env import GLOBAL_LOG_LEVEL
# Configure root logging to write to stdout at the globally configured level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
    """Run the ``action`` handler of a stored function for a chat message.

    ``action_id`` is either a bare function id or
    ``"<function_id>.<sub_action_id>"``. Only the first ``.`` separates the
    two parts, so a sub-action id may itself contain dots.

    Args:
        request: Incoming request. ``request.app.state.MODELS`` is used as the
            model registry unless ``request.state.direct`` is truthy and
            ``request.state.model`` exists, in which case only that model is
            considered.
        action_id: Function id, optionally suffixed with ``.<sub_action_id>``.
        form_data: Request payload. The keys ``model``, ``chat_id``, ``id``
            and ``session_id`` are read from it, and the whole dict is passed
            to the handler as ``body``.
        user: The calling user. ``user.id`` is read; the handler's
            ``__user__`` argument is built from ``user.model_dump()`` when
            ``user`` is a ``UserModel`` (an empty dict otherwise).

    Returns:
        ``form_data`` unchanged when the function module has no ``action``
        attribute. Otherwise the handler's return value, replaced by the
        processed result of ``process_tool_result`` when that call yields
        embeds.

    Raises:
        Exception: If the function or the model cannot be found, or (wrapped
            as ``"Error: ..."`` and chained to the original) if anything
            fails while invoking the handler or processing its result.
        KeyError: If ``form_data`` lacks one of the keys listed above.
    """
    # Split on the first dot only: unpacking an unbounded split would raise
    # ValueError for ids such as "fn.group.item".
    if "." in action_id:
        action_id, sub_action_id = action_id.split(".", 1)
    else:
        sub_action_id = None

    if not Functions.get_function_by_id(action_id):
        raise Exception(f"Action not found: {action_id}")

    # Populate the model registry when it is still empty.
    if not request.app.state.MODELS:
        await get_all_models(request, user=user)

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    data = form_data
    model_id = data["model"]
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Both socket helpers receive the same metadata; each gets its own copy
    # so neither can observe changes made through the other.
    event_context = {
        "chat_id": data["chat_id"],
        "message_id": data["id"],
        "session_id": data["session_id"],
        "user_id": user.id,
    }
    __event_emitter__ = get_event_emitter(dict(event_context))
    __event_call__ = get_event_call(dict(event_context))

    function_module, _, _ = get_function_module_from_cache(request, action_id)

    # Refresh the module's valves from the stored values (defaults if none).
    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(action_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))

    # Nothing to invoke: hand the payload back untouched.
    if not hasattr(function_module, "action"):
        return data

    try:
        handler = function_module.action

        # Get the signature of the function
        sig = inspect.signature(handler)
        params = {"body": data}

        # Extra parameters to be passed to the function
        extra_params = {
            "__model__": model,
            "__id__": sub_action_id if sub_action_id is not None else action_id,
            "__event_emitter__": __event_emitter__,
            "__event_call__": __event_call__,
            "__request__": request,
        }

        # Only pass the extra params the handler actually declares.
        params.update(
            {
                key: value
                for key, value in extra_params.items()
                if key in sig.parameters
            }
        )

        if "__user__" in sig.parameters:
            __user__ = user.model_dump() if isinstance(user, UserModel) else {}

            # Best effort: a failure to load user valves is logged and the
            # handler is still called, just without "valves" in __user__.
            try:
                if hasattr(function_module, "UserValves"):
                    __user__["valves"] = function_module.UserValves(
                        **Functions.get_user_valves_by_id_and_user_id(
                            action_id, user.id
                        )
                    )
            except Exception as e:
                log.exception(f"Failed to get user values: {e}")

            params = {**params, "__user__": __user__}

        if inspect.iscoroutinefunction(handler):
            data = await handler(**params)
        else:
            data = handler(**params)

        # Process action result for Rich UI embeds (HTMLResponse, tuple with headers)
        processed_result, _, action_embeds = process_tool_result(
            request,
            action_id,
            data,
            "action",
        )

        if action_embeds:
            await __event_emitter__(
                {
                    "type": "embeds",
                    "data": {
                        "embeds": action_embeds,
                    },
                }
            )
            # Replace data with the processed status dict so we don't
            # try to serialize the raw HTMLResponse / tuple back to the client
            data = processed_result
    except Exception as e:
        # Keep the original exception attached as the explicit cause.
        raise Exception(f"Error: {e}") from e

    return data
|