import sys, os
sys.path.append("../")
# Set the allocator config here, before torch is imported ("expandable_segments" reduces CUDA memory fragmentation).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces
import torch
import random
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
from pipeline import InstantCharacterFluxPipeline
# --------------------------------------------
# Global & Model paths
# --------------------------------------------
MAX_SEED = np.iinfo(np.int32).max
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
ip_adapter_path = hf_hub_download("tencent/InstantCharacter", "instantcharacter_ip-adapter.bin")
base_model = "black-forest-labs/FLUX.1-dev"
image_encoder_path = "google/siglip-so400m-patch14-384"
image_encoder_2_path = "facebook/dinov2-giant"
birefnet_path = "ZhengPeng7/BiRefNet"
makoto_style_lora_path = hf_hub_download("InstantX/FLUX.1-dev-LoRA-Makoto-Shinkai", "Makoto_Shinkai_style.safetensors")
ghibli_style_lora_path = hf_hub_download("InstantX/FLUX.1-dev-LoRA-Ghibli", "ghibli_style.safetensors")
# local One Piece LoRA
onepiece_style_lora_path = os.path.join(os.path.dirname(__file__), "onepiece_flux_v2.safetensors")
ONEPIECE_TRIGGER = "onepiece style"
# ---- Universal prompt (works for a man or a woman)
UNIVERSAL_PROMPT = (
"Upper-body anime portrait of a pirate character inspired by One Piece, confident and charismatic expression, "
"original and dynamic pose, expressive eyes with anime-style lighting, slightly windswept hair, preserving the subject’s "
"distinctive facial features and hairstyle (and facial hair if present), detailed anime rendering of the face, natural matte skin tone, "
"lips matching the skin color (no pink or gloss), wearing stylish pirate clothing appropriate to the subject (open shirt, coat, vest, "
"belts, scarves, cape, etc...), with optional pirate accessories (earrings, necklace, bandana or hat) only if they fit the subject’s style, "
"well-framed head and shoulders, centered and balanced, cinematic warm lighting, high-quality cel-shaded coloring and clean linework, "
"One Piece-style background (ship deck or ocean sky), designed to look cool, original and iconic like a real One Piece portrait character, "
"no frame, no text."
)
# --------------------------------------------
# Init pipeline
# --------------------------------------------
pipe = InstantCharacterFluxPipeline.from_pretrained(base_model, torch_dtype=dtype)
pipe.to(device)
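# Attach the InstantCharacter subject adapter: SigLIP and DINOv2 serve as the two
# image encoders that feed subject features into the IP-Adapter.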
pipe.init_adapter(
image_encoder_path=image_encoder_path,
image_encoder_2_path=image_encoder_2_path,
subject_ipadapter_cfg=dict(subject_ip_adapter_path=ip_adapter_path, nb_token=1024),
)
# --------------------------------------------
# Background remover
# --------------------------------------------
birefnet = AutoModelForImageSegmentation.from_pretrained(birefnet_path, trust_remote_code=True)
birefnet.to(device)
birefnet.eval()
birefnet_transform = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
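# remove_bkg: run BiRefNet matting on the subject photo, keep the foreground,
# composite it onto a white background, and pad the result to a square.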
def remove_bkg(subject_image):
def infer_matting(img_pil):
        inp = birefnet_transform(img_pil.convert("RGB")).unsqueeze(0).to(device)  # force 3 channels to match Normalize stats
with torch.no_grad():
preds = birefnet(inp)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
mask = transforms.ToPILImage()(pred).resize(img_pil.size)
return np.array(mask)[..., None]
def pad_to_square(image, pad_value=255):
H, W = image.shape[:2]
if H == W:
return image
pad = abs(H - W)
pad1, pad2 = pad // 2, pad - pad // 2
pad_param = ((0, 0), (pad1, pad2), (0, 0)) if H > W else ((pad1, pad2), (0, 0), (0, 0))
return np.pad(image, pad_param, "constant", constant_values=pad_value)
    mask = infer_matting(subject_image)[..., 0]
    subject_np = np.array(subject_image.convert("RGB"))
    # Binarize the matte, composite the subject onto white, then pad to a square.
    mask = (mask > 128).astype(np.uint8) * 255
    sample_mask = np.stack([mask] * 3, axis=-1)
    obj = sample_mask / 255 * subject_np + (1 - sample_mask / 255) * 255
    padded = pad_to_square(obj, 255)
    return Image.fromarray(padded.astype(np.uint8))
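# Quick local sanity check (a sketch, commented out; assumes ./assets/girl.jpg exists):
# cutout = remove_bkg(Image.open("./assets/girl.jpg"))
# cutout.save("cutout_preview.png")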
# --------------------------------------------
# Generation logic
# --------------------------------------------
def randomize_seed(seed, randomize):
return random.randint(0, MAX_SEED) if randomize else seed
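# create_image is the GPU entry point: on Hugging Face Spaces, @spaces.GPU
# requests a ZeroGPU slot for the duration of each call.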
@spaces.GPU
def create_image(input_image, prompt, scale, guidance_scale, num_inference_steps, seed, style_mode, negative_prompt=""):
input_image = remove_bkg(input_image)
if style_mode == "Makoto Shinkai style":
lora_path, trigger = makoto_style_lora_path, "Makoto Shinkai style"
elif style_mode == "Ghibli style":
lora_path, trigger = ghibli_style_lora_path, "ghibli style"
elif style_mode == "One Piece style":
lora_path, trigger = onepiece_style_lora_path, ONEPIECE_TRIGGER
else:
lora_path, trigger = None, ""
    generator = torch.Generator(device="cpu").manual_seed(seed)  # dedicated generator; leaves the global RNG untouched
common_args = dict(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
        width=1024, height=768,  # fixed 1024x768 output
subject_image=input_image,
subject_scale=scale,
generator=generator,
)
if lora_path:
result = pipe.with_style_lora(lora_file_path=lora_path, trigger=trigger, **common_args)
else:
result = pipe(**common_args)
return result.images
# --------------------------------------------
# UI definition (Gradio 5)
# --------------------------------------------
def generate_fn(image, prompt, scale, style, guidance, steps, seed, randomize, negative_prompt):
seed = randomize_seed(seed, randomize)
return create_image(image, prompt, scale, guidance, steps, seed, style, negative_prompt)
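# generate_fn is the single API-facing function; with gr.Interface it is exposed
# as the /predict endpoint, which is what external callers such as Make.com hit.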
title = "🎨 InstantCharacter + One Piece LoRA"
description = (
"Upload your photo, use the universal One Piece prompt, choose **One Piece style**. "
"Output is fixed to **1024×780**. API is enabled for Make.com."
)
# (do NOT put api_open here)
demo = gr.Interface(
fn=generate_fn,
inputs=[
gr.Image(label="Source Image", type="pil"),
gr.Textbox(label="Prompt", value=f"a character is riding a bike in snow, {ONEPIECE_TRIGGER}"),
gr.Slider(0, 1.5, value=1.0, step=0.01, label="Scale"),
        gr.Dropdown(choices=["No style", "Makoto Shinkai style", "Ghibli style", "One Piece style"],
                    value="One Piece style", label="Style"),  # any unrecognized choice falls through to the no-LoRA branch
gr.Slider(1, 7.0, value=3.5, step=0.01, label="Guidance Scale"),
gr.Slider(5, 50, value=28, step=1, label="Inference Steps"),
gr.Slider(-1000000, 1000000, value=123456, step=1, label="Seed"),
gr.Checkbox(value=True, label="Randomize Seed"),
gr.Textbox(label="Negative Prompt", placeholder="e.g. photorealistic, realistic skin, pores, hdr")
],
outputs=gr.Gallery(label="Generated Image"),
title=title,
description=description,
examples=[
["./assets/girl.jpg", f"A girl playing guitar, {ONEPIECE_TRIGGER}", 0.9, "One Piece style", 3.5, 28, 123, False, ""],
["./assets/boy.jpg", f"A boy riding a bike, {ONEPIECE_TRIGGER}", 0.9, "One Piece style", 3.5, 28, 123, False, ""]
]
)
# use show_api=True here so the endpoint is documented and callable
demo.launch(show_api=True)
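# Example external call (a sketch, commented out; "OWNER/SPACE" is a placeholder
# for this Space's id, and arguments follow the `inputs` order above):
# from gradio_client import Client, handle_file
# client = Client("OWNER/SPACE")
# images = client.predict(
#     handle_file("photo.jpg"),           # Source Image
#     f"a portrait, {ONEPIECE_TRIGGER}",  # Prompt
#     1.0,                                # Scale
#     "One Piece style",                  # Style
#     3.5,                                # Guidance Scale
#     28,                                 # Inference Steps
#     123456,                             # Seed
#     True,                               # Randomize Seed
#     "",                                 # Negative Prompt
#     api_name="/predict",
# )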