| |
|
| |
|
| | import sys, os, random, numpy as np, torch |
| | sys.path.append("../") |
| |
|
| | from PIL import Image |
| | import spaces |
| | import gradio as gr |
| | from gradio.themes import Soft |
| | from huggingface_hub import hf_hub_download |
| | from transformers import AutoModelForImageSegmentation |
| | from torchvision import transforms |
| |
|
| | from pipeline import InstantCharacterFluxPipeline |
| |
|
| | |
| | |
| | |
# Upper bound for user-selectable seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Use the GPU when torch can see one, otherwise the CPU.
# Half precision is only paired with CUDA; the CPU path stays in float32.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32
| |
|
| | |
| | |
| | |
# Subject IP-adapter weights, resolved to a local file via hf_hub_download.
ip_adapter_path = hf_hub_download(
    repo_id="tencent/InstantCharacter",
    filename="instantcharacter_ip-adapter.bin",
)

# Model identifiers consumed by the pipeline / matting setup further down.
base_model = "black-forest-labs/FLUX.1-dev"
image_encoder_path = "google/siglip-so400m-patch14-384"
image_encoder2_path = "facebook/dinov2-giant"
birefnet_path = "ZhengPeng7/BiRefNet"

# Style LoRA weight files, one per selectable style.
makoto_style_path = hf_hub_download(
    repo_id="InstantX/FLUX.1-dev-LoRA-Makoto-Shinkai",
    filename="Makoto_Shinkai_style.safetensors",
)
ghibli_style_path = hf_hub_download(
    repo_id="InstantX/FLUX.1-dev-LoRA-Ghibli",
    filename="ghibli_style.safetensors",
)
| |
|
| | |
| | |
| | |
# Build the generation pipeline in bfloat16 and move it to the chosen device.
pipe = InstantCharacterFluxPipeline.from_pretrained(
    base_model,
    torch_dtype=torch.bfloat16,
)
pipe.to(device)

# Attach the two image encoders and the subject IP-adapter weights.
pipe.init_adapter(
    image_encoder_path=image_encoder_path,
    image_encoder_2_path=image_encoder2_path,
    subject_ipadapter_cfg={
        "subject_ip_adapter_path": ip_adapter_path,
        "nb_token": 1024,
    },
)
| |
|
| | |
| | |
| | |
# Segmentation model used for background removal; put in eval mode because
# it is only ever used for inference here.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    birefnet_path,
    trust_remote_code=True,
)
birefnet.to(device)
birefnet.eval()

# Preprocessing for the matting model: fixed 1024x1024 input, tensor
# conversion, then per-channel normalization with three-channel statistics.
birefnet_tf = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
| |
|
| | |
| | |
| | |
def randomize_seed_fn(seed: int, randomize: bool) -> int:
    """Pick the seed for the next generation.

    Args:
        seed: Seed currently entered by the user.
        randomize: When truthy, ignore ``seed`` and draw a fresh one.

    Returns:
        ``seed`` unchanged, or a random integer in ``[0, MAX_SEED]``.
    """
    if not randomize:
        return seed
    return random.randint(0, MAX_SEED)
| |
|
def _infer_matting(img_pil):
    """Predict a foreground matte for ``img_pil`` with the BiRefNet model.

    Args:
        img_pil: PIL image to segment.

    Returns:
        A 2-D ``uint8`` array with the same height and width as ``img_pil``.
        Values are the sigmoid of the model's last output, scaled to 0..255.
    """
    # The preprocessing normalizes with three channel statistics, so make
    # sure the model always sees a three-channel image.
    rgb = img_pil.convert("RGB")

    with torch.no_grad():
        batch = birefnet_tf(rgb).unsqueeze(0).to(device)
        prob = birefnet(batch)[-1].sigmoid().cpu()[0, 0].numpy()
    mask = (prob * 255).astype(np.uint8)

    # The preprocessing resizes every input to a fixed square, so the
    # prediction generally does not match the source resolution. Callers
    # combine the mask with the original pixels, so bring it back to the
    # source size (PIL sizes are (width, height)).
    if mask.shape != (img_pil.height, img_pil.width):
        mask = np.array(Image.fromarray(mask).resize(img_pil.size))
    return mask
| |
|
| | def _bbox_from_mask(mask, th=128): |
| | ys, xs = np.where(mask >= th) |
| | if not len(xs): |
| | return [0, 0, mask.shape[1]-1, mask.shape[0]-1] |
| | return [xs.min(), ys.min(), xs.max(), ys.max()] |
| |
|
| | def _pad_square(arr, pad_val=255): |
| | h, w = arr.shape[:2] |
| | if h == w: |
| | return arr |
| | diff = abs(h - w) |
| | pad_1 = diff // 2 |
| | pad_2 = diff - pad_1 |
| | if h > w: |
| | pad = ((0, 0), (pad_1, pad_2), (0, 0)) |
| | else: |
| | pad = ((pad_1, pad_2), (0, 0), (0, 0)) |
| | return np.pad(arr, pad, constant_values=pad_val) |
| |
|
def remove_bkg(img_pil: Image.Image) -> Image.Image:
    """Cut the subject out of ``img_pil`` onto a white square canvas.

    Steps: predict a matte, replace everything below the 128 threshold
    with white, crop to the subject's bounding box, then pad the crop to
    a square with white.

    Args:
        img_pil: Source image.

    Returns:
        A new square PIL image containing only the subject on white.
    """
    matte = _infer_matting(img_pil)
    left, top, right, bottom = _bbox_from_mask(matte)

    # Trailing axis lets the 2-D mask broadcast over the colour channels.
    foreground = (matte >= 128)[..., None]
    pixels = np.array(img_pil)
    whitened = np.where(foreground, pixels, np.uint8(255))

    # The bounding box is inclusive, hence the +1 on the upper slice bounds.
    subject = whitened[top:bottom + 1, left:right + 1]
    squared = _pad_square(subject)
    return Image.fromarray(squared.astype(np.uint8))
| |
|
def get_example():
    """Return the rows shown in the "Quick Examples" table.

    Returns:
        A list of ``[image_path, prompt, scale, style]`` rows.
    """
    scale = 0.9
    style = "Makoto Shinkai style"
    cases = (
        ("./assets/girl.jpg", "A girl is playing a guitar in street"),
        ("./assets/boy.jpg", "A boy is riding a bike in snow"),
    )
    return [[image, prompt, scale, style] for image, prompt in cases]
| |
|
@spaces.GPU
def create_image(input_image, prompt, scale,
                 guidance_scale, num_inference_steps,
                 seed, style_mode):
    """Generate images of the subject in ``input_image`` following ``prompt``.

    Args:
        input_image: PIL image of the subject; its background is removed
            before it is handed to the pipeline.
        prompt: Text prompt.
        scale: Passed to the pipeline as ``subject_scale``.
        guidance_scale: Passed through to the pipeline.
        num_inference_steps: Passed through to the pipeline.
        seed: Seed for the torch random generator.
        style_mode: ``None`` or ``"None"`` for no style LoRA,
            ``"Makoto Shinkai style"`` for the Makoto Shinkai LoRA; any other
            value selects the Ghibli LoRA.

    Returns:
        The ``images`` attribute of the pipeline result.

    Raises:
        gr.Error: If no source image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # deep inside the background-removal code.
        raise gr.Error("Please upload a source image first.")

    input_image = remove_bkg(input_image)
    gen = torch.manual_seed(int(seed))

    # Arguments shared by the plain and the style-LoRA code paths.
    common = dict(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        width=1024, height=1024,
        subject_image=input_image, subject_scale=scale,
        generator=gen,
    )

    # The Style dropdown emits the *string* "None", not the None object, so
    # both spellings must select the LoRA-free path.
    if style_mode is None or style_mode == "None":
        return pipe(**common).images

    if style_mode == "Makoto Shinkai style":
        lora_path, trigger = makoto_style_path, "Makoto Shinkai style"
    else:
        lora_path, trigger = ghibli_style_path, "ghibli style"

    return pipe.with_style_lora(
        lora_file_path=lora_path, trigger=trigger,
        **common,
    ).images
| |
|
def run_for_examples(src, p, s, st):
    """Run ``create_image`` for an example row with fixed advanced settings.

    Args:
        src: Source image.
        p: Prompt.
        s: Subject scale.
        st: Style mode.

    Returns:
        Whatever ``create_image`` returns.
    """
    # Example rows carry no advanced options, so these are pinned here.
    guidance_scale, steps, seed = 3.5, 28, 123456
    return create_image(src, p, s, guidance_scale, steps, seed, st)
| |
|
| | |
| | |
| | |
# Soft theme with a pink accent and the Inter font.
theme = Soft(
    primary_hue="pink",
    font=[gr.themes.GoogleFont("Inter")],
)

# Custom stylesheet: gradient page background, centered title,
# translucent "card" columns, rounded media, hidden footer.
css = """
body{
background:#141e30;
background:linear-gradient(135deg,#141e30,#243b55);
}
#title{
text-align:center;
font-size:2.2rem;
font-weight:700;
color:#ffffff;
padding:20px 0 6px;
}
.card{
border-radius:18px;
background:#ffffff0d;
padding:18px 22px;
backdrop-filter:blur(6px);
}
.gr-image,.gr-video{border-radius:14px}
.gr-image:hover{box-shadow:0 0 0 4px #ec4899}
footer{visibility:hidden}
"""
| |
|
| | |
| | |
| | |
# NOTE(review): the original indentation was lost; the nesting below
# (controls inside the left card, click handler and examples inside the
# "Generate" tab) is the most plausible reading - confirm against the
# deployed layout.
with gr.Blocks(css=css, theme=theme) as demo:
    # Page header.
    gr.Markdown("<div id='title'>InstantCharacter PLUS</div>")
    gr.Markdown(
        "<b>Official 🤗 Gradio demo of "
        "<a href='https://instantcharacter.github.io/' target='_blank'>InstantCharacter</a></b>"
    )

    with gr.Tabs():
        with gr.TabItem("Generate"):
            with gr.Row(equal_height=True):
                # Left card: inputs.
                with gr.Column(elem_classes="card"):
                    image_pil = gr.Image(
                        label="Source Image",
                        type="pil",
                        height=380,
                    )
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="A character is riding a bike in snow",
                        lines=2,
                    )
                    scale = gr.Slider(
                        minimum=0,
                        maximum=1.5,
                        value=1.0,
                        step=0.01,
                        label="Scale",
                    )
                    style_mode = gr.Dropdown(
                        choices=["None", "Makoto Shinkai style", "Ghibli style"],
                        label="Style",
                        value="Makoto Shinkai style",
                    )

                    with gr.Accordion("⚙️ Advanced Options", open=False):
                        guidance_scale = gr.Slider(
                            minimum=1,
                            maximum=7,
                            value=3.5,
                            step=0.01,
                            label="Guidance scale",
                        )
                        num_inference_steps = gr.Slider(
                            minimum=5,
                            maximum=50,
                            value=28,
                            step=1,
                            label="# Inference steps",
                        )
                        seed = gr.Number(
                            value=123456,
                            label="Seed",
                            precision=0,
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                        )

                    generate_btn = gr.Button(
                        value="🚀 Generate",
                        variant="primary",
                        size="lg",
                        elem_classes="contrast",
                    )

                # Right card: results.
                with gr.Column(elem_classes="card"):
                    generated_image = gr.Gallery(
                        label="Generated Image",
                        show_label=True,
                        height="auto",
                        columns=[1],
                    )

            # Two-step click: first settle the seed (written back into the
            # Seed field), then run the generation with that seed.
            generate_btn.click(
                fn=randomize_seed_fn,
                inputs=[seed, randomize_seed],
                outputs=seed,
                queue=False,
            ).then(
                fn=create_image,
                inputs=[
                    image_pil,
                    prompt,
                    scale,
                    guidance_scale,
                    num_inference_steps,
                    seed,
                    style_mode,
                ],
                outputs=generated_image,
            )

            # Example rows; they only carry image, prompt, scale and style,
            # so they go through run_for_examples.
            gr.Markdown("### 🔥 Quick Examples")
            gr.Examples(
                examples=get_example(),
                inputs=[image_pil, prompt, scale, style_mode],
                outputs=generated_image,
                fn=run_for_examples,
                cache_examples=True,
            )
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # Configure the request queue first, then start the server.
    demo.queue(max_size=10, api_open=False)
    demo.launch()
| |
|