Mihai Băluță-Cujbă
Add initial implementation of AI-Powered Technical Initiative Generator
c509185
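
"""Provider-agnostic LLM client for the AI-Powered Technical Initiative Generator.

Wraps Hugging Face Inference, a local transformers pipeline, OpenAI, Together,
and Groq behind a single generate() interface; provider and model are selected
via environment variables (see LLMConfig below).
"""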

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import List, Optional

import requests
from huggingface_hub import InferenceClient


@dataclass
class LLMConfig:
    provider: str = os.environ.get("LLM_PROVIDER", "hf_inference")  # hf_inference|local|openai|together|groq
    model: str = os.environ.get("HF_INFERENCE_MODEL", "Qwen/Qwen2.5-3B-Instruct")
    max_new_tokens: int = 600
    temperature: float = 0.4
    top_p: float = 0.9
    stop: Optional[List[str]] = None
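
# Note: the model field is read from HF_INFERENCE_MODEL regardless of provider,
# so point it at a model id the chosen provider serves. A sketch of selecting a
# provider via the environment (values here are illustrative):
#   export LLM_PROVIDER=groq
#   export HF_INFERENCE_MODEL=llama-3.1-8b-instant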


class LLMClient:
    def __init__(self, config: Optional[LLMConfig] = None):
        self.config = config or LLMConfig()
        self._init_provider()

    def _init_provider(self):
        p = self.config.provider
        if p == "hf_inference":
            token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
            self.hf = InferenceClient(token=token)
        elif p == "openai":
            # Lazy import so the dependency is only required when this provider is selected
            from openai import OpenAI  # type: ignore

            self.openai = OpenAI()
        elif p == "together":
            # Together exposes an OpenAI-compatible chat completions API; the key is
            # read here, and requests will fail with 401 if it is missing.
            self.together_api_key = os.environ.get("TOGETHER_API_KEY")
        elif p == "groq":
            self.groq_api_key = os.environ.get("GROQ_API_KEY")
        elif p == "local":
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
            import torch

            # Load the model in-process: fp16 on GPU, fp32 on CPU.
            model_id = self.config.model
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if torch.cuda.is_available() else -1,
            )
        else:
            raise ValueError(f"Unsupported provider: {p}")

    def generate(self, prompt: str, **kwargs) -> str:
        p = self.config.provider
        if p == "hf_inference":
            resp = self.hf.text_generation(
                prompt,
                model=self.config.model,
                max_new_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
                temperature=kwargs.get("temperature", self.config.temperature),
                top_p=kwargs.get("top_p", self.config.top_p),
                stop_sequences=kwargs.get("stop", self.config.stop),
            )
            return resp
        elif p == "local":
            out = self.pipe(
                prompt,
                max_new_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
                temperature=kwargs.get("temperature", self.config.temperature),
                top_p=kwargs.get("top_p", self.config.top_p),
                do_sample=True,
                return_full_text=False,  # return only the completion, matching the hosted providers
            )
            return out[0]["generated_text"]
        elif p == "openai":
            # Assumes the configured model is a chat model id
            completion = self.openai.chat.completions.create(
                model=os.environ.get("OPENAI_MODEL", self.config.model),
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                temperature=kwargs.get("temperature", self.config.temperature),
                max_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
                top_p=kwargs.get("top_p", self.config.top_p),
            )
            return completion.choices[0].message.content or ""
        elif p in ("together", "groq"):
            # Both expose OpenAI-compatible chat completion endpoints, so they
            # share one request path.
            if p == "together":
                url = "https://api.together.xyz/v1/chat/completions"
                api_key = self.together_api_key
            else:
                url = "https://api.groq.com/openai/v1/chat/completions"
                api_key = self.groq_api_key
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            data = {
                "model": self.config.model,
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                "temperature": kwargs.get("temperature", self.config.temperature),
                "max_tokens": kwargs.get("max_new_tokens", self.config.max_new_tokens),
                "top_p": kwargs.get("top_p", self.config.top_p),
            }
            r = requests.post(url, headers=headers, json=data, timeout=60)
            r.raise_for_status()
            return r.json()["choices"][0]["message"]["content"]
        else:
            raise ValueError(f"Unsupported provider: {p}")
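

# Minimal usage sketch (assumes the matching API key/token is set in the
# environment; the prompt and token budget here are illustrative):
if __name__ == "__main__":
    client = LLMClient()
    print(client.generate("Draft a one-paragraph technical initiative proposal.", max_new_tokens=200))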