| # Lazy imports for heavy ML libraries - only imported when needed | |
| # This reduces startup time from ~1 minute to a few seconds | |
| import gc | |
| import sys | |
| import os | |
| import time | |
| import psutil | |
| import json | |
| import spaces | |
| from threading import Thread | |
| #----------------- | |
| from relatively_constant_variables import knowledge_base | |
| # Lazy import placeholders - will be imported on first use | |
| torch = None | |
| transformers = None | |
| diffusers = None | |
| sentence_transformers = None | |
| def _ensure_torch(): | |
| """Lazy import torch only when needed.""" | |
| global torch | |
| if torch is None: | |
| import torch as _torch | |
| torch = _torch | |
| return torch | |
| def _ensure_transformers(): | |
| """Lazy import transformers only when needed.""" | |
| global transformers | |
| if transformers is None: | |
| import transformers as _transformers | |
| transformers = _transformers | |
| return transformers | |
| def _ensure_diffusers(): | |
| """Lazy import diffusers only when needed.""" | |
| global diffusers | |
| if diffusers is None: | |
| import diffusers as _diffusers | |
| diffusers = _diffusers | |
| return diffusers | |
| def _ensure_sentence_transformers(): | |
| """Lazy import sentence_transformers only when needed.""" | |
| global sentence_transformers | |
| if sentence_transformers is None: | |
| import sentence_transformers as _st | |
| sentence_transformers = _st | |
| return sentence_transformers | |
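| # Usage note: call the matching _ensure_*() helper at the top of any function | |
| # that needs the library; the module is imported once and cached in the global. | |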
| # Directory for saving generated media (same as file_explorer_and_upload.py) | |
| GENERATED_MEDIA_DIR = os.path.abspath("saved_media") | |
| os.makedirs(GENERATED_MEDIA_DIR, exist_ok=True) | |
| modelnames = ["stvlynn/Gemma-2-2b-Chinese-it", "unsloth/Llama-3.2-1B-Instruct", "unsloth/Llama-3.2-3B-Instruct", "nbeerbower/mistral-nemo-wissenschaft-12B", "princeton-nlp/gemma-2-9b-it-SimPO", "cognitivecomputations/dolphin-2.9.3-mistral-7B-32k", "01-ai/Yi-Coder-9B-Chat", "ArliAI/Llama-3.1-8B-ArliAI-RPMax-v1.1", "ArliAI/Phi-3.5-mini-3.8B-ArliAI-RPMax-v1.1", | |
| "Qwen/Qwen2.5-7B-Instruct", "Qwen/Qwen2-0.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-7B-Instruct", "Qwen/Qwen1.5-MoE-A2.7B-Chat", "HuggingFaceTB/SmolLM-135M-Instruct", "microsoft/Phi-3-mini-4k-instruct", "Groq/Llama-3-Groq-8B-Tool-Use", "hugging-quants/Meta-Llama-3.1-8B-Instruct-BNB-NF4", | |
| "SpectraSuite/TriLM_3.9B_Unpacked", "h2oai/h2o-danube3-500m-chat", "OuteAI/Lite-Mistral-150M-v2-Instruct", "Zyphra/Zamba2-1.2B", "anthracite-org/magnum-v2-4b", | |
| "unsloth/functiongemma-270m-it", # FunctionGemma for function calling | |
| # New models (Dec 2025) | |
| "HuggingFaceTB/SmolLM3-3B", | |
| "unsloth/Ministral-3-3B-Instruct-2512-bnb-4bit", | |
| "unsloth/granite-4.0-h-micro-bnb-4bit", | |
| # New models (Jan 2026) | |
| "tiiuae/Falcon-H1R-7B", # Hybrid Transformer+Mamba2, reasoning-specialized | |
| "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8", # MoE 3.5B active/30B total, FP8 | |
| "openai/gpt-oss-20b", # MoE 3.6B active/21B total, Apache 2.0, agentic | |
| ] | |
| # T5Gemma2 encoder-decoder models (require AutoModelForSeq2SeqLM) | |
| seq2seq_modelnames = [ | |
| "google/t5gemma-2-270m-270m", | |
| "google/t5gemma-2-1b-1b", | |
| ] | |
| # imagemodelnames = ["black-forest-labs/FLUX.1-schnell", ] | |
| current_model_index = 0 | |
| current_image_model_index = 0 | |
| modelname = modelnames[current_model_index] | |
| # imagemodelname = imagemodelnames[current_image_model_index] | |
| lastmodelnameinloadfunction = None | |
| lastimagemodelnameinloadfunction = None | |
| embedding_model = None | |
| knowledge_base_embeddings = None | |
| def initialize_rag(): | |
| global embedding_model, knowledge_base_embeddings | |
| if embedding_model is None: | |
| st = _ensure_sentence_transformers() | |
| embedding_model = st.SentenceTransformer('all-MiniLM-L6-v2') | |
| knowledge_base_embeddings = embedding_model.encode([doc["content"] for doc in knowledge_base]) | |
| # Initialize model and tokenizer as global variables | |
| model = None | |
| tokenizer = None | |
| image_pipe = None | |
| imagemodelnames = [ | |
| "stabilityai/sd-turbo", | |
| "stabilityai/sdxl-turbo", | |
| # New models (Dec 2025) | |
| "radames/Real-Time-Text-to-Image-SDXL-Lightning", | |
| "unsloth/Qwen-Image-GGUF", # GGUF - may need special handling | |
| "unsloth/Z-Image-Turbo-GGUF", # GGUF - may need special handling | |
| ] | |
| current_image_model = imagemodelnames[0] # Default to sd-turbo (smaller/faster) | |
| # Video/I2V models | |
| videomodelnames = [ | |
| # LTX Video - distilled, fast (7-8 steps), works with diffusers | |
| "Lightricks/LTX-Video-0.9.7-distilled", # 13B distilled, CFG=1, fast iterations | |
| # Wan2.2 - Text/Image to Video | |
| "Wan-AI/Wan2.2-TI2V-5B-Diffusers", # 5B, T2V+I2V, 720P, runs on 4090 | |
| "Wan-AI/Wan2.2-T2V-A14B-Diffusers", # 14B MoE, text-to-video | |
| "Wan-AI/Wan2.2-I2V-A14B-Diffusers", # 14B MoE, image-to-video | |
| # HunyuanVideo - Tencent, consumer GPU friendly (use community diffusers version) | |
| "hunyuanvideo-community/HunyuanVideo", # 13B original, diffusers-compatible | |
| # GGUF format (may need llama.cpp or special handling) | |
| "QuantStack/Wan2.2-I2V-A14B-GGUF", # Image-to-Video, GGUF format | |
| ] | |
| # Dictionary to store loaded models | |
| loaded_models = {} | |
| # Seq2seq model globals (for T5Gemma2) | |
| seq2seq_model = None | |
| seq2seq_processor = None | |
| # Gemma Scope SAE globals | |
| gemma_scope_sae = None | |
| gemma_scope_layer = None | |
| def get_size_str(num_bytes): | |
| """Format a byte count as a human-readable string.""" | |
| for unit in ['B', 'KB', 'MB', 'GB', 'TB']: | |
| if num_bytes < 1024: | |
| return f"{num_bytes:.2f} {unit}" | |
| num_bytes /= 1024 | |
| return f"{num_bytes:.2f} PB"  # fallback for values >= 1024 TB | |
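| # Illustrative: get_size_str(1536) returns "1.50 KB" | |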
| # Track currently loaded model name for model switching | |
| current_loaded_model_name = None | |
| def load_model(model_name): | |
| """ | |
| Load model on CPU only - DO NOT use device_map="auto" or CUDA operations here. | |
| CUDA operations must only happen inside @spaces.GPU decorated functions. | |
| The model will be moved to GPU inside generate_response(). | |
| """ | |
| global model, tokenizer, lastmodelnameinloadfunction, loaded_models, current_loaded_model_name | |
| # Lazy import heavy libraries | |
| _torch = _ensure_torch() | |
| tf = _ensure_transformers() | |
| print(f"Loading model and tokenizer: {model_name}") | |
| # Clear old model and tokenizer if they exist | |
| if 'model' in globals() and model is not None: | |
| del model | |
| model = None | |
| if 'tokenizer' in globals() and tokenizer is not None: | |
| tokenizer = None | |
| # Force garbage collection (no CUDA here - that happens in @spaces.GPU) | |
| gc.collect() | |
| # Load model on CPU - it will be moved to GPU inside @spaces.GPU function | |
| # Use device_map=None to avoid CUDA initialization | |
| model = tf.AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| torch_dtype=_torch.bfloat16, # Use bfloat16 for efficiency | |
| device_map=None, # Don't auto-map to GPU - we'll do it in @spaces.GPU | |
| low_cpu_mem_usage=True | |
| ) | |
| tokenizer = tf.AutoTokenizer.from_pretrained(model_name) | |
| # Calculate sizes (CPU only, no CUDA) | |
| model_size = sum(p.numel() * p.element_size() for p in model.parameters()) | |
| tokenizer_size = sum(sys.getsizeof(v) for v in tokenizer.__dict__.values()) | |
| loaded_models[model_name] = [str(time.time()), model_size] | |
| current_loaded_model_name = model_name | |
| lastmodelnameinloadfunction = (model_name, model_size, tokenizer_size) | |
| print(f"Model and tokenizer {model_name} loaded successfully (on CPU)") | |
| print(f"Model size: {get_size_str(model_size)}") | |
| print(f"Tokenizer size: {get_size_str(tokenizer_size)}") | |
| return (f"Model {model_name} loaded (CPU). " | |
| f"Size: {get_size_str(model_size)}. " | |
| f"Will move to GPU on generation.") | |
| def load_seq2seq_model(model_name): | |
| """Load T5Gemma2 or similar encoder-decoder model.""" | |
| global seq2seq_model, seq2seq_processor | |
| _torch = _ensure_torch() | |
| tf = _ensure_transformers() | |
| print(f"Loading seq2seq model: {model_name}") | |
| # Don't call cuda.memory_allocated() here - it can initialize CUDA outside @spaces.GPU | |
| # Clear previous | |
| if seq2seq_model is not None: | |
| seq2seq_model = None | |
| if seq2seq_processor is not None: | |
| seq2seq_processor = None | |
| # Don't call cuda.empty_cache() here - it initializes CUDA outside @spaces.GPU | |
| gc.collect() | |
| seq2seq_processor = tf.AutoProcessor.from_pretrained(model_name) | |
| # Load on CPU - will be moved to GPU in @spaces.GPU function | |
| seq2seq_model = tf.AutoModelForSeq2SeqLM.from_pretrained( | |
| model_name, | |
| torch_dtype=_torch.bfloat16, | |
| device_map=None, # Don't auto-map to GPU | |
| low_cpu_mem_usage=True | |
| ) | |
| print(f"Seq2seq model {model_name} loaded on CPU. Will move to GPU on generation.") | |
| return f"Loaded: {model_name} (CPU). Will move to GPU on generation." | |
| def generate_seq2seq_response(prompt, image_url=None): | |
| """Generate response using T5Gemma2.""" | |
| global seq2seq_model, seq2seq_processor | |
| _torch = _ensure_torch() | |
| if seq2seq_model is None: | |
| load_seq2seq_model(seq2seq_modelnames[0]) | |
| zero = _torch.Tensor([0]).cuda() | |
| seq2seq_model.to(zero.device) | |
| if image_url: | |
| from PIL import Image | |
| import requests | |
| image = Image.open(requests.get(image_url, stream=True).raw) | |
| inputs = seq2seq_processor(text=prompt, images=image, return_tensors="pt") | |
| else: | |
| inputs = seq2seq_processor(text=prompt, return_tensors="pt") | |
| inputs = {k: v.to(zero.device) for k, v in inputs.items()} | |
| outputs = seq2seq_model.generate(**inputs, max_new_tokens=256) | |
| response = seq2seq_processor.decode(outputs[0], skip_special_tokens=True) | |
| return response | |
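| # Illustrative call (not in the original file): | |
| #   generate_seq2seq_response("Summarize: T5Gemma2 is an encoder-decoder model.") | |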
| # ============ GEMMA SCOPE 2 SAE FUNCTIONS ============ | |
| def load_gemma_scope_sae(layer_num=12): | |
| """Load Gemma Scope SAE for a specific layer.""" | |
| global gemma_scope_sae, gemma_scope_layer | |
| _torch = _ensure_torch() | |
| try: | |
| from sae_lens import SAE | |
| except ImportError: | |
| return "Error: sae_lens not installed. Run: pip install sae_lens" | |
| # Use canonical release with correct layer ID format | |
| layer_id = f"layer_{layer_num}/width_16k/canonical" | |
| try: | |
| # Load on CPU - will be moved to GPU in @spaces.GPU function | |
| gemma_scope_sae = SAE.from_pretrained( | |
| release="gemma-scope-2b-pt-res-canonical", # Gemma 2 2B canonical | |
| sae_id=layer_id, | |
| device="cpu" # Don't initialize CUDA here | |
| ) | |
| gemma_scope_layer = layer_num | |
| return f"Loaded SAE for layer {layer_num}: {layer_id} (CPU)" | |
| except Exception as e: | |
| return f"Error loading SAE: {str(e)}" | |
| def analyze_prompt_features(prompt, top_k=10): | |
| """Analyze which SAE features activate for a given prompt.""" | |
| global model, tokenizer, gemma_scope_sae | |
| _torch = _ensure_torch() | |
| top_k = int(top_k) # Ensure it's an int (from slider) | |
| # Need a Gemma 2 model for SAE analysis - use the Chinese fine-tune from modelnames | |
| if model is None or "gemma" not in str(getattr(model, 'name_or_path', '')).lower(): | |
| load_model("stvlynn/Gemma-2-2b-Chinese-it") # Use existing Gemma 2 from modelnames | |
| if gemma_scope_sae is None: | |
| load_result = load_gemma_scope_sae() | |
| if "Error" in load_result: | |
| return load_result | |
| zero = _torch.Tensor([0]).cuda() | |
| model.to(zero.device) | |
| # Move SAE to GPU if it has a .to() method | |
| if hasattr(gemma_scope_sae, 'to'): | |
| gemma_scope_sae.to(zero.device) | |
| # Get model activations | |
| inputs = tokenizer(prompt, return_tensors="pt").to(zero.device) | |
| with _torch.no_grad(): | |
| outputs = model(**inputs, output_hidden_states=True) | |
| # Run through SAE - hidden_states[0] is embedding, so layer N is at index N+1 | |
| layer_idx = gemma_scope_layer + 1 if gemma_scope_layer is not None else 13 | |
| if layer_idx >= len(outputs.hidden_states): | |
| layer_idx = len(outputs.hidden_states) - 1 # Use last layer if out of bounds | |
| hidden_state = outputs.hidden_states[layer_idx] | |
| feature_acts = gemma_scope_sae.encode(hidden_state) | |
| # Get top activated features | |
| top_features = _torch.topk(feature_acts.mean(dim=1).squeeze(), top_k) | |
| # Build Neuronpedia base URL for this layer/SAE | |
| # Format: https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-16k/{feature_id} | |
| layer_num = gemma_scope_layer if gemma_scope_layer is not None else 12 | |
| neuronpedia_base = f"https://www.neuronpedia.org/gemma-2-2b/{layer_num}-gemmascope-res-16k" | |
| results = ["## Top Activated Features\n"] | |
| results.append("| Feature | Activation | Neuronpedia Link |") | |
| results.append("|---------|------------|------------------|") | |
| for idx, val in zip(top_features.indices, top_features.values): | |
| feature_id = idx.item() | |
| activation = val.item() | |
| link = f"{neuronpedia_base}/{feature_id}" | |
| results.append(f"| {feature_id:5d} | {activation:8.2f} | [View Feature]({link}) |") | |
| results.append("") | |
| results.append("---") | |
| results.append("**How to use:** Click the links to see what concepts each feature represents.") | |
| results.append("- Higher activation = concept is more relevant to your prompt") | |
| results.append("- Compare prompts to find features that make configs interesting vs predictable") | |
| return "\n".join(results) | |
| def fetch_neuronpedia_feature(feature_id, layer=12, width="16k"): | |
| """Fetch feature data from Neuronpedia API.""" | |
| import requests | |
| feature_id = int(feature_id) | |
| layer = int(layer) | |
| # Neuronpedia API endpoint | |
| api_url = f"https://www.neuronpedia.org/api/feature/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id}" | |
| try: | |
| response = requests.get(api_url, timeout=10) | |
| if response.status_code == 200: | |
| data = response.json() | |
| return format_neuronpedia_feature(data, feature_id, layer, width) | |
| elif response.status_code == 404: | |
| return f"Feature {feature_id} not found at layer {layer}" | |
| else: | |
| return f"API error: {response.status_code}" | |
| except requests.exceptions.Timeout: | |
| return "Request timed out - Neuronpedia may be slow" | |
| except Exception as e: | |
| return f"Error fetching feature: {str(e)}" | |
| def format_neuronpedia_feature(data, feature_id, layer, width): | |
| """Format Neuronpedia feature data as markdown.""" | |
| results = [] | |
| # Header | |
| results.append(f"## Feature {feature_id} (Layer {layer}, {width} width)") | |
| results.append("") | |
| # Description if available | |
| if data.get("description"): | |
| results.append(f"**Description:** {data['description']}") | |
| results.append("") | |
| # Auto-interp explanation if available | |
| if data.get("explanations") and len(data["explanations"]) > 0: | |
| explanation = data["explanations"][0].get("description", "") | |
| if explanation: | |
| results.append(f"**Auto-interpretation:** {explanation}") | |
| results.append("") | |
| # Activation examples | |
| if data.get("activations") and len(data["activations"]) > 0: | |
| results.append("### Top Activating Examples") | |
| results.append("") | |
| for i, act in enumerate(data["activations"][:5]): | |
| tokens = act.get("tokens", []) | |
| values = act.get("values", []) | |
| if tokens: | |
| # Highlight the max activating token | |
| max_idx = values.index(max(values)) if values else 0 | |
| text_parts = [] | |
| for j, tok in enumerate(tokens): | |
| if j == max_idx: | |
| text_parts.append(f"**{tok}**") | |
| else: | |
| text_parts.append(tok) | |
| text = "".join(text_parts) | |
| results.append(f"{i+1}. {text}") | |
| results.append("") | |
| # Stats | |
| results.append("### Feature Stats") | |
| results.append(f"- **Neuronpedia ID:** `gemma-2-2b_{layer}-gemmascope-res-{width}_{feature_id}`") | |
| if data.get("max_activation"): | |
| results.append(f"- **Max Activation:** {data['max_activation']:.2f}") | |
| if data.get("frac_nonzero"): | |
| results.append(f"- **Activation Frequency:** {data['frac_nonzero']*100:.2f}%") | |
| results.append("") | |
| results.append(f"[View on Neuronpedia](https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id})") | |
| return "\n".join(results) | |
| def load_image_model(model_name=None): | |
| """Load image model on CPU - will be moved to GPU in @spaces.GPU function.""" | |
| global image_pipe, current_image_model | |
| _torch = _ensure_torch() | |
| diff = _ensure_diffusers() | |
| if model_name: | |
| current_image_model = model_name | |
| print(f"Loading image model: {current_image_model}") | |
| # Don't call cuda.empty_cache() here - it initializes CUDA outside @spaces.GPU | |
| gc.collect() | |
| image_pipe = diff.AutoPipelineForText2Image.from_pretrained( | |
| current_image_model, | |
| torch_dtype=_torch.float16, | |
| variant="fp16" | |
| ) | |
| # Don't move to CUDA here - will be done in @spaces.GPU function | |
| print(f"Image model {current_image_model} loaded on CPU") | |
| return image_pipe | |
| def clear_all_models(): | |
| """Clear all loaded models from memory.""" | |
| global model, tokenizer, image_pipe, loaded_models | |
| # loaded_models only holds metadata (load time, size); releasing the global | |
| # references below is what actually frees the model weights | |
| model = None | |
| tokenizer = None | |
| image_pipe = None | |
| loaded_models.clear() | |
| # Don't call cuda.empty_cache() here - it initializes CUDA outside @spaces.GPU | |
| gc.collect() | |
| return "All models cleared from memory." | |
| def load_model_list(model_list): | |
| messages = [] | |
| for model_name in model_list: | |
| message = load_model(model_name) | |
| messages.append(message) | |
| return "\n".join(messages) | |
| def loaded_model_list(): | |
| global loaded_models | |
| return loaded_models | |
| # Initial model load | |
| # load_model(modelname) | |
| # load_image_model(imagemodelname) | |
| # Create embeddings for the knowledge base | |
| def retrieve(query, k=2): | |
| _torch = _ensure_torch() | |
| initialize_rag() | |
| query_embedding = embedding_model.encode([query]) | |
| similarities = _torch.nn.functional.cosine_similarity(_torch.tensor(query_embedding), _torch.tensor(knowledge_base_embeddings)) | |
| top_k_indices = similarities.argsort(descending=True)[:k] | |
| return [(knowledge_base[i]["content"], knowledge_base[i]["id"]) for i in top_k_indices] | |
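| # Illustrative call (not in the original file): retrieve("how does saving work?", k=2) | |
| # returns the k most similar (content, id) pairs from knowledge_base. | |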
| def get_ram_usage(): | |
| ram = psutil.virtual_memory() | |
| return f"RAM Usage: {ram.percent:.2f}%, Available: {ram.available / (1024 ** 3):.2f}GB, Total: {ram.total / (1024 ** 3):.2f}GB" | |
| # Global dictionary to store outputs | |
| output_dict = {} | |
| def empty_output_dict(): | |
| global output_dict | |
| output_dict = {} | |
| print("Output dictionary has been emptied.") | |
| def get_model_details(model): | |
| return { | |
| "name": model.config.name_or_path, | |
| "architecture": model.config.architectures[0] if model.config.architectures else "Unknown", | |
| "num_parameters": sum(p.numel() for p in model.parameters()), | |
| } | |
| def get_tokenizer_details(tokenizer): | |
| return { | |
| "name": tokenizer.__class__.__name__, | |
| "vocab_size": tokenizer.vocab_size, | |
| "model_max_length": tokenizer.model_max_length, | |
| } | |
| def generate_response(prompt, use_rag, stream=False, max_tokens=512, model_name=None): | |
| """ | |
| Generate text response using the loaded model. | |
| Args: | |
| prompt: The input prompt | |
| use_rag: Whether to use RAG (retrieval augmented generation) | |
| stream: Whether to stream the response | |
| max_tokens: Maximum number of tokens to generate (default 512) | |
| model_name: Optional model name - if different from loaded model, will reload | |
| """ | |
| global output_dict, model, tokenizer, current_loaded_model_name | |
| _torch = _ensure_torch() | |
| tf = _ensure_transformers() | |
| # Check if we need to load or switch models | |
| if model_name and model_name != current_loaded_model_name: | |
| print(f"Model switch requested: {current_loaded_model_name} -> {model_name}") | |
| load_model(model_name) | |
| # Check if model is loaded | |
| if model is None or tokenizer is None: | |
| yield ("Error: No model loaded. Please select and load a model first using the model dropdown.", "N/A", "N/A", "N/A") | |
| return | |
| zero = _torch.Tensor([0]).cuda() | |
| print(f"GPU device: {zero.device}, Model: {current_loaded_model_name}") | |
| _torch.cuda.empty_cache() | |
| # Move model to GPU for inference | |
| model.to(zero.device) | |
| if use_rag: | |
| retrieved_docs = retrieve(prompt) | |
| context = " ".join([doc for doc, _ in retrieved_docs]) | |
| doc_ids = [doc_id for _, doc_id in retrieved_docs] | |
| full_prompt = f"Context: {context}\nQuestion: {prompt}\nAnswer:" | |
| else: | |
| full_prompt = prompt | |
| doc_ids = None | |
| messages = [ | |
| {"role": "system", "content": "You are a helpful assistant."}, | |
| {"role": "user", "content": full_prompt} | |
| ] | |
| text = tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True | |
| ) | |
| model_inputs = tokenizer([text], return_tensors="pt").to(zero.device) | |
| start_time = time.time() | |
| total_tokens = 0 | |
| print(output_dict) | |
| output_key = f"output_{len(output_dict) + 1}" | |
| print(output_key) | |
| output_dict[output_key] = { | |
| "input_prompt": prompt, | |
| "full_prompt": full_prompt, | |
| "use_rag": use_rag, | |
| "max_tokens": max_tokens, | |
| "model_name": current_loaded_model_name, | |
| "generated_text": "", | |
| "tokens_per_second": 0, | |
| "ram_usage": "", | |
| "doc_ids": doc_ids if doc_ids else "N/A", | |
| "model_details": get_model_details(model), | |
| "tokenizer_details": get_tokenizer_details(tokenizer), | |
| "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start_time)) | |
| } | |
| print(output_dict) | |
| # Ensure max_tokens is an integer | |
| max_tokens = int(max_tokens) if max_tokens else 512 | |
| if stream: | |
| streamer = tf.TextIteratorStreamer(tokenizer, skip_special_tokens=True) | |
| generation_kwargs = dict( | |
| model_inputs, | |
| streamer=streamer, | |
| max_new_tokens=max_tokens, | |
| do_sample=True,  # required for temperature to take effect | |
| temperature=0.7, | |
| ) | |
| thread = Thread(target=model.generate, kwargs=generation_kwargs) | |
| thread.start() | |
| for new_text in streamer: | |
| output_dict[output_key]["generated_text"] += new_text | |
| total_tokens += 1 | |
| current_time = time.time() | |
| tokens_per_second = total_tokens / (current_time - start_time) | |
| ram_usage = get_ram_usage() | |
| output_dict[output_key]["tokens_per_second"] = f"{tokens_per_second:.2f}" | |
| output_dict[output_key]["ram_usage"] = ram_usage | |
| yield (output_dict[output_key]["generated_text"], | |
| output_dict[output_key]["tokens_per_second"], | |
| output_dict[output_key]["ram_usage"], | |
| output_dict[output_key]["doc_ids"]) | |
| else: | |
| generated_ids = model.generate( | |
| model_inputs.input_ids, | |
| max_new_tokens=max_tokens | |
| ) | |
| generated_ids = [ | |
| output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) | |
| ] | |
| response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
| total_tokens = len(generated_ids[0]) | |
| end_time = time.time() | |
| tokens_per_second = total_tokens / (end_time - start_time) | |
| ram_usage = get_ram_usage() | |
| output_dict[output_key]["generated_text"] = response | |
| output_dict[output_key]["tokens_per_second"] = f"{tokens_per_second:.2f}" | |
| output_dict[output_key]["ram_usage"] = ram_usage | |
| print(output_dict) | |
| yield (output_dict[output_key]["generated_text"], | |
| output_dict[output_key]["tokens_per_second"], | |
| output_dict[output_key]["ram_usage"], | |
| output_dict[output_key]["doc_ids"]) | |
| def generate_image(prompt, model_choice=None): | |
| global image_pipe, current_image_model | |
| _torch = _ensure_torch() | |
| try: | |
| # Load model on-demand if not loaded or if different model requested | |
| if image_pipe is None or (model_choice and model_choice != current_image_model): | |
| print(f"Loading image model on-demand: {model_choice or current_image_model}") | |
| load_image_model(model_choice) | |
| if image_pipe is None: | |
| ram_usage = get_ram_usage() | |
| return "Error: Failed to load image model.", ram_usage, None | |
| # Move model to GPU (loaded on CPU in load_image_model) | |
| zero = _torch.Tensor([0]).cuda() | |
| image_pipe.to(zero.device) | |
| # Generate image using SD-turbo or SDXL-turbo | |
| # These models work best with guidance_scale=0.0 and few steps | |
| image = image_pipe( | |
| prompt=prompt, | |
| num_inference_steps=4, | |
| guidance_scale=0.0, | |
| ).images[0] | |
| # Save to saved_media folder so it appears in file explorer | |
| image_filename = f"sd_output_{time.time()}.png" | |
| image_path = os.path.join(GENERATED_MEDIA_DIR, image_filename) | |
| image.save(image_path) | |
| ram_usage = get_ram_usage() | |
| return f"Image generated with {current_image_model}: {image_filename}", ram_usage, image | |
| except Exception as e: | |
| ram_usage = get_ram_usage() | |
| return f"Error generating image: {str(e)}", ram_usage, None | |
| def get_output_details(output_key): | |
| if output_key in output_dict: | |
| return output_dict[output_key] | |
| else: | |
| return f"No output found for key: {output_key}" | |
| # Update the switch_model function to return the load_model message | |
| def switch_model(choice): | |
| global modelname | |
| modelname = choice | |
| load_message = load_model(modelname) | |
| return load_message, f"Current model: {modelname}" | |
| # Update the model_change_handler function | |
| def model_change_handler(choice): | |
| message, current_model = switch_model(choice) | |
| return message, current_model, message # Use the same message for both outputs | |
| def format_output_dict(): | |
| global output_dict | |
| formatted_output = "" | |
| for key, value in output_dict.items(): | |
| formatted_output += f"Key: {key}\n" | |
| formatted_output += json.dumps(value, indent=2) | |
| formatted_output += "\n\n" | |
| print(formatted_output) | |
| return formatted_output | |
| # ============================================================ | |
| # TTS GENERATION (Multiple Backends) | |
| # ============================================================ | |
| # Supported TTS models: | |
| # - hexgrad/Kokoro-82M: Fast, lightweight TTS (82M params) | |
| # - Supertone/supertonic-2: High-quality expressive TTS (66M params, ONNX) | |
| # - zai-org/GLM-TTS: Multilingual text-to-speech | |
| TTS_MODELS = { | |
| "kokoro": { | |
| "name": "Kokoro-82M", | |
| "space": "Pendrokar/TTS-Spaces-Arena", # Arena has API enabled, supports Kokoro | |
| "fallback_spaces": ["eric-cli/Kokoro-TTS-Local"], | |
| "description": "Fast, lightweight TTS with natural voices", | |
| "local_support": True, | |
| "voices": ["af_heart", "af_bella", "af_nicole", "af_sarah", "af_sky", | |
| "am_adam", "am_michael", "bf_emma", "bf_isabella", "bm_george", "bm_lewis"] | |
| }, | |
| "supertonic": { | |
| "name": "Supertonic-2", | |
| "space": "Supertone/supertonic-2", | |
| "fallback_spaces": [], | |
| "description": "High-quality expressive speech synthesis (ONNX)", | |
| "local_support": True, | |
| "voices": ["F1", "F2", "F3", "F4", "F5", "M1", "M2", "M3", "M4", "M5"] | |
| }, | |
| "glm-tts": { | |
| "name": "GLM-TTS", | |
| "space": "zai-org/GLM-TTS", | |
| "fallback_spaces": [], | |
| "description": "Multilingual text-to-speech with voice cloning", | |
| "local_support": False, | |
| "voices": ["default"] | |
| } | |
| } | |
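| # Illustrative lookup: TTS_MODELS["kokoro"]["voices"][0] is "af_heart", the default | |
| # voice used by generate_tts_local() below. | |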
| # Cached model instances | |
| _kokoro_pipeline = None | |
| _supertonic_model = None | |
| def _load_kokoro(): | |
| """Load Kokoro-82M pipeline for local TTS generation.""" | |
| global _kokoro_pipeline | |
| if _kokoro_pipeline is None: | |
| print("Loading Kokoro-82M...") | |
| from kokoro import KPipeline | |
| _kokoro_pipeline = KPipeline(lang_code='a') | |
| print("Kokoro-82M loaded successfully") | |
| return _kokoro_pipeline | |
| def _load_supertonic(): | |
| """Load Supertonic-2 model for local TTS generation.""" | |
| global _supertonic_model | |
| if _supertonic_model is None: | |
| print("Loading Supertonic-2...") | |
| # Suppress ONNX runtime GPU discovery warnings on systems without proper GPU access | |
| import os | |
| import warnings | |
| os.environ.setdefault('ORT_DISABLE_ALL_WARNINGS', '1') | |
| warnings.filterwarnings('ignore', message='.*device_discovery.*') | |
| warnings.filterwarnings('ignore', message='.*GPU device discovery failed.*') | |
| from supertonic import TTS | |
| _supertonic_model = TTS(auto_download=True) | |
| print("Supertonic-2 loaded successfully") | |
| return _supertonic_model | |
| def generate_tts_local(text, model="kokoro", voice="af_heart"): | |
| """ | |
| Generate TTS audio locally using ZeroGPU. | |
| Args: | |
| text: The text to convert to speech | |
| model: One of "kokoro", "supertonic" | |
| voice: Voice name (model-specific) | |
| Returns: | |
| Tuple of (status_message, audio_path or None) | |
| """ | |
| import soundfile as sf | |
| try: | |
| safe_text = text[:30].replace(' ', '_').replace('/', '_').replace('\\', '_') | |
| filename = f"tts_{model}_{safe_text}_{int(time.time())}.wav" | |
| filepath = os.path.join(GENERATED_MEDIA_DIR, filename) | |
| if model == "kokoro": | |
| pipeline = _load_kokoro() | |
| if pipeline is None: | |
| return "Error: Failed to load Kokoro model", None | |
| # Generate audio - Kokoro yields segments | |
| generator = pipeline(text, voice=voice) | |
| audio_segments = [] | |
| for i, (gs, ps, audio) in enumerate(generator): | |
| audio_segments.append(audio) | |
| # Concatenate all segments | |
| import numpy as np | |
| full_audio = np.concatenate(audio_segments) if len(audio_segments) > 1 else audio_segments[0] | |
| # Kokoro outputs 24kHz audio | |
| sf.write(filepath, full_audio, 24000) | |
| return f"TTS saved as {filepath}", filepath | |
| elif model == "supertonic": | |
| tts = _load_supertonic() | |
| if tts is None: | |
| return "Error: Failed to load Supertonic model", None | |
| # Get voice style (F3 is a good default female voice) | |
| # Available: M1-M5 (male), F1-F5 (female) | |
| voice_name = voice if voice != "default" else "F3" | |
| style = tts.get_voice_style(voice_name=voice_name) | |
| # Generate audio with Supertonic | |
| wav, duration = tts.synthesize(text, voice_style=style) | |
| # Supertonic outputs 24kHz audio, wav shape is (1, num_samples) | |
| audio = wav.squeeze() # Remove batch dimension | |
| sf.write(filepath, audio, 24000) | |
| return f"TTS saved as {filepath}", filepath | |
| else: | |
| return f"Error: Model '{model}' does not support local generation", None | |
| except Exception as e: | |
| return f"Error generating TTS locally with {model}: {str(e)}", None | |
| def generate_tts_api(text, model="kokoro", voice="default"): | |
| """ | |
| Generate TTS audio using HuggingFace Space APIs (fallback). | |
| Args: | |
| text: The text to convert to speech | |
| model: One of "kokoro", "supertonic", or "glm-tts" | |
| voice: Voice parameter (model-specific) | |
| Returns: | |
| Tuple of (status_message, audio_path or None) | |
| """ | |
| from gradio_client import Client | |
| if model not in TTS_MODELS: | |
| return f"Error: Unknown TTS model '{model}'. Available: {list(TTS_MODELS.keys())}", None | |
| model_info = TTS_MODELS[model] | |
| spaces_to_try = [model_info["space"]] + model_info.get("fallback_spaces", []) | |
| last_error = None | |
| for space in spaces_to_try: | |
| try: | |
| print(f"Trying TTS via {space}...") | |
| client = Client(space) | |
| # Try to discover API endpoints | |
| result = None | |
| if model == "kokoro": | |
| # TTS Arena uses different endpoint names | |
| if "Arena" in space: | |
| # Try arena-style endpoints | |
| try: | |
| result = client.predict( | |
| text, # text input | |
| voice if voice != "default" else "af_heart", # voice | |
| 1.0, # speed | |
| api_name="/synthesize" | |
| ) | |
| except Exception: | |
| # Try alternate endpoint | |
| result = client.predict( | |
| text, | |
| api_name="/predict" | |
| ) | |
| else: | |
| # Try common Kokoro endpoint names | |
| for endpoint in ["/generate_speech", "/generate", "/synthesize", "/predict"]: | |
| try: | |
| result = client.predict( | |
| text, | |
| voice if voice != "default" else "af_heart", | |
| 1.0, # speed | |
| api_name=endpoint | |
| ) | |
| break | |
| except Exception: | |
| continue | |
| elif model == "supertonic": | |
| for endpoint in ["/synthesize", "/predict", "/generate"]: | |
| try: | |
| result = client.predict(text, api_name=endpoint) | |
| break | |
| except Exception: | |
| continue | |
| elif model == "glm-tts": | |
| for endpoint in ["/synthesize", "/predict", "/generate", "/infer"]: | |
| try: | |
| result = client.predict(text, api_name=endpoint) | |
| break | |
| except Exception: | |
| continue | |
| if result is None: | |
| continue | |
| # Process result - usually returns audio file path or tuple | |
| audio_path = None | |
| if isinstance(result, str) and os.path.exists(result): | |
| audio_path = result | |
| elif isinstance(result, tuple): | |
| for item in result: | |
| if isinstance(item, str) and os.path.exists(item): | |
| audio_path = item | |
| break | |
| elif isinstance(result, dict) and 'audio' in result: | |
| audio_path = result['audio'] | |
| if audio_path and os.path.exists(audio_path): | |
| safe_text = text[:30].replace(' ', '_').replace('/', '_').replace('\\', '_') | |
| filename = f"tts_{model}_{safe_text}_{int(time.time())}.wav" | |
| filepath = os.path.join(GENERATED_MEDIA_DIR, filename) | |
| import shutil | |
| shutil.copy(audio_path, filepath) | |
| return f"TTS saved as {filepath}", filepath | |
| except Exception as e: | |
| last_error = str(e) | |
| print(f"TTS API error with {space}: {e}") | |
| continue | |
| return f"Error: All TTS API attempts failed. Last error: {last_error}", None | |
| # ============================================================ | |
| # LOCAL 3D GENERATION (Shap-E) | |
| # ============================================================ | |
| shap_e_model = None | |
| shap_e_diffusion = None | |
| shap_e_xm = None | |
| def load_shap_e(): | |
| """Load Shap-E model for local 3D generation.""" | |
| global shap_e_model, shap_e_diffusion, shap_e_xm | |
| if shap_e_model is None: | |
| _torch = _ensure_torch() | |
| print("Loading Shap-E...") | |
| import shap_e | |
| from shap_e.diffusion.sample import sample_latents | |
| from shap_e.diffusion.gaussian_diffusion import diffusion_from_config | |
| # Alias the shap_e loader so it does not shadow this module's own load_model() | |
| from shap_e.models.download import load_model as shap_e_load_model, load_config | |
| device = _torch.device("cuda" if _torch.cuda.is_available() else "cpu") | |
| shap_e_xm = shap_e_load_model('transmitter', device=device) | |
| shap_e_model = shap_e_load_model('text300M', device=device) | |
| shap_e_diffusion = diffusion_from_config(load_config('diffusion')) | |
| print("Shap-E loaded successfully") | |
| return shap_e_model, shap_e_diffusion, shap_e_xm | |
| def generate_3d_local(prompt, guidance_scale=15.0, num_steps=64): | |
| """ | |
| Generate 3D model locally using Shap-E. | |
| Args: | |
| prompt: Text description of the 3D object | |
| guidance_scale: Classifier-free guidance scale | |
| num_steps: Number of diffusion steps | |
| Returns: | |
| Tuple of (status_message, model_path or None) | |
| """ | |
| global shap_e_model, shap_e_diffusion, shap_e_xm | |
| try: | |
| _torch = _ensure_torch() | |
| from shap_e.diffusion.sample import sample_latents | |
| from shap_e.util.notebooks import decode_latent_mesh | |
| import trimesh | |
| device = _torch.device("cuda" if _torch.cuda.is_available() else "cpu") | |
| # Load model if needed | |
| load_shap_e() | |
| if shap_e_model is None: | |
| return "Error: Failed to load Shap-E model", None | |
| # Generate latents | |
| latents = sample_latents( | |
| batch_size=1, | |
| model=shap_e_model, | |
| diffusion=shap_e_diffusion, | |
| guidance_scale=guidance_scale, | |
| model_kwargs=dict(texts=[prompt]), | |
| progress=True, | |
| clip_denoised=True, | |
| use_fp16=True, | |
| use_karras=True, | |
| karras_steps=num_steps, | |
| sigma_min=1e-3, | |
| sigma_max=160, | |
| s_churn=0, | |
| ) | |
| # Decode to mesh | |
| mesh = decode_latent_mesh(shap_e_xm, latents[0]).tri_mesh() | |
| # Save as GLB | |
| safe_prompt = prompt[:40].replace(' ', '_').replace('/', '_').replace('\\', '_') | |
| filename = f"3d_local_{safe_prompt}_{int(time.time())}.glb" | |
| filepath = os.path.join(GENERATED_MEDIA_DIR, filename) | |
| # Convert to trimesh and export | |
| tri_mesh = trimesh.Trimesh(vertices=mesh.verts, faces=mesh.faces) | |
| tri_mesh.export(filepath) | |
| return f"3D model saved as {filepath}", filepath | |
| except Exception as e: | |
| return f"Error generating 3D locally: {str(e)}", None | |
| # ============================================================ | |
| # VIDEO GENERATION (Text-to-Video, Image-to-Video) | |
| # ============================================================ | |
| _video_pipe = None | |
| _current_video_model = None | |
| # 3 min timeout for video generation on ZeroGPU | |
| @spaces.GPU(duration=180)  # 180 s = 3 min, matching the comment above | |
| def generate_video_t2v(prompt, model_name="Lightricks/LTX-Video-0.9.7-distilled", | |
| num_steps=4, duration_seconds=2, width=512, height=320): | |
| """ | |
| Generate video from text prompt using diffusers. | |
| Args: | |
| prompt: Text description of the video | |
| model_name: HuggingFace model ID | |
| num_steps: Number of inference steps | |
| duration_seconds: Video duration in seconds | |
| width: Video width | |
| height: Video height | |
| Returns: | |
| Tuple of (status_message, video_path or None) | |
| """ | |
| global _video_pipe, _current_video_model | |
| _torch = _ensure_torch() | |
| try: | |
| from diffusers.utils import export_to_video | |
| # Calculate frames (target 24fps) | |
| raw_frames = int(duration_seconds * 24)  # cast in case the UI passes a float | |
| # LTX-Video requires (frames - 1) divisible by 8, so frames = 8n + 1 | |
| # Valid: 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97... | |
| if "LTX" in model_name or "Lightricks" in model_name: | |
| # Round to nearest valid frame count (8n + 1) | |
| n = round((raw_frames - 1) / 8) | |
| num_frames = max(9, n * 8 + 1) # Minimum 9 frames | |
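| # e.g. duration_seconds=2 -> raw_frames=48 -> n=round(47/8)=6 -> num_frames=49 | |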
| # Ensure dimensions divisible by 32 | |
| width = (width // 32) * 32 | |
| height = (height // 32) * 32 | |
| print(f"[LTX] Adjusted to {num_frames} frames (was {raw_frames}), {width}x{height}") | |
| else: | |
| num_frames = raw_frames | |
| negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" | |
| # Load pipeline based on model type | |
| if _video_pipe is None or _current_video_model != model_name: | |
| print(f"Loading video model: {model_name}") | |
| _torch.cuda.empty_cache() | |
| gc.collect() | |
| if "LTX" in model_name or "Lightricks" in model_name: | |
| from diffusers import LTXPipeline | |
| _video_pipe = LTXPipeline.from_pretrained( | |
| model_name, torch_dtype=_torch.bfloat16 | |
| ) | |
| # Use CPU offload for memory efficiency | |
| _video_pipe.enable_model_cpu_offload() | |
| if hasattr(_video_pipe, 'vae'): | |
| _video_pipe.vae.enable_tiling() | |
| elif "Wan" in model_name: | |
| from diffusers import WanPipeline, AutoencoderKLWan | |
| vae = AutoencoderKLWan.from_pretrained( | |
| model_name, subfolder="vae", torch_dtype=_torch.float32 | |
| ) | |
| _video_pipe = WanPipeline.from_pretrained( | |
| model_name, vae=vae, torch_dtype=_torch.bfloat16 | |
| ) | |
| # Use CPU offload for memory efficiency | |
| _video_pipe.enable_model_cpu_offload() | |
| if hasattr(_video_pipe, 'vae'): | |
| _video_pipe.vae.enable_tiling() | |
| elif "Hunyuan" in model_name: | |
| from diffusers import HunyuanVideoPipeline | |
| _video_pipe = HunyuanVideoPipeline.from_pretrained( | |
| model_name, torch_dtype=_torch.bfloat16 | |
| ) | |
| # Use CPU offload for memory efficiency | |
| _video_pipe.enable_model_cpu_offload() | |
| if hasattr(_video_pipe, 'vae'): | |
| _video_pipe.vae.enable_tiling() | |
| else: | |
| from diffusers import DiffusionPipeline | |
| _video_pipe = DiffusionPipeline.from_pretrained( | |
| model_name, torch_dtype=_torch.bfloat16 | |
| ) | |
| _video_pipe.enable_model_cpu_offload() | |
| _current_video_model = model_name | |
| print(f"Video model loaded: {model_name}") | |
| print(f"Generating video: {width}x{height}, {num_frames} frames, {num_steps} steps") | |
| # Generate video with model-specific parameters | |
| if "LTX" in model_name or "Lightricks" in model_name: | |
| output = _video_pipe( | |
| prompt=prompt, | |
| negative_prompt=negative_prompt, | |
| width=width, | |
| height=height, | |
| num_frames=num_frames, | |
| num_inference_steps=num_steps, | |
| guidance_scale=1.0, | |
| ) | |
| elif "Wan" in model_name: | |
| output = _video_pipe( | |
| prompt=prompt, | |
| negative_prompt=negative_prompt, | |
| height=height, | |
| width=width, | |
| num_frames=num_frames, | |
| guidance_scale=5.0, | |
| num_inference_steps=num_steps, | |
| ) | |
| elif "Hunyuan" in model_name: | |
| output = _video_pipe( | |
| prompt=prompt, | |
| negative_prompt=negative_prompt, | |
| height=height, | |
| width=width, | |
| num_frames=num_frames, | |
| num_inference_steps=num_steps, | |
| ) | |
| else: | |
| output = _video_pipe( | |
| prompt=prompt, | |
| num_inference_steps=num_steps, | |
| num_frames=num_frames, | |
| width=width, | |
| height=height, | |
| ) | |
| # Get video frames | |
| if hasattr(output, 'frames'): | |
| frames = output.frames[0] if isinstance(output.frames, list) else output.frames | |
| else: | |
| frames = output[0] | |
| # Save to file | |
| safe_prompt = prompt[:30].replace(' ', '_').replace('/', '_').replace('\\', '_') | |
| filename = f"video_t2v_{safe_prompt}_{int(time.time())}.mp4" | |
| filepath = os.path.join(GENERATED_MEDIA_DIR, filename) | |
| export_to_video(frames, filepath, fps=24) | |
| return f"Video saved as {filepath}", filepath | |
| except Exception as e: | |
| import traceback | |
| print(f"Error generating video: {traceback.format_exc()}") | |
| return f"Error generating video: {str(e)}", None | |
| # 3 min timeout for video generation on ZeroGPU | |
| @spaces.GPU(duration=180)  # 180 s = 3 min, matching the comment above | |
| def generate_video_i2v(image_path, prompt="", model_name="Wan-AI/Wan2.2-TI2V-5B-Diffusers", | |
| num_steps=8, duration_seconds=2): | |
| """ | |
| Generate video from image using diffusers. | |
| Args: | |
| image_path: Path to input image | |
| prompt: Optional motion/style prompt | |
| model_name: HuggingFace model ID | |
| num_steps: Number of inference steps | |
| duration_seconds: Video duration in seconds | |
| Returns: | |
| Tuple of (status_message, video_path or None) | |
| """ | |
| global _video_pipe, _current_video_model | |
| _torch = _ensure_torch() | |
| try: | |
| from diffusers.utils import export_to_video | |
| from PIL import Image | |
| # Load image | |
| image = Image.open(image_path).convert("RGB") | |
| num_frames = int(duration_seconds * 24)  # cast in case the UI passes a float | |
| negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" | |
| # Load pipeline if needed or if model changed | |
| if _video_pipe is None or _current_video_model != model_name: | |
| print(f"Loading video model: {model_name}") | |
| _torch.cuda.empty_cache() | |
| gc.collect() | |
| if "Wan" in model_name: | |
| from diffusers import WanImageToVideoPipeline, AutoencoderKLWan | |
| # TI2V-5B supports both T2V and I2V, use it directly | |
| # For T2V-A14B, switch to I2V-A14B-Diffusers | |
| if "TI2V" in model_name: | |
| i2v_model = model_name # TI2V-5B handles I2V directly | |
| elif "T2V" in model_name: | |
| i2v_model = model_name.replace("T2V", "I2V") # T2V-A14B -> I2V-A14B | |
| else: | |
| i2v_model = model_name # Already I2V model | |
| vae = AutoencoderKLWan.from_pretrained( | |
| i2v_model, subfolder="vae", torch_dtype=_torch.float32 | |
| ) | |
| _video_pipe = WanImageToVideoPipeline.from_pretrained( | |
| i2v_model, vae=vae, torch_dtype=_torch.bfloat16 | |
| ) | |
| # Use CPU offload for memory efficiency | |
| _video_pipe.enable_model_cpu_offload() | |
| if hasattr(_video_pipe, 'vae'): | |
| _video_pipe.vae.enable_tiling() | |
| else: | |
| from diffusers import DiffusionPipeline | |
| _video_pipe = DiffusionPipeline.from_pretrained( | |
| model_name, torch_dtype=_torch.bfloat16 | |
| ) | |
| _video_pipe.enable_model_cpu_offload() | |
| _current_video_model = model_name | |
| print(f"Video model loaded: {model_name}") | |
| # Get image dimensions | |
| width, height = image.size | |
| # Ensure dimensions are multiples of 16 | |
| width = (width // 16) * 16 | |
| height = (height // 16) * 16 | |
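| # e.g. a 1023x575 input image is snapped down to 1008x560 before resizing | |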
| image = image.resize((width, height)) | |
| print(f"Generating I2V: {width}x{height}, {num_frames} frames, {num_steps} steps") | |
| # Generate video from image | |
| if "Wan" in model_name: | |
| output = _video_pipe( | |
| image=image, | |
| prompt=prompt if prompt else "camera movement, smooth motion", | |
| negative_prompt=negative_prompt, | |
| height=height, | |
| width=width, | |
| num_frames=num_frames, | |
| guidance_scale=5.0, | |
| num_inference_steps=num_steps, | |
| ) | |
| else: | |
| output = _video_pipe( | |
| image=image, | |
| prompt=prompt if prompt else None, | |
| num_inference_steps=num_steps, | |
| num_frames=num_frames, | |
| ) | |
| if hasattr(output, 'frames'): | |
| frames = output.frames[0] if isinstance(output.frames, list) else output.frames | |
| else: | |
| frames = output[0] | |
| # Save to file | |
| safe_prompt = (prompt[:20] if prompt else "i2v").replace(' ', '_').replace('/', '_') | |
| filename = f"video_i2v_{safe_prompt}_{int(time.time())}.mp4" | |
| filepath = os.path.join(GENERATED_MEDIA_DIR, filename) | |
| export_to_video(frames, filepath, fps=24) | |
| return f"Video saved as {filepath}", filepath | |
| except Exception as e: | |
| import traceback | |
| print(f"Error generating I2V: {traceback.format_exc()}") | |
| return f"Error generating video: {str(e)}", None | |
| # ============================================================ | |
| # LOCAL TALKING HEAD GENERATION (SadTalker) | |
| # ============================================================ | |
| sadtalker_model = None | |
| def load_sadtalker(): | |
| """Load SadTalker model for local talking head generation.""" | |
| global sadtalker_model | |
| if sadtalker_model is None: | |
| print("[SadTalker] Loading SadTalker model...") | |
| try: | |
| # Clone and setup SadTalker if not present | |
| import subprocess | |
| import sys | |
| sadtalker_path = os.path.join(os.path.dirname(__file__), "SadTalker") | |
| if not os.path.exists(sadtalker_path): | |
| print("[SadTalker] Cloning SadTalker repository...") | |
| subprocess.run([ | |
| "git", "clone", "--depth", "1", | |
| "https://github.com/OpenTalker/SadTalker.git", | |
| sadtalker_path | |
| ], check=True) | |
| # Add to path | |
| if sadtalker_path not in sys.path: | |
| sys.path.insert(0, sadtalker_path) | |
| # Download checkpoints if needed | |
| checkpoints_path = os.path.join(sadtalker_path, "checkpoints") | |
| if not os.path.exists(checkpoints_path): | |
| print("[SadTalker] Downloading checkpoints...") | |
| os.makedirs(checkpoints_path, exist_ok=True) | |
| # Use huggingface_hub to download | |
| from huggingface_hub import hf_hub_download | |
| # Download the main checkpoints | |
| for filename in [ | |
| "mapping_00109-model.pth.tar", | |
| "mapping_00229-model.pth.tar", | |
| "SadTalker_V0.0.2_256.safetensors", | |
| "SadTalker_V0.0.2_512.safetensors" | |
| ]: | |
| try: | |
| hf_hub_download( | |
| repo_id="vinthony/SadTalker", | |
| filename=filename, | |
| local_dir=checkpoints_path | |
| ) | |
| except Exception as e: | |
| print(f"[SadTalker] Warning: Could not download {filename}: {e}") | |
| sadtalker_model = {"path": sadtalker_path, "loaded": True} | |
| print("[SadTalker] SadTalker loaded successfully") | |
| except Exception as e: | |
| print(f"[SadTalker] Failed to load: {e}") | |
| sadtalker_model = None | |
| return sadtalker_model | |
| def generate_talking_head_local(image_path, audio_path, preprocess="crop"): | |
| """ | |
| Generate talking head video locally using SadTalker. | |
| Args: | |
| image_path: Path to portrait image | |
| audio_path: Path to audio file | |
| preprocess: Preprocessing mode - "crop", "resize", or "full" | |
| Returns: | |
| Tuple of (status_message, video_path or None) | |
| """ | |
| global sadtalker_model | |
| try: | |
| import subprocess | |
| import sys | |
| print(f"[SadTalker] Starting local generation...") | |
| print(f"[SadTalker] Image: {image_path}") | |
| print(f"[SadTalker] Audio: {audio_path}") | |
| # Load model | |
| model_info = load_sadtalker() | |
| if model_info is None: | |
| return "Error: Failed to load SadTalker model", None | |
| sadtalker_path = model_info["path"] | |
| # Create output directory | |
| output_dir = os.path.join(GENERATED_MEDIA_DIR, "sadtalker_output") | |
| os.makedirs(output_dir, exist_ok=True) | |
| # Run inference using subprocess (SadTalker's inference script) | |
| inference_script = os.path.join(sadtalker_path, "inference.py") | |
| if os.path.exists(inference_script): | |
| cmd = [ | |
| sys.executable, inference_script, | |
| "--driven_audio", audio_path, | |
| "--source_image", image_path, | |
| "--result_dir", output_dir, | |
| "--preprocess", preprocess, | |
| "--size", "256", | |
| "--still", # Less head movement for stability | |
| ] | |
| print(f"[SadTalker] Running: {' '.join(cmd)}") | |
| result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) | |
| if result.returncode != 0: | |
| print(f"[SadTalker] Error output: {result.stderr}") | |
| return f"Error: SadTalker inference failed: {result.stderr[:500]}", None | |
| # Find the output video | |
| for f in os.listdir(output_dir): | |
| if f.endswith(".mp4"): | |
| video_path = os.path.join(output_dir, f) | |
| # Move to main output directory | |
| final_path = os.path.join( | |
| GENERATED_MEDIA_DIR, | |
| f"talking_head_local_{int(time.time())}.mp4" | |
| ) | |
| import shutil | |
| shutil.move(video_path, final_path) | |
| print(f"[SadTalker] Success! Video saved to: {final_path}") | |
| return f"Talking head video saved as {final_path}", final_path | |
| return "Error: No output video found", None | |
| else: | |
| # No module-level fallback is implemented, so just report the missing script | |
| print("[SadTalker] inference.py not found") | |
| return "Error: SadTalker inference script not found", None | |
| except subprocess.TimeoutExpired: | |
| return "Error: SadTalker generation timed out (>5 minutes)", None | |
| except Exception as e: | |
| import traceback | |
| print(f"[SadTalker] Error: {traceback.format_exc()}") | |
| return f"Error generating talking head locally: {str(e)}", None |