tangchao5355 committed (verified)
Commit 1192c0b · Parent(s): 5c3d7fb

Update app.py

Files changed (1)
  1. app.py +266 -112
app.py CHANGED
@@ -1,136 +1,290 @@
  import gradio as gr
  import torch
- from transformers import pipeline, AutoTokenizer, T5ForConditionalGeneration
- from diffusers import StableDiffusionPipeline
- import speech_recognition as sr
- import gc
- from accelerate import init_empty_weights
-
- # ===== Model initialization =====
- def load_models():
-     # Prompt enhancement model
-     prompt_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
-     prompt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
-
-     # Stable Diffusion pipeline
-     sd_pipe = StableDiffusionPipeline.from_pretrained(
-         "runwayml/stable-diffusion-v1-5",
-         torch_dtype=torch.float32,
-         use_safetensors=True,
-         variant="fp16",
-         device_map="auto",
-         offload_state_dict=True
-     )
-     sd_pipe.enable_attention_slicing()
-     sd_pipe.enable_sequential_cpu_offload()

-     return prompt_model, prompt_tokenizer, sd_pipe

- prompt_model, prompt_tokenizer, sd_pipe = load_models()
 
- # ===== Core functionality =====
- def enhance_prompt(raw_input, style_choice):
-     template = f"Generate a detailed Stable Diffusion prompt about: {raw_input} in {style_choice} style."
-     inputs = prompt_tokenizer(template, return_tensors="pt")
-     outputs = prompt_model.generate(inputs.input_ids, max_length=100)
-     return prompt_tokenizer.decode(outputs[0], skip_special_tokens=True)

- def generate_image(enhanced_prompt, steps=20, guidance=7.5):
      try:
-         image = sd_pipe(
-             enhanced_prompt,
-             num_inference_steps=int(steps),
-             guidance_scale=guidance,
-             generator=torch.Generator().manual_seed(42)
-         ).images[0]
-     finally:
-         # Free memory
-         gc.collect()
-         with init_empty_weights():
-             reload_models()
-     return image
-
- def reload_models():
-     global sd_pipe
-     del sd_pipe
-     sd_pipe = StableDiffusionPipeline.from_pretrained(
-         "runwayml/stable-diffusion-v1-5",
-         torch_dtype=torch.float32,
-         device_map="auto",
-         offload_folder="offload"
      )

- # ===== Speech processing =====
- recognizer = sr.Recognizer()

- def audio_to_text(audio_file):
-     if not audio_file:
          return ""
      try:
-         with sr.AudioFile(audio_file) as source:
-             audio = recognizer.record(source)
-         return recognizer.recognize_whisper(audio, model="tiny.en")
      except Exception as e:
-         print(f"语音识别错误: {e}")
          return ""
 
- # ===== Gradio UI =====
- with gr.Blocks(title="AI Art Studio", css=".gradio-container {max-width: 800px !important}") as app:
-     gr.Markdown("## 🎨 AI 艺术生成器 (CPU优化版)")

      with gr.Row():
-         with gr.Column(scale=2):
-             # Input controls
-             input_type = gr.Radio(["文字", "语音"], label="输入方式", value="文字")
-             voice_input = gr.Audio(
-                 sources=["upload"],
-                 type="filepath",
-                 visible=False,
-                 label="上传语音文件",
-                 elem_classes="voice-input"
              )
-             text_input = gr.Textbox(label="输入描述", placeholder="例:空中的魔法树屋...", lines=3)

-             # Style selection
-             style_choice = gr.Dropdown(
-                 ["数字艺术", "油画", "动漫", "照片写实"],
-                 value="数字艺术",
-                 label="艺术风格"
-             )

-             # Generate button
-             generate_btn = gr.Button("生成作品", variant="primary")

-             # Advanced settings
-             with gr.Accordion("高级设置", open=False):
-                 steps_slider = gr.Slider(10, 30, value=20, step=1, label="生成步数")
-                 guidance_slider = gr.Slider(5.0, 10.0, value=7.5, label="创意自由度")
-
-         with gr.Column(scale=3):
-             # Output display
-             prompt_output = gr.Textbox(label="优化后的Prompt", interactive=False)
-             image_output = gr.Image(label="生成结果", show_label=False, elem_id="output-image")
-
-     # Interaction logic
-     input_type.change(
-         fn=lambda x: gr.update(visible=x == "语音"),
-         inputs=input_type,
-         outputs=voice_input
-     )
-
-     generate_btn.click(
-         fn=audio_to_text,
-         inputs=voice_input,
-         outputs=text_input
-     ).success(
-         fn=enhance_prompt,
-         inputs=[text_input, style_choice],
-         outputs=prompt_output
-     ).success(
-         fn=generate_image,
-         inputs=[prompt_output, steps_slider, guidance_slider],
-         outputs=image_output
-     )

  if __name__ == "__main__":
      app.launch(server_name="0.0.0.0", server_port=7860)
 
  import gradio as gr
  import torch
+ from transformers import pipeline, set_seed
+ from diffusers import AutoPipelineForText2Image
+ import openai
+ import os
+ import time
+ import traceback
+ from typing import Optional, Tuple, Union, Literal, TypedDict
+ from PIL import Image
+
+ # ---- Type definitions ----
+ class ModelConfig(TypedDict):
+     model_id: str
+     dtype: torch.dtype
+     timeout: int
+
+ class UIConfig(TypedDict):
+     title: str
+     description: str
+     warning_css: str
+
+ # ---- Configuration ----
+ class AppConfig:
+     # Hardware config
+     DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Model config
+     MODEL: ModelConfig = {
+         "model_id": "nota-ai/bk-sdm-tiny",
+         "dtype": torch.float32,
+         "timeout": 300
+     }

+     # UI config
+     UI: UIConfig = {
+         "title": "🎨 轻量级AI图像生成器(CPU/GPU版)",
+         "description": """\
+ 💡 使用技巧:输入简短描述后选择风格和质量选项\n
+ 🚀 支持语音输入 • 自动提示词优化 • 快速生成模式\n
+ ⚠️ 注意:小模型生成速度快但细节有限,建议使用具体描述""",
+         "warning_css": """
+     .warning {color: orange !important; border-left: 3px solid orange; padding: 10px;}
+     .success {color: green !important;}
+         """
+     }
+
+     # Generation parameters
+     DEFAULT_STEPS: int = 20
+     MAX_STEPS: int = 40
+     DEFAULT_GUIDANCE: float = 5.0
+
+     # Error template
+     @staticmethod
+     def error_msg(message: str) -> str:
+         return f"❌ 错误:{message}"

+ config = AppConfig()
 
+ # ---- Initialization checks ----
+ openai_client: Optional[openai.OpenAI] = None
+ openai_available: bool = False

+ if os.environ.get("OPENAI_API_KEY"):
      try:
+         openai_client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+         openai_available = True
+         print("✅ OpenAI 客户端初始化成功")
+     except Exception as e:
+         print(config.error_msg(f"OpenAI 初始化失败: {e}"))
+
+ # ---- Model loading ----
+ class DummyPipe:
+     def __call__(self, *args, **kwargs) -> None:
+         raise RuntimeError("图像生成模型未加载")
+
+ # Speech recognition model
+ asr_pipeline = None
+ try:
+     asr_pipeline = pipeline(
+         "automatic-speech-recognition",
+         model="openai/whisper-base",
+         device=config.DEVICE,
+         torch_dtype=config.MODEL["dtype"]
      )
+     print("✅ 语音识别模型加载成功")
+ except Exception as e:
+     print(config.error_msg(f"语音模型加载失败: {e}"))

+ # Image generation model
+ image_pipe: Union[AutoPipelineForText2Image, DummyPipe] = DummyPipe()
+ try:
+     image_pipe = AutoPipelineForText2Image.from_pretrained(
+         config.MODEL["model_id"],
+         torch_dtype=config.MODEL["dtype"],
+         use_safetensors=True,
+         resume_download=True,
+         timeout=config.MODEL["timeout"]
+     ).to(config.DEVICE)
+     print(f"✅ 图像模型 {config.MODEL['model_id']} 加载成功")
+ except Exception as e:
+     print(config.error_msg(f"图像模型加载失败: {e}"))
+
+ # ---- Core functionality ----
+ def enhance_prompt(short_prompt: str, style: str, quality: list) -> str:
+     """Prompt enhancement."""
+     if not short_prompt.strip():
+         raise gr.Error("描述内容不能为空")
+
+     # Base enhancement template
+     base_prompt = f"{short_prompt.strip()}, {style}, {', '.join(quality)}"
+
+     if not openai_available:
+         return base_prompt
+
+     try:
+         response = openai_client.chat.completions.create(
+             model="gpt-3.5-turbo",
+             messages=[{
+                 "role": "system",
+                 "content": "你是一个AI绘画提示词专家,请把用户的简短描述扩展为适合小模型使用的详细提示词。"
+             }, {
+                 "role": "user",
+                 "content": f"请优化这个提示词:'{base_prompt}'。要求:保持简洁,适合快速生成,包含主要视觉元素。"
+             }],
+             temperature=0.7,
+             max_tokens=100
+         )
+         return response.choices[0].message.content.strip('"')
+     except Exception as e:
+         print(config.error_msg(f"提示词优化失败: {e}"))
+         return base_prompt
 
+ def generate_image(prompt: str, neg_prompt: str, cfg: float, steps: int) -> Image.Image:
+     """Core image generation."""
+     if isinstance(image_pipe, DummyPipe):
+         raise gr.Error("图像生成功能不可用:模型加载失败")
+
+     try:
+         with torch.no_grad():
+             result = image_pipe(
+                 prompt=prompt,
+                 negative_prompt=neg_prompt,
+                 guidance_scale=cfg,
+                 num_inference_steps=steps,
+                 generator=torch.Generator(config.DEVICE).manual_seed(int(time.time()))
+             )
+         return result.images[0]
+     except Exception as e:
+         raise gr.Error(f"生成失败: {str(e)}")
+
+ def transcribe_audio(audio_path: str) -> str:
+     """Speech-to-text."""
+     if not asr_pipeline or not audio_path:
          return ""
+
      try:
+         return asr_pipeline(audio_path)["text"].strip()
      except Exception as e:
+         print(config.error_msg(f"语音识别失败: {e}"))
          return ""
 
+ # ---- UI logic ----
+ STYLE_OPTIONS = {
+     "🎥 电影风格": "cinematic lighting",
+     "🖼️ 照片写实": "photorealistic",
+     "🇯🇵 二次元": "anime style",
+     "🎨 水彩艺术": "watercolor painting"
+ }
+
+ QUALITY_OPTIONS = [
+     "高清细节", "复杂构图",
+     "专业光影", "4K分辨率"
+ ]
+
+ def process_inputs(
+     text: str,
+     audio: Optional[str],
+     style: str,
+     quality: list,
+     neg_prompt: str,
+     cfg: float,
+     steps: int
+ ) -> Tuple[str, Optional[Image.Image]]:
+     """Main processing pipeline."""
+     try:
+         # Input handling
+         final_text = text.strip()
+         if audio and os.path.exists(audio):
+             final_text = transcribe_audio(audio) or final_text
+
+         # Prompt enhancement
+         enhanced = enhance_prompt(final_text, STYLE_OPTIONS[style], quality)
+
+         # Image generation
+         start_time = time.time()
+         image = generate_image(enhanced, neg_prompt, cfg, steps)
+         time_cost = time.time() - start_time
+
+         return f"✅ 生成成功(耗时:{time_cost:.1f}s)\n{enhanced}", image
+     except Exception as e:
+         return f"❌ 生成失败:{str(e)}", None
+
+ # ---- Gradio UI ----
+ with gr.Blocks(theme=gr.themes.Soft(), css=config.UI["warning_css"]) as app:
+     # Title area
+     gr.Markdown(f"## {config.UI['title']}")
+     gr.Markdown(config.UI["description"])

+     # Status notices
+     if not openai_available:
+         gr.HTML("<div class='warning'>⚠️ OpenAI服务未启用,使用基础提示优化</div>")
+     if isinstance(image_pipe, DummyPipe):
+         gr.HTML("<div class='warning'>⚠️ 图像生成功能不可用:模型加载失败</div>")
+
      with gr.Row():
+         # Input column
+         with gr.Column(scale=1):
+             input_text = gr.Textbox(
+                 label="📝 输入描述",
+                 placeholder="例:机械猫在火星咖啡馆喝咖啡",
+                 max_lines=3
              )

+             audio_input = gr.Audio(
+                 sources=["microphone"],
+                 type="filepath",
+                 label="🎤 语音输入",
+                 visible=bool(asr_pipeline)
+             )

+             with gr.Accordion("⚙️ 高级参数", open=False):
+                 style_select = gr.Dropdown(
+                     label="艺术风格",
+                     choices=list(STYLE_OPTIONS.keys()),
+                     value="🎥 电影风格"
+                 )
+                 quality_check = gr.CheckboxGroup(
+                     label="质量增强",
+                     choices=QUALITY_OPTIONS,
+                     value=["高清细节"]
+                 )
+                 neg_prompt = gr.Textbox(
+                     label="🚫 排除内容",
+                     placeholder="输入不希望出现的元素..."
+                 )
+                 cfg_slider = gr.Slider(
+                     1.0, 10.0,
+                     value=config.DEFAULT_GUIDANCE,
+                     label="生成引导强度"
+                 )
+                 steps_slider = gr.Slider(
+                     5, config.MAX_STEPS,
+                     value=config.DEFAULT_STEPS,
+                     label="迭代步数"
+                 )

+             generate_btn = gr.Button(
+                 "✨ 开始生成",
+                 variant="primary",
+                 interactive=not isinstance(image_pipe, DummyPipe)
+             )
+
+         # Output column
+         with gr.Column(scale=1):
+             prompt_output = gr.Textbox(
+                 label="📋 生成提示",
+                 interactive=False,
+                 lines=4
+             )
+             image_output = gr.Image(
+                 label="🖼️ 生成结果",
+                 type="pil",
+                 height=512,
+                 show_download_button=True
+             )
+
+     # Event binding
+     inputs = [input_text, audio_input, style_select, quality_check, neg_prompt, cfg_slider, steps_slider]
+     generate_btn.click(process_inputs, inputs, [prompt_output, image_output])
+
+     # Clear the text box when audio is provided
+     if asr_pipeline:
+         audio_input.change(
+             lambda x: "" if x else gr.update(),
+             audio_input, input_text
+         )

+ # ---- Launch the app ----
  if __name__ == "__main__":
      app.launch(server_name="0.0.0.0", server_port=7860)