| | from enum import Enum |
| | from typing import Any, List, Literal, Optional, Union |
| |
|
| | from pydantic import BaseModel, Field |
| |
|
| |
|
class Role(str, Enum):
    """Message role options.

    Mixing in ``str`` makes every member compare equal to its plain string
    value (``Role.USER == "user"``). Declaration order matters: it is the
    order in which the module-level role-values tuple is built.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
| |
|
| |
|
# Every Role value in declaration order; derived from the enum so the
# Literal below cannot drift out of sync with it.
ROLE_VALUES = tuple(member.value for member in Role)
# Subscripting Literal with a tuple unpacks it into one argument per value,
# so at runtime this is Literal[<each Role value>]. Static type checkers
# cannot evaluate a Literal built from a variable.
ROLE_TYPE = Literal[ROLE_VALUES]
| |
|
| |
|
class ToolChoice(str, Enum):
    """Tool choice options.

    String-valued enum: each member compares equal to its lowercase value.
    Declaration order is the order of the module-level tool-choice values tuple.
    """

    NONE = "none"
    AUTO = "auto"
    REQUIRED = "required"
| |
|
| |
|
# Every ToolChoice value in declaration order, derived from the enum itself.
TOOL_CHOICE_VALUES = tuple(option.value for option in ToolChoice)
# Literal[tuple] unpacks the tuple, giving Literal[<each ToolChoice value>]
# at runtime; static checkers cannot evaluate this dynamic form.
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES]
| |
|
| |
|
class AgentState(str, Enum):
    """Agent execution states.

    Unlike the other enums in this module, each value is the upper-case
    member name itself (``AgentState.IDLE == "IDLE"``).
    """

    IDLE = "IDLE"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"
| |
|
| |
|
# Name plus argument payload of a function referenced by a ToolCall.
class Function(BaseModel):
    # Both fields are required (Field(...) marks a field as having no default).
    name: str = Field(...)
    # Kept as an unparsed string; presumably a JSON-encoded argument object
    # produced by the LLM -- TODO confirm against callers.
    arguments: str = Field(...)
| |
|
| |
|
class ToolCall(BaseModel):
    """Represents a tool/function call in a message"""

    id: str = Field(...)
    # Defaults to "function"; Message.from_tool_calls also always sets "function".
    type: str = Field(default="function")
    function: Function = Field(...)
| |
|
| |
|
class Message(BaseModel):
    """Represents a chat message in the conversation.

    Only ``role`` is required. Every other field defaults to ``None`` and is
    left out of :meth:`to_dict` while unset.
    """

    role: ROLE_TYPE = Field(...)
    content: Optional[str] = Field(default=None)
    # Populated for assistant messages created via from_tool_calls().
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    # Set by tool_message() alongside tool_call_id.
    name: Optional[str] = Field(default=None)
    tool_call_id: Optional[str] = Field(default=None)
    # Optional base64-encoded image attached to the message.
    base64_image: Optional[str] = Field(default=None)

    def __add__(self, other) -> List["Message"]:
        """Support ``Message + list`` and ``Message + Message``.

        Returns a new list with this message first; neither operand is
        mutated. For any other operand type, returns ``NotImplemented`` so
        Python can try ``other.__radd__`` before raising the standard
        ``TypeError``.
        """
        if isinstance(other, list):
            return [self] + other
        if isinstance(other, Message):
            return [self, other]
        return NotImplemented

    def __radd__(self, other) -> List["Message"]:
        """Support ``list + Message``.

        Returns a new list with this message appended; the left list is not
        mutated. Returns ``NotImplemented`` for non-list operands, which makes
        Python raise the standard ``TypeError``.
        """
        if isinstance(other, list):
            return other + [self]
        return NotImplemented

    def to_dict(self) -> dict:
        """Convert the message to a plain dict.

        ``role`` is always present. Every other key is included only when
        its field is not ``None``. Tool calls are dumped to nested dicts.
        """
        message = {"role": self.role}
        if self.content is not None:
            message["content"] = self.content
        if self.tool_calls is not None:
            # model_dump() is the pydantic v2 API; .dict() is deprecated there.
            message["tool_calls"] = [
                tool_call.model_dump() for tool_call in self.tool_calls
            ]
        if self.name is not None:
            message["name"] = self.name
        if self.tool_call_id is not None:
            message["tool_call_id"] = self.tool_call_id
        if self.base64_image is not None:
            message["base64_image"] = self.base64_image
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a user message, optionally carrying a base64 image."""
        return cls(role=Role.USER, content=content, base64_image=base64_image)

    @classmethod
    def system_message(cls, content: str) -> "Message":
        """Create a system message."""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
        """Create an assistant message; ``content`` may be ``None``."""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)

    @classmethod
    def tool_message(
        cls,
        content: str,
        name: Optional[str],
        tool_call_id: str,
        base64_image: Optional[str] = None,
    ) -> "Message":
        """Create a tool message.

        Args:
            content: Text produced by the tool.
            name: Stored in the message's ``name`` field.
            tool_call_id: Stored in the message's ``tool_call_id`` field.
            base64_image: Optional base64-encoded image.
        """
        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> "Message":
        """Create an assistant Message from raw tool calls.

        Args:
            tool_calls: Raw tool calls from the LLM. Each must expose ``.id``
                and a ``.function`` object with a ``model_dump()`` method
                (presumably SDK pydantic objects -- confirm with callers).
            content: Optional message content.
            base64_image: Optional base64-encoded image.
            **kwargs: Extra Message fields forwarded to the constructor.

        Returns:
            A Message with role ``assistant`` whose ``tool_calls`` are
            normalized into ``ToolCall`` entries of type ``"function"``.
        """
        # NOTE(review): `content` is annotated as possibly List[str], but the
        # `content` field is Optional[str], so passing a list will presumably
        # fail validation -- confirm whether any caller relies on this.
        formatted_calls = [
            {"id": call.id, "function": call.function.model_dump(), "type": "function"}
            for call in tool_calls
        ]
        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
        )
| |
|
| |
|
class Memory(BaseModel):
    """Ordered store of conversation messages, capped at ``max_messages``.

    When the cap is exceeded, the oldest messages are discarded first.
    """

    messages: List[Message] = Field(default_factory=list)
    max_messages: int = Field(default=100)

    def _trim(self) -> None:
        """Drop the oldest messages so at most ``max_messages`` remain.

        A non-positive ``max_messages`` empties the list. A ``[-k:]`` slice
        is avoided on purpose: with ``k == 0`` it yields the whole list.
        """
        overflow = len(self.messages) - max(self.max_messages, 0)
        if overflow > 0:
            del self.messages[:overflow]

    def add_message(self, message: Message) -> None:
        """Append a message, then trim to ``max_messages``."""
        self.messages.append(message)
        self._trim()

    def add_messages(self, messages: List[Message]) -> None:
        """Append several messages in order, then trim to ``max_messages``."""
        self.messages.extend(messages)
        self._trim()

    def clear(self) -> None:
        """Remove all messages (in place)."""
        self.messages.clear()

    def get_recent_messages(self, n: int) -> List[Message]:
        """Return a new list of the ``n`` most recent messages, oldest first.

        Returns an empty list when ``n <= 0``. A bare ``[-n:]`` slice would
        return every message for ``n == 0``.
        """
        if n <= 0:
            return []
        return self.messages[-n:]

    def to_dict_list(self) -> List[dict]:
        """Convert every stored message with ``Message.to_dict``."""
        return [msg.to_dict() for msg in self.messages]
| |
|