Add tokenizer normalization retry in load_model
Browse files
app.py
CHANGED
|
@@ -37,6 +37,10 @@ TOKENIZER = None
|
|
| 37 |
MODEL_NAME = None
|
| 38 |
DEVICE = "cpu"
|
| 39 |
MODEL_LOCK = threading.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
# ----------------------------- Utilities ---# ------------------------------
|
| 42 |
|
|
@@ -83,8 +87,18 @@ def _diagnose_and_fix_tokenizer_model(tok: AutoTokenizer, mdl: AutoModelForCausa
|
|
| 83 |
# ensure pad token exists and ids/config align
|
| 84 |
if getattr(tok, "pad_token", None) is None:
|
| 85 |
tok.pad_token = getattr(tok, "eos_token", "[PAD]")
|
|
|
|
| 86 |
try:
|
| 87 |
tok.add_special_tokens({"pad_token": tok.pad_token})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
except Exception:
|
| 89 |
pass
|
| 90 |
try:
|
|
@@ -235,6 +249,25 @@ def repair_tokenizer_on_hub(repo_id: str) -> str:
|
|
| 235 |
|
| 236 |
# ----------------------------- Model loading -------------------------------
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
|
| 239 |
"""Load model + tokenizer from the Hub. Graceful fallbacks and HF-token support.
|
| 240 |
|
|
@@ -250,6 +283,11 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
|
|
| 250 |
if MODEL is not None and MODEL_NAME == repo_id and not force_reload:
|
| 251 |
return f"Model already loaded: {MODEL_NAME} (@ {DEVICE})"
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
MODEL = None
|
| 254 |
TOKENIZER = None
|
| 255 |
MODEL_NAME = repo_id
|
|
@@ -265,40 +303,135 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
|
|
| 265 |
trust_remote_code=True,
|
| 266 |
use_auth_token=hf_token,
|
| 267 |
)
|
| 268 |
-
|
| 269 |
-
#
|
| 270 |
try:
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
try:
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
# 2) Load model (prefer 4-bit on GPU if available)
|
| 298 |
if DEVICE == "cuda" and HAS_BNB:
|
| 299 |
try:
|
| 300 |
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
| 301 |
-
MODEL =
|
| 302 |
repo_id,
|
| 303 |
device_map="auto",
|
| 304 |
quantization_config=bnb_config,
|
|
@@ -307,6 +440,8 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
|
|
| 307 |
)
|
| 308 |
MODEL.eval()
|
| 309 |
_diagnose_and_fix_tokenizer_model(TOKENIZER, MODEL)
|
|
|
|
|
|
|
| 310 |
return f"Loaded {repo_id} (4-bit, device_map=auto)"
|
| 311 |
except Exception as e:
|
| 312 |
print("bnb/4bit load failed - falling back:", e)
|
|
@@ -314,7 +449,7 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
|
|
| 314 |
# 3) FP16 / CPU fallback
|
| 315 |
try:
|
| 316 |
if DEVICE == "cuda":
|
| 317 |
-
MODEL =
|
| 318 |
repo_id,
|
| 319 |
device_map="auto",
|
| 320 |
torch_dtype=torch.float16,
|
|
@@ -322,7 +457,7 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
|
|
| 322 |
use_auth_token=hf_token,
|
| 323 |
)
|
| 324 |
else:
|
| 325 |
-
MODEL =
|
| 326 |
repo_id,
|
| 327 |
low_cpu_mem_usage=True,
|
| 328 |
torch_dtype=torch.float32,
|
|
@@ -333,11 +468,15 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
|
|
| 333 |
|
| 334 |
MODEL.eval()
|
| 335 |
_diagnose_and_fix_tokenizer_model(TOKENIZER, MODEL)
|
|
|
|
|
|
|
| 336 |
return f"Loaded {repo_id} (@{DEVICE})"
|
| 337 |
except Exception as e:
|
| 338 |
MODEL = None
|
| 339 |
TOKENIZER = None
|
| 340 |
-
# provide a helpful diagnostic message
|
|
|
|
|
|
|
| 341 |
return f"Model load failed: {e} (hint: check HF_TOKEN, repo contents and ensure tokenizer.model is present)"
|
| 342 |
|
| 343 |
|
|
@@ -358,6 +497,8 @@ def _normalize_history(raw_history) -> List[Tuple[str, str]]:
|
|
| 358 |
and return a list of (user, assistant) pairs suitable for prompt construction.
|
| 359 |
|
| 360 |
Behavior: pairs each user message with the next assistant message (assistant may be "" if not present).
|
|
|
|
|
|
|
| 361 |
"""
|
| 362 |
if not raw_history:
|
| 363 |
return []
|
|
@@ -409,6 +550,33 @@ def build_prompt(history, user_input: str, system_prompt: str, max_history: int
|
|
| 409 |
pairs = _normalize_history(history or [])
|
| 410 |
pairs = pairs[-max_history:]
|
| 411 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
parts: List[str] = [f"System: {system_prompt}"]
|
| 413 |
for u, a in pairs:
|
| 414 |
# include previous turns as completed instruction/response pairs
|
|
@@ -425,7 +593,10 @@ def _generate_text(prompt: str, temperature: float, top_p: float, top_k: int, ma
|
|
| 425 |
if MODEL is None or TOKENIZER is None:
|
| 426 |
raise RuntimeError("Model is not loaded. Press 'Load model' first.")
|
| 427 |
|
| 428 |
-
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
gen_kwargs = dict(
|
| 431 |
input_ids=input_ids,
|
|
@@ -452,7 +623,8 @@ def _generate_stream(prompt: str, temperature: float, top_p: float, top_k: int,
|
|
| 452 |
raise RuntimeError("Model is not loaded. Press 'Load model' first.")
|
| 453 |
|
| 454 |
streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
|
| 455 |
-
|
|
|
|
| 456 |
|
| 457 |
gen_kwargs = dict(
|
| 458 |
input_ids=input_ids,
|
|
@@ -491,6 +663,17 @@ def submit_message(user_message: str, history, system_prompt: str, temperature:
|
|
| 491 |
# Append current user turn (assistant reply empty until generated)
|
| 492 |
pairs.append((str(user_message or ""), ""))
|
| 493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
prompt = build_prompt(pairs[:-1], user_message, system_prompt, max_history)
|
| 495 |
|
| 496 |
# If user is running the full Nanbeige model on CPU, warn and suggest options
|
|
@@ -538,7 +721,15 @@ def regenerate(history, system_prompt: str, temperature: float, top_p: float, to
|
|
| 538 |
|
| 539 |
def load_model_ui(repo: str):
|
| 540 |
status = load_model(repo, force_reload=True)
|
| 541 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
|
| 543 |
|
| 544 |
def apply_lora_adapter(adapter_repo: str):
|
|
@@ -609,7 +800,7 @@ with gr.Blocks(title="Nanbeige2.5 — Chat UI") as demo:
|
|
| 609 |
apply_adapter = gr.Button("Apply LoRA adapter")
|
| 610 |
|
| 611 |
# Events
|
| 612 |
-
load_btn.click(fn=
|
| 613 |
repair_btn.click(fn=repair_tokenizer_on_hub, inputs=model_input, outputs=model_status)
|
| 614 |
|
| 615 |
send.click(
|
|
@@ -635,7 +826,21 @@ with gr.Blocks(title="Nanbeige2.5 — Chat UI") as demo:
|
|
| 635 |
|
| 636 |
# auto-load default model in background (non-blocking)
|
| 637 |
def _bg_initial_load():
|
| 638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
|
| 640 |
# For local smoke tests you can skip automatic model loading by setting
|
| 641 |
# environment variable `SKIP_AUTOLOAD=1` so the UI starts without loading
|
|
@@ -644,6 +849,11 @@ with gr.Blocks(title="Nanbeige2.5 — Chat UI") as demo:
|
|
| 644 |
model_status.value = "Auto-load skipped (SKIP_AUTOLOAD=1)"
|
| 645 |
else:
|
| 646 |
model_status.value = _bg_initial_load()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 647 |
|
| 648 |
# CPU warning / demo hint (visible in UI)
|
| 649 |
gr.Markdown("""
|
|
|
|
| 37 |
MODEL_NAME = None
|
| 38 |
DEVICE = "cpu"
|
| 39 |
MODEL_LOCK = threading.Lock()
|
| 40 |
+
# flag: whether a model load is currently in progress (prevents requests)
|
| 41 |
+
MODEL_LOADING = False
|
| 42 |
+
# flag: whether the loaded tokenizer exposes a chat template helper
|
| 43 |
+
USE_CHAT_TEMPLATE = False
|
| 44 |
|
| 45 |
# ----------------------------- Utilities ---# ------------------------------
|
| 46 |
|
|
|
|
| 87 |
# ensure pad token exists and ids/config align
|
| 88 |
if getattr(tok, "pad_token", None) is None:
|
| 89 |
tok.pad_token = getattr(tok, "eos_token", "[PAD]")
|
| 90 |
+
# Be defensive: different tokenizer backends expect different arg types
|
| 91 |
try:
|
| 92 |
tok.add_special_tokens({"pad_token": tok.pad_token})
|
| 93 |
+
except TypeError as e:
|
| 94 |
+
# try list form or add_tokens fallback
|
| 95 |
+
try:
|
| 96 |
+
tok.add_special_tokens([tok.pad_token])
|
| 97 |
+
except Exception:
|
| 98 |
+
try:
|
| 99 |
+
tok.add_tokens([tok.pad_token])
|
| 100 |
+
except Exception:
|
| 101 |
+
pass
|
| 102 |
except Exception:
|
| 103 |
pass
|
| 104 |
try:
|
|
|
|
| 249 |
|
| 250 |
# ----------------------------- Model loading -------------------------------
|
| 251 |
|
| 252 |
+
|
| 253 |
+
def _safe_model_from_pretrained(repo_id, *args, **kwargs):
|
| 254 |
+
"""Call AutoModelForCausalLM.from_pretrained but retry without `use_auth_token`
|
| 255 |
+
if the called class improperly forwards unexpected kwargs into __init__.
|
| 256 |
+
"""
|
| 257 |
+
try:
|
| 258 |
+
return AutoModelForCausalLM.from_pretrained(repo_id, *args, **kwargs)
|
| 259 |
+
except TypeError as e:
|
| 260 |
+
msg = str(e)
|
| 261 |
+
if "use_auth_token" in msg or "unexpected keyword argument" in msg:
|
| 262 |
+
# retry without auth-token kwargs (some remote `from_pretrained` may leak kwargs)
|
| 263 |
+
kwargs2 = dict(kwargs)
|
| 264 |
+
kwargs2.pop("use_auth_token", None)
|
| 265 |
+
kwargs2.pop("token", None)
|
| 266 |
+
print(f"_safe_model_from_pretrained: retrying without auth-token due to: {e}")
|
| 267 |
+
return AutoModelForCausalLM.from_pretrained(repo_id, *args, **kwargs2)
|
| 268 |
+
raise
|
| 269 |
+
|
| 270 |
+
|
| 271 |
def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
|
| 272 |
"""Load model + tokenizer from the Hub. Graceful fallbacks and HF-token support.
|
| 273 |
|
|
|
|
| 283 |
if MODEL is not None and MODEL_NAME == repo_id and not force_reload:
|
| 284 |
return f"Model already loaded: {MODEL_NAME} (@ {DEVICE})"
|
| 285 |
|
| 286 |
+
# mark loading state so UI handlers can guard incoming requests
|
| 287 |
+
global MODEL_LOADING
|
| 288 |
+
MODEL_LOADING = True
|
| 289 |
+
print(f"Model load started: {repo_id}")
|
| 290 |
+
|
| 291 |
MODEL = None
|
| 292 |
TOKENIZER = None
|
| 293 |
MODEL_NAME = repo_id
|
|
|
|
| 303 |
trust_remote_code=True,
|
| 304 |
use_auth_token=hf_token,
|
| 305 |
)
|
| 306 |
+
print(f"Tokenizer loaded from repo: {repo_id}")
|
| 307 |
+
# detect whether tokenizer supports the Nanbeige chat template API
|
| 308 |
try:
|
| 309 |
+
global USE_CHAT_TEMPLATE
|
| 310 |
+
USE_CHAT_TEMPLATE = hasattr(TOKENIZER, "apply_chat_template")
|
| 311 |
+
print(f"USE_CHAT_TEMPLATE={USE_CHAT_TEMPLATE}")
|
| 312 |
+
except Exception:
|
| 313 |
+
USE_CHAT_TEMPLATE = False
|
| 314 |
+
except Exception as e_tok:
|
| 315 |
+
print(f"Tokenizer load from {repo_id} failed: {e_tok}")
|
| 316 |
+
# specific fix: some tokenizers fail with 'Input must be a List...' when
|
| 317 |
+
# `special_tokens_map.json` contains dict entries instead of plain strings.
|
| 318 |
+
# Try an in-memory normalization + local retry before broader fallbacks/repairs.
|
| 319 |
+
if "Input must be a List" in str(e_tok) or "Input must be a List[Union[str, AddedToken]]" in str(e_tok):
|
| 320 |
+
try:
|
| 321 |
+
print('Detected tokenizer add-tokens type error; attempting in-place normalization and retry...')
|
| 322 |
+
# try to download tokenizer files and normalize special_tokens_map.json
|
| 323 |
+
try:
|
| 324 |
+
from huggingface_hub import hf_hub_download
|
| 325 |
+
import json, tempfile, shutil
|
| 326 |
+
|
| 327 |
+
tmp = tempfile.mkdtemp(prefix="tokfix_")
|
| 328 |
+
# files we need locally for AutoTokenizer
|
| 329 |
+
candidates = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "added_tokens.json"]
|
| 330 |
+
for fn in candidates:
|
| 331 |
+
try:
|
| 332 |
+
src = hf_hub_download(repo_id=repo_id, filename=fn, token=hf_token)
|
| 333 |
+
shutil.copy(src, tmp)
|
| 334 |
+
except Exception:
|
| 335 |
+
# ignore missing files — AutoTokenizer is tolerant
|
| 336 |
+
pass
|
| 337 |
+
|
| 338 |
+
# normalize special_tokens_map.json if present
|
| 339 |
+
stm = os.path.join(tmp, "special_tokens_map.json")
|
| 340 |
+
if os.path.exists(stm):
|
| 341 |
+
try:
|
| 342 |
+
with open(stm, "r", encoding="utf-8") as f:
|
| 343 |
+
stm_j = json.load(f)
|
| 344 |
+
changed = False
|
| 345 |
+
if "additional_special_tokens" in stm_j:
|
| 346 |
+
new = []
|
| 347 |
+
for it in stm_j["additional_special_tokens"]:
|
| 348 |
+
if isinstance(it, dict):
|
| 349 |
+
new.append(it.get("content") or it.get("token") or str(it))
|
| 350 |
+
changed = True
|
| 351 |
+
else:
|
| 352 |
+
new.append(it)
|
| 353 |
+
stm_j["additional_special_tokens"] = new
|
| 354 |
+
for k in ["bos_token", "eos_token", "pad_token", "unk_token"]:
|
| 355 |
+
if k in stm_j and isinstance(stm_j[k], dict):
|
| 356 |
+
stm_j[k] = stm_j[k].get("content", stm_j[k])
|
| 357 |
+
changed = True
|
| 358 |
+
if changed:
|
| 359 |
+
with open(stm, "w", encoding="utf-8") as f:
|
| 360 |
+
json.dump(stm_j, f, ensure_ascii=False, indent=2)
|
| 361 |
+
print('Normalized special_tokens_map.json in temp dir')
|
| 362 |
+
except Exception:
|
| 363 |
+
pass
|
| 364 |
+
|
| 365 |
+
# try loading tokenizer from the temporary normalized directory
|
| 366 |
+
TOKENIZER = AutoTokenizer.from_pretrained(tmp, use_fast=False, trust_remote_code=True)
|
| 367 |
+
print('Tokenizer reloaded from normalized temp copy')
|
| 368 |
+
shutil.rmtree(tmp)
|
| 369 |
+
except Exception as e_localnorm:
|
| 370 |
+
print('In-place normalization retry failed:', e_localnorm)
|
| 371 |
+
# fall through to the existing repair path below
|
| 372 |
+
|
| 373 |
+
# as a fallback, attempt to auto-repair the remote repo (if HF token available)
|
| 374 |
+
if hf_token:
|
| 375 |
+
print('Attempting repo-side auto-repair/upload from base tokenizer...')
|
| 376 |
+
_repair_and_upload_tokenizer(repo_id, hf_token=hf_token)
|
| 377 |
+
TOKENIZER = AutoTokenizer.from_pretrained(repo_id, use_fast=False, trust_remote_code=True)
|
| 378 |
+
print('Tokenizer reloaded after repo repair')
|
| 379 |
+
else:
|
| 380 |
+
# final fallback will be handled by the outer fallbacks below
|
| 381 |
+
raise RuntimeError('Normalization + auto-repair could not proceed (no HF_TOKEN)')
|
| 382 |
+
except Exception as e_retry:
|
| 383 |
+
print('Repair/retry failed:', e_retry)
|
| 384 |
+
return f"Tokenizer load failed: {e_retry}"
|
| 385 |
+
else:
|
| 386 |
+
# If a local repo was cloned without git-lfs, tokenizer.model may be a pointer file — try auto-fetch
|
| 387 |
try:
|
| 388 |
+
if os.path.isdir(repo_id) and _ensure_local_tokenizer_model(repo_id, hf_token=hf_token):
|
| 389 |
+
print(f"Found LFS pointer at {repo_id}/tokenizer.model — fetched real tokenizer.model; retrying tokenizer load...")
|
| 390 |
+
TOKENIZER = AutoTokenizer.from_pretrained(
|
| 391 |
+
repo_id,
|
| 392 |
+
use_fast=False,
|
| 393 |
+
trust_remote_code=True,
|
| 394 |
+
use_auth_token=hf_token,
|
| 395 |
+
)
|
| 396 |
+
print(f"Tokenizer loaded from local repo after fetching LFS: {repo_id}")
|
| 397 |
+
else:
|
| 398 |
+
# Local workspace fallback: use bundled Nanbeige4.1 tokenizer if available
|
| 399 |
+
local_fallback = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'models', 'Nanbeige4.1-3B'))
|
| 400 |
+
if os.path.isdir(local_fallback):
|
| 401 |
+
try:
|
| 402 |
+
print(f"Attempting local workspace tokenizer fallback: {local_fallback}")
|
| 403 |
+
TOKENIZER = AutoTokenizer.from_pretrained(local_fallback, use_fast=False, trust_remote_code=True)
|
| 404 |
+
print(f"Tokenizer loaded from local workspace: {local_fallback}")
|
| 405 |
+
except Exception as e_local:
|
| 406 |
+
print(f"Local tokenizer fallback failed: {e_local}")
|
| 407 |
+
raise e_local
|
| 408 |
+
else:
|
| 409 |
+
# Try known base tokenizer on the Hub (Nanbeige4.1 if repo looks like 4.1)
|
| 410 |
+
base = "Nanbeige/Nanbeige4.1-3B" if "4.1" in repo_id.lower() else "PioTio/Nanbeige2.5"
|
| 411 |
+
print(f"Falling back to base tokenizer: {base}")
|
| 412 |
+
TOKENIZER = AutoTokenizer.from_pretrained(base, use_fast=False, trust_remote_code=True, use_auth_token=hf_token)
|
| 413 |
+
|
| 414 |
+
# If HF token is available, attempt to auto-repair/upload tokenizer files to the target repo
|
| 415 |
+
if hf_token:
|
| 416 |
+
try:
|
| 417 |
+
uploaded = _repair_and_upload_tokenizer(repo_id, hf_token=hf_token)
|
| 418 |
+
print(f"Auto-repair attempt to {repo_id}: {'succeeded' if uploaded else 'no-change/failure'}")
|
| 419 |
+
except Exception as e_rep:
|
| 420 |
+
print(f"Auto-repair attempt failed: {e_rep}")
|
| 421 |
+
except Exception as e_base:
|
| 422 |
+
# last-resort: try fast tokenizer (may still fail or produce garbled output)
|
| 423 |
+
try:
|
| 424 |
+
print(f"All fallbacks failed: {e_base}. Trying generic AutoTokenizer as last resort...")
|
| 425 |
+
TOKENIZER = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True, use_auth_token=hf_token)
|
| 426 |
+
except Exception as e_final:
|
| 427 |
+
MODEL_LOADING = False
|
| 428 |
+
return f"Tokenizer load failed: {e_final}"
|
| 429 |
|
| 430 |
# 2) Load model (prefer 4-bit on GPU if available)
|
| 431 |
if DEVICE == "cuda" and HAS_BNB:
|
| 432 |
try:
|
| 433 |
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
| 434 |
+
MODEL = _safe_model_from_pretrained(
|
| 435 |
repo_id,
|
| 436 |
device_map="auto",
|
| 437 |
quantization_config=bnb_config,
|
|
|
|
| 440 |
)
|
| 441 |
MODEL.eval()
|
| 442 |
_diagnose_and_fix_tokenizer_model(TOKENIZER, MODEL)
|
| 443 |
+
MODEL_LOADING = False
|
| 444 |
+
print(f"Model load finished (4-bit): {repo_id}")
|
| 445 |
return f"Loaded {repo_id} (4-bit, device_map=auto)"
|
| 446 |
except Exception as e:
|
| 447 |
print("bnb/4bit load failed - falling back:", e)
|
|
|
|
| 449 |
# 3) FP16 / CPU fallback
|
| 450 |
try:
|
| 451 |
if DEVICE == "cuda":
|
| 452 |
+
MODEL = _safe_model_from_pretrained(
|
| 453 |
repo_id,
|
| 454 |
device_map="auto",
|
| 455 |
torch_dtype=torch.float16,
|
|
|
|
| 457 |
use_auth_token=hf_token,
|
| 458 |
)
|
| 459 |
else:
|
| 460 |
+
MODEL = _safe_model_from_pretrained(
|
| 461 |
repo_id,
|
| 462 |
low_cpu_mem_usage=True,
|
| 463 |
torch_dtype=torch.float32,
|
|
|
|
| 468 |
|
| 469 |
MODEL.eval()
|
| 470 |
_diagnose_and_fix_tokenizer_model(TOKENIZER, MODEL)
|
| 471 |
+
MODEL_LOADING = False
|
| 472 |
+
print(f"Model load finished: {repo_id} (@{DEVICE})")
|
| 473 |
return f"Loaded {repo_id} (@{DEVICE})"
|
| 474 |
except Exception as e:
|
| 475 |
MODEL = None
|
| 476 |
TOKENIZER = None
|
| 477 |
+
# clear loading flag and provide a helpful diagnostic message
|
| 478 |
+
MODEL_LOADING = False
|
| 479 |
+
print(f"Model load failed: {repo_id} -> {e}")
|
| 480 |
return f"Model load failed: {e} (hint: check HF_TOKEN, repo contents and ensure tokenizer.model is present)"
|
| 481 |
|
| 482 |
|
|
|
|
| 497 |
and return a list of (user, assistant) pairs suitable for prompt construction.
|
| 498 |
|
| 499 |
Behavior: pairs each user message with the next assistant message (assistant may be "" if not present).
|
| 500 |
+
NOTE: For chat-first models (Nanbeige4.1) we prefer `tokenizer.apply_chat_template` later
|
| 501 |
+
so this function only normalizes the history shape.
|
| 502 |
"""
|
| 503 |
if not raw_history:
|
| 504 |
return []
|
|
|
|
| 550 |
pairs = _normalize_history(history or [])
|
| 551 |
pairs = pairs[-max_history:]
|
| 552 |
|
| 553 |
+
# If tokenizer provides a chat-template helper (Nanbeige4.1), use it.
|
| 554 |
+
# This avoids instruction-format mismatches that produce garbled output.
|
| 555 |
+
try:
|
| 556 |
+
from __main__ import TOKENIZER # safe access to global TOKENIZER when available
|
| 557 |
+
except Exception:
|
| 558 |
+
TOKENIZER = None
|
| 559 |
+
|
| 560 |
+
if TOKENIZER is not None and hasattr(TOKENIZER, "apply_chat_template"):
|
| 561 |
+
# build messages list with optional system prompt first
|
| 562 |
+
messages = []
|
| 563 |
+
if system_prompt:
|
| 564 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 565 |
+
for u, a in pairs:
|
| 566 |
+
messages.append({"role": "user", "content": u})
|
| 567 |
+
if a:
|
| 568 |
+
messages.append({"role": "assistant", "content": a})
|
| 569 |
+
# current user turn
|
| 570 |
+
messages.append({"role": "user", "content": user_input})
|
| 571 |
+
# use tokenizer's chat template (returns the full prompt string)
|
| 572 |
+
try:
|
| 573 |
+
prompt = TOKENIZER.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
| 574 |
+
return prompt
|
| 575 |
+
except Exception:
|
| 576 |
+
# fall back to ALPACA format if anything goes wrong
|
| 577 |
+
pass
|
| 578 |
+
|
| 579 |
+
# Default / fallback: ALPACA-style instruction template
|
| 580 |
parts: List[str] = [f"System: {system_prompt}"]
|
| 581 |
for u, a in pairs:
|
| 582 |
# include previous turns as completed instruction/response pairs
|
|
|
|
| 593 |
if MODEL is None or TOKENIZER is None:
|
| 594 |
raise RuntimeError("Model is not loaded. Press 'Load model' first.")
|
| 595 |
|
| 596 |
+
# When using a chat-template prompt we must avoid adding special tokens again
|
| 597 |
+
add_special_tokens = False if hasattr(TOKENIZER, "apply_chat_template") else True
|
| 598 |
+
|
| 599 |
+
input_ids = TOKENIZER(prompt, return_tensors="pt", truncation=True, max_length=2048, add_special_tokens=add_special_tokens).input_ids.to(next(MODEL.parameters()).device)
|
| 600 |
|
| 601 |
gen_kwargs = dict(
|
| 602 |
input_ids=input_ids,
|
|
|
|
| 623 |
raise RuntimeError("Model is not loaded. Press 'Load model' first.")
|
| 624 |
|
| 625 |
streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
|
| 626 |
+
add_special_tokens = False if hasattr(TOKENIZER, "apply_chat_template") else True
|
| 627 |
+
input_ids = TOKENIZER(prompt, return_tensors="pt", truncation=True, max_length=2048, add_special_tokens=add_special_tokens).input_ids.to(next(MODEL.parameters()).device)
|
| 628 |
|
| 629 |
gen_kwargs = dict(
|
| 630 |
input_ids=input_ids,
|
|
|
|
| 663 |
# Append current user turn (assistant reply empty until generated)
|
| 664 |
pairs.append((str(user_message or ""), ""))
|
| 665 |
|
| 666 |
+
# Guard: block generation while model is loading or not loaded
|
| 667 |
+
if MODEL_LOADING:
|
| 668 |
+
pairs[-1] = (user_message, "⚠️ Model is still loading — please wait and try again. Check 'Status' for progress.")
|
| 669 |
+
yield pairs, ""
|
| 670 |
+
return
|
| 671 |
+
|
| 672 |
+
if MODEL is None:
|
| 673 |
+
pairs[-1] = (user_message, "⚠️ Model is not loaded — click 'Load model' first.")
|
| 674 |
+
yield pairs, ""
|
| 675 |
+
return
|
| 676 |
+
|
| 677 |
prompt = build_prompt(pairs[:-1], user_message, system_prompt, max_history)
|
| 678 |
|
| 679 |
# If user is running the full Nanbeige model on CPU, warn and suggest options
|
|
|
|
| 721 |
|
| 722 |
def load_model_ui(repo: str):
|
| 723 |
status = load_model(repo, force_reload=True)
|
| 724 |
+
try:
|
| 725 |
+
suffix = " — chat-template detected" if USE_CHAT_TEMPLATE else ""
|
| 726 |
+
except NameError:
|
| 727 |
+
suffix = ""
|
| 728 |
+
# enable the Send button only when the model actually loaded
|
| 729 |
+
loaded = str(status).lower().startswith("loaded")
|
| 730 |
+
from gradio import update as gr_update
|
| 731 |
+
send_state = gr_update(interactive=loaded)
|
| 732 |
+
return status + suffix, send_state
|
| 733 |
|
| 734 |
|
| 735 |
def apply_lora_adapter(adapter_repo: str):
|
|
|
|
| 800 |
apply_adapter = gr.Button("Apply LoRA adapter")
|
| 801 |
|
| 802 |
# Events
|
| 803 |
+
load_btn.click(fn=load_model_ui, inputs=model_input, outputs=[model_status, send])
|
| 804 |
repair_btn.click(fn=repair_tokenizer_on_hub, inputs=model_input, outputs=model_status)
|
| 805 |
|
| 806 |
send.click(
|
|
|
|
| 826 |
|
| 827 |
# auto-load default model in background (non-blocking)
|
| 828 |
def _bg_initial_load():
|
| 829 |
+
# run load_model in a background thread to warm up model on Space startup
|
| 830 |
+
def _worker():
|
| 831 |
+
res = load_model(DEFAULT_MODEL, force_reload=False)
|
| 832 |
+
try:
|
| 833 |
+
# update UI Send button when loaded
|
| 834 |
+
from gradio import update as gr_update
|
| 835 |
+
interactive = str(res).lower().startswith("loaded")
|
| 836 |
+
send.update(interactive=interactive)
|
| 837 |
+
except Exception:
|
| 838 |
+
pass
|
| 839 |
+
return res
|
| 840 |
+
|
| 841 |
+
t = threading.Thread(target=_worker, daemon=True)
|
| 842 |
+
t.start()
|
| 843 |
+
return "Loading model in background..."
|
| 844 |
|
| 845 |
# For local smoke tests you can skip automatic model loading by setting
|
| 846 |
# environment variable `SKIP_AUTOLOAD=1` so the UI starts without loading
|
|
|
|
| 849 |
model_status.value = "Auto-load skipped (SKIP_AUTOLOAD=1)"
|
| 850 |
else:
|
| 851 |
model_status.value = _bg_initial_load()
|
| 852 |
+
# disable Send while background load is in progress
|
| 853 |
+
try:
|
| 854 |
+
send.update(interactive=False)
|
| 855 |
+
except Exception:
|
| 856 |
+
pass
|
| 857 |
|
| 858 |
# CPU warning / demo hint (visible in UI)
|
| 859 |
gr.Markdown("""
|