init

- .gitignore +6 -0
- app.py +103 -0
- infer.py +472 -0
- infer.sh +30 -0
- pipeline.py +214 -0
- requirements.txt +14 -0
- utils/image_utils.py +514 -0
- utils/seed_all.py +33 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
__pycache__/
outputs/
tmp/
.DS_Store
weights/
tmp_*
app.py
ADDED
@@ -0,0 +1,103 @@
import spaces  # must be first!
import sys
import os
import torch
from PIL import Image
import gradio as gr
from glob import glob
from contextlib import nullcontext
from pipeline import Lotus2Pipeline
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    FluxTransformer2DModel,
)
from infer import (
    load_lora_and_lcm_weights,
    process_single_image
)

pipeline = None
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16
task = None

@spaces.GPU
def load_pipeline():
    global pipeline, device, weight_dtype, task
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        'black-forest-labs/FLUX.1-dev', subfolder="scheduler", num_train_timesteps=10
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        'black-forest-labs/FLUX.1-dev', subfolder="transformer", revision=None, variant=None
    )
    transformer.requires_grad_(False)
    transformer.to(device=device, dtype=weight_dtype)
    transformer, local_continuity_module = load_lora_and_lcm_weights(transformer, None, None, None, task)
    pipeline = Lotus2Pipeline.from_pretrained(
        'black-forest-labs/FLUX.1-dev',
        scheduler=noise_scheduler,
        transformer=transformer,
        revision=None,
        variant=None,
        torch_dtype=weight_dtype,
    )
    pipeline.local_continuity_module = local_continuity_module
    pipeline = pipeline.to(device)

@spaces.GPU
def fn(image_path):
    global pipeline, device, task
    pipeline.set_progress_bar_config(disable=True)
    with nullcontext():
        _, output_vis, _ = process_single_image(
            image_path, pipeline,
            task_name=task,
            device=device,
            num_inference_steps=10,
            process_res=1024
        )
    return [Image.open(image_path), output_vis]

def build_demo():
    global task
    inputs = [
        gr.Image(label="Image", type="filepath")
    ]
    outputs = [
        gr.ImageSlider(
            label=f"{task.title()}",
            type="pil",
            slider_position=20,
        )
    ]
    examples = glob(f"assets/demo_examples/{task}/*.png") + glob(f"assets/demo_examples/{task}/*.jpg")
    demo = gr.Interface(
        fn=fn,
        title="Lotus-2: Advancing Geometric Dense Prediction with Powerful Image Generative Model",
        description=f"""
        <strong>Please consider starring <span style="color: orange">★</span> our <a href="https://github.com/EnVision-Research/Lotus-2" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful! 😊</strong>
        <br>
        <strong>Current Task: </strong><strong style="color: red;">{task.title()}</strong>
        """,
        inputs=inputs,
        outputs=outputs,
        examples=examples,
        examples_per_page=10
    )
    return demo

def main(task_name):
    global task
    task = task_name
    load_pipeline()
    demo = build_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=6381,
    )

if __name__ == "__main__":
    task_name = "depth"
    if task_name not in ['depth', 'normal']:
        raise ValueError("Invalid task. Please choose from 'depth' and 'normal'.")
    main(task_name)
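Note: app.py pins task_name = "depth" inside its __main__ block. Below is a minimal, hypothetical sketch of how the same entry point could pick the task from an environment variable instead of editing the file; the LOTUS_TASK variable name is an assumption and is not defined anywhere in this repo.

# Hypothetical variant of app.py's __main__ block: choose the task via an
# environment variable (LOTUS_TASK is an assumed name, not part of the repo).
import os

if __name__ == "__main__":
    task_name = os.environ.get("LOTUS_TASK", "depth")
    if task_name not in ['depth', 'normal']:
        raise ValueError("Invalid task. Please choose from 'depth' and 'normal'.")
    main(task_name)  # main() as defined above in app.py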
infer.py
ADDED
@@ -0,0 +1,472 @@
#!/usr/bin/env python
# coding=utf-8
"""
Lotus-2 Inference Script

Usage:
    python infer.py --pretrained_model_name_or_path <model_path> [other_args]

If --core_predictor_model_path, --lcm_model_path, or --detail_sharpener_model_path
are not provided, the script will automatically download the corresponding model
weights from the default HuggingFace repositories.
"""

import argparse
import logging
import os
from contextlib import nullcontext
from pathlib import Path

import numpy as np
import torch
import torch.utils.checkpoint
from peft import LoraConfig, set_peft_model_state_dict
from PIL import Image
from torch import nn
from tqdm.auto import tqdm

try:
    from huggingface_hub import snapshot_download
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    logging.warning("huggingface_hub not available. Model auto-download will not work.")

from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    FluxTransformer2DModel,
)
from diffusers.utils import convert_unet_state_dict_to_peft
from utils.image_utils import colorize_depth_map
from pipeline import Lotus2Pipeline
from utils.seed_all import seed_all

# Default HuggingFace repositories and model filenames
DEFAULT_CORE_PREDICTOR_REPO = "jingheya/Lotus-2"
DEFAULT_LCM_REPO = "jingheya/Lotus-2"
DEFAULT_DETAIL_SHARPENER_REPO = "jingheya/Lotus-2"

CORE_PREDICTOR_FILENAME = {
    "depth": "lotus-2_core_predictor_depth.safetensors",
    "normal": "lotus-2_core_predictor_normal.safetensors"
}

LCM_FILENAME = {
    "depth": "lotus-2_lcm_depth.safetensors",
    "normal": "lotus-2_lcm_normal.safetensors"
}

DETAIL_SHARPENER_FILENAME = {
    "depth": "lotus-2_detail_sharpener_depth.safetensors",
    "normal": "lotus-2_detail_sharpener_normal.safetensors"
}

def get_model_path(model_path, repo_id, filename):
    """
    Get the local path for a model. If model_path is None, download from HuggingFace.

    Args:
        model_path: Local path to model or None to download from HF
        repo_id: HuggingFace repository ID
        filename: Model filename in the repository

    Returns:
        Local path to the model file
    """
    if model_path is not None:
        return model_path

    if not HF_AVAILABLE:
        raise ImportError(
            f"huggingface_hub is required for auto-downloading {filename} model weights. "
            "Please install it with: pip install huggingface_hub"
        )

    logging.info(f"Downloading {filename} model weights from {repo_id}/{filename}")

    try:
        # Create cache directory if it doesn't exist
        cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
        os.makedirs(cache_dir, exist_ok=True)

        # Download the entire repository and get the specific file
        repo_path = snapshot_download(
            repo_id=repo_id,
            cache_dir=cache_dir,
            local_files_only=False,
        )

        # Construct the full path to the specific file
        full_path = os.path.join(repo_path, filename)

        if not os.path.exists(full_path):
            # Try to find the file in the repo
            for root, dirs, files in os.walk(repo_path):
                if filename in files:
                    full_path = os.path.join(root, filename)
                    break
            else:
                raise FileNotFoundError(f"Could not find {filename} in the downloaded repository")

        logging.info(f"Successfully downloaded {filename} model to: {full_path}")
        return full_path

    except Exception as e:
        raise RuntimeError(f"Failed to download {filename} model from {repo_id}: {str(e)}")


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
# check_min_version("0.33.0.dev0")


class Local_Continuity_Module(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.lcm = nn.Sequential(
            nn.Conv2d(num_channels, num_channels * 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(num_channels * 2, num_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        lcm_dtype = next(self.lcm.parameters()).dtype
        if x.dtype != lcm_dtype:
            x = x.to(dtype=lcm_dtype)
        return x + self.lcm(x)

def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Run Lotus-2.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--core_predictor_model_path",
        type=str,
        default=None,
        help="Path to core predictor model weights",
    )
    parser.add_argument(
        "--lcm_model_path",
        type=str,
        default=None,
        help="Path to local continuity module model weights",
    )
    parser.add_argument(
        "--detail_sharpener_model_path",
        type=str,
        default=None,
        help="Path to detail sharpener model weights",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
    )
    parser.add_argument(
        "--process_res",
        type=int,
        default=768,
        help="The resolution for processing the images.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=10,
        help="Number of timesteps to infer the model.",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        default=None,
        help="The directory where the input images are stored.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="flux-dreambooth-lora",
        help="The output directory where the model predictions will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
    parser.add_argument(
        "--task_name",
        type=str,
        default="depth",  # "normal"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system"
            " or the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    return args

def process_single_image(image_path, pipeline, task_name, device,
                         num_inference_steps, process_res=768):
    image = Image.open(image_path).convert("RGB")
    image_np = np.array(image).astype(np.float32)
    image_ts = torch.tensor(image_np).permute(2,0,1).unsqueeze(0)
    image_ts = image_ts / 127.5 - 1.0
    image_ts = image_ts.to(device)

    prediction = pipeline(
        rgb_in=image_ts,
        prompt='',
        num_inference_steps=num_inference_steps,
        output_type='np',
        process_res=process_res,
    ).images[0]

    if task_name == "depth":
        output_npy = prediction.mean(axis=-1)
        output_vis = colorize_depth_map(output_npy, reverse_color=True)
    elif task_name == "normal":
        output_npy = prediction
        output_vis = Image.fromarray((output_npy * 255).astype(np.uint8))
    else:
        raise ValueError(f"Invalid task name: {task_name}")

    return image, output_vis, output_npy

def load_lora_and_lcm_weights(transformer, core_predictor_model_path, lcm_model_path, detail_sharpener_model_path, task_name):
    lora_rank = 128 if task_name == 'depth' else 256
    device = transformer.device
    weight_dtype = transformer.dtype

    target_lora_modules = [
        "attn.to_k",
        "attn.to_q",
        "attn.to_v",
        "attn.to_out.0",
        "attn.add_k_proj",
        "attn.add_q_proj",
        "attn.add_v_proj",
        "attn.to_add_out",
        "ff.net.0.proj",
        "ff.net.2",
        "ff_context.net.0.proj",
        "ff_context.net.2",
    ]

    # Auto-download models if paths are None
    core_predictor_model_path = get_model_path(
        core_predictor_model_path,
        DEFAULT_CORE_PREDICTOR_REPO,
        CORE_PREDICTOR_FILENAME[task_name]
    )

    lcm_model_path = get_model_path(
        lcm_model_path,
        DEFAULT_LCM_REPO,
        LCM_FILENAME[task_name]
    )

    detail_sharpener_model_path = get_model_path(
        detail_sharpener_model_path,
        DEFAULT_DETAIL_SHARPENER_REPO,
        DETAIL_SHARPENER_FILENAME[task_name]
    )

    # load lora weights for core predictor
    core_transformer_lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_rank,
        init_lora_weights="gaussian",
        target_modules=target_lora_modules,
    )
    transformer.add_adapter(core_transformer_lora_config, adapter_name="core_predictor")

    core_lora_state_dict = Lotus2Pipeline.lora_state_dict(core_predictor_model_path)
    core_transformer_state_dict = {
        f'{k.replace("transformer.", "")}': v for k, v in core_lora_state_dict.items() if k.startswith("transformer.")
    }
    core_transformer_state_dict = convert_unet_state_dict_to_peft(core_transformer_state_dict)
    incompatible_keys = set_peft_model_state_dict(transformer, core_transformer_state_dict, adapter_name="core_predictor")
    if incompatible_keys is not None:
        # check only for unexpected keys
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            logging.warning(
                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                f" {unexpected_keys}. "
            )

    for name, param in transformer.named_parameters():
        if "core_predictor" in name:
            param.requires_grad = False
    # transformer.to(device=device, dtype=weight_dtype)
    logging.info(f"Successfully loaded lora weights for [core predictor].")

    # stage1 lcm weights
    local_continuity_module = Local_Continuity_Module(transformer.config.in_channels//4)
    lcm_state_dict = torch.load(lcm_model_path, map_location="cpu", weights_only=True)
    local_continuity_module.load_state_dict(lcm_state_dict)
    local_continuity_module.requires_grad_(False)
    local_continuity_module.to(device=device, dtype=weight_dtype)
    logging.info(f"Successfully loaded weights for [local continuity module (LCM)].")

    # stage2 lora weights (detail sharpener)
    sharpener_transformer_lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_rank,
        init_lora_weights="gaussian",
        target_modules=target_lora_modules,
    )
    transformer.add_adapter(sharpener_transformer_lora_config, adapter_name="detail_sharpener")

    sharpener_lora_state_dict = Lotus2Pipeline.lora_state_dict(detail_sharpener_model_path)
    sharpener_transformer_state_dict = {
        f'{k.replace("transformer.", "")}': v for k, v in sharpener_lora_state_dict.items() if k.startswith("transformer.")
    }
    sharpener_transformer_state_dict = convert_unet_state_dict_to_peft(sharpener_transformer_state_dict)
    incompatible_keys = set_peft_model_state_dict(transformer, sharpener_transformer_state_dict, adapter_name="detail_sharpener")
    if incompatible_keys is not None:
        # check only for unexpected keys
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            logging.warning(
                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                f" {unexpected_keys}. "
            )

    # freeze the stage2 lora
    for name, param in transformer.named_parameters():
        if "detail_sharpener" in name:
            param.requires_grad = False
    # transformer.to(device=device, dtype=weight_dtype)
    logging.info(f"Successfully loaded lora weights for [detail sharpener].")

    return transformer, local_continuity_module

def main(args):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logging.info("Run Lotus-2! ")

    # -------------------- Preparation --------------------
    # Check if model paths are provided; if not, they will be auto-downloaded from HuggingFace
    if args.core_predictor_model_path is None or args.lcm_model_path is None or args.detail_sharpener_model_path is None:
        if HF_AVAILABLE:
            logging.info("Some model paths are not provided. Model weights will be automatically downloaded from HuggingFace.")
            logging.info(f"Core predictor repo: {DEFAULT_CORE_PREDICTOR_REPO}")
            logging.info(f"LCM repo: {DEFAULT_LCM_REPO}")
            logging.info(f"Detail sharpener repo: {DEFAULT_DETAIL_SHARPENER_REPO}")
        else:
            logging.warning("Some model paths are not provided and huggingface_hub is not available.")
            logging.warning("Please install huggingface_hub: pip install huggingface_hub")
            logging.warning("Or provide local paths for all model weights.")

    # Random seed
    if args.seed is not None:
        seed_all(args.seed)

    # Output directories
    os.makedirs(args.output_dir, exist_ok=True)

    output_dir_vis = os.path.join(args.output_dir, f'{args.task_name}_vis')
    output_dir_npy = os.path.join(args.output_dir, f'{args.task_name}_npy')
    if not os.path.exists(output_dir_vis): os.makedirs(output_dir_vis)
    if not os.path.exists(output_dir_npy): os.makedirs(output_dir_npy)

    logging.info(f"Output dir = {args.output_dir}")

    # Mixed precision
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32
    logging.info(f"Running with {weight_dtype} precision.")

    # Device
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"Device = {device}")

    # -------------------- Data --------------------
    input_dir = Path(args.input_dir)
    test_images = list(input_dir.rglob('*.png')) + list(input_dir.rglob('*.jpg'))
    test_images = sorted(test_images)
    logging.info(f'==> There are {len(test_images)} images for validation.')

    # -------------------- Load scheduler and models --------------------
    # scheduler
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler", num_train_timesteps=10
    )
    # transformer
    transformer = FluxTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
    )
    transformer.requires_grad_(False)
    transformer.to(device=device, dtype=weight_dtype)

    # load weights
    transformer, local_continuity_module = load_lora_and_lcm_weights(transformer,
                                                                     args.core_predictor_model_path,
                                                                     args.lcm_model_path,
                                                                     args.detail_sharpener_model_path,
                                                                     args.task_name
                                                                     )

    # -------------------- Pipeline --------------------
    pipeline = Lotus2Pipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        scheduler=noise_scheduler,
        transformer=transformer,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    pipeline.local_continuity_module = local_continuity_module
    pipeline = pipeline.to(device)

    # -------------------- Run inference! --------------------
    pipeline.set_progress_bar_config(disable=True)

    with nullcontext():
        for image_path in tqdm(test_images):
            # print("\n",image_path)
            _, output_vis, output_npy = process_single_image(
                image_path, pipeline,
                task_name=args.task_name,
                device=device,
                num_inference_steps=args.num_inference_steps,
                process_res=args.process_res
            )

            output_vis.save(os.path.join(output_dir_vis, f'{image_path.stem}.png'))
            np.save(os.path.join(output_dir_npy, f'{image_path.stem}.npy'), output_npy)

if __name__ == "__main__":
    args = parse_args()
    main(args)
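For reference, the same components can be driven programmatically without the CLI. The following is a minimal sketch that mirrors main() above and load_pipeline() in app.py; all names come from infer.py and pipeline.py, while "assets/example.jpg" and the output filename are placeholder paths. Passing None for the three weight paths triggers the auto-download from the default jingheya/Lotus-2 repo.

# Minimal programmatic sketch (mirrors infer.py's main()): build the pipeline
# once, then run depth prediction on a single image.
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel
from pipeline import Lotus2Pipeline
from infer import load_lora_and_lcm_weights, process_single_image

task_name = "depth"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="scheduler", num_train_timesteps=10
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
transformer.requires_grad_(False)
transformer.to(device=device, dtype=dtype)
# None paths -> weights are auto-downloaded via get_model_path()
transformer, lcm = load_lora_and_lcm_weights(transformer, None, None, None, task_name)

pipe = Lotus2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    scheduler=scheduler,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.local_continuity_module = lcm
pipe = pipe.to(device)

image, vis, npy = process_single_image(
    "assets/example.jpg", pipe, task_name=task_name,   # placeholder input path
    device=device, num_inference_steps=10, process_res=1024,
)
vis.save("depth_vis.png")                              # placeholder output path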
infer.sh
ADDED
@@ -0,0 +1,30 @@
export OPENCV_IO_ENABLE_OPENEXR=1
export TOKENIZERS_PARALLELISM=false

export TASK_NAME="normal"

# paths
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
# export CORE_PREDICTOR_MODEL_PATH="weights/lotus-2_core_predictor_$TASK_NAME.safetensors"
# export DETAIL_SHARPENER_MODEL_PATH="weights/lotus-2_detail_sharpener_$TASK_NAME.safetensors"
# export LCM_MODEL_PATH="weights/lotus-2_lcm_$TASK_NAME.safetensors"

export INPUT_DIR="assets"
export OUTPUT_DIR="outputs/infer/"

# configs
export NUM_INFERENCE_STEPS=10


CUDA_VISIBLE_DEVICES=0 python infer.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --input_dir=$INPUT_DIR \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision="bf16" \
    --num_inference_steps=$NUM_INFERENCE_STEPS \
    --seed="0" \
    --task_name=$TASK_NAME \
    --process_res=1024
    # --core_predictor_model_path=$CORE_PREDICTOR_MODEL_PATH \
    # --detail_sharpener_model_path=$DETAIL_SHARPENER_MODEL_PATH \
    # --lcm_model_path=$LCM_MODEL_PATH \
pipeline.py
ADDED
@@ -0,0 +1,214 @@
from typing import Union, Optional, List, Dict, Any
import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.flux import FluxPipelineOutput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.utils import is_torch_xla_available

from utils.image_utils import resize_image, resize_image_first

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

class Lotus2Pipeline(FluxPipeline):
    @torch.no_grad()
    def __call__(
        self,
        rgb_in: Optional[torch.FloatTensor] = None,
        prompt: Union[str, List[str]] = None,
        num_inference_steps: int = 10,
        output_type: Optional[str] = "pil",
        process_res: Optional[int] = None,
        timestep_core_predictor: int = 1,
        guidance_scale: float = 3.5,
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            rgb_in (`torch.FloatTensor`, *optional*):
                The input image to be used for generation.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the prediction. Default is ''.
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of denoising steps. More denoising steps usually lead to a sharper prediction at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

        Examples:

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """
        # 1. prepare
        batch_size = rgb_in.shape[0]
        input_size = rgb_in.shape[2:]
        rgb_in = resize_image_first(rgb_in, process_res)
        height, width = rgb_in.shape[2:]

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        device = self._execution_device

        # 2. encode prompt
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            device=device,
        )

        # 3. prepare latent variables
        rgb_in = rgb_in.to(device=device, dtype=self.dtype)
        rgb_latents = self.vae.encode(rgb_in).latent_dist.sample()
        rgb_latents = (rgb_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        packed_rgb_latents = self._pack_latents(
            rgb_latents,
            batch_size=rgb_latents.shape[0],
            num_channels_latents=rgb_latents.shape[1],
            height=rgb_latents.shape[2],
            width=rgb_latents.shape[3],
        )

        latent_image_ids_core_predictor = self._prepare_latent_image_ids(batch_size, rgb_latents.shape[2]//2, rgb_latents.shape[3]//2, device, rgb_latents.dtype)
        latent_image_ids = self._prepare_latent_image_ids(batch_size, rgb_latents.shape[2]//2, rgb_latents.shape[3]//2, device, rgb_latents.dtype)

        # 4. prepare timesteps
        timestep_core_predictor = torch.tensor(timestep_core_predictor).expand(batch_size).to(device=rgb_in.device, dtype=rgb_in.dtype)

        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = packed_rgb_latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)  # 0
        self._num_timesteps = len(timesteps)

        # 5. handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(packed_rgb_latents.shape[0])
        else:
            guidance = None

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        # 6. core predictor
        self.transformer.set_adapter("core_predictor")
        latents = self.transformer(
            hidden_states=packed_rgb_latents,
            timestep=timestep_core_predictor / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids_core_predictor,
            joint_attention_kwargs=self.joint_attention_kwargs,  # {}
            return_dict=False,
        )[0]
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = self.local_continuity_module(latents)

        # 7. Denoising loop for detail sharpener
        self.transformer.set_adapter("detail_sharpener")
        latents = self._pack_latents(
            latents,
            batch_size=latents.shape[0],
            num_channels_latents=latents.shape[1],
            height=latents.shape[2],
            width=latents.shape[3],
        )

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        latents = latents.to(dtype=self.dtype)

        if output_type == "latent":
            image = latents

        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Resize output image to match input size
        image = resize_image(image, input_size)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
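The pipeline expects rgb_in as a (B, 3, H, W) tensor normalized to [-1, 1], exactly as process_single_image in infer.py prepares it. Below is a minimal calling-convention sketch, assuming `pipe` is an already loaded Lotus2Pipeline with both LoRA adapters and the local continuity module attached (as built in infer.py); "assets/example.jpg" is a placeholder path.

# Calling-convention sketch: normalize a PIL image to [-1, 1], run the
# two-stage pipeline, and read back a numpy prediction in [0, 1].
import numpy as np
import torch
from PIL import Image

img = Image.open("assets/example.jpg").convert("RGB")   # placeholder input
x = torch.tensor(np.array(img), dtype=torch.float32)    # (H, W, 3) in [0, 255]
x = x.permute(2, 0, 1).unsqueeze(0) / 127.5 - 1.0       # (1, 3, H, W) in [-1, 1]

pred = pipe(
    rgb_in=x.to(pipe.device),
    prompt="",
    num_inference_steps=10,
    output_type="np",
    process_res=1024,
).images[0]                                             # (H, W, 3) numpy array
depth = pred.mean(axis=-1)                              # per-pixel depth in [0, 1]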
requirements.txt
ADDED
@@ -0,0 +1,14 @@
numpy==1.26.4
matplotlib==3.10.0
peft==0.14.0
protobuf==5.29.0
sentencepiece==0.2.0
opencv-python==4.11.0.86
huggingface-hub==0.36.0
diffusers==0.32.2
torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121
torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
gradio==5.49.0
gradio-client==1.13.3
gradio-imageslider==0.0.20
spaces==0.42.1
utils/image_utils.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import matplotlib
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import List
|
| 5 |
+
import csv
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torchvision.transforms import InterpolationMode
|
| 12 |
+
from torchvision.transforms.functional import resize
|
| 13 |
+
|
| 14 |
+
def numpy_to_pil(images: np.ndarray) -> List[Image.Image]:
|
| 15 |
+
r"""
|
| 16 |
+
Convert a numpy image or a batch of images to a PIL image.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
images (`np.ndarray`):
|
| 20 |
+
The image array to convert to PIL format.
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
`List[PIL.Image.Image]`:
|
| 24 |
+
A list of PIL images.
|
| 25 |
+
"""
|
| 26 |
+
if images.ndim == 3:
|
| 27 |
+
images = images[None, ...]
|
| 28 |
+
images = (images * 255).round().astype("uint8")
|
| 29 |
+
if images.shape[-1] == 1:
|
| 30 |
+
# special case for grayscale (single channel) images
|
| 31 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
| 32 |
+
else:
|
| 33 |
+
pil_images = [Image.fromarray(image) for image in images]
|
| 34 |
+
|
| 35 |
+
return pil_images
|
| 36 |
+
|
| 37 |
+
def resize_output(image, target_size):
|
| 38 |
+
"""
|
| 39 |
+
Resize output image to target size
|
| 40 |
+
Args:
|
| 41 |
+
image: Image in PIL.Image, numpy.array or torch.tensor format
|
| 42 |
+
target_size: tuple, target size (H, W)
|
| 43 |
+
Returns:
|
| 44 |
+
Resized image in original format
|
| 45 |
+
"""
|
| 46 |
+
if isinstance(image, list):
|
| 47 |
+
return [resize_output(img, target_size) for img in image]
|
| 48 |
+
|
| 49 |
+
if isinstance(image, Image.Image):
|
| 50 |
+
return image.resize(target_size[::-1], Image.BILINEAR)
|
| 51 |
+
elif isinstance(image, np.ndarray):
|
| 52 |
+
# Handle numpy array with shape (1, H, W, 3)
|
| 53 |
+
if image.ndim == 4:
|
| 54 |
+
resized = np.stack([cv2.resize(img, target_size[::-1]) for img in image])
|
| 55 |
+
return resized
|
| 56 |
+
else:
|
| 57 |
+
return cv2.resize(image, target_size[::-1])
|
| 58 |
+
elif isinstance(image, torch.Tensor):
|
| 59 |
+
# Handle tensor with shape (1, 3, H, W)
|
| 60 |
+
if image.dim() == 4:
|
| 61 |
+
return torch.nn.functional.interpolate(
|
| 62 |
+
image,
|
| 63 |
+
size=target_size,
|
| 64 |
+
mode='bilinear',
|
| 65 |
+
align_corners=False
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
return torch.nn.functional.interpolate(
|
| 69 |
+
image.unsqueeze(0),
|
| 70 |
+
size=target_size,
|
| 71 |
+
mode='bilinear',
|
| 72 |
+
align_corners=False
|
| 73 |
+
).squeeze(0)
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError(f"Unsupported image format: {type(image)}")
|
| 76 |
+
|
| 77 |
+
def resize_image(image, target_size):
|
| 78 |
+
"""
|
| 79 |
+
Resize output image to target size
|
| 80 |
+
Args:
|
| 81 |
+
image: Image in PIL.Image, numpy.array or torch.tensor format
|
| 82 |
+
target_size: tuple, target size (H, W)
|
| 83 |
+
Returns:
|
| 84 |
+
Resized image in original format
|
| 85 |
+
"""
|
| 86 |
+
if isinstance(image, list):
|
| 87 |
+
return [resize_image(img, target_size) for img in image]
|
| 88 |
+
|
| 89 |
+
if isinstance(image, Image.Image):
|
| 90 |
+
return image.resize(target_size[::-1], Image.BILINEAR)
|
| 91 |
+
elif isinstance(image, np.ndarray):
|
| 92 |
+
# Handle numpy array with shape (1, H, W, 3)
|
| 93 |
+
if image.ndim == 4:
|
| 94 |
+
resized = np.stack([cv2.resize(img, target_size[::-1]) for img in image])
|
| 95 |
+
return resized
|
| 96 |
+
else:
|
| 97 |
+
return cv2.resize(image, target_size[::-1])
|
| 98 |
+
elif isinstance(image, torch.Tensor):
|
| 99 |
+
# Handle tensor with shape (1, 3, H, W)
|
| 100 |
+
if image.dim() == 4:
|
| 101 |
+
return torch.nn.functional.interpolate(
|
| 102 |
+
image,
|
| 103 |
+
size=target_size,
|
| 104 |
+
mode='bilinear',
|
| 105 |
+
align_corners=False
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
return torch.nn.functional.interpolate(
|
| 109 |
+
image.unsqueeze(0),
|
| 110 |
+
size=target_size,
|
| 111 |
+
mode='bilinear',
|
| 112 |
+
align_corners=False
|
| 113 |
+
).squeeze(0)
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError(f"Unsupported image format: {type(image)}")
|
| 116 |
+
|
| 117 |
+
def resize_image_first(image_tensor, process_res=None):
|
| 118 |
+
if process_res:
|
| 119 |
+
max_edge = max(image_tensor.shape[2], image_tensor.shape[3])
|
| 120 |
+
if max_edge > process_res:
|
| 121 |
+
scale = process_res / max_edge
|
| 122 |
+
new_height = int(image_tensor.shape[2] * scale)
|
| 123 |
+
new_width = int(image_tensor.shape[3] * scale)
|
| 124 |
+
image_tensor = resize_image(image_tensor, (new_height, new_width))
|
| 125 |
+
|
| 126 |
+
image_tensor = resize_to_multiple_of_16(image_tensor)
|
| 127 |
+
|
| 128 |
+
return image_tensor
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def smooth_image(image, method='gaussian', kernel_size=31, sigma=15.0, bilateral_d=9, bilateral_color=75, bilateral_space=75):
|
| 132 |
+
"""
|
| 133 |
+
应用多种平滑方法来消除图像中的网格伪影
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
image: PIL.Image, numpy.array 或 torch.tensor 格式的图像
|
| 137 |
+
method: 平滑方法,可选 'gaussian'(高斯模糊), 'bilateral'(双边滤波), 'median'(中值滤波),
|
| 138 |
+
'guided'(引导滤波), 'strong'(结合多种滤波的强力平滑)
|
| 139 |
+
kernel_size: 高斯和中值滤波的核大小,默认为31,应为奇数
|
| 140 |
+
sigma: 高斯滤波的标准差,默认为15.0
|
| 141 |
+
bilateral_d: 双边滤波的直径,默认为9
|
| 142 |
+
bilateral_color: 双边滤波的颜色空间标准差,默认为75
|
| 143 |
+
bilateral_space: 双边滤波的坐标空间标准差,默认为75
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
平滑后的图像,保持原始格式
|
| 147 |
+
"""
|
| 148 |
+
if isinstance(image, list):
|
| 149 |
+
return [smooth_image(img, method, kernel_size, sigma, bilateral_d, bilateral_color, bilateral_space) for img in image]
|
| 150 |
+
|
| 151 |
+
# 确保kernel_size是奇数
|
| 152 |
+
if kernel_size % 2 == 0:
|
| 153 |
+
kernel_size += 1
|
| 154 |
+
|
| 155 |
+
# 转换为numpy数组进行处理
|
| 156 |
+
is_pil = isinstance(image, Image.Image)
|
| 157 |
+
is_tensor = isinstance(image, torch.Tensor)
|
| 158 |
+
|
| 159 |
+
if is_pil:
|
| 160 |
+
img_array = np.array(image)
|
| 161 |
+
elif is_tensor:
|
| 162 |
+
device = image.device
|
| 163 |
+
if image.dim() == 4: # (B, C, H, W)
|
| 164 |
+
batch_size, channels, height, width = image.shape
|
| 165 |
+
img_array = image.permute(0, 2, 3, 1).cpu().numpy() # (B, H, W, C)
|
| 166 |
+
else: # (C, H, W)
|
| 167 |
+
img_array = image.permute(1, 2, 0).cpu().numpy() # (H, W, C)
|
| 168 |
+
else:
|
| 169 |
+
img_array = image
|
| 170 |
+
|
| 171 |
+
# 保存原始数据类型
|
| 172 |
+
original_dtype = img_array.dtype
|
| 173 |
+
|
| 174 |
+
# 应用选定的平滑方法
|
| 175 |
+
if method == 'gaussian':
|
| 176 |
+
# 标准高斯模糊,适合轻微平滑
|
| 177 |
+
if img_array.ndim == 4:
|
| 178 |
+
smoothed = np.stack([cv2.GaussianBlur(img, (kernel_size, kernel_size), sigma) for img in img_array])
|
| 179 |
+
else:
|
| 180 |
+
smoothed = cv2.GaussianBlur(img_array, (kernel_size, kernel_size), sigma)
|
| 181 |
+
|
| 182 |
+
elif method == 'bilateral':
|
| 183 |
+
# 双边滤波,保持边缘的同时平滑平坦区域
|
| 184 |
+
if img_array.ndim == 4:
|
| 185 |
+
# 确保图像是8位类型
|
| 186 |
+
imgs_uint8 = [img.astype(np.uint8) if img.dtype != np.uint8 else img for img in img_array]
|
| 187 |
+
smoothed = np.stack([cv2.bilateralFilter(img, bilateral_d, bilateral_color, bilateral_space) for img in imgs_uint8])
|
| 188 |
+
# 转回原始类型
|
| 189 |
+
if original_dtype != np.uint8:
|
| 190 |
+
smoothed = smoothed.astype(original_dtype)
|
| 191 |
+
else:
|
| 192 |
+
# 确保图像是8位类型
|
| 193 |
+
img_uint8 = img_array.astype(np.uint8) if img_array.dtype != np.uint8 else img_array
|
| 194 |
+
smoothed = cv2.bilateralFilter(img_uint8, bilateral_d, bilateral_color, bilateral_space)
|
| 195 |
+
# 转回原始类型
|
| 196 |
+
if original_dtype != np.uint8:
|
| 197 |
+
smoothed = smoothed.astype(original_dtype)
|
| 198 |
+
|
| 199 |
+
elif method == 'median':
|
| 200 |
+
# 中值滤波,对于消除盐和胡椒噪声和小格子非常有效
|
| 201 |
+
# 中值滤波要求输入为uint8或uint16
|
| 202 |
+
if img_array.ndim == 4:
|
| 203 |
+
# 转换为8位无符号整数并确保格式正确
|
| 204 |
+
imgs_uint8 = []
|
| 205 |
+
for img in img_array:
|
| 206 |
+
# 对浮点图像进行缩放到0-255范围
|
| 207 |
+
if img.dtype != np.uint8:
|
| 208 |
+
if img.max() <= 1.0: # 检查是否是0-1范围的浮点数
|
| 209 |
+
img = (img * 255).astype(np.uint8)
|
| 210 |
+
else:
|
| 211 |
+
img = img.astype(np.uint8)
|
| 212 |
+
imgs_uint8.append(img)
|
| 213 |
+
|
| 214 |
+
smoothed = np.stack([cv2.medianBlur(img, kernel_size) for img in imgs_uint8])
|
| 215 |
+
# 转回原始类型
|
| 216 |
+
if original_dtype != np.uint8:
|
| 217 |
+
if original_dtype == np.float32 or original_dtype == np.float64:
|
| 218 |
+
if img_array.max() <= 1.0: # 检查原始数据是否在0-1范围
|
| 219 |
+
smoothed = smoothed.astype(float) / 255.0
|
| 220 |
+
|
| 221 |
+
else:
|
| 222 |
+
# 转换为8位无符号整数
|
| 223 |
+
if img_array.dtype != np.uint8:
|
| 224 |
+
if img_array.max() <= 1.0: # 检查是否是0-1范围的浮点数
|
| 225 |
+
img_uint8 = (img_array * 255).astype(np.uint8)
|
| 226 |
+
else:
|
| 227 |
+
img_uint8 = img_array.astype(np.uint8)
|
| 228 |
+
else:
|
| 229 |
+
img_uint8 = img_array
|
| 230 |
+
|
| 231 |
+
smoothed = cv2.medianBlur(img_uint8, kernel_size)
|
| 232 |
+
# 转回原始类型
|
| 233 |
+
if original_dtype != np.uint8:
|
| 234 |
+
if original_dtype == np.float32 or original_dtype == np.float64:
|
| 235 |
+
if img_array.max() <= 1.0: # 检查原始数据是否在0-1范围
|
| 236 |
+
smoothed = smoothed.astype(float) / 255.0
|
| 237 |
+
else:
|
| 238 |
+
smoothed = smoothed.astype(original_dtype)
|
| 239 |
+
|
| 240 |
+
elif method == 'guided':
|
| 241 |
+
# 引导滤波,在保持边缘的同时平滑区域
|
| 242 |
+
if img_array.ndim == 4:
|
| 243 |
+
smoothed = np.stack([cv2.ximgproc.guidedFilter(
|
| 244 |
+
guide=img, src=img, radius=kernel_size//2, eps=1e-6) for img in img_array])
|
| 245 |
+
else:
|
| 246 |
+
smoothed = cv2.ximgproc.guidedFilter(
|
| 247 |
+
guide=img_array, src=img_array, radius=kernel_size//2, eps=1e-6)
|
| 248 |
+
|
| 249 |
+
elif method == 'strong':
|
| 250 |
+
# 强力平滑:先应用中值滤波去除尖锐噪点,然后用双边滤波保持边缘,最后用高斯进一步平滑
|
| 251 |
+
if img_array.ndim == 4:
|
| 252 |
+
# 转换为8位无符号整数
|
| 253 |
+
imgs_uint8 = []
|
| 254 |
+
for img in img_array:
|
| 255 |
+
# 对浮点图像进行缩放到0-255范围
|
| 256 |
+
if img.dtype != np.uint8:
|
| 257 |
+
if img.max() <= 1.0: # 检查是否是0-1范围的浮点数
|
| 258 |
+
img = (img * 255).astype(np.uint8)
|
| 259 |
+
else:
|
| 260 |
+
img = img.astype(np.uint8)
|
| 261 |
+
imgs_uint8.append(img)
|
| 262 |
+
|
| 263 |
+
temp = np.stack([cv2.medianBlur(img, min(15, kernel_size)) for img in imgs_uint8])
|
| 264 |
+
temp = np.stack([cv2.bilateralFilter(img, bilateral_d, bilateral_color, bilateral_space) for img in temp])
|
| 265 |
+
smoothed = np.stack([cv2.GaussianBlur(img, (kernel_size, kernel_size), sigma) for img in temp])
|
| 266 |
+
|
| 267 |
+
# 转回原始类型
|
| 268 |
+
if original_dtype != np.uint8:
|
| 269 |
+
if original_dtype == np.float32 or original_dtype == np.float64:
|
| 270 |
+
if img_array.max() <= 1.0: # 检查原始数据是否在0-1范围
|
| 271 |
+
smoothed = smoothed.astype(float) / 255.0
|
| 272 |
+
else:
|
| 273 |
+
smoothed = smoothed.astype(original_dtype)
|
| 274 |
+
else:
|
| 275 |
+
# 转换为8位无符号整数
|
| 276 |
+
if img_array.dtype != np.uint8:
|
| 277 |
+
if img_array.max() <= 1.0: # 检查是否是0-1范围的浮点数
|
| 278 |
+
img_uint8 = (img_array * 255).astype(np.uint8)
|
| 279 |
+
else:
|
| 280 |
+
img_uint8 = img_array.astype(np.uint8)
|
| 281 |
+
else:
|
| 282 |
+
img_uint8 = img_array
|
| 283 |
+
|
| 284 |
+
temp = cv2.medianBlur(img_uint8, min(15, kernel_size))
|
| 285 |
+
temp = cv2.bilateralFilter(temp, bilateral_d, bilateral_color, bilateral_space)
|
| 286 |
+
smoothed = cv2.GaussianBlur(temp, (kernel_size, kernel_size), sigma)
|
| 287 |
+
|
| 288 |
+
# 转回原始类型
|
| 289 |
+
if original_dtype != np.uint8:
|
| 290 |
+
if original_dtype == np.float32 or original_dtype == np.float64:
|
| 291 |
+
if img_array.max() <= 1.0: # 检查原始数据是否在0-1范围
|
| 292 |
+
smoothed = smoothed.astype(float) / 255.0
|
| 293 |
+
else:
|
| 294 |
+
smoothed = smoothed.astype(original_dtype)
|
| 295 |
+
|
| 296 |
+
else:
|
| 297 |
+
raise ValueError(f"不支持的平滑方法: {method},请选择 'gaussian', 'bilateral', 'median', 'guided' 或 'strong'")
|
| 298 |
+
|
| 299 |
+
# 将结果转换回原始格式
|
| 300 |
+
if is_pil:
|
| 301 |
+
# 如果结果是浮点类型且值在0-1之间,需要先转换为0-255的uint8
|
| 302 |
+
if smoothed.dtype == np.float32 or smoothed.dtype == np.float64:
|
| 303 |
+
if smoothed.max() <= 1.0:
|
| 304 |
+
smoothed = (smoothed * 255).astype(np.uint8)
|
| 305 |
+
return Image.fromarray(smoothed.astype(np.uint8))
|
| 306 |
+
elif is_tensor:
|
| 307 |
+
if image.dim() == 4:
|
| 308 |
+
return torch.from_numpy(smoothed).permute(0, 3, 1, 2).to(device)
|
| 309 |
+
else:
|
| 310 |
+
return torch.from_numpy(smoothed).permute(2, 0, 1).to(device)
|
| 311 |
+
else:
|
| 312 |
+
return smoothed
|
| 313 |
+
|
| 314 |
+
def resize_to_multiple_of_16(image_tensor):
    """
    Resize image tensor to make shorter side closest multiple of 16 while maintaining aspect ratio
    Args:
        image_tensor: Input tensor of shape (B, C, H, W)
    Returns:
        Resized tensor where shorter side is multiple of 16
    """
    # Calculate scale ratio based on shorter side to make it closest multiple of 16
    h, w = image_tensor.shape[2], image_tensor.shape[3]
    min_side = min(h, w)
    scale = (min_side // 16) * 16 / min_side

    # Calculate new height and width
    new_h = int(h * scale)
    new_w = int(w * scale)

    # Ensure both height and width are multiples of 16
    new_h = (new_h // 16) * 16
    new_w = (new_w // 16) * 16

    # Resize image while maintaining aspect ratio
    resized_tensor = torch.nn.functional.interpolate(
        image_tensor,
        size=(new_h, new_w),
        mode='bilinear',
        align_corners=False
    )
    return resized_tensor

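# Illustrative usage sketch (assumed demo, not part of the original file): both output sides
# are floored to multiples of 16, so the aspect ratio is only approximately preserved.
def _demo_resize_to_multiple_of_16():
    dummy = torch.rand(1, 3, 480, 641)   # H is already a multiple of 16, W is not
    out = resize_to_multiple_of_16(dummy)
    print(out.shape)                     # torch.Size([1, 3, 480, 640])
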
def load_color_list(csv_path):
    color_list = []
    with open(csv_path, newline='') as file:
        reader = csv.reader(file)

        next(reader)  # skip the header row

        for row in reader:
            last_three = tuple(map(int, row[-3:]))  # R, G, B values are in the last three columns
            color_list.append(last_three)

    # prepend black as palette index 0
    color_list = [(0,0,0)] + color_list

    return color_list

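# Illustrative usage sketch (assumed demo, not part of the original file): the palette file
# name and contents below are hypothetical; they only show the CSV layout load_color_list
# expects (a header row, with R, G, B in the last three columns).
def _demo_load_color_list():
    with open("tmp_palette_demo.csv", "w", newline="") as f:
        f.write("name,r,g,b\nwall,120,120,120\nfloor,80,50,50\n")
    print(load_color_list("tmp_palette_demo.csv"))   # [(0, 0, 0), (120, 120, 120), (80, 50, 50)]
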
def conver_rgb_to_semantic_map(image: Image, color_list: List):
    # Convert PIL Image to numpy array
    image_array = np.array(image)

    # Initialize an empty array for the indexed image
    indexed_image = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=int)

    # Loop through each pixel in the image
    for i in range(image_array.shape[0]):
        for j in range(image_array.shape[1]):
            # Get the color of the current pixel
            pixel_color = tuple(image_array[i, j][:3])  # Exclude the alpha channel if present

            # Find the closest color from the color list and get its index
            # Here, the Euclidean distance is used to find the closest color
            distances = np.sqrt(np.sum((np.array(color_list) - np.array(pixel_color))**2, axis=1))
            closest_color_index = np.argmin(distances)

            # Set the index in the indexed image
            indexed_image[i, j] = closest_color_index

    indexed_image = indexed_image - 1

    return indexed_image


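# Illustrative sketch (assumed alternative, not part of the original file): the per-pixel
# Python loop above is O(H*W*K); the same nearest-colour lookup can be vectorised with
# NumPy broadcasting, keeping the -1 shift so the first palette entry maps to -1.
def _rgb_to_semantic_map_vectorized(image: Image, color_list: List):
    rgb = np.array(image)[..., :3].astype(np.int32)                  # (H, W, 3)
    palette = np.array(color_list, dtype=np.int32)                   # (K, 3)
    dists = np.linalg.norm(rgb[:, :, None, :] - palette[None, None, :, :], axis=-1)  # (H, W, K)
    return dists.argmin(axis=-1) - 1                                 # (H, W) index map
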
def concatenate_images(*image_lists):
    # Ensure at least one image list is provided
    if not image_lists or not image_lists[0]:
        raise ValueError("At least one non-empty image list must be provided")

    # Determine the maximum width of any single row and the total height
    max_width = 0
    total_height = 0
    row_widths = []
    row_heights = []

    # Compute dimensions for each row
    for image_list in image_lists:
        if image_list:  # Ensure the list is not empty
            width = sum(img.width for img in image_list)
            height = max(img.height for img in image_list)
            max_width = max(max_width, width)
            total_height += height
            row_widths.append(width)
            row_heights.append(height)

    # Create a new image to concatenate everything into
    new_image = Image.new('RGB', (max_width, total_height))

    # Concatenate each row of images
    y_offset = 0
    for i, image_list in enumerate(image_lists):
        x_offset = 0
        for img in image_list:
            new_image.paste(img, (x_offset, y_offset))
            x_offset += img.width
        y_offset += row_heights[i]  # Move the offset down to the next row

    return new_image


# def concatenate_images(image_list1, image_list2):
#     # Ensure both image lists are not empty
#     if not image_list1 or not image_list2:
#         raise ValueError("Image lists cannot be empty")

#     # Get the width and height of the first image
#     width, height = image_list1[0].size

#     # Calculate the total width and height
#     total_width = max(len(image_list1), len(image_list2)) * width
#     total_height = 2 * height  # For two rows

#     # Create a new image to concatenate everything into
#     new_image = Image.new('RGB', (total_width, total_height))

#     # Concatenate the first row of images
#     x_offset = 0
#     for img in image_list1:
#         new_image.paste(img, (x_offset, 0))
#         x_offset += img.width

#     # Concatenate the second row of images
#     x_offset = 0
#     for img in image_list2:
#         new_image.paste(img, (x_offset, height))
#         x_offset += img.width

#     return new_image

def colorize_depth_map(depth, mask=None, reverse_color=False):
    cm = matplotlib.colormaps["Spectral"]
    # normalize
    depth = ((depth - depth.min()) / (depth.max() - depth.min()))
    # colorize
    if reverse_color:
        img_colored_np = cm(1 - depth, bytes=False)[:, :, 0:3]  # Invert the depth values before applying colormap
    else:
        img_colored_np = cm(depth, bytes=False)[:, :, 0:3]  # (h,w,3)

    depth_colored = (img_colored_np * 255).astype(np.uint8)
    if mask is not None:
        masked_image = np.zeros_like(depth_colored)
        masked_image[mask.numpy()] = depth_colored[mask.numpy()]
        depth_colored_img = Image.fromarray(masked_image)
    else:
        depth_colored_img = Image.fromarray(depth_colored)
    return depth_colored_img


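# Illustrative usage sketch (assumed demo, not part of the original file): the depth input is
# normalised internally, and the optional mask is expected to be a boolean torch tensor
# (note the mask.numpy() calls above). The output path below is hypothetical.
def _demo_colorize_depth_map():
    depth = np.linspace(0.0, 1.0, 64 * 64).reshape(64, 64)   # synthetic depth ramp
    vis = colorize_depth_map(depth, reverse_color=True)      # PIL.Image, Spectral colormap
    vis.save("tmp_depth_vis_demo.png")
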
def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image to limit maximum edge length while keeping aspect ratio.

    Args:
        img (`torch.Tensor`):
            Image tensor to be resized. Expected shape: [B, C, H, W]
        max_edge_resolution (`int`):
            Maximum edge length (pixel).
        resample_method (`InterpolationMode`):
            Resampling method used to resize images.

    Returns:
        `torch.Tensor`: Resized image.
    """
    assert 4 == img.dim(), f"Invalid input shape {img.shape}"

    original_height, original_width = img.shape[-2:]
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    new_width = int(original_width * downscale_factor)
    new_height = int(original_height * downscale_factor)

    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
    return resized_img


def get_tv_resample_method(method_str: str) -> InterpolationMode:
    resample_method_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST_EXACT,
        "nearest-exact": InterpolationMode.NEAREST_EXACT,
    }
    resample_method = resample_method_dict.get(method_str, None)
    if resample_method is None:
        # report the unknown input string rather than the (always None) lookup result
        raise ValueError(f"Unknown resampling method: {method_str}")
    else:
        return resample_method

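# Illustrative usage sketch (assumed demo, not part of the original file): the two helpers
# above can be chained to cap a batch at a maximum edge length before inference.
def _demo_resize_max_res():
    batch = torch.rand(1, 3, 768, 1024)
    method = get_tv_resample_method("bilinear")
    resized = resize_max_res(batch, max_edge_resolution=512, resample_method=method)
    print(resized.shape)   # torch.Size([1, 3, 384, 512])
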
utils/seed_all.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


import numpy as np
import random
import torch


def seed_all(seed: int = 0):
    """
    Set random seeds of all components.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
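
# Illustrative usage sketch (assumed, not part of the original file): seed_all only seeds the
# RNGs; for fully deterministic cuDNN kernels the two flags below could be set as well.
def _demo_seed_all():
    seed_all(42)
    torch.backends.cudnn.deterministic = True   # assumption: not set by the original helper
    torch.backends.cudnn.benchmark = False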