gaoyang07 commited on
Commit
685e40d
·
1 Parent(s): c5b84ea

fix app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -36
app.py CHANGED
@@ -108,6 +108,7 @@ def load_backend(model_path: str, device_str: str, attn_implementation: str):
108
  )
109
  if hasattr(processor, "audio_tokenizer"):
110
  processor.audio_tokenizer = processor.audio_tokenizer.to(device)
 
111
 
112
  model_kwargs = {
113
  "trust_remote_code": True,
@@ -558,21 +559,7 @@ def build_demo(args: argparse.Namespace):
558
  )
559
 
560
  run_btn.click(
561
- fn=lambda text, reference_audio, mode_with_reference, duration_control_enabled, duration_tokens, temperature, top_p, top_k, repetition_penalty, max_new_tokens: run_inference(
562
- text=text,
563
- reference_audio=reference_audio,
564
- mode_with_reference=mode_with_reference,
565
- duration_control_enabled=duration_control_enabled,
566
- duration_tokens=duration_tokens,
567
- temperature=temperature,
568
- top_p=top_p,
569
- top_k=top_k,
570
- repetition_penalty=repetition_penalty,
571
- model_path=args.model_path,
572
- device=args.device,
573
- attn_implementation=args.attn_implementation,
574
- max_new_tokens=max_new_tokens,
575
- ),
576
  inputs=[
577
  text,
578
  reference_audio,
@@ -583,6 +570,9 @@ def build_demo(args: argparse.Namespace):
583
  top_p,
584
  top_k,
585
  repetition_penalty,
 
 
 
586
  max_new_tokens,
587
  ],
588
  outputs=[output_audio, status],
@@ -617,19 +607,6 @@ def parse_port(value: str | None, default: int) -> int:
617
  return default
618
 
619
 
620
- def build_default_args() -> argparse.Namespace:
621
- return resolve_runtime_attn(
622
- argparse.Namespace(
623
- model_path=MODEL_PATH,
624
- device="cuda:0",
625
- attn_implementation=DEFAULT_ATTN_IMPLEMENTATION,
626
- host=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
627
- port=parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), 7860),
628
- share=False,
629
- )
630
- )
631
-
632
-
633
  def main():
634
  parser = argparse.ArgumentParser(description="MossTTS Gradio Demo")
635
  parser.add_argument("--model_path", type=str, default=MODEL_PATH)
@@ -680,12 +657,5 @@ def main():
680
  )
681
 
682
 
683
- # Expose a module-level demo for Gradio hot-reload/Spaces launcher.
684
- demo = build_demo(build_default_args())
685
-
686
-
687
  if __name__ == "__main__":
688
- if os.getenv("GRADIO_HOT_RELOAD"):
689
- print("[Startup] GRADIO_HOT_RELOAD detected. Skipping explicit launch().", flush=True)
690
- else:
691
- main()
 
108
  )
109
  if hasattr(processor, "audio_tokenizer"):
110
  processor.audio_tokenizer = processor.audio_tokenizer.to(device)
111
+ processor.audio_tokenizer.eval()
112
 
113
  model_kwargs = {
114
  "trust_remote_code": True,
 
559
  )
560
 
561
  run_btn.click(
562
+ fn=run_inference,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
  inputs=[
564
  text,
565
  reference_audio,
 
570
  top_p,
571
  top_k,
572
  repetition_penalty,
573
+ gr.State(args.model_path),
574
+ gr.State(args.device),
575
+ gr.State(args.attn_implementation),
576
  max_new_tokens,
577
  ],
578
  outputs=[output_audio, status],
 
607
  return default
608
 
609
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
  def main():
611
  parser = argparse.ArgumentParser(description="MossTTS Gradio Demo")
612
  parser.add_argument("--model_path", type=str, default=MODEL_PATH)
 
657
  )
658
 
659
 
 
 
 
 
660
  if __name__ == "__main__":
661
+ main()