Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import pipeline, set_seed | |
| from diffusers import AutoPipelineForText2Image | |
| import openai | |
| import os | |
| import time | |
| import traceback | |
| from typing import Optional, Tuple, Union, Literal, TypedDict | |
| from PIL import Image | |
# SECURITY: do not hard-code an API key in source code.
# The placeholder assignment that used to live here overwrote any real
# OPENAI_API_KEY supplied by the environment (for example a deployment
# secret) and made the "is OpenAI available" check below always succeed.
# Provide the key through the environment instead; when it is absent the
# app falls back to the basic, non-OpenAI prompt template.
import os
| # ---- 类型定义 ---- | |
class ModelConfig(TypedDict):
    """Typed shape of the model settings dictionary."""

    # Model identifier handed to the image pipeline's ``from_pretrained``.
    model_id: str
    # Tensor dtype used when loading the models.
    dtype: torch.dtype
    # Timeout value forwarded to ``from_pretrained`` (unit not stated here).
    timeout: int
class UIConfig(TypedDict):
    """Typed shape of the user-interface settings dictionary."""

    # Heading rendered at the top of the page.
    title: str
    # Markdown text rendered below the heading.
    description: str
    # CSS passed to the Gradio Blocks container.
    warning_css: str
| # ---- 配置管理 ---- | |
class AppConfig:
    """Static application configuration.

    Groups the device choice, model settings, UI texts and default
    generation parameters as class attributes, plus a small helper that
    formats error messages.
    """

    # Hardware: use CUDA when torch reports it is available, otherwise CPU.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Model settings consumed when the pipelines are loaded.
    MODEL: ModelConfig = {
        "model_id": "nota-ai/bk-sdm-tiny",
        "dtype": torch.float32,
        "timeout": 300,
    }

    # UI texts. The multi-line strings are kept flush-left on purpose so
    # that no extra leading whitespace ends up in the rendered Markdown/CSS.
    UI: UIConfig = {
        "title": "🎨 轻量级AI图像生成器(CPU/GPU版)",
        "description": """\
💡 使用技巧:输入简短描述后选择风格和质量选项\n
🚀 支持语音输入 • 自动提示词优化 • 快速生成模式\n
⚠️ 注意:小模型生成速度快但细节有限,建议使用具体描述""",
        "warning_css": """
.warning {color: orange !important; border-left: 3px solid orange; padding: 10px;}
.success {color: green !important;}
""",
    }

    # Generation parameter defaults and limits used by the sliders.
    DEFAULT_STEPS: int = 20
    MAX_STEPS: int = 40
    DEFAULT_GUIDANCE: float = 5.0

    @staticmethod
    def error_msg(message: str) -> str:
        """Return *message* formatted with the application's error prefix.

        Declared as a static method because it takes no ``self``: without
        the decorator, calling it on an instance (``config.error_msg(...)``)
        would pass the instance as ``message`` and raise ``TypeError``.
        """
        return f"❌ 错误:{message}"
# Configuration object shared by the rest of the module.
config = AppConfig()

# ---- Initialization checks ----
# The OpenAI client is optional: it is only created when an API key is
# present in the environment, and the flag records whether that succeeded.
openai_client: Optional[openai.OpenAI] = None
openai_available: bool = False

_api_key = os.environ.get("OPENAI_API_KEY")
if _api_key:
    try:
        openai_client = openai.OpenAI(api_key=_api_key)
        openai_available = True
        print("✅ OpenAI 客户端初始化成功")
    except Exception as e:
        # Best effort: report the problem and continue without OpenAI.
        print(config.error_msg(f"OpenAI 初始化失败: {e}"))
| # ---- 模型加载 ---- | |
class DummyPipe:
    """Placeholder used in place of the image pipeline when it is not loaded."""

    def __call__(self, *args, **kwargs) -> None:
        """Reject any invocation by raising ``RuntimeError``."""
        reason = "图像生成模型未加载"
        raise RuntimeError(reason)
# Speech-recognition model.
# Stays None when loading fails; the UI and transcribe_audio() check for that.
asr_pipeline = None
try:
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        torch_dtype=config.MODEL["dtype"],
        device=config.DEVICE,
    )
    print("✅ 语音识别模型加载成功")
except Exception as e:
    # Best effort: report the failure and keep running without speech input.
    print(config.error_msg(f"语音模型加载失败: {e}"))
# Image-generation model.
# Starts as a DummyPipe placeholder and is only replaced once the real
# pipeline has been loaded and moved to the configured device.
image_pipe: Union[AutoPipelineForText2Image, DummyPipe] = DummyPipe()
try:
    _loaded_pipe = AutoPipelineForText2Image.from_pretrained(
        config.MODEL["model_id"],
        torch_dtype=config.MODEL["dtype"],
        use_safetensors=True,
        resume_download=True,
        timeout=config.MODEL["timeout"],
    )
    image_pipe = _loaded_pipe.to(config.DEVICE)
    print(f"✅ 图像模型 {config.MODEL['model_id']} 加载成功")
except Exception as e:
    # Best effort: report the failure; image_pipe remains the placeholder.
    print(config.error_msg(f"图像模型加载失败: {e}"))
| # ---- 核心功能 ---- | |
def enhance_prompt(short_prompt: str, style: str, quality: list) -> str:
    """Build the final image prompt from the user's short description.

    The description, the style keywords and the selected quality tags are
    joined into a comma-separated base prompt. When the OpenAI client is
    available, the base prompt is sent to the chat model for rewriting;
    otherwise, or if that request fails or yields nothing, the base prompt
    is returned unchanged.

    Args:
        short_prompt: The user's description.
        style: Style keywords to append.
        quality: Quality tags to append; may be empty.

    Returns:
        The optimized prompt, or the base prompt as a fallback.

    Raises:
        gr.Error: If the description is empty or only whitespace.
    """
    subject = (short_prompt or "").strip()
    if not subject:
        raise gr.Error("描述内容不能为空")

    # Join only the non-empty parts so that an empty quality selection does
    # not leave a dangling ", " at the end of the prompt.
    parts = [subject, style, *(quality or [])]
    base_prompt = ", ".join(part for part in parts if part)

    if not openai_available:
        return base_prompt

    try:
        response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{
                "role": "system",
                "content": "你是一个AI绘画提示词专家,请把用户的简短描述扩展为适合小模型使用的详细提示词。"
            }, {
                "role": "user",
                "content": f"请优化这个提示词:'{base_prompt}'。要求:保持简洁,适合快速生成,包含主要视觉元素。"
            }],
            temperature=0.7,
            max_tokens=100,
        )
        content = response.choices[0].message.content
        # The reply may be missing or wrapped in quotes/whitespace.
        optimized = (content or "").strip().strip('"').strip()
    except Exception as e:
        # Best effort: any failure of the optional service falls back to
        # the locally built prompt.
        print(config.error_msg(f"提示词优化失败: {e}"))
        return base_prompt

    # Never hand an empty prompt to the image model.
    return optimized or base_prompt
def generate_image(prompt: str, neg_prompt: str, cfg: float, steps: int) -> Image.Image:
    """Run the image pipeline once and return the first generated image.

    Args:
        prompt: Text prompt for the pipeline.
        neg_prompt: Negative prompt (content to avoid).
        cfg: Guidance scale.
        steps: Number of inference steps. UI sliders may deliver a float,
            so the value is rounded to an integer (at least 1).

    Returns:
        The first image of the pipeline result.

    Raises:
        gr.Error: If the model is not loaded or generation fails.
    """
    if isinstance(image_pipe, DummyPipe):
        raise gr.Error("图像生成功能不可用:模型加载失败")
    try:
        # The step count must be a whole number; coerce defensively because
        # slider values can arrive as (possibly fractional) floats.
        step_count = max(1, int(round(steps)))
        # Seed from the current time so consecutive runs differ.
        generator = torch.Generator(config.DEVICE).manual_seed(int(time.time()))
        with torch.no_grad():
            result = image_pipe(
                prompt=prompt,
                negative_prompt=neg_prompt,
                guidance_scale=cfg,
                num_inference_steps=step_count,
                generator=generator,
            )
        return result.images[0]
    except Exception as e:
        # Surface the failure to the UI while keeping the original cause.
        raise gr.Error(f"生成失败: {str(e)}") from e
def transcribe_audio(audio_path: str) -> str:
    """Convert a recorded audio file to text.

    Returns the stripped ``"text"`` field of the speech pipeline result.
    An empty string is returned when the pipeline is not loaded, no path
    is given, or recognition raises.
    """
    if asr_pipeline and audio_path:
        try:
            recognized = asr_pipeline(audio_path)
            return recognized["text"].strip()
        except Exception as e:
            # Best effort: report the failure and fall through to "".
            print(config.error_msg(f"语音识别失败: {e}"))
    return ""
| # ---- 界面逻辑 ---- | |
# Maps the style label shown in the dropdown to the keywords appended to
# the prompt. Insertion order is the order shown in the UI.
STYLE_OPTIONS = {
    "🎥 电影风格": "cinematic lighting",
    "🖼️ 照片写实": "photorealistic",
    "🇯🇵 二次元": "anime style",
    "🎨 水彩艺术": "watercolor painting",
}

# Quality tags offered as checkboxes; selected tags are appended to the prompt.
QUALITY_OPTIONS = [
    "高清细节",
    "复杂构图",
    "专业光影",
    "4K分辨率",
]
def process_inputs(
    text: str,
    audio: Optional[str],
    style: str,
    quality: list,
    neg_prompt: str,
    cfg: float,
    steps: int
) -> Tuple[str, Optional[Image.Image]]:
    """Handle one click of the generate button.

    Picks the prompt text (a non-empty transcription of the recording wins
    over the typed text), enhances it, generates an image and reports how
    long the generation step took.

    Returns:
        ``(status_text, image)``; on any failure the status text carries
        the error message and the image is ``None``.
    """
    try:
        # Input handling: typed text is the default prompt source.
        prompt_text = text.strip()
        has_recording = bool(audio) and os.path.exists(audio)
        if has_recording:
            transcript = transcribe_audio(audio)
            if transcript:
                prompt_text = transcript

        # Prompt optimization.
        style_keywords = STYLE_OPTIONS[style]
        enhanced = enhance_prompt(prompt_text, style_keywords, quality)

        # Image generation; only this step is timed.
        started_at = time.time()
        image = generate_image(enhanced, neg_prompt, cfg, steps)
        elapsed = time.time() - started_at

        status = f"✅ 生成成功(耗时:{elapsed:.1f}s)\n{enhanced}"
        return status, image
    except Exception as e:
        # Errors are reported in the status box instead of propagating.
        return f"❌ 生成失败:{str(e)}", None
| # ---- Gradio界面 ---- | |
# Builds the Gradio interface and wires the events.
# NOTE(review): the nesting below (which widgets sit inside the Row, the
# Columns and the Accordion) was reconstructed from a source whose
# indentation was lost - confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), css=config.UI["warning_css"]) as app:
    # Header area
    gr.Markdown(f"## {config.UI['title']}")
    gr.Markdown(config.UI["description"])

    # Status banners for optional features that failed to initialize
    if not openai_available:
        gr.HTML("<div class='warning'>⚠️ OpenAI服务未启用,使用基础提示优化</div>")
    if isinstance(image_pipe, DummyPipe):
        gr.HTML("<div class='warning'>⚠️ 图像生成功能不可用:模型加载失败</div>")

    with gr.Row():
        # Input column
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="📝 输入描述",
                placeholder="例:机械猫在火星咖啡馆喝咖啡",
                max_lines=3
            )
            # Hidden when the speech model could not be loaded.
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="🎤 语音输入",
                visible=bool(asr_pipeline)
            )
            with gr.Accordion("⚙️ 高级参数", open=False):
                style_select = gr.Dropdown(
                    label="艺术风格",
                    choices=list(STYLE_OPTIONS.keys()),
                    value="🎥 电影风格"
                )
                quality_check = gr.CheckboxGroup(
                    label="质量增强",
                    choices=QUALITY_OPTIONS,
                    value=["高清细节"]
                )
                neg_prompt = gr.Textbox(
                    label="🚫 排除内容",
                    placeholder="输入不希望出现的元素..."
                )
                cfg_slider = gr.Slider(
                    1.0, 10.0,
                    value=config.DEFAULT_GUIDANCE,
                    label="生成引导强度"
                )
                # step=1 keeps the value a whole number: without an explicit
                # step the slider uses a fractional increment for this range,
                # and the step count handed to the pipeline must be an integer.
                steps_slider = gr.Slider(
                    5, config.MAX_STEPS,
                    value=config.DEFAULT_STEPS,
                    step=1,
                    label="迭代步数"
                )
            # Disabled when the image model is only the placeholder.
            generate_btn = gr.Button(
                "✨ 开始生成",
                variant="primary",
                interactive=not isinstance(image_pipe, DummyPipe)
            )

        # Output column
        with gr.Column(scale=1):
            prompt_output = gr.Textbox(
                label="📋 生成提示",
                interactive=False,
                lines=4
            )
            image_output = gr.Image(
                label="🖼️ 生成结果",
                type="pil",
                height=512,
                show_download_button=True
            )

    # Event wiring: the order of `inputs` matches process_inputs' parameters.
    inputs = [input_text, audio_input, style_select, quality_check, neg_prompt, cfg_slider, steps_slider]
    generate_btn.click(process_inputs, inputs, [prompt_output, image_output])

    # A new recording clears the text box; clearing the recording leaves it as is.
    if asr_pipeline:
        audio_input.change(
            lambda x: "" if x else gr.update(),
            audio_input, input_text
        )

# ---- Launch the app ----
if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    app.launch(server_name="0.0.0.0", server_port=7860)