immango commited on
Commit
637d3ba
·
verified ·
1 Parent(s): 4fd25de

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +470 -0
app.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import tempfile
4
+ import os
5
+ import sys
6
+ import shutil
7
+ from pathlib import Path
8
+ import time
9
+
10
+ # 输出 Gradio 版本信息
11
+ print(f"===== Application Startup at {time.strftime('%Y-%m-%d %H:%M:%S')} =====")
12
+ print(f"Gradio version: {gr.__version__}")
13
+ print(f"Python version: {sys.version}")
14
+ print(f"Python executable: {sys.executable}")
15
+ print("=" * 60)
16
+
17
+ class Wan2S2VPipeline:
18
+ def __init__(self):
19
+ self.model_loaded = False
20
+ self.model_path = None
21
+ self.script_path = None
22
+ self.ckpt_dir = None
23
+ self.model_repo = "Wan-AI/Wan2.2-S2V-14B"
24
+
25
+ def load_model(self):
26
+ """下载Wan2.2-S2V-14B模型和脚本"""
27
+ try:
28
+ if self.model_loaded:
29
+ return True, "模型已加载"
30
+
31
+ # 设置工作目录(使用持久目录)
32
+ work_dir = "/tmp/wan2.2"
33
+ os.makedirs(work_dir, exist_ok=True)
34
+
35
+ # 步骤1: 克隆官方代码仓库
36
+ print("步骤1: 克隆官方代码仓库...")
37
+ repo_path = os.path.join(work_dir, "Wan2.2")
38
+
39
+ if not os.path.exists(os.path.join(repo_path, ".git")):
40
+ # 如果目录不存在或不是git仓库,则克隆
41
+ if os.path.exists(repo_path):
42
+ shutil.rmtree(repo_path)
43
+
44
+ result = subprocess.run(
45
+ ["git", "clone", "https://github.com/Wan-Video/Wan2.2.git", repo_path],
46
+ capture_output=True,
47
+ text=True,
48
+ timeout=300
49
+ )
50
+
51
+ if result.returncode != 0:
52
+ return False, f"❌ 克隆代码仓库失败: {result.stderr}"
53
+
54
+ print("✅ 代码仓库克隆成功")
55
+ else:
56
+ print("✅ 代码仓库已存在,跳过克隆")
57
+
58
+ # 步骤2: 下载模型权重
59
+ print("步骤2: 下载模型权重...")
60
+ model_dir = os.path.join(work_dir, "Wan2.2-S2V-14B")
61
+
62
+ if not os.path.exists(model_dir):
63
+ from huggingface_hub import snapshot_download
64
+
65
+ print(f"正在下载模型 {self.model_repo}...")
66
+ model_path = snapshot_download(
67
+ repo_id=self.model_repo,
68
+ cache_dir="/tmp/hf_cache",
69
+ local_dir=model_dir,
70
+ local_dir_use_symlinks=False
71
+ )
72
+ print(f"✅ 模型权重下载完成: {model_path}")
73
+ else:
74
+ print("✅ 模型权重已存在,跳过下载")
75
+
76
+ # 步骤3: 安装依赖
77
+ print("步骤3: 安装依赖...")
78
+ requirements_file = os.path.join(repo_path, "requirements.txt")
79
+ if os.path.exists(requirements_file):
80
+ try:
81
+ result = subprocess.run(
82
+ [sys.executable, "-m", "pip", "install", "-r", requirements_file],
83
+ capture_output=True,
84
+ text=True,
85
+ timeout=600,
86
+ cwd=repo_path
87
+ )
88
+ if result.returncode == 0:
89
+ print("✅ 依赖安装成功")
90
+ else:
91
+ print(f"⚠️ 依赖安装警告: {result.stderr}")
92
+ except Exception as e:
93
+ print(f"⚠️ 依赖安装跳过: {e}")
94
+ else:
95
+ print("⚠️ 未找到 requirements.txt,跳过依赖安装")
96
+
97
+ # 步骤4: 设置路径
98
+ self.model_path = repo_path
99
+ self.script_path = os.path.join(repo_path, "generate.py")
100
+ self.ckpt_dir = model_dir
101
+
102
+ # 验证文件
103
+ if not os.path.exists(self.script_path):
104
+ return False, "❌ 未找到 generate.py 脚本"
105
+
106
+ if not os.path.exists(self.ckpt_dir):
107
+ return False, "❌ 未找到模型权重目录"
108
+
109
+ self.model_loaded = True
110
+ print("🎉 Wan2.2-S2V-14B 模型准备完成!")
111
+ return True, "✅ 模型加载成功!"
112
+
113
+ except Exception as e:
114
+ error_msg = f"模型加载失败: {str(e)}"
115
+ print(error_msg)
116
+ return False, f"❌ {error_msg}"
117
+
118
+ def generate(self, task, size, prompt, image_file, audio_file,
119
+ num_frames=16, guidance_scale=7.5,
120
+ num_inference_steps=20, seed=-1, offload_model=True,
121
+ convert_model_dtype=True):
122
+ """执行Wan2.2-S2V-14B生成命令"""
123
+ try:
124
+ if not self.model_loaded:
125
+ success, message = self.load_model()
126
+ if not success:
127
+ return None, message
128
+
129
+ # 设置环境变量解决 OMP_NUM_THREADS 问题
130
+ env = os.environ.copy()
131
+ env["OMP_NUM_THREADS"] = "1"
132
+ env["TOKENIZERS_PARALLELISM"] = "false"
133
+
134
+ # 验证必需参数
135
+ if not prompt or not prompt.strip():
136
+ return None, "❌ 提示词不能为空"
137
+ if not image_file:
138
+ return None, "❌ 请上传输入图片"
139
+ if not audio_file:
140
+ return None, "❌ 请上传输入音频"
141
+
142
+ # 构建命令行参数
143
+ cmd = [sys.executable, self.script_path]
144
+
145
+ # 必需参数
146
+ cmd.extend(["--task", task])
147
+ cmd.extend(["--size", size])
148
+ cmd.extend(["--ckpt_dir", self.ckpt_dir])
149
+ cmd.extend(["--prompt", prompt])
150
+ cmd.extend(["--image", image_file])
151
+ cmd.extend(["--audio", audio_file])
152
+
153
+ # 可选参数
154
+ if num_frames is not None:
155
+ cmd.extend(["--frame_num", str(num_frames)])
156
+ # 使用 infer_frames 替代 fps 参数
157
+ cmd.extend(["--infer_frames", str(num_frames)])
158
+ if guidance_scale is not None:
159
+ cmd.extend(["--sample_guide_scale", str(guidance_scale)])
160
+ if num_inference_steps is not None:
161
+ cmd.extend(["--sample_steps", str(num_inference_steps)])
162
+ if seed is not None and seed != -1:
163
+ cmd.extend(["--base_seed", str(seed)])
164
+
165
+ # 模型优化参数
166
+ if offload_model:
167
+ cmd.extend(["--offload_model", "True"])
168
+ else:
169
+ cmd.extend(["--offload_model", "False"])
170
+ if convert_model_dtype:
171
+ cmd.append("--convert_model_dtype")
172
+
173
+ print(f"执行命令: {' '.join(cmd)}")
174
+
175
+ # 创建临时输出目录
176
+ output_dir = os.path.join(self.model_path, "outputs")
177
+ os.makedirs(output_dir, exist_ok=True)
178
+
179
+ # 执行命令(实时输出日志)
180
+ start_time = time.time()
181
+ print("🚀 开始执行 generate.py 脚本...")
182
+ print("=" * 50)
183
+
184
+ # 使用 Popen 实现实时日志输出
185
+ process = subprocess.Popen(
186
+ cmd,
187
+ stdout=subprocess.PIPE,
188
+ stderr=subprocess.STDOUT, # 将 stderr 重定向到 stdout
189
+ text=True,
190
+ bufsize=1, # 行缓冲
191
+ cwd=self.model_path,
192
+ env=env
193
+ )
194
+
195
+ # 实时读取输出(带超时检查)
196
+ all_output = []
197
+ start_read_time = time.time()
198
+ timeout_seconds = 3600 # 10分钟超时
199
+
200
+ while True:
201
+ # 检查是否超时
202
+ if time.time() - start_read_time > timeout_seconds:
203
+ process.terminate() # 尝试优雅终止
204
+ try:
205
+ process.wait(timeout=10) # 等待10秒
206
+ except subprocess.TimeoutExpired:
207
+ process.kill() # 强制终止
208
+ raise subprocess.TimeoutExpired(cmd, timeout_seconds)
209
+
210
+ # 尝试读取输出(非阻塞)
211
+ output_line = process.stdout.readline()
212
+ if output_line == '' and process.poll() is not None:
213
+ break
214
+ if output_line:
215
+ output_line = output_line.strip()
216
+ if output_line: # 忽略空行
217
+ print(f"[generate.py] {output_line}")
218
+ all_output.append(output_line)
219
+ # 重置超时计时器(有输出说明脚本还在运行)
220
+ start_read_time = time.time()
221
+
222
+ # 等待进程完成
223
+ return_code = process.wait()
224
+ execution_time = time.time() - start_time
225
+
226
+ print("=" * 50)
227
+ print(f"脚本执行完成,返回码: {return_code}")
228
+ print(f"总耗时: {execution_time:.1f}秒")
229
+
230
+ if return_code == 0:
231
+ print("✅ 命令执行成功")
232
+
233
+ # 构建详细的成功消息
234
+ success_msg = f"✅ 生成成功!耗时: {execution_time:.1f}秒\n\n"
235
+ if all_output:
236
+ success_msg += f"脚本输出:\n" + "\n".join(all_output) + "\n"
237
+
238
+ # 查找输出文件
239
+ output_files = self._find_output_files()
240
+ if output_files:
241
+ # 直接返回原始输出文件路径
242
+ output_file = output_files[0]
243
+ print(f"找到输出文件: {output_file}")
244
+ return output_file, success_msg
245
+ else:
246
+ return None, f"⚠️ 生成成功但未找到输出文件\n\n脚本输出:\n" + "\n".join(all_output)
247
+ else:
248
+ # 构建详细的错误消息
249
+ error_msg = f"脚本执行失败,返回码: {return_code}\n\n"
250
+ if all_output:
251
+ error_msg += f"脚本输出:\n" + "\n".join(all_output)
252
+ else:
253
+ error_msg += "无输出信息"
254
+
255
+ print(f"❌ 命令执行失败: {error_msg}")
256
+ return None, f"❌ 生成失败:\n{error_msg}"
257
+
258
+ except subprocess.TimeoutExpired:
259
+ return None, "⏰ 生成超时(10分钟),请尝试减少参数或检查模型状态"
260
+ except Exception as e:
261
+ error_msg = f"执行失败: {str(e)}"
262
+ print(error_msg)
263
+ return None, f"❌ {error_msg}"
264
+
265
+ def _find_output_files(self):
266
+ """查找输出文件"""
267
+ output_extensions = ['.mp4', '.gif', '.avi', '.mov', '.png', '.jpg', '.jpeg']
268
+ output_files = []
269
+
270
+ # 优先搜索 outputs 目录
271
+ outputs_dir = os.path.join(self.model_path, "outputs")
272
+ if os.path.exists(outputs_dir):
273
+ for ext in output_extensions:
274
+ for file_path in Path(outputs_dir).rglob(f"*{ext}"):
275
+ if file_path.is_file():
276
+ output_files.append(str(file_path))
277
+ print(f"在 outputs 目录找到文件: {file_path}")
278
+
279
+ # 如果没有找到,搜索整个模型目录
280
+ if not output_files:
281
+ print("在 outputs 目录未找到文件,搜索整个模型目录...")
282
+ for ext in output_extensions:
283
+ for file_path in Path(self.model_path).rglob(f"*{ext}"):
284
+ if file_path.is_file():
285
+ # 排除一些不需要的文件
286
+ file_path_str = str(file_path)
287
+ if not any(exclude in file_path_str.lower() for exclude in ['.git', '__pycache__', 'node_modules']):
288
+ output_files.append(file_path_str)
289
+ print(f"在模型目录找到文件: {file_path_str}")
290
+
291
+ # 按修改时间排序,最新的文件在前面
292
+ if output_files:
293
+ output_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)
294
+ print(f"找到 {len(output_files)} 个输出文件,按时间排序")
295
+
296
+ return output_files
297
+
298
+ def _copy_output_for_display(self, output_file):
299
+ """复制输出文件到临时目录以便Gradio显示(已弃用)"""
300
+ # 此方法已不再使用,直接返回原始文件路径
301
+ print(f"直接使用原始文件: {output_file}")
302
+ return output_file
303
+
304
+ # 创建全局实例
305
+ pipeline = Wan2S2VPipeline()
306
+
307
+ def generate_interface(task, size, prompt, image_file, audio_file,
308
+ num_frames, guidance_scale, num_inference_steps,
309
+ seed, offload_model, convert_model_dtype):
310
+ """Gradio 界面函数"""
311
+ # 执行生成
312
+ result, message = pipeline.generate(
313
+ task=task,
314
+ size=size,
315
+ prompt=prompt,
316
+ image_file=image_file,
317
+ audio_file=audio_file,
318
+ num_frames=num_frames,
319
+ guidance_scale=guidance_scale,
320
+ num_inference_steps=num_inference_steps,
321
+ seed=seed,
322
+ offload_model=offload_model,
323
+ convert_model_dtype=convert_model_dtype
324
+ )
325
+
326
+ return result, message
327
+
328
+ def load_model_interface():
329
+ """加载模型界面函数"""
330
+ success, message = pipeline.load_model()
331
+ return message
332
+
333
+ # 创建 Gradio 界面
334
+ with gr.Blocks(title="Wan2.2-S2V-14B 视频生成器") as demo:
335
+ gr.Markdown("""
336
+
337
+ # 使用前说明:本项目无法正常运行是因为没有选择GPU部署
338
+ # 完整的运行,请参考工程Files或者复制这个space,部署时最低选择 Nvidia 1xL40S 48G VRAM
339
+
340
+ # 🎬 Wan2.2-S2V-14B 视频生成器
341
+
342
+ **模型介绍**: Wan2.2-S2V-14B 是一个强大的图像到视频生成模型,支持音频引导。
343
+
344
+ **使用方法**:
345
+ 1. 点击"🚀 加载模型"按钮下载模型
346
+ 2. 填写提示词、上传图片和音频
347
+ 3. 调整参数后点击"🎬 开始生成"
348
+
349
+ **注意**: 首次使用需要下载约14GB的模型文件,请耐心等待。
350
+ """)
351
+
352
+ with gr.Row():
353
+ with gr.Column(scale=1):
354
+ # 模型加载
355
+ gr.Markdown("### 📥 模型管理")
356
+ load_btn = gr.Button("🚀 加载模型", variant="primary", size="lg")
357
+ load_status = gr.Textbox(label="模型状态", interactive=False, value="等待加载模型...")
358
+
359
+ # 必需参数
360
+ gr.Markdown("### 📝 必需参数")
361
+ task = gr.Textbox(
362
+ label="任务类型",
363
+ value="s2v-14B",
364
+ interactive=False
365
+ )
366
+ size = gr.Dropdown(
367
+ label="分辨率",
368
+ choices=["1024*704", "1024*1024", "704*1024", "512*512"],
369
+ value="1024*704"
370
+ )
371
+
372
+ prompt = gr.Textbox(
373
+ label="提示词 *",
374
+ lines=3,
375
+ placeholder="例如: Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard."
376
+ )
377
+
378
+ image = gr.Image(
379
+ label="输入图片 *",
380
+ type="filepath"
381
+ )
382
+ audio = gr.Audio(
383
+ label="输入音频 *",
384
+ type="filepath"
385
+ )
386
+
387
+ # 高级参数
388
+ with gr.Accordion("🔧 高级参数", open=False):
389
+ num_frames = gr.Slider(
390
+ 8, 32, 16,
391
+ step=1,
392
+ label="帧数 (frame_num/infer_frames)"
393
+ )
394
+ guidance_scale = gr.Slider(
395
+ 1.0, 20.0, 7.5,
396
+ step=0.1,
397
+ label="引导强度 (sample_guide_scale)"
398
+ )
399
+ num_inference_steps = gr.Slider(
400
+ 10, 100, 20,
401
+ step=1,
402
+ label="推理步数 (sample_steps)"
403
+ )
404
+ seed = gr.Number(
405
+ label="随机种子 (base_seed)",
406
+ value=-1
407
+ )
408
+
409
+ with gr.Row():
410
+ offload_model = gr.Checkbox(
411
+ label="模型卸载",
412
+ value=True
413
+ )
414
+ convert_model_dtype = gr.Checkbox(
415
+ label="转换数据类型",
416
+ value=True
417
+ )
418
+
419
+ # 生成按钮
420
+ generate_btn = gr.Button("🎬 开始生成", variant="primary", size="lg")
421
+
422
+ with gr.Column(scale=1):
423
+ # 输出结果
424
+ gr.Markdown("### 🎥 生成结果")
425
+ output = gr.File(label="输出视频")
426
+ status = gr.Textbox(label="生成状态", interactive=False, lines=3)
427
+
428
+ # 使用说明
429
+ gr.Markdown("""
430
+ ### 📋 使用说明
431
+
432
+ **参数说明**:
433
+ - **分辨率**: 选择适合你需求的视频尺寸
434
+ - **提示词**: 用英文描述想要的视频内容,越详细越好
435
+ - **图片**: 上传参考图片,模型会基于此生成视频
436
+ - **音频**: 上传音频文件,模型会结合音频内容生成视频
437
+
438
+ **高级参数**:
439
+ - **帧数 (frame_num/infer_frames)**: 控制视频长度,8-32帧
440
+ - **引导强度 (sample_guide_scale)**: 生成质量控制,1.0-20.0
441
+ - **推理步数 (sample_steps)**: 生成精度,10-100步
442
+ - **随机种子 (base_seed)**: 结果重现,-1为随机
443
+
444
+ **优化建议**:
445
+ - 首次使用建议保持默认参数
446
+ - 如果显存不足,可以降低分辨率和帧数
447
+ - 提示词使用英文效果更好
448
+ - 音频文件建议使用清晰的语音或音乐
449
+
450
+ **注意事项**:
451
+ - 生成时间取决于参数设置,通常需要5-10分钟
452
+ - 确保上传的图片和音频文件格式正确
453
+ - 如果遇到错误,请检查参数设置和文件格式
454
+ """)
455
+
456
+ # 事件绑定
457
+ load_btn.click(load_model_interface, outputs=load_status)
458
+ generate_btn.click(
459
+ generate_interface,
460
+ inputs=[
461
+ task, size, prompt, image, audio,
462
+ num_frames, guidance_scale, num_inference_steps,
463
+ seed, offload_model, convert_model_dtype
464
+ ],
465
+ outputs=[output, status]
466
+ )
467
+
468
+ # 启动应用
469
+ if __name__ == "__main__":
470
+ demo.launch(server_name="0.0.0.0", server_port=7860)