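"""Unified client for text generation across several LLM providers.

The active provider is chosen via the LLM_PROVIDER environment variable:
hf_inference (default), local, openai, together, or groq.
"""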
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import List, Optional

import requests
from huggingface_hub import InferenceClient


@dataclass
class LLMConfig:
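    """Generation defaults; provider and model are read from the environment."""
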
    provider: str = os.environ.get("LLM_PROVIDER", "hf_inference")  # hf_inference|local|openai|together|groq
    model: str = os.environ.get("HF_INFERENCE_MODEL", "Qwen/Qwen2.5-3B-Instruct")
    max_new_tokens: int = 600
    temperature: float = 0.4
    top_p: float = 0.9
    stop: Optional[List[str]] = None


class LLMClient:
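    """Thin wrapper that routes generate() calls to the configured provider."""
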
    def __init__(self, config: Optional[LLMConfig] = None):
        self.config = config or LLMConfig()
        self._init_provider()

    def _init_provider(self):
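        """Set up provider-specific state; heavy dependencies are imported lazily."""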
        p = self.config.provider
        if p == "hf_inference":
            token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
            self.hf = InferenceClient(token=token)
        elif p == "openai":
            # Lazy import so the openai package is only required for this provider.
            from openai import OpenAI  # type: ignore

            self.openai = OpenAI()
        elif p == "together":
            # Together exposes an OpenAI-compatible chat completions endpoint.
            self.together_api_key = os.environ.get("TOGETHER_API_KEY")
        elif p == "groq":
            self.groq_api_key = os.environ.get("GROQ_API_KEY")
        elif p == "local":
            # Heavy imports kept local so remote-only setups avoid them.
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
            import torch

            model_id = self.config.model
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if torch.cuda.is_available() else -1,
            )
        else:
            raise ValueError(f"Unsupported provider: {p}")

    def generate(self, prompt: str, **kwargs) -> str:
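        """Return the model's completion for prompt; kwargs override config values."""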
        p = self.config.provider
        if p == "hf_inference":
            resp = self.hf.text_generation(
                prompt,
                model=self.config.model,
                max_new_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
                temperature=kwargs.get("temperature", self.config.temperature),
                top_p=kwargs.get("top_p", self.config.top_p),
                stop_sequences=kwargs.get("stop", self.config.stop),
            )
            return resp
elif p == "local":
out = self.pipe(
prompt,
max_new_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
temperature=kwargs.get("temperature", self.config.temperature),
top_p=kwargs.get("top_p", self.config.top_p),
do_sample=True,
)
return out[0]["generated_text"]
elif p == "openai":
# Assumes model is a chat model id
client = self.openai
completion = client.chat.completions.create(
model=os.environ.get("OPENAI_MODEL", self.config.model),
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
temperature=kwargs.get("temperature", self.config.temperature),
max_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
top_p=kwargs.get("top_p", self.config.top_p),
)
return completion.choices[0].message.content or ""
elif p == "together":
url = "https://api.together.xyz/v1/chat/completions"
headers = {"Authorization": f"Bearer {self.together_api_key}", "Content-Type": "application/json"}
data = {
"model": self.config.model,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_new_tokens", self.config.max_new_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
}
r = requests.post(url, headers=headers, json=data, timeout=60)
r.raise_for_status()
j = r.json()
return j["choices"][0]["message"]["content"]
elif p == "groq":
url = "https://api.groq.com/openai/v1/chat/completions"
headers = {"Authorization": f"Bearer {self.groq_api_key}", "Content-Type": "application/json"}
data = {
"model": self.config.model,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
"temperature": kwargs.get("temperature", self.config.temperature),
"max_tokens": kwargs.get("max_new_tokens", self.config.max_new_tokens),
"top_p": kwargs.get("top_p", self.config.top_p),
}
r = requests.post(url, headers=headers, json=data, timeout=60)
r.raise_for_status()
j = r.json()
return j["choices"][0]["message"]["content"]
else:
raise ValueError(f"Unsupported provider: {p}")
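

if __name__ == "__main__":
    # Minimal usage sketch: assumes the API token for the selected provider
    # (e.g. HF_TOKEN for the default hf_inference provider) is set in the
    # environment before running.
    client = LLMClient()
    print(client.generate("Say hello in one sentence."))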