| | from typing import Dict, List, Any |
| | from tangoflux import TangoFluxInference |
| | import torchaudio |
| |
|
| | from huggingface_inference_toolkit.logging import logger |
| | import io |
| | import base64 |
| |
|
class EndpointHandler():
    """Inference endpoint wrapper around TangoFlux text-to-audio generation.

    Accepts a JSON-like request body containing a text prompt (under
    ``inputs`` or ``prompt``) plus optional generation parameters, and
    returns the generated audio as a base64-encoded WAV string.
    """

    def __init__(self, path=""):
        # `path` is part of the handler contract but unused here: the model
        # is always loaded from the Hub by name and placed on CUDA.
        self.model = TangoFluxInference(name='declare-lab/TangoFlux', device='cuda')

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate audio for the prompt carried in `data`.

        Args:
            data: request body. Must contain a non-empty string under either
                the key ``inputs`` or ``prompt``. May also contain a
                ``parameters`` dict with ``num_inference_steps`` (default 50),
                ``duration`` in seconds (default 10) and ``guidance_scale``
                (default 3.5).

        Returns:
            The generated audio serialized as a WAV file (44.1 kHz) and
            base64-encoded into a UTF-8 string.

        Raises:
            ValueError: if neither ``inputs`` nor ``prompt`` holds a
                non-empty string.
        """
        logger.info(f"Received incoming request with {data=}")

        if "inputs" in data and isinstance(data["inputs"], str):
            prompt = data.pop("inputs")
        elif "prompt" in data and isinstance(data["prompt"], str):
            prompt = data.pop("prompt")
        else:
            prompt = None

        # Enforce the "non-empty string" contract the error message promises
        # (the original accepted empty/whitespace-only prompts).
        if prompt is None or not prompt.strip():
            raise ValueError(
                "Provided input body must contain either the key `inputs` or `prompt` with the"
                " prompt to use for the audio generation, and it needs to be a non-empty string."
            )

        parameters = data.pop("parameters", {})
        num_inference_steps = parameters.get("num_inference_steps", 50)
        duration = parameters.get("duration", 10)
        guidance_scale = parameters.get("guidance_scale", 3.5)

        audio = self.model.generate(
            prompt,
            steps=num_inference_steps,
            duration=duration,
            guidance_scale=guidance_scale,
        )

        # Serialize in memory; 44100 Hz is assumed to match TangoFlux's
        # native output sample rate -- TODO confirm against the model card.
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio, 44100, format="wav")
        buffer.seek(0)
        return base64.b64encode(buffer.read()).decode('utf-8')