"""Initial implementation of the AI-Powered Technical Initiative Generator's LLM client."""
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import List, Optional

import requests
from huggingface_hub import InferenceClient

@dataclass
class LLMConfig:
    provider: str = os.environ.get("LLM_PROVIDER", "hf_inference")  # hf_inference|local|openai|together|groq
    model: str = os.environ.get("HF_INFERENCE_MODEL", "Qwen/Qwen2.5-3B-Instruct")
    max_new_tokens: int = 600
    temperature: float = 0.4
    top_p: float = 0.9
    stop: Optional[List[str]] = None
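
# Sketch of overriding the env-driven defaults in code rather than via the
# environment; the provider and model names below are illustrative assumptions,
# not values this module requires:
#
#     cfg = LLMConfig(provider="groq", model="llama-3.1-8b-instant", temperature=0.2)
#     client = LLMClient(cfg)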


class LLMClient:
    """Thin wrapper that routes prompts to the configured LLM provider."""

    def __init__(self, config: Optional[LLMConfig] = None):
        self.config = config or LLMConfig()
        self._init_provider()

    def _init_provider(self):
        p = self.config.provider
        if p == "hf_inference":
            token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
            self.hf = InferenceClient(token=token)
        elif p == "openai":
            # Lazy import to avoid the dependency unless this provider is selected
            from openai import OpenAI  # type: ignore

            self.openai = OpenAI()
        elif p == "together":
            # Together exposes an OpenAI-compatible chat completions API; see generate()
            self.together_api_key = os.environ.get("TOGETHER_API_KEY")
        elif p == "groq":
            self.groq_api_key = os.environ.get("GROQ_API_KEY")
        elif p == "local":
            # Lazy imports: transformers and torch are only required for local inference
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
            import torch

            model_id = self.config.model
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if torch.cuda.is_available() else -1,
            )
        else:
            raise ValueError(f"Unsupported provider: {p}")

    def generate(self, prompt: str, **kwargs) -> str:
        """Return a completion for `prompt`; kwargs override the config defaults."""
        p = self.config.provider
        if p == "hf_inference":
            return self.hf.text_generation(
                prompt,
                model=self.config.model,
                max_new_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
                temperature=kwargs.get("temperature", self.config.temperature),
                top_p=kwargs.get("top_p", self.config.top_p),
                stop_sequences=kwargs.get("stop", self.config.stop),
            )
elif p == "local":
out = self.pipe(
prompt,
max_new_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
temperature=kwargs.get("temperature", self.config.temperature),
top_p=kwargs.get("top_p", self.config.top_p),
do_sample=True,
)
return out[0]["generated_text"]
elif p == "openai":
# Assumes model is a chat model id
client = self.openai
completion = client.chat.completions.create(
model=os.environ.get("OPENAI_MODEL", self.config.model),
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
temperature=kwargs.get("temperature", self.config.temperature),
max_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
top_p=kwargs.get("top_p", self.config.top_p),
)
return completion.choices[0].message.content or ""
elif p == "together":
url = "https://api.together.xyz/v1/chat/completions"
headers = {"Authorization": f"Bearer {self.together_api_key}", "Content-Type": "application/json"}
data = {
"model": self.config.model,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_new_tokens", self.config.max_new_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
}
r = requests.post(url, headers=headers, json=data, timeout=60)
r.raise_for_status()
j = r.json()
return j["choices"][0]["message"]["content"]
elif p == "groq":
url = "https://api.groq.com/openai/v1/chat/completions"
headers = {"Authorization": f"Bearer {self.groq_api_key}", "Content-Type": "application/json"}
data = {
"model": self.config.model,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_new_tokens", self.config.max_new_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
}
r = requests.post(url, headers=headers, json=data, timeout=60)
r.raise_for_status()
j = r.json()
return j["choices"][0]["message"]["content"]
else:
raise ValueError(f"Unsupported provider: {p}")
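

# A minimal usage sketch, assuming this file is run as a script with the
# relevant credentials exported (e.g. HF_TOKEN for the default hf_inference
# provider); the prompt below is illustrative, not part of the module.
if __name__ == "__main__":
    client = LLMClient()  # provider/model resolved from LLM_PROVIDER / HF_INFERENCE_MODEL
    print(client.generate("Draft three technical initiatives for reducing CI flakiness.", max_new_tokens=200))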