import torch
import json
import base64
import io
from typing import Dict, Any, List

from PIL import Image
import numpy as np


class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the MultiTalk model handler."""
        # Fail early with a clear message if optional dependencies are missing
        try:
            from diffusers import DiffusionPipeline
            import librosa  # noqa: F401 -- needed for audio-driven inputs
        except ImportError as e:
            print(f"Missing dependency: {e}")
            print("Please ensure all requirements are installed")
            raise

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load the pipeline in half precision and let accelerate place the weights
        try:
            self.pipeline = DiffusionPipeline.from_pretrained(
                path if path else "MeiGen-AI/MeiGen-MultiTalk",
                torch_dtype=torch.float16,
                device_map="auto",
            )

            # Enable memory-saving options when the pipeline exposes them
            if hasattr(self.pipeline, "enable_attention_slicing"):
                self.pipeline.enable_attention_slicing()
            if hasattr(self.pipeline, "enable_vae_slicing"):
                self.pipeline.enable_vae_slicing()

            print("Model loaded successfully")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Args:
            data: Input payload containing:
                - inputs: a text prompt, or a dict with "prompt" and an
                  optional base64-encoded "image"
                - parameters: optional generation parameters

        Returns:
            Dict with base64-encoded PNG frames, or an "error" key on failure.
        """
        try:
            # Extract inputs and optional generation parameters
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Handle the supported input shapes
            if isinstance(inputs, str):
                # Plain text prompt
                prompt = inputs
                image = None
            elif isinstance(inputs, dict):
                prompt = inputs.get("prompt", "")
                # Decode a base64-encoded conditioning image if provided
                if "image" in inputs:
                    image_data = base64.b64decode(inputs["image"])
                    image = Image.open(io.BytesIO(image_data))
                else:
                    image = None
            else:
                prompt = str(inputs)
                image = None

            # Generation parameters with defaults
            num_inference_steps = parameters.get("num_inference_steps", 25)
            guidance_scale = parameters.get("guidance_scale", 7.5)
            height = parameters.get("height", 480)
            width = parameters.get("width", 640)
            num_frames = parameters.get("num_frames", 16)

            if getattr(self, "pipeline", None) is None:
                return {"error": "Model pipeline not properly initialized"}

            # Generate the video frames
            with torch.no_grad():
                result = self.pipeline(
                    prompt=prompt,
                    image=image,
                    height=height,
                    width=width,
                    num_frames=num_frames,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                )

            if not hasattr(result, "frames"):
                return {
                    "error": "Model output format not recognized",
                    "result": str(result),
                }

            # Pipelines return frames batched per prompt; take the first sample
            frames = result.frames[0] if len(result.frames) > 0 else []

            # Encode each frame as a base64 PNG string
            encoded_frames = []
            for frame in frames:
                if isinstance(frame, np.ndarray):
                    # Some pipelines return float arrays in [0, 1] instead of
                    # PIL images; convert them so they are not silently dropped
                    if frame.dtype != np.uint8:
                        frame = (np.clip(frame, 0, 1) * 255).astype(np.uint8)
                    frame = Image.fromarray(frame)
                if isinstance(frame, Image.Image):
                    buffered = io.BytesIO()
                    frame.save(buffered, format="PNG")
                    encoded_frames.append(
                        base64.b64encode(buffered.getvalue()).decode()
                    )

            return {
                "frames": encoded_frames,
                "num_frames": len(encoded_frames),
                "message": "Video generated successfully",
            }
        except Exception as e:
            import traceback

            return {
                "error": str(e),
                "traceback": traceback.format_exc(),
            }
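

# --- Usage sketch (illustrative only, not part of the deployed handler) ---
# A minimal local smoke test, assuming the model weights are reachable and the
# request follows the Hugging Face custom-handler payload shape
# ({"inputs": ..., "parameters": ...}). The prompt and parameter values below
# are placeholders, not recommended settings.
if __name__ == "__main__":
    handler = EndpointHandler(path="")  # falls back to MeiGen-AI/MeiGen-MultiTalk

    request = {
        "inputs": {
            "prompt": "two people having a conversation at a cafe",
            # "image": "<base64-encoded PNG>",  # optional conditioning image
        },
        "parameters": {
            "num_inference_steps": 25,
            "num_frames": 16,
            "height": 480,
            "width": 640,
        },
    }

    response = handler(request)
    if "error" in response:
        print("Generation failed:", response["error"])
    elif response.get("frames"):
        # Decode the first returned frame back into a PIL image for inspection
        first_frame = Image.open(
            io.BytesIO(base64.b64decode(response["frames"][0]))
        )
        print(f"Got {response['num_frames']} frames; first frame size: {first_frame.size}")
    else:
        print("No frames returned")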