PioTio committed
Commit ef417e5 · verified · 1 Parent(s): 8876dbe

Add tokenizer normalization retry in load_model
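This commit's headline change: when `AutoTokenizer.from_pretrained` fails with the fast-tokenizer error "Input must be a List[Union[str, AddedToken]]", `load_model` now downloads the tokenizer files, rewrites `special_tokens_map.json` so AddedToken-style dict entries become plain strings, and retries from the normalized local copy. A minimal standalone sketch of that normalization step (the helper name is illustrative, not part of app.py):

import json

def normalize_special_tokens_map(path: str) -> bool:
    """Flatten AddedToken-style dict entries back to plain strings."""
    with open(path, "r", encoding="utf-8") as f:
        stm = json.load(f)
    changed = False
    # list-valued entry: items may be dicts like {"content": "<tok>", "lstrip": false, ...}
    extra = stm.get("additional_special_tokens", [])
    flat = [t.get("content", str(t)) if isinstance(t, dict) else t for t in extra]
    if flat != extra:
        stm["additional_special_tokens"] = flat
        changed = True
    # scalar entries may be serialized the same way
    for key in ("bos_token", "eos_token", "pad_token", "unk_token"):
        if isinstance(stm.get(key), dict):
            stm[key] = stm[key].get("content", "")
            changed = True
    if changed:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(stm, f, ensure_ascii=False, indent=2)
    return changed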

Files changed (1): app.py (+245 -35)
app.py CHANGED
@@ -37,6 +37,10 @@ TOKENIZER = None
 MODEL_NAME = None
 DEVICE = "cpu"
 MODEL_LOCK = threading.Lock()
+# flag: whether a model load is currently in progress (prevents requests)
+MODEL_LOADING = False
+# flag: whether the loaded tokenizer exposes a chat template helper
+USE_CHAT_TEMPLATE = False
 
 # ----------------------------- Utilities ------------------------------
 
@@ -83,8 +87,18 @@ def _diagnose_and_fix_tokenizer_model(tok: AutoTokenizer, mdl: AutoModelForCausa
     # ensure pad token exists and ids/config align
     if getattr(tok, "pad_token", None) is None:
         tok.pad_token = getattr(tok, "eos_token", "[PAD]")
+    # Be defensive: different tokenizer backends expect different arg types
     try:
         tok.add_special_tokens({"pad_token": tok.pad_token})
+    except TypeError as e:
+        # try list form or add_tokens fallback
+        try:
+            tok.add_special_tokens([tok.pad_token])
+        except Exception:
+            try:
+                tok.add_tokens([tok.pad_token])
+            except Exception:
+                pass
     except Exception:
         pass
     try:
@@ -235,6 +249,25 @@ def repair_tokenizer_on_hub(repo_id: str) -> str:
 
 # ----------------------------- Model loading -------------------------------
 
+
+def _safe_model_from_pretrained(repo_id, *args, **kwargs):
+    """Call AutoModelForCausalLM.from_pretrained but retry without `use_auth_token`
+    if the called class improperly forwards unexpected kwargs into __init__.
+    """
+    try:
+        return AutoModelForCausalLM.from_pretrained(repo_id, *args, **kwargs)
+    except TypeError as e:
+        msg = str(e)
+        if "use_auth_token" in msg or "unexpected keyword argument" in msg:
+            # retry without auth-token kwargs (some remote `from_pretrained` may leak kwargs)
+            kwargs2 = dict(kwargs)
+            kwargs2.pop("use_auth_token", None)
+            kwargs2.pop("token", None)
+            print(f"_safe_model_from_pretrained: retrying without auth-token due to: {e}")
+            return AutoModelForCausalLM.from_pretrained(repo_id, *args, **kwargs2)
+        raise
+
+
 def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
     """Load model + tokenizer from the Hub. Graceful fallbacks and HF-token support.
 
@@ -250,6 +283,11 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
     if MODEL is not None and MODEL_NAME == repo_id and not force_reload:
         return f"Model already loaded: {MODEL_NAME} (@ {DEVICE})"
 
+    # mark loading state so UI handlers can guard incoming requests
+    global MODEL_LOADING
+    MODEL_LOADING = True
+    print(f"Model load started: {repo_id}")
+
     MODEL = None
     TOKENIZER = None
     MODEL_NAME = repo_id
@@ -265,40 +303,135 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
             trust_remote_code=True,
             use_auth_token=hf_token,
         )
-    except Exception as e_tok:
-        # If a local repo was cloned without git-lfs, tokenizer.model may be a pointer file — try auto-fetch
+        print(f"Tokenizer loaded from repo: {repo_id}")
+        # detect whether tokenizer supports the Nanbeige chat template API
         try:
-            if os.path.isdir(repo_id) and _ensure_local_tokenizer_model(repo_id, hf_token=hf_token):
-                print(f"Found LFS pointer at {repo_id}/tokenizer.model — fetched real tokenizer.model; retrying tokenizer load...")
-                TOKENIZER = AutoTokenizer.from_pretrained(
-                    repo_id,
-                    use_fast=False,
-                    trust_remote_code=True,
-                    use_auth_token=hf_token,
-                )
-                # success continue to model load
-            else:
-                # fallback: try base model tokenizer (common fix when adapter upload missed tokenizer.model)
-                print(f"Tokenizer load from {repo_id} failed: {e_tok}. Falling back to base tokenizer PioTio/Nanbeige2.5...")
-                TOKENIZER = AutoTokenizer.from_pretrained(
-                    DEFAULT_MODEL,
-                    use_fast=False,
-                    trust_remote_code=True,
-                    use_auth_token=hf_token,
-                )
-        except Exception as e_base:
-            # last-resort: try fast tokenizer (may still fail or produce garbled output)
+            global USE_CHAT_TEMPLATE
+            USE_CHAT_TEMPLATE = hasattr(TOKENIZER, "apply_chat_template")
+            print(f"USE_CHAT_TEMPLATE={USE_CHAT_TEMPLATE}")
+        except Exception:
+            USE_CHAT_TEMPLATE = False
+    except Exception as e_tok:
+        print(f"Tokenizer load from {repo_id} failed: {e_tok}")
+        # specific fix: some tokenizers fail with 'Input must be a List...' when
+        # `special_tokens_map.json` contains dict entries instead of plain strings.
+        # Try an in-memory normalization + local retry before broader fallbacks/repairs.
+        if "Input must be a List" in str(e_tok) or "Input must be a List[Union[str, AddedToken]]" in str(e_tok):
+            try:
+                print('Detected tokenizer add-tokens type error; attempting in-place normalization and retry...')
+                # try to download tokenizer files and normalize special_tokens_map.json
+                try:
+                    from huggingface_hub import hf_hub_download
+                    import json, tempfile, shutil
+
+                    tmp = tempfile.mkdtemp(prefix="tokfix_")
+                    # files we need locally for AutoTokenizer
+                    candidates = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "added_tokens.json"]
+                    for fn in candidates:
+                        try:
+                            src = hf_hub_download(repo_id=repo_id, filename=fn, token=hf_token)
+                            shutil.copy(src, tmp)
+                        except Exception:
+                            # ignore missing files — AutoTokenizer is tolerant
+                            pass
+
+                    # normalize special_tokens_map.json if present
+                    stm = os.path.join(tmp, "special_tokens_map.json")
+                    if os.path.exists(stm):
+                        try:
+                            with open(stm, "r", encoding="utf-8") as f:
+                                stm_j = json.load(f)
+                            changed = False
+                            if "additional_special_tokens" in stm_j:
+                                new = []
+                                for it in stm_j["additional_special_tokens"]:
+                                    if isinstance(it, dict):
+                                        new.append(it.get("content") or it.get("token") or str(it))
+                                        changed = True
+                                    else:
+                                        new.append(it)
+                                stm_j["additional_special_tokens"] = new
+                            for k in ["bos_token", "eos_token", "pad_token", "unk_token"]:
+                                if k in stm_j and isinstance(stm_j[k], dict):
+                                    stm_j[k] = stm_j[k].get("content", stm_j[k])
+                                    changed = True
+                            if changed:
+                                with open(stm, "w", encoding="utf-8") as f:
+                                    json.dump(stm_j, f, ensure_ascii=False, indent=2)
+                                print('Normalized special_tokens_map.json in temp dir')
+                        except Exception:
+                            pass
+
+                    # try loading tokenizer from the temporary normalized directory
+                    TOKENIZER = AutoTokenizer.from_pretrained(tmp, use_fast=False, trust_remote_code=True)
+                    print('Tokenizer reloaded from normalized temp copy')
+                    shutil.rmtree(tmp)
+                except Exception as e_localnorm:
+                    print('In-place normalization retry failed:', e_localnorm)
+                    # fall through to the existing repair path below
+
+                # as a fallback, attempt to auto-repair the remote repo (if HF token available)
+                if hf_token:
+                    print('Attempting repo-side auto-repair/upload from base tokenizer...')
+                    _repair_and_upload_tokenizer(repo_id, hf_token=hf_token)
+                    TOKENIZER = AutoTokenizer.from_pretrained(repo_id, use_fast=False, trust_remote_code=True)
+                    print('Tokenizer reloaded after repo repair')
+                else:
+                    # final fallback will be handled by the outer fallbacks below
+                    raise RuntimeError('Normalization + auto-repair could not proceed (no HF_TOKEN)')
+            except Exception as e_retry:
+                print('Repair/retry failed:', e_retry)
+                return f"Tokenizer load failed: {e_retry}"
+        else:
+            # If a local repo was cloned without git-lfs, tokenizer.model may be a pointer file — try auto-fetch
             try:
-                print(f"Base tokenizer fallback failed: {e_base}. Trying generic AutoTokenizer...")
-                TOKENIZER = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True, use_auth_token=hf_token)
-            except Exception as e_final:
-                return f"Tokenizer load failed: {e_final}"
+                if os.path.isdir(repo_id) and _ensure_local_tokenizer_model(repo_id, hf_token=hf_token):
+                    print(f"Found LFS pointer at {repo_id}/tokenizer.model — fetched real tokenizer.model; retrying tokenizer load...")
+                    TOKENIZER = AutoTokenizer.from_pretrained(
+                        repo_id,
+                        use_fast=False,
+                        trust_remote_code=True,
+                        use_auth_token=hf_token,
+                    )
+                    print(f"Tokenizer loaded from local repo after fetching LFS: {repo_id}")
+                else:
+                    # Local workspace fallback: use bundled Nanbeige4.1 tokenizer if available
+                    local_fallback = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'models', 'Nanbeige4.1-3B'))
+                    if os.path.isdir(local_fallback):
+                        try:
+                            print(f"Attempting local workspace tokenizer fallback: {local_fallback}")
+                            TOKENIZER = AutoTokenizer.from_pretrained(local_fallback, use_fast=False, trust_remote_code=True)
+                            print(f"Tokenizer loaded from local workspace: {local_fallback}")
+                        except Exception as e_local:
+                            print(f"Local tokenizer fallback failed: {e_local}")
+                            raise e_local
+                    else:
+                        # Try known base tokenizer on the Hub (Nanbeige4.1 if repo looks like 4.1)
+                        base = "Nanbeige/Nanbeige4.1-3B" if "4.1" in repo_id.lower() else "PioTio/Nanbeige2.5"
+                        print(f"Falling back to base tokenizer: {base}")
+                        TOKENIZER = AutoTokenizer.from_pretrained(base, use_fast=False, trust_remote_code=True, use_auth_token=hf_token)
+
+                    # If HF token is available, attempt to auto-repair/upload tokenizer files to the target repo
+                    if hf_token:
+                        try:
+                            uploaded = _repair_and_upload_tokenizer(repo_id, hf_token=hf_token)
+                            print(f"Auto-repair attempt to {repo_id}: {'succeeded' if uploaded else 'no-change/failure'}")
+                        except Exception as e_rep:
+                            print(f"Auto-repair attempt failed: {e_rep}")
+            except Exception as e_base:
+                # last-resort: try fast tokenizer (may still fail or produce garbled output)
+                try:
+                    print(f"All fallbacks failed: {e_base}. Trying generic AutoTokenizer as last resort...")
+                    TOKENIZER = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True, use_auth_token=hf_token)
+                except Exception as e_final:
+                    MODEL_LOADING = False
+                    return f"Tokenizer load failed: {e_final}"
 
     # 2) Load model (prefer 4-bit on GPU if available)
     if DEVICE == "cuda" and HAS_BNB:
         try:
             bnb_config = BitsAndBytesConfig(load_in_4bit=True)
-            MODEL = AutoModelForCausalLM.from_pretrained(
+            MODEL = _safe_model_from_pretrained(
                 repo_id,
                 device_map="auto",
                 quantization_config=bnb_config,
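Both the quantized and full-precision loads now route through `_safe_model_from_pretrained`. A stripped-down illustration of the retry pattern it implements; `flaky_from_pretrained` here is a stand-in for a remote modeling class that forwards kwargs into __init__, not app.py code:

def flaky_from_pretrained(repo_id, **kwargs):
    # stand-in: custom-code model classes sometimes choke on auth kwargs
    if "use_auth_token" in kwargs or "token" in kwargs:
        raise TypeError("__init__() got an unexpected keyword argument 'use_auth_token'")
    return f"model from {repo_id}"

def safe_from_pretrained(repo_id, **kwargs):
    try:
        return flaky_from_pretrained(repo_id, **kwargs)
    except TypeError as e:
        if "unexpected keyword argument" in str(e):
            # drop the auth kwargs and retry once, as the diff's helper does
            retry = {k: v for k, v in kwargs.items() if k not in ("use_auth_token", "token")}
            return flaky_from_pretrained(repo_id, **retry)
        raise

assert safe_from_pretrained("demo/repo", use_auth_token="hf_xxx") == "model from demo/repo"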
@@ -307,6 +440,8 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
             )
             MODEL.eval()
             _diagnose_and_fix_tokenizer_model(TOKENIZER, MODEL)
+            MODEL_LOADING = False
+            print(f"Model load finished (4-bit): {repo_id}")
             return f"Loaded {repo_id} (4-bit, device_map=auto)"
         except Exception as e:
             print("bnb/4bit load failed - falling back:", e)
@@ -314,7 +449,7 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
     # 3) FP16 / CPU fallback
     try:
         if DEVICE == "cuda":
-            MODEL = AutoModelForCausalLM.from_pretrained(
+            MODEL = _safe_model_from_pretrained(
                 repo_id,
                 device_map="auto",
                 torch_dtype=torch.float16,
@@ -322,7 +457,7 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
                 use_auth_token=hf_token,
             )
         else:
-            MODEL = AutoModelForCausalLM.from_pretrained(
+            MODEL = _safe_model_from_pretrained(
                 repo_id,
                 low_cpu_mem_usage=True,
                 torch_dtype=torch.float32,
@@ -333,11 +468,15 @@ def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
 
         MODEL.eval()
         _diagnose_and_fix_tokenizer_model(TOKENIZER, MODEL)
+        MODEL_LOADING = False
+        print(f"Model load finished: {repo_id} (@{DEVICE})")
         return f"Loaded {repo_id} (@{DEVICE})"
     except Exception as e:
         MODEL = None
         TOKENIZER = None
-        # provide a helpful diagnostic message
+        # clear loading flag and provide a helpful diagnostic message
+        MODEL_LOADING = False
+        print(f"Model load failed: {repo_id} -> {e}")
         return f"Model load failed: {e} (hint: check HF_TOKEN, repo contents and ensure tokenizer.model is present)"
 
 
@@ -358,6 +497,8 @@ def _normalize_history(raw_history) -> List[Tuple[str, str]]:
     and return a list of (user, assistant) pairs suitable for prompt construction.
 
     Behavior: pairs each user message with the next assistant message (assistant may be "" if not present).
+    NOTE: For chat-first models (Nanbeige4.1) we prefer `tokenizer.apply_chat_template` later
+    so this function only normalizes the history shape.
     """
     if not raw_history:
         return []
@@ -409,6 +550,33 @@ def build_prompt(history, user_input: str, system_prompt: str, max_history: int
     pairs = _normalize_history(history or [])
     pairs = pairs[-max_history:]
 
+    # If tokenizer provides a chat-template helper (Nanbeige4.1), use it.
+    # This avoids instruction-format mismatches that produce garbled output.
+    try:
+        from __main__ import TOKENIZER  # safe access to global TOKENIZER when available
+    except Exception:
+        TOKENIZER = None
+
+    if TOKENIZER is not None and hasattr(TOKENIZER, "apply_chat_template"):
+        # build messages list with optional system prompt first
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        for u, a in pairs:
+            messages.append({"role": "user", "content": u})
+            if a:
+                messages.append({"role": "assistant", "content": a})
+        # current user turn
+        messages.append({"role": "user", "content": user_input})
+        # use tokenizer's chat template (returns the full prompt string)
+        try:
+            prompt = TOKENIZER.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+            return prompt
+        except Exception:
+            # fall back to ALPACA format if anything goes wrong
+            pass
+
+    # Default / fallback: ALPACA-style instruction template
     parts: List[str] = [f"System: {system_prompt}"]
     for u, a in pairs:
         # include previous turns as completed instruction/response pairs
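The chat-template branch above builds an OpenAI-style messages list and lets the tokenizer render the prompt. For reference, the round trip with a small instruct tokenizer from the Hub (the repo name here is only an example, not what app.py loads):

from transformers import AutoTokenizer

# example: any tokenizer that bundles a chat template works the same way
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is a tokenizer?"},
]
# tokenize=False returns the rendered prompt string; add_generation_prompt=True
# appends the template's assistant-turn header so the model starts a fresh reply
prompt = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)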
@@ -425,7 +593,10 @@ def _generate_text(prompt: str, temperature: float, top_p: float, top_k: int, ma
     if MODEL is None or TOKENIZER is None:
         raise RuntimeError("Model is not loaded. Press 'Load model' first.")
 
-    input_ids = TOKENIZER(prompt, return_tensors="pt", truncation=True, max_length=2048).input_ids.to(next(MODEL.parameters()).device)
+    # When using a chat-template prompt we must avoid adding special tokens again
+    add_special_tokens = False if hasattr(TOKENIZER, "apply_chat_template") else True
+
+    input_ids = TOKENIZER(prompt, return_tensors="pt", truncation=True, max_length=2048, add_special_tokens=add_special_tokens).input_ids.to(next(MODEL.parameters()).device)
 
     gen_kwargs = dict(
         input_ids=input_ids,
@@ -452,7 +623,8 @@ def _generate_stream(prompt: str, temperature: float, top_p: float, top_k: int,
         raise RuntimeError("Model is not loaded. Press 'Load model' first.")
 
     streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
-    input_ids = TOKENIZER(prompt, return_tensors="pt", truncation=True, max_length=2048).input_ids.to(next(MODEL.parameters()).device)
+    add_special_tokens = False if hasattr(TOKENIZER, "apply_chat_template") else True
+    input_ids = TOKENIZER(prompt, return_tensors="pt", truncation=True, max_length=2048, add_special_tokens=add_special_tokens).input_ids.to(next(MODEL.parameters()).device)
 
     gen_kwargs = dict(
         input_ids=input_ids,
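A templated prompt already carries the template's own special tokens, which is why both generation paths now tokenize with add_special_tokens=False. A quick check of what that avoids, assuming a tokenizer that both ships a chat template and prepends BOS (the repo name is again just an example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hi"}],
    add_generation_prompt=True,
    tokenize=False,
)
doubled = tok(prompt, add_special_tokens=True).input_ids
clean = tok(prompt, add_special_tokens=False).input_ids
# with add_special_tokens=True a second BOS lands in front of the template's
# own BOS; passing False, as the diff does, avoids the duplicate
print(doubled[:3], clean[:3])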
@@ -491,6 +663,17 @@ def submit_message(user_message: str, history, system_prompt: str, temperature:
     # Append current user turn (assistant reply empty until generated)
     pairs.append((str(user_message or ""), ""))
 
+    # Guard: block generation while model is loading or not loaded
+    if MODEL_LOADING:
+        pairs[-1] = (user_message, "⚠️ Model is still loading — please wait and try again. Check 'Status' for progress.")
+        yield pairs, ""
+        return
+
+    if MODEL is None:
+        pairs[-1] = (user_message, "⚠️ Model is not loaded — click 'Load model' first.")
+        yield pairs, ""
+        return
+
     prompt = build_prompt(pairs[:-1], user_message, system_prompt, max_history)
 
     # If user is running the full Nanbeige model on CPU, warn and suggest options
@@ -538,7 +721,15 @@ def regenerate(history, system_prompt: str, temperature: float, top_p: float, to
 
 def load_model_ui(repo: str):
     status = load_model(repo, force_reload=True)
-    return status
+    try:
+        suffix = " — chat-template detected" if USE_CHAT_TEMPLATE else ""
+    except NameError:
+        suffix = ""
+    # enable the Send button only when the model actually loaded
+    loaded = str(status).lower().startswith("loaded")
+    from gradio import update as gr_update
+    send_state = gr_update(interactive=loaded)
+    return status + suffix, send_state
 
 
 def apply_lora_adapter(adapter_repo: str):
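`load_model_ui` now returns a second value so the same click can toggle the Send button; in Gradio, an event handler may return `gr.update(...)` for any of its output components. A self-contained sketch of that wiring (stub logic and labels are hypothetical):

import gradio as gr

def load_stub(repo: str):
    # pretend-load: succeed whenever a repo id was typed
    ok = bool(repo.strip())
    status = f"Loaded {repo}" if ok else "Model load failed: empty repo id"
    return status, gr.update(interactive=ok)

with gr.Blocks() as demo:
    repo_box = gr.Textbox(label="Repo")
    status_box = gr.Textbox(label="Status")
    send_btn = gr.Button("Send", interactive=False)
    load_btn = gr.Button("Load model")
    # one handler updates both the status text and the Send button state
    load_btn.click(fn=load_stub, inputs=repo_box, outputs=[status_box, send_btn])

demo.launch()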
@@ -609,7 +800,7 @@ with gr.Blocks(title="Nanbeige2.5 — Chat UI") as demo:
         apply_adapter = gr.Button("Apply LoRA adapter")
 
     # Events
-    load_btn.click(fn=lambda repo: load_model_ui(repo), inputs=model_input, outputs=model_status)
+    load_btn.click(fn=load_model_ui, inputs=model_input, outputs=[model_status, send])
     repair_btn.click(fn=repair_tokenizer_on_hub, inputs=model_input, outputs=model_status)
 
     send.click(
@@ -635,7 +826,21 @@ with gr.Blocks(title="Nanbeige2.5 — Chat UI") as demo:
 
     # auto-load default model in background (non-blocking)
     def _bg_initial_load():
-        return load_model(DEFAULT_MODEL, force_reload=False)
+        # run load_model in a background thread to warm up model on Space startup
+        def _worker():
+            res = load_model(DEFAULT_MODEL, force_reload=False)
+            try:
+                # update UI Send button when loaded
+                from gradio import update as gr_update
+                interactive = str(res).lower().startswith("loaded")
+                send.update(interactive=interactive)
+            except Exception:
+                pass
+            return res
+
+        t = threading.Thread(target=_worker, daemon=True)
+        t.start()
+        return "Loading model in background..."
 
     # For local smoke tests you can skip automatic model loading by setting
     # environment variable `SKIP_AUTOLOAD=1` so the UI starts without loading
@@ -644,6 +849,11 @@ with gr.Blocks(title="Nanbeige2.5 — Chat UI") as demo:
         model_status.value = "Auto-load skipped (SKIP_AUTOLOAD=1)"
     else:
         model_status.value = _bg_initial_load()
+        # disable Send while background load is in progress
+        try:
+            send.update(interactive=False)
+        except Exception:
+            pass
 
     # CPU warning / demo hint (visible in UI)
     gr.Markdown("""
 