| | |
import ast
import asyncio
import logging
import operator
import os
import random
import sys
from typing import Any

import pandas as pd
import requests
import wikipedia as wiki
from dotenv import load_dotenv
from google.generativeai import types, configure
from markdownify import markdownify as to_markdown
from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool
| |
|
| | |
# Load settings from a local .env file (if present) into the process
# environment before any API key is read from it below.
load_dotenv()
# Hand the Gemini API key to the google.generativeai SDK for this process.
configure(api_key=os.environ.get("GOOGLE_API_KEY"))
| |
|
| | |
| | |
| | |
| |
|
| | |
# Model identifiers. The first four use the "<provider>/<model>" form that is
# passed to LiteLLMModel elsewhere in this file.
GEMINI_MODEL_NAME: str = "gemini/gemini-2.0-flash"
OPENAI_MODEL_NAME: str = "openai/gpt-4o"
GROQ_MODEL_NAME: str = "groq/llama3-70b-8192"
DEEPSEEK_MODEL_NAME: str = "deepseek/deepseek-chat"
# Hugging Face style "<org>/<model>" identifier.
HF_MODEL_NAME: str = "Qwen/Qwen2.5-Coder-32B-Instruct"
| |
|
| | |
# Arithmetic operators MathSolver is willing to apply, keyed by AST node type.
_MATH_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_MATH_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}


def _eval_math_node(node):
    """Recursively evaluate one node of a whitelisted arithmetic expression.

    Only int/float literals, the binary operators in ``_MATH_BINARY_OPS`` and
    unary plus/minus are accepted; any other node raises ``ValueError``.
    """
    if isinstance(node, ast.Constant):
        value = node.value
        # bool is a subclass of int; reject it so True/False are not numbers.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"unsupported literal: {value!r}")
        return value
    if isinstance(node, ast.BinOp) and type(node.op) in _MATH_BINARY_OPS:
        apply = _MATH_BINARY_OPS[type(node.op)]
        return apply(_eval_math_node(node.left), _eval_math_node(node.right))
    if isinstance(node, ast.UnaryOp) and type(node.op) in _MATH_UNARY_OPS:
        apply = _MATH_UNARY_OPS[type(node.op)]
        return apply(_eval_math_node(node.operand))
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


class MathSolver(Tool):
    """Tool that evaluates a plain arithmetic expression given as text.

    SECURITY: this tool previously called ``eval`` with ``__builtins__``
    emptied. That is not a sandbox (attribute access on literals can still
    reach arbitrary objects), and the expression comes from model output.
    The expression is now parsed with ``ast`` and only numeric literals and
    arithmetic operators are evaluated.
    """

    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Evaluate ``input`` and return the result as a string.

        Args:
            input: Arithmetic expression, e.g. ``"(2 + 3) * 4"``.

        Returns:
            ``str`` of the computed value, or ``"Math error: ..."`` when the
            expression is malformed, unsupported, or fails (e.g. division by
            zero). This method never raises.
        """
        try:
            # strip(): ast.parse rejects leading whitespace as an indent.
            tree = ast.parse(input.strip(), mode="eval")
            return str(_eval_math_node(tree.body))
        except Exception as e:  # tool boundary: report instead of raising
            return f"Math error: {e}"
| |
|
class RiddleSolver(Tool):
    """Keyword-based riddle tool.

    It recognises exactly one pattern: a prompt containing both the words
    "forward" and "backward" is answered with "A palindrome".
    """

    name = "riddle_solver"
    description = "Solve basic riddles using logic."
    inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return "A palindrome" if both keywords occur in ``input``, else a failure note."""
        mentions_both = all(word in input for word in ("forward", "backward"))
        return "A palindrome" if mentions_both else "RiddleSolver failed."
| |
|
class TextTransformer(Tool):
    """Prefix-driven text transformation tool.

    The input must start with ``reverse:``, ``upper:`` or ``lower:``; the
    remainder (whitespace-stripped) is the text to transform.
    """

    name = "text_ops"
    description = "Transform text: reverse, upper, lower."
    inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Apply the transformation selected by the prefix of ``input``.

        Returns:
            The transformed text, or "Unknown transformation." when no known
            prefix matches. For ``reverse:``, if the reversed text contains
            "left" (case-insensitive) the literal answer "right" is returned
            instead of the reversed text.
        """
        reverse_tag = "reverse:"
        if input.startswith(reverse_tag):
            flipped = input[len(reverse_tag):].strip()[::-1]
            return "right" if "left" in flipped.lower() else flipped

        # Remaining prefixes map directly onto a str method.
        for tag, transform in (("upper:", str.upper), ("lower:", str.lower)):
            if input.startswith(tag):
                return transform(input[len(tag):].strip())

        return "Unknown transformation."
| |
|
class GeminiVideoQA(Tool):
    """Tool that asks a Gemini model a question about a video given by URL.

    The request goes to the Generative Language REST endpoint
    ``.../v1beta/models/<model>:generateContent`` with the video passed as a
    ``fileData`` part and the question as a text part.
    """

    name = "video_inspector"
    description = "Analyze video content to answer questions."
    inputs = {
        "video_url": {"type": "string", "description": "URL of video."},
        "user_query": {"type": "string", "description": "Question about video."}
    }
    output_type = "string"

    # Upper bound (seconds) for the HTTP call so a stalled request cannot
    # block the agent forever.
    timeout_seconds = 300

    def __init__(self, model_name, *args, **kwargs):
        """Store the model to query.

        Args:
            model_name: Gemini model id. A LiteLLM-style ``gemini/`` prefix or
                a ``models/`` prefix is accepted and stripped at request time.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, video_url: str, user_query: str) -> str:
        """Send the video URL and question to Gemini and return its text answer.

        Returns:
            The concatenated text parts of the first candidate, or a string
            starting with "Video error" describing what went wrong (transport
            failure, non-200 status, or a response without usable candidates).
        """
        # The REST path needs the bare model id; "gemini/gemini-2.0-flash"
        # would otherwise produce ".../models/gemini/gemini-2.0-flash".
        api_model = self.model_name.removeprefix("gemini/").removeprefix("models/")

        req = {
            'model': f'models/{api_model}',
            'contents': [{
                "parts": [
                    {"fileData": {"fileUri": video_url}},
                    {"text": f"Please watch the video and answer the question: {user_query}"}
                ]
            }]
        }
        url = f'https://generativelanguage.googleapis.com/v1beta/models/{api_model}:generateContent'

        # Send the key as a header rather than in the query string so it does
        # not end up in URLs echoed by exception messages or logs.
        headers = {'Content-Type': 'application/json'}
        api_key = os.getenv("GOOGLE_API_KEY")
        if api_key:
            headers['x-goog-api-key'] = api_key

        try:
            res = requests.post(url, json=req, headers=headers, timeout=self.timeout_seconds)
        except requests.RequestException as e:
            return f"Video error: {e}"

        if res.status_code != 200:
            return f"Video error {res.status_code}: {res.text}"

        try:
            parts = res.json()['candidates'][0]['content']['parts']
        except (ValueError, KeyError, IndexError, TypeError):
            # e.g. a response with no candidates, or a non-JSON body.
            return f"Video error: unexpected response: {res.text[:500]}"
        return "".join(p.get('text', '') for p in parts)
| |
|
class WikiTitleFinder(Tool):
    """Tool that returns Wikipedia page titles matching a search query."""

    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the matching titles joined by ", ", or "No results." if none."""
        titles = wiki.search(query)
        if not titles:
            return "No results."
        return ", ".join(titles)
| |
|
class WikiContentFetcher(Tool):
    """Tool that fetches a Wikipedia page and returns its HTML as Markdown."""

    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    # Cap on how many disambiguation candidates are echoed back.
    max_options = 20

    def forward(self, page_title: str) -> str:
        """Return the page content as Markdown.

        Args:
            page_title: Title of the page to fetch.

        Returns:
            The Markdown-converted page HTML; a "not found" message if the
            page does not exist; or, for an ambiguous title, a message listing
            candidate page titles so the caller can retry with a precise one.
        """
        try:
            return to_markdown(wiki.page(page_title).html())
        except wiki.exceptions.DisambiguationError as e:
            # Ambiguous titles are common; surface the candidates instead of
            # letting the exception escape the tool.
            options = ", ".join(e.options[: self.max_options])
            return f"'{page_title}' is ambiguous. Possible pages: {options}"
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
| |
|
class GoogleSearchTool(Tool):
    """Tool that queries the Google Custom Search JSON API for one result.

    Credentials are read from the ``GOOGLE_SEARCH_API_KEY`` and
    ``GOOGLE_SEARCH_ENGINE_ID`` environment variables.
    """

    name = "google_search"
    description = "Search the web using Google. Returns top summary from the web."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    # Upper bound (seconds) for the HTTP call.
    timeout_seconds = 30

    def forward(self, query: str) -> str:
        """Return the snippet of the top search hit.

        Returns:
            The first result's snippet (falling back to its title),
            "No results found." when the search returned no items, or a
            "GoogleSearch error: ..." string for transport failures, non-JSON
            bodies, and error payloads returned by the API.
        """
        try:
            resp = requests.get(
                "https://www.googleapis.com/customsearch/v1",
                params={
                    "q": query,
                    "key": os.getenv("GOOGLE_SEARCH_API_KEY"),
                    "cx": os.getenv("GOOGLE_SEARCH_ENGINE_ID"),
                    "num": 1,
                },
                timeout=self.timeout_seconds,
            )
            data = resp.json()
        except (requests.RequestException, ValueError) as e:
            return f"GoogleSearch error: {e}"

        if not isinstance(data, dict):
            return "GoogleSearch error: unexpected response format"

        # Quota/credential problems come back as an "error" object; report
        # them rather than claiming the search had no results.
        error = data.get("error")
        if error:
            message = error.get("message", error) if isinstance(error, dict) else error
            return f"GoogleSearch error: {message}"

        items = data.get("items") or []
        if not items:
            return "No results found."
        top_hit = items[0]
        return top_hit.get("snippet") or top_hit.get("title") or "No results found."
| |
|
| |
|
class FileAttachmentQueryTool(Tool):
    """Tool that downloads a task's attached file and asks Gemini about it.

    The file is fetched from the scoring service by task id and sent inline
    to the model together with the user's question.
    """

    name = "run_query_with_file"
    description = """
    Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
    This assumes the file is 20MB or less.
    """
    inputs = {
        "task_id": {
            "type": "string",
            "description": "A unique identifier for the task related to this file, used to download it.",
            "nullable": True
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the file."
        }
    }
    output_type = "string"

    # Upper bound (seconds) for the file download.
    timeout_seconds = 60

    def __init__(self, model_name: str = "gemini-2.0-flash", *args, **kwargs):
        """Store the model to query.

        Without this constructor the ``model_name`` keyword passed by callers
        was never stored, so ``forward`` failed on ``self.model_name``.

        Args:
            model_name: Gemini model id. A LiteLLM-style ``gemini/`` prefix is
                accepted and stripped before the SDK call.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, task_id: str | None, user_query: str) -> str:
        """Download the file for ``task_id`` and answer ``user_query`` about it.

        Returns:
            The model's text answer, a "Failed to download file: ..." message
            when the file cannot be fetched, or a "File query error: ..."
            message when the model call fails.
        """
        if not task_id:
            # task_id is nullable; without it the URL would end in "/files/None".
            return "Failed to download file: no task_id was provided."

        file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        try:
            file_response = requests.get(file_url, timeout=self.timeout_seconds)
        except requests.RequestException as e:
            return f"Failed to download file: {e}"
        if file_response.status_code != 200:
            return f"Failed to download file: {file_response.status_code} - {file_response.text}"

        # Prefer the server-declared media type (dropping parameters such as
        # "; charset=..."); fall back to a generic binary type.
        declared_type = file_response.headers.get("Content-Type", "")
        mime_type = declared_type.split(";")[0].strip() or "application/octet-stream"

        from google.generativeai import GenerativeModel

        # The SDK expects a bare model id, not LiteLLM's "gemini/<model>" form.
        model = GenerativeModel(self.model_name.removeprefix("gemini/"))
        try:
            # google.generativeai accepts inline binary content as a
            # {"mime_type", "data"} dict; `types.Part.from_bytes` belongs to a
            # different SDK and is not available here.
            response = model.generate_content([
                {"mime_type": mime_type, "data": file_response.content},
                user_query,
            ])
            return response.text
        except Exception as e:  # tool boundary: report instead of raising
            return f"File query error: {e}"
| |
|
| | |
class BasicAgent:
    """GAIA question-answering agent built on a smolagents ``CodeAgent``.

    The agent is configured with the search, Wikipedia, video, file, math,
    riddle and text tools defined in this module and a strict answer-format
    system prompt.
    """

    # Answer-format instructions assigned to the agent in __init__.
    _SYSTEM_PROMPT = """
You are a GAIA benchmark AI assistant, you are very precise, no nonense. Your sole purpose is to output the minimal, final answer in the format:
[ANSWER]
You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer, strictly enclosed in `[ANSWER]`.
Your behavior must be governed by these rules:
1. **Format**:
- limit the token used (within 65536 tokens).
- Output ONLY the final answer.
- Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
- No follow-ups, justifications, or clarifications.
2. **Numerical Answers**:
- Use **digits only**, e.g., `4` not `four`.
- No commas, symbols, or units unless explicitly required.
- Never use approximate words like "around", "roughly", "about".
3. **String Answers**:
- Omit **articles** ("a", "the").
- Use **full words**; no abbreviations unless explicitly requested.
- For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
- For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.
4. **Lists**:
- Output in **comma-separated** format with no conjunctions.
- Sort **alphabetically** or **numerically** depending on type.
- No braces or brackets unless explicitly asked.
5. **Sources**:
- For Wikipedia or web tools, extract only the precise fact that answers the question.
- Ignore any unrelated content.
6. **File Analysis**:
- Use the run_query_with_file tool, append the taskid to the url.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.
7. **Video**:
- Use the relevant video tool.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.
8. **Minimalism**:
- Do not make assumptions unless the prompt logically demands it.
- If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
- If the answer is not found, say `[ANSWER] - unknown`.
---
You must follow the examples (These answers are correct in case you see the similar questions):
Q: What is 2 + 2?
A: 4
Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
A: 3
Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
A: b, e
Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
A: 519
"""

    def __init__(self, provider="hf"):
        """Build the underlying agent.

        Args:
            provider: Which model backend to use; see ``select_model`` for
                the recognised values.
        """
        print("BasicAgent initialized.")
        model = self.select_model(provider)
        tools = [
            GoogleSearchTool(),
            DuckDuckGoSearchTool(),
            GeminiVideoQA(GEMINI_MODEL_NAME),
            WikiTitleFinder(),
            WikiContentFetcher(),
            MathSolver(),
            RiddleSolver(),
            TextTransformer(),
            FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
        ]
        self.agent = CodeAgent(
            model=model,
            tools=tools,
            add_base_tools=False,
            max_steps=10,
        )
        # NOTE(review): some smolagents versions build the system prompt from
        # prompt templates at run time; confirm this assignment takes effect
        # on the installed version.
        self.agent.system_prompt = self._SYSTEM_PROMPT

    def select_model(self, provider: str):
        """Return the model wrapper for ``provider``.

        Recognised values are "openai", "groq", "deepseek" and "hf"; any
        other value selects Gemini. API keys are read from the environment.
        """
        if provider == "openai":
            return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("OPENAI_API_KEY"))
        elif provider == "groq":
            return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=os.getenv("GROQ_API_KEY"))
        elif provider == "deepseek":
            return LiteLLMModel(model_id=DEEPSEEK_MODEL_NAME, api_key=os.getenv("DEEPSEEK_API_KEY"))
        elif provider == "hf":
            # Pin the model like every other provider instead of relying on
            # the library's default.
            return InferenceClientModel(model_id=HF_MODEL_NAME)
        else:
            return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("GOOGLE_API_KEY"))

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its answer, whitespace-stripped."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.agent.run(question)
        final_str = str(result).strip()

        return final_str

    def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
        """Score the agent on a random sample of questions from a CSV file.

        Args:
            csv_path: CSV with ``taskid``, ``question`` and ``answer`` columns.
            sample_size: Number of rows to sample; capped at the number of
                rows available.
            show_steps: When true, print each question/answer pair as it is
                evaluated.

        Prints a results table and the exact-match accuracy. Returns ``None``.
        """
        from rich.table import Table
        from rich.console import Console

        df = pd.read_csv(csv_path)
        # "taskid" is read for every row below, so it is required as well.
        if not {"taskid", "question", "answer"}.issubset(df.columns):
            print("CSV must contain 'taskid', 'question' and 'answer' columns.")
            print("Found columns:", df.columns.tolist())
            return

        # df.sample raises if asked for more rows than exist (or a negative n).
        effective_size = min(sample_size, len(df))
        if effective_size <= 0:
            print("No questions to evaluate.")
            return

        samples = df.sample(n=effective_size)
        records = []
        correct_count = 0

        for _, row in samples.iterrows():
            # str(): pandas may parse id-like or numeric columns as non-strings.
            taskid = str(row["taskid"]).strip()
            question = str(row["question"]).strip()
            expected = str(row['answer']).strip()
            agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()

            is_correct = (expected == agent_answer)
            correct_count += int(is_correct)
            records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))

            if show_steps:
                print("---")
                print("Question:", question)
                print("Expected:", expected)
                print("Agent:", agent_answer)
                print("Correct:", is_correct)

        console = Console()
        table = Table(show_lines=True)
        table.add_column("Question", overflow="fold")
        table.add_column("Expected")
        table.add_column("Agent")
        table.add_column("Correct")

        for question, expected, agent_ans, correct in records:
            table.add_row(question, expected, agent_ans, correct)

        console.print(table)
        # Divide by the number actually evaluated, which is never zero here.
        evaluated = len(records)
        percent = (correct_count / evaluated) * 100
        print(f"\nTotal Correct: {correct_count} / {evaluated} ({percent:.2f}%)")
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: all arguments are joined into one question,
    # except the single word "dev", which runs the sampled evaluation.
    args = sys.argv[1:]
    if not args or args[0] in {"-h", "--help"}:
        print("Usage: python agent.py [question | dev]")
        print(" - Provide a question to get a GAIA-style answer.")
        # The file named here matches evaluate_random_questions' default path.
        print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_extracted.csv.")
        sys.exit(0)

    q = " ".join(args)
    agent = BasicAgent()
    if q == "dev":
        agent.evaluate_random_questions()
    else:
        print(agent(q))
| |
|