rabiyulfahim committed
Commit 040c903 · verified · 1 Parent(s): ba8f36f

Update main.py

Files changed (1)
  1. main.py +13 -12
main.py CHANGED
@@ -1,5 +1,5 @@
 from fastapi import FastAPI, Query, HTTPException
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import HTMLResponse
@@ -8,10 +8,12 @@ import os
 import torch
 
 # -----------------------
-# Hugging Face cache
+# Set cache dirs (avoid Docker errors)
 # -----------------------
-os.environ["HF_HOME"] = "/tmp"  # writable cache
-os.environ["TRANSFORMERS_CACHE"] = "/tmp"  # optional
+os.environ["HF_HOME"] = "/tmp"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp"
+os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
+os.makedirs("/tmp/torch_inductor_cache", exist_ok=True)
 
 # -----------------------
 # Model Setup
@@ -20,10 +22,15 @@ model_id = "LLM360/K2-Think"
 
 print("Loading tokenizer and model...")
 tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/tmp")
+
+bnb_config = BitsAndBytesConfig(
+    load_in_8bit=True  # 8-bit quantization
+)
+
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    device_map="auto",  # auto assign to GPU/CPU
-    load_in_8bit=True,  # 8-bit quantization for low memory
+    quantization_config=bnb_config,
+    device_map="auto",
     cache_dir="/tmp"
 )
 print("Model loaded!")
@@ -59,12 +66,6 @@ class QueryRequest(BaseModel):
 def home():
     return {"message": "Welcome to K2-Think QA API 🚀"}
 
-@app.get("/ui", response_class=HTMLResponse)
-def serve_ui():
-    html_path = os.path.join("static", "index.html")
-    with open(html_path, "r", encoding="utf-8") as f:
-        return HTMLResponse(f.read())
-
 @app.get("/health")
 def health():
     return {"status": "ok"}