"""Provider-agnostic LLM client that routes text generation to Hugging Face
Inference, a local transformers pipeline, OpenAI, Together, or Groq."""

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import List, Optional

import requests
from huggingface_hub import InferenceClient


@dataclass
class LLMConfig:
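    """Generation settings; provider and model default to LLM_PROVIDER and
    HF_INFERENCE_MODEL, read from the environment when the module is imported."""
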
    provider: str = os.environ.get("LLM_PROVIDER", "hf_inference")  # hf_inference|local|openai|together|groq
    model: str = os.environ.get("HF_INFERENCE_MODEL", "Qwen/Qwen2.5-3B-Instruct")
    max_new_tokens: int = 600
    temperature: float = 0.4
    top_p: float = 0.9
    stop: Optional[List[str]] = None


class LLMClient:
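    """Thin wrapper that dispatches generate() calls to the provider named in LLMConfig."""
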
    def __init__(self, config: Optional[LLMConfig] = None):
        self.config = config or LLMConfig()
        self._init_provider()

    def _init_provider(self):
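        """Set up the client for the configured provider; optional dependencies are imported lazily."""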
        p = self.config.provider
        if p == "hf_inference":
            token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
            self.hf = InferenceClient(token=token)
        elif p == "openai":
            # Lazy import to avoid dependency unless needed
            from openai import OpenAI  # type: ignore

            self.openai = OpenAI()
        elif p == "together":
            # Uses Together's API
            self.together_api_key = os.environ.get("TOGETHER_API_KEY")
        elif p == "groq":
            self.groq_api_key = os.environ.get("GROQ_API_KEY")
        elif p == "local":
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
            import torch

            model_id = self.config.model
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if torch.cuda.is_available() else -1,
            )
        else:
            raise ValueError(f"Unsupported provider: {p}")

    def generate(self, prompt: str, **kwargs) -> str:
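        """Return a completion for `prompt`; kwargs (max_new_tokens, temperature,
        top_p, stop) override the config defaults. `stop` is only honored by the
        hf_inference provider."""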
        p = self.config.provider
        if p == "hf_inference":
            model = self.config.model
            resp = self.hf.text_generation(
                prompt,
                model=model,
                max_new_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
                temperature=kwargs.get("temperature", self.config.temperature),
                top_p=kwargs.get("top_p", self.config.top_p),
                stop_sequences=kwargs.get("stop", self.config.stop),
            )
            return resp
        elif p == "local":
            out = self.pipe(
                prompt,
                max_new_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
                temperature=kwargs.get("temperature", self.config.temperature),
                top_p=kwargs.get("top_p", self.config.top_p),
                do_sample=True,
            )
            return out[0]["generated_text"]
        elif p == "openai":
            # Assumes model is a chat model id
            client = self.openai
            completion = client.chat.completions.create(
                model=os.environ.get("OPENAI_MODEL", self.config.model),
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                temperature=kwargs.get("temperature", self.config.temperature),
                max_tokens=kwargs.get("max_new_tokens", self.config.max_new_tokens),
                top_p=kwargs.get("top_p", self.config.top_p),
            )
            return completion.choices[0].message.content or ""
        elif p == "together":
            url = "https://api.together.xyz/v1/chat/completions"
            headers = {"Authorization": f"Bearer {self.together_api_key}", "Content-Type": "application/json"}
            data = {
                "model": self.config.model,
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                "temperature": kwargs.get("temperature", self.config.temperature),
                "max_tokens": kwargs.get("max_new_tokens", self.config.max_new_tokens),
                "top_p": kwargs.get("top_p", self.config.top_p),
            }
            r = requests.post(url, headers=headers, json=data, timeout=60)
            r.raise_for_status()
            j = r.json()
            return j["choices"][0]["message"]["content"]
        elif p == "groq":
            url = "https://api.groq.com/openai/v1/chat/completions"
            headers = {"Authorization": f"Bearer {self.groq_api_key}", "Content-Type": "application/json"}
            data = {
                "model": self.config.model,
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                "temperature": kwargs.get("temperature", self.config.temperature),
                "max_tokens": kwargs.get("max_new_tokens", self.config.max_new_tokens),
                "top_p": kwargs.get("top_p", self.config.top_p),
            }
            r = requests.post(url, headers=headers, json=data, timeout=60)
            r.raise_for_status()
            j = r.json()
            return j["choices"][0]["message"]["content"]
        else:
            raise ValueError(f"Unsupported provider: {p}")
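

# Minimal usage sketch (illustrative only, not part of the module's public surface).
# It assumes the chosen provider's credentials (e.g. HUGGINGFACEHUB_API_TOKEN for the
# default hf_inference provider) are already exported in the environment.
if __name__ == "__main__":
    client = LLMClient(LLMConfig(max_new_tokens=128, temperature=0.2))
    print(client.generate("In one sentence, explain what an LLM client abstraction does."))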