Spaces:
Sleeping
Sleeping
| from typing import Any, Dict, List, Callable, Optional | |
| from langchain_core.messages import BaseMessage | |
| from langchain_core.runnables import RunnableConfig | |
| from langgraph.graph.state import CompiledStateGraph | |
| import uuid | |
| def random_uuid(): | |
| return str(uuid.uuid4()) | |
| async def astream_graph( | |
| graph: CompiledStateGraph, | |
| inputs: dict, | |
| config: Optional[RunnableConfig] = None, | |
| node_names: List[str] = [], | |
| callback: Optional[Callable] = None, | |
| stream_mode: str = "messages", | |
| include_subgraphs: bool = False, | |
| ) -> Dict[str, Any]: | |
| """ | |
| LangGraph์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋น๋๊ธฐ์ ์ผ๋ก ์คํธ๋ฆฌ๋ฐํ๊ณ ์ง์ ์ถ๋ ฅํ๋ ํจ์์ ๋๋ค. | |
| Args: | |
| graph (CompiledStateGraph): ์คํํ ์ปดํ์ผ๋ LangGraph ๊ฐ์ฒด | |
| inputs (dict): ๊ทธ๋ํ์ ์ ๋ฌํ ์ ๋ ฅ๊ฐ ๋์ ๋๋ฆฌ | |
| config (Optional[RunnableConfig]): ์คํ ์ค์ (์ ํ์ ) | |
| node_names (List[str], optional): ์ถ๋ ฅํ ๋ ธ๋ ์ด๋ฆ ๋ชฉ๋ก. ๊ธฐ๋ณธ๊ฐ์ ๋น ๋ฆฌ์คํธ | |
| callback (Optional[Callable], optional): ๊ฐ ์ฒญํฌ ์ฒ๋ฆฌ๋ฅผ ์ํ ์ฝ๋ฐฑ ํจ์. ๊ธฐ๋ณธ๊ฐ์ None | |
| ์ฝ๋ฐฑ ํจ์๋ {"node": str, "content": Any} ํํ์ ๋์ ๋๋ฆฌ๋ฅผ ์ธ์๋ก ๋ฐ์ต๋๋ค. | |
| stream_mode (str, optional): ์คํธ๋ฆฌ๋ฐ ๋ชจ๋ ("messages" ๋๋ "updates"). ๊ธฐ๋ณธ๊ฐ์ "messages" | |
| include_subgraphs (bool, optional): ์๋ธ๊ทธ๋ํ ํฌํจ ์ฌ๋ถ. ๊ธฐ๋ณธ๊ฐ์ False | |
| Returns: | |
| Dict[str, Any]: ์ต์ข ๊ฒฐ๊ณผ (์ ํ์ ) | |
| """ | |
| config = config or {} | |
| final_result = {} | |
| def format_namespace(namespace): | |
| return namespace[-1].split(":")[0] if len(namespace) > 0 else "root graph" | |
| prev_node = "" | |
| if stream_mode == "messages": | |
| async for chunk_msg, metadata in graph.astream( | |
| inputs, config, stream_mode=stream_mode | |
| ): | |
| curr_node = metadata["langgraph_node"] | |
| final_result = { | |
| "node": curr_node, | |
| "content": chunk_msg, | |
| "metadata": metadata, | |
| } | |
| # node_names๊ฐ ๋น์ด์๊ฑฐ๋ ํ์ฌ ๋ ธ๋๊ฐ node_names์ ์๋ ๊ฒฝ์ฐ์๋ง ์ฒ๋ฆฌ | |
| if not node_names or curr_node in node_names: | |
| # ์ฝ๋ฐฑ ํจ์๊ฐ ์๋ ๊ฒฝ์ฐ ์คํ | |
| if callback: | |
| result = callback({"node": curr_node, "content": chunk_msg}) | |
| if hasattr(result, "__await__"): | |
| await result | |
| # ์ฝ๋ฐฑ์ด ์๋ ๊ฒฝ์ฐ ๊ธฐ๋ณธ ์ถ๋ ฅ | |
| else: | |
| # ๋ ธ๋๊ฐ ๋ณ๊ฒฝ๋ ๊ฒฝ์ฐ์๋ง ๊ตฌ๋ถ์ ์ถ๋ ฅ | |
| if curr_node != prev_node: | |
| print("\n" + "=" * 50) | |
| print(f"๐ Node: \033[1;36m{curr_node}\033[0m ๐") | |
| print("- " * 25) | |
| # Claude/Anthropic ๋ชจ๋ธ์ ํ ํฐ ์ฒญํฌ ์ฒ๋ฆฌ - ํญ์ ํ ์คํธ๋ง ์ถ์ถ | |
| if hasattr(chunk_msg, "content"): | |
| # ๋ฆฌ์คํธ ํํ์ content (Anthropic/Claude ์คํ์ผ) | |
| if isinstance(chunk_msg.content, list): | |
| for item in chunk_msg.content: | |
| if isinstance(item, dict) and "text" in item: | |
| print(item["text"], end="", flush=True) | |
| # ๋ฌธ์์ด ํํ์ content | |
| elif isinstance(chunk_msg.content, str): | |
| print(chunk_msg.content, end="", flush=True) | |
| # ๊ทธ ์ธ ํํ์ chunk_msg ์ฒ๋ฆฌ | |
| else: | |
| print(chunk_msg, end="", flush=True) | |
| prev_node = curr_node | |
| elif stream_mode == "updates": | |
| # ์๋ฌ ์์ : ์ธํจํน ๋ฐฉ์ ๋ณ๊ฒฝ | |
| # REACT ์์ด์ ํธ ๋ฑ ์ผ๋ถ ๊ทธ๋ํ์์๋ ๋จ์ผ ๋์ ๋๋ฆฌ๋ง ๋ฐํํจ | |
| async for chunk in graph.astream( | |
| inputs, config, stream_mode=stream_mode, subgraphs=include_subgraphs | |
| ): | |
| # ๋ฐํ ํ์์ ๋ฐ๋ผ ์ฒ๋ฆฌ ๋ฐฉ๋ฒ ๋ถ๊ธฐ | |
| if isinstance(chunk, tuple) and len(chunk) == 2: | |
| # ๊ธฐ์กด ์์ ํ์: (namespace, chunk_dict) | |
| namespace, node_chunks = chunk | |
| else: | |
| # ๋จ์ผ ๋์ ๋๋ฆฌ๋ง ๋ฐํํ๋ ๊ฒฝ์ฐ (REACT ์์ด์ ํธ ๋ฑ) | |
| namespace = [] # ๋น ๋ค์์คํ์ด์ค (๋ฃจํธ ๊ทธ๋ํ) | |
| node_chunks = chunk # chunk ์์ฒด๊ฐ ๋ ธ๋ ์ฒญํฌ ๋์ ๋๋ฆฌ | |
| # ๋์ ๋๋ฆฌ์ธ์ง ํ์ธํ๊ณ ํญ๋ชฉ ์ฒ๋ฆฌ | |
| if isinstance(node_chunks, dict): | |
| for node_name, node_chunk in node_chunks.items(): | |
| final_result = { | |
| "node": node_name, | |
| "content": node_chunk, | |
| "namespace": namespace, | |
| } | |
| # node_names๊ฐ ๋น์ด์์ง ์์ ๊ฒฝ์ฐ์๋ง ํํฐ๋ง | |
| if len(node_names) > 0 and node_name not in node_names: | |
| continue | |
| # ์ฝ๋ฐฑ ํจ์๊ฐ ์๋ ๊ฒฝ์ฐ ์คํ | |
| if callback is not None: | |
| result = callback({"node": node_name, "content": node_chunk}) | |
| if hasattr(result, "__await__"): | |
| await result | |
| # ์ฝ๋ฐฑ์ด ์๋ ๊ฒฝ์ฐ ๊ธฐ๋ณธ ์ถ๋ ฅ | |
| else: | |
| # ๋ ธ๋๊ฐ ๋ณ๊ฒฝ๋ ๊ฒฝ์ฐ์๋ง ๊ตฌ๋ถ์ ์ถ๋ ฅ (messages ๋ชจ๋์ ๋์ผํ๊ฒ) | |
| if node_name != prev_node: | |
| print("\n" + "=" * 50) | |
| print(f"๐ Node: \033[1;36m{node_name}\033[0m ๐") | |
| print("- " * 25) | |
| # ๋ ธ๋์ ์ฒญํฌ ๋ฐ์ดํฐ ์ถ๋ ฅ - ํ ์คํธ ์ค์ฌ์ผ๋ก ์ฒ๋ฆฌ | |
| if isinstance(node_chunk, dict): | |
| for k, v in node_chunk.items(): | |
| if isinstance(v, BaseMessage): | |
| # BaseMessage์ content ์์ฑ์ด ํ ์คํธ๋ ๋ฆฌ์คํธ์ธ ๊ฒฝ์ฐ๋ฅผ ์ฒ๋ฆฌ | |
| if hasattr(v, "content"): | |
| if isinstance(v.content, list): | |
| for item in v.content: | |
| if ( | |
| isinstance(item, dict) | |
| and "text" in item | |
| ): | |
| print( | |
| item["text"], end="", flush=True | |
| ) | |
| else: | |
| print(v.content, end="", flush=True) | |
| else: | |
| v.pretty_print() | |
| elif isinstance(v, list): | |
| for list_item in v: | |
| if isinstance(list_item, BaseMessage): | |
| if hasattr(list_item, "content"): | |
| if isinstance(list_item.content, list): | |
| for item in list_item.content: | |
| if ( | |
| isinstance(item, dict) | |
| and "text" in item | |
| ): | |
| print( | |
| item["text"], | |
| end="", | |
| flush=True, | |
| ) | |
| else: | |
| print( | |
| list_item.content, | |
| end="", | |
| flush=True, | |
| ) | |
| else: | |
| list_item.pretty_print() | |
| elif ( | |
| isinstance(list_item, dict) | |
| and "text" in list_item | |
| ): | |
| print(list_item["text"], end="", flush=True) | |
| else: | |
| print(list_item, end="", flush=True) | |
| elif isinstance(v, dict) and "text" in v: | |
| print(v["text"], end="", flush=True) | |
| else: | |
| print(v, end="", flush=True) | |
| elif node_chunk is not None: | |
| if hasattr(node_chunk, "__iter__") and not isinstance( | |
| node_chunk, str | |
| ): | |
| for item in node_chunk: | |
| if isinstance(item, dict) and "text" in item: | |
| print(item["text"], end="", flush=True) | |
| else: | |
| print(item, end="", flush=True) | |
| else: | |
| print(node_chunk, end="", flush=True) | |
| # ๊ตฌ๋ถ์ ์ ์ฌ๊ธฐ์ ์ถ๋ ฅํ์ง ์์ (messages ๋ชจ๋์ ๋์ผํ๊ฒ) | |
| prev_node = node_name | |
| else: | |
| # ๋์ ๋๋ฆฌ๊ฐ ์๋ ๊ฒฝ์ฐ ์ ์ฒด ์ฒญํฌ ์ถ๋ ฅ | |
| print("\n" + "=" * 50) | |
| print(f"๐ Raw output ๐") | |
| print("- " * 25) | |
| print(node_chunks, end="", flush=True) | |
| # ๊ตฌ๋ถ์ ์ ์ฌ๊ธฐ์ ์ถ๋ ฅํ์ง ์์ | |
| final_result = {"content": node_chunks} | |
| else: | |
| raise ValueError( | |
| f"Invalid stream_mode: {stream_mode}. Must be 'messages' or 'updates'." | |
| ) | |
| # ํ์์ ๋ฐ๋ผ ์ต์ข ๊ฒฐ๊ณผ ๋ฐํ | |
| return final_result | |
| async def ainvoke_graph( | |
| graph: CompiledStateGraph, | |
| inputs: dict, | |
| config: Optional[RunnableConfig] = None, | |
| node_names: List[str] = [], | |
| callback: Optional[Callable] = None, | |
| include_subgraphs: bool = True, | |
| ) -> Dict[str, Any]: | |
| """ | |
| LangGraph ์ฑ์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋น๋๊ธฐ์ ์ผ๋ก ์คํธ๋ฆฌ๋ฐํ์ฌ ์ถ๋ ฅํ๋ ํจ์์ ๋๋ค. | |
| Args: | |
| graph (CompiledStateGraph): ์คํํ ์ปดํ์ผ๋ LangGraph ๊ฐ์ฒด | |
| inputs (dict): ๊ทธ๋ํ์ ์ ๋ฌํ ์ ๋ ฅ๊ฐ ๋์ ๋๋ฆฌ | |
| config (Optional[RunnableConfig]): ์คํ ์ค์ (์ ํ์ ) | |
| node_names (List[str], optional): ์ถ๋ ฅํ ๋ ธ๋ ์ด๋ฆ ๋ชฉ๋ก. ๊ธฐ๋ณธ๊ฐ์ ๋น ๋ฆฌ์คํธ | |
| callback (Optional[Callable], optional): ๊ฐ ์ฒญํฌ ์ฒ๋ฆฌ๋ฅผ ์ํ ์ฝ๋ฐฑ ํจ์. ๊ธฐ๋ณธ๊ฐ์ None | |
| ์ฝ๋ฐฑ ํจ์๋ {"node": str, "content": Any} ํํ์ ๋์ ๋๋ฆฌ๋ฅผ ์ธ์๋ก ๋ฐ์ต๋๋ค. | |
| include_subgraphs (bool, optional): ์๋ธ๊ทธ๋ํ ํฌํจ ์ฌ๋ถ. ๊ธฐ๋ณธ๊ฐ์ True | |
| Returns: | |
| Dict[str, Any]: ์ต์ข ๊ฒฐ๊ณผ (๋ง์ง๋ง ๋ ธ๋์ ์ถ๋ ฅ) | |
| """ | |
| config = config or {} | |
| final_result = {} | |
| def format_namespace(namespace): | |
| return namespace[-1].split(":")[0] if len(namespace) > 0 else "root graph" | |
| # subgraphs ๋งค๊ฐ๋ณ์๋ฅผ ํตํด ์๋ธ๊ทธ๋ํ์ ์ถ๋ ฅ๋ ํฌํจ | |
| async for chunk in graph.astream( | |
| inputs, config, stream_mode="updates", subgraphs=include_subgraphs | |
| ): | |
| # ๋ฐํ ํ์์ ๋ฐ๋ผ ์ฒ๋ฆฌ ๋ฐฉ๋ฒ ๋ถ๊ธฐ | |
| if isinstance(chunk, tuple) and len(chunk) == 2: | |
| # ๊ธฐ์กด ์์ ํ์: (namespace, chunk_dict) | |
| namespace, node_chunks = chunk | |
| else: | |
| # ๋จ์ผ ๋์ ๋๋ฆฌ๋ง ๋ฐํํ๋ ๊ฒฝ์ฐ (REACT ์์ด์ ํธ ๋ฑ) | |
| namespace = [] # ๋น ๋ค์์คํ์ด์ค (๋ฃจํธ ๊ทธ๋ํ) | |
| node_chunks = chunk # chunk ์์ฒด๊ฐ ๋ ธ๋ ์ฒญํฌ ๋์ ๋๋ฆฌ | |
| # ๋์ ๋๋ฆฌ์ธ์ง ํ์ธํ๊ณ ํญ๋ชฉ ์ฒ๋ฆฌ | |
| if isinstance(node_chunks, dict): | |
| for node_name, node_chunk in node_chunks.items(): | |
| final_result = { | |
| "node": node_name, | |
| "content": node_chunk, | |
| "namespace": namespace, | |
| } | |
| # node_names๊ฐ ๋น์ด์์ง ์์ ๊ฒฝ์ฐ์๋ง ํํฐ๋ง | |
| if node_names and node_name not in node_names: | |
| continue | |
| # ์ฝ๋ฐฑ ํจ์๊ฐ ์๋ ๊ฒฝ์ฐ ์คํ | |
| if callback is not None: | |
| result = callback({"node": node_name, "content": node_chunk}) | |
| # ์ฝ๋ฃจํด์ธ ๊ฒฝ์ฐ await | |
| if hasattr(result, "__await__"): | |
| await result | |
| # ์ฝ๋ฐฑ์ด ์๋ ๊ฒฝ์ฐ ๊ธฐ๋ณธ ์ถ๋ ฅ | |
| else: | |
| print("\n" + "=" * 50) | |
| formatted_namespace = format_namespace(namespace) | |
| if formatted_namespace == "root graph": | |
| print(f"๐ Node: \033[1;36m{node_name}\033[0m ๐") | |
| else: | |
| print( | |
| f"๐ Node: \033[1;36m{node_name}\033[0m in [\033[1;33m{formatted_namespace}\033[0m] ๐" | |
| ) | |
| print("- " * 25) | |
| # ๋ ธ๋์ ์ฒญํฌ ๋ฐ์ดํฐ ์ถ๋ ฅ | |
| if isinstance(node_chunk, dict): | |
| for k, v in node_chunk.items(): | |
| if isinstance(v, BaseMessage): | |
| v.pretty_print() | |
| elif isinstance(v, list): | |
| for list_item in v: | |
| if isinstance(list_item, BaseMessage): | |
| list_item.pretty_print() | |
| else: | |
| print(list_item) | |
| elif isinstance(v, dict): | |
| for node_chunk_key, node_chunk_value in v.items(): | |
| print(f"{node_chunk_key}:\n{node_chunk_value}") | |
| else: | |
| print(f"\033[1;32m{k}\033[0m:\n{v}") | |
| elif node_chunk is not None: | |
| if hasattr(node_chunk, "__iter__") and not isinstance( | |
| node_chunk, str | |
| ): | |
| for item in node_chunk: | |
| print(item) | |
| else: | |
| print(node_chunk) | |
| print("=" * 50) | |
| else: | |
| # ๋์ ๋๋ฆฌ๊ฐ ์๋ ๊ฒฝ์ฐ ์ ์ฒด ์ฒญํฌ ์ถ๋ ฅ | |
| print("\n" + "=" * 50) | |
| print(f"๐ Raw output ๐") | |
| print("- " * 25) | |
| print(node_chunks) | |
| print("=" * 50) | |
| final_result = {"content": node_chunks} | |
| # ์ต์ข ๊ฒฐ๊ณผ ๋ฐํ | |
| return final_result | |