Sayoyo commited on
Commit
f8052f0
·
1 Parent(s): 109427f

feat: gitignore python_embeded

Browse files
Files changed (2) hide show
  1. .gitignore +2 -1
  2. profile_inference.py +43 -5
.gitignore CHANGED
@@ -223,4 +223,5 @@ torchinductor_root/
223
  scripts/
224
  checkpoints_legacy/
225
  lora_output/
226
- datasets/
 
 
223
  scripts/
224
  checkpoints_legacy/
225
  lora_output/
226
+ datasets/
227
+ python_embeded/
profile_inference.py CHANGED
@@ -36,6 +36,34 @@ project_root = os.path.abspath(os.path.dirname(__file__))
36
  if project_root not in sys.path:
37
  sys.path.insert(0, project_root)
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  import torch
40
  from acestep.inference import generate_music, GenerationParams, GenerationConfig
41
  from acestep.handler import AceStepHandler
@@ -522,14 +550,21 @@ def load_example_config(example_file: str) -> Tuple[GenerationParams, Generation
522
  def main():
523
  global timer, llm_debugger
524
 
 
 
 
525
  parser = argparse.ArgumentParser(
526
  description="Profile ACE-Step inference with LLM debugging"
527
  )
528
  parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints")
529
- parser.add_argument("--config-path", type=str, default="acestep-v15-turbo-rl")
530
- parser.add_argument("--device", type=str, default="cuda")
531
- parser.add_argument("--lm-model", type=str, default="acestep-5Hz-lm-0.6B-v3")
532
- parser.add_argument("--lm-backend", type=str, default="vllm")
 
 
 
 
533
  parser.add_argument("--no-warmup", action="store_true")
534
  parser.add_argument("--detailed", action="store_true")
535
  parser.add_argument("--llm-debug", action="store_true",
@@ -553,7 +588,10 @@ def main():
553
  print("=" * 100)
554
  print("🎵 ACE-Step Inference Profiler (LLM Performance Analysis)")
555
  print("=" * 100)
556
- print(f"\nConfiguration:")
 
 
 
557
  print(f" Device: {args.device}")
558
  print(f" LLM Backend: {args.lm_backend}")
559
  print(f" LLM Debug: {'Enabled' if args.llm_debug else 'Disabled'}")
 
36
  if project_root not in sys.path:
37
  sys.path.insert(0, project_root)
38
 
39
+
40
+ def load_env_config():
41
+ """从 .env 文件加载配置"""
42
+ env_config = {
43
+ 'ACESTEP_CONFIG_PATH': 'acestep-v15-turbo-rl',
44
+ 'ACESTEP_LM_MODEL_PATH': 'acestep-5Hz-lm-0.6B-v3',
45
+ 'ACESTEP_DEVICE': 'auto',
46
+ 'ACESTEP_LM_BACKEND': 'vllm',
47
+ }
48
+
49
+ env_file = os.path.join(project_root, '.env')
50
+ if os.path.exists(env_file):
51
+ with open(env_file, 'r', encoding='utf-8') as f:
52
+ for line in f:
53
+ line = line.strip()
54
+ # 跳过空行和注释
55
+ if not line or line.startswith('#'):
56
+ continue
57
+ # 解析键值对
58
+ if '=' in line:
59
+ key, value = line.split('=', 1)
60
+ key = key.strip()
61
+ value = value.strip()
62
+ if key in env_config and value:
63
+ env_config[key] = value
64
+
65
+ return env_config
66
+
67
  import torch
68
  from acestep.inference import generate_music, GenerationParams, GenerationConfig
69
  from acestep.handler import AceStepHandler
 
550
  def main():
551
  global timer, llm_debugger
552
 
553
+ # 从 .env 文件加载默认配置
554
+ env_config = load_env_config()
555
+
556
  parser = argparse.ArgumentParser(
557
  description="Profile ACE-Step inference with LLM debugging"
558
  )
559
  parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints")
560
+ parser.add_argument("--config-path", type=str, default=env_config['ACESTEP_CONFIG_PATH'],
561
+ help=f"模型配置路径 (默认从 .env: {env_config['ACESTEP_CONFIG_PATH']})")
562
+ parser.add_argument("--device", type=str, default=env_config['ACESTEP_DEVICE'],
563
+ help=f"设备 (默认从 .env: {env_config['ACESTEP_DEVICE']})")
564
+ parser.add_argument("--lm-model", type=str, default=env_config['ACESTEP_LM_MODEL_PATH'],
565
+ help=f"LLM 模型路径 (默认从 .env: {env_config['ACESTEP_LM_MODEL_PATH']})")
566
+ parser.add_argument("--lm-backend", type=str, default=env_config['ACESTEP_LM_BACKEND'],
567
+ help=f"LLM 后端 (默认从 .env: {env_config['ACESTEP_LM_BACKEND']})")
568
  parser.add_argument("--no-warmup", action="store_true")
569
  parser.add_argument("--detailed", action="store_true")
570
  parser.add_argument("--llm-debug", action="store_true",
 
588
  print("=" * 100)
589
  print("🎵 ACE-Step Inference Profiler (LLM Performance Analysis)")
590
  print("=" * 100)
591
+ print(f"\n模型配置 (从 .env 加载):")
592
+ print(f" DiT 模型: {args.config_path}")
593
+ print(f" LLM 模型: {args.lm_model}")
594
+ print(f"\n运行配置:")
595
  print(f" Device: {args.device}")
596
  print(f" LLM Backend: {args.lm_backend}")
597
  print(f" LLM Debug: {'Enabled' if args.llm_debug else 'Disabled'}")