# multitalk-handler/handler.py
# MultiTalk custom handler for a Hugging Face Inference Endpoint.
import torch
import json
import base64
import io
from typing import Dict, Any, List
from PIL import Image
import numpy as np


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the MultiTalk model handler.
        """
        # Fail early with a clear message if a required dependency is missing.
        try:
            from diffusers import DiffusionPipeline
            import librosa  # noqa: F401 -- imported only to verify the audio dependency is present
        except ImportError as e:
            print(f"Missing dependency: {e}")
            print("Please ensure all requirements are installed")
            raise

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load the pipeline in half precision on GPU (full precision on CPU),
        # then enable memory-saving options for low-VRAM endpoints.
        try:
            self.pipeline = DiffusionPipeline.from_pretrained(
                path if path else "MeiGen-AI/MeiGen-MultiTalk",
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            )
            self.pipeline.to(self.device)

            # Enable memory-efficient attention / VAE slicing if the pipeline supports them
            if hasattr(self.pipeline, "enable_attention_slicing"):
                self.pipeline.enable_attention_slicing()
            if hasattr(self.pipeline, "enable_vae_slicing"):
                self.pipeline.enable_vae_slicing()

            print("Model loaded successfully")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
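
    # NOTE: sketch of a further memory-saving option, not part of the original
    # handler. If the endpoint still runs out of GPU memory even with attention
    # and VAE slicing, diffusers pipelines (with accelerate installed) also
    # expose model CPU offload, which keeps submodules on the CPU until needed:
    #
    #     if hasattr(self.pipeline, "enable_model_cpu_offload"):
    #         self.pipeline.enable_model_cpu_offload()
    #
    # This trades some latency for a substantially smaller VRAM footprint.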

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Args:
            data: Input data containing:
                - inputs: the text prompt, or a dict with "prompt" and an
                  optional base64-encoded "image"
                - parameters: additional generation parameters

        Returns:
            Dict containing the generated output.
        """
        try:
            # Extract inputs
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Handle the different input types
            if isinstance(inputs, str):
                # Plain text prompt
                prompt = inputs
                image = None
            elif isinstance(inputs, dict):
                prompt = inputs.get("prompt", "")
                # Decode a base64-encoded reference image if provided
                if "image" in inputs:
                    image_data = base64.b64decode(inputs["image"])
                    image = Image.open(io.BytesIO(image_data))
                else:
                    image = None
            else:
                prompt = str(inputs)
                image = None

            # Generation parameters with sensible defaults
            num_inference_steps = parameters.get("num_inference_steps", 25)
            guidance_scale = parameters.get("guidance_scale", 7.5)
            height = parameters.get("height", 480)
            width = parameters.get("width", 640)
            num_frames = parameters.get("num_frames", 16)

            # Generate the video
            with torch.no_grad():
                if not callable(self.pipeline):
                    return {"error": "Model pipeline not properly initialized"}

                call_kwargs = dict(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_frames=num_frames,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                )
                # Only pass the reference image if one was supplied
                if image is not None:
                    call_kwargs["image"] = image

                result = self.pipeline(**call_kwargs)

            # Handle the output
            if hasattr(result, "frames"):
                frames = result.frames[0] if len(result.frames) > 0 else []

                # Encode each frame as a base64 PNG string
                encoded_frames = []
                for frame in frames:
                    if isinstance(frame, np.ndarray):
                        # Pipelines may return float arrays in [0, 1]; convert to uint8
                        if frame.dtype != np.uint8:
                            frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
                        frame = Image.fromarray(frame)
                    if isinstance(frame, Image.Image):
                        buffered = io.BytesIO()
                        frame.save(buffered, format="PNG")
                        img_str = base64.b64encode(buffered.getvalue()).decode()
                        encoded_frames.append(img_str)

                return {
                    "frames": encoded_frames,
                    "num_frames": len(encoded_frames),
                    "message": "Video generated successfully",
                }
            else:
                return {
                    "error": "Model output format not recognized",
                    "result": str(result),
                }
        except Exception as e:
            import traceback
            return {
                "error": str(e),
                "traceback": traceback.format_exc(),
            }
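

# Minimal local smoke test (not part of the deployed handler). It sketches the
# request/response contract the handler expects: "inputs" carries the prompt
# (and optionally a base64-encoded reference image), "parameters" carries the
# generation settings, and the response returns base64 PNG frames. The prompt
# and file names below are illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler()

    payload = {
        "inputs": {
            "prompt": "two people having a conversation in a cafe",
            # "image": base64.b64encode(open("reference.png", "rb").read()).decode(),
        },
        "parameters": {
            "num_inference_steps": 25,
            "guidance_scale": 7.5,
            "height": 480,
            "width": 640,
            "num_frames": 16,
        },
    }

    response = handler(payload)
    if "frames" in response:
        print(f"Generated {response['num_frames']} frames")
        # Decode the first frame back into a PIL image for inspection
        first = Image.open(io.BytesIO(base64.b64decode(response["frames"][0])))
        first.save("frame_0.png")
    else:
        print("Error:", response.get("error"))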