Commit 436f5aa by haodongli
1 parent: a58f1cf
Files changed (8):
  1. .gitignore +6 -0
  2. app.py +103 -0
  3. infer.py +472 -0
  4. infer.sh +30 -0
  5. pipeline.py +214 -0
  6. requirements.txt +14 -0
  7. utils/image_utils.py +514 -0
  8. utils/seed_all.py +33 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__/
+ outputs/
+ tmp/
+ .DS_Store
+ weights/
+ tmp_*
app.py ADDED
@@ -0,0 +1,103 @@
+ import spaces  # must be first!
+ import sys
+ import os
+ import torch
+ from PIL import Image
+ import gradio as gr
+ from glob import glob
+ from contextlib import nullcontext
+ from pipeline import Lotus2Pipeline
+ from diffusers import (
+     FlowMatchEulerDiscreteScheduler,
+     FluxTransformer2DModel,
+ )
+ from infer import (
+     load_lora_and_lcm_weights,
+     process_single_image
+ )
+
+ pipeline = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ weight_dtype = torch.bfloat16
+ task = None
+
+ @spaces.GPU
+ def load_pipeline():
+     global pipeline, device, weight_dtype, task
+     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+         'black-forest-labs/FLUX.1-dev', subfolder="scheduler", num_train_timesteps=10
+     )
+     transformer = FluxTransformer2DModel.from_pretrained(
+         'black-forest-labs/FLUX.1-dev', subfolder="transformer", revision=None, variant=None
+     )
+     transformer.requires_grad_(False)
+     transformer.to(device=device, dtype=weight_dtype)
+     transformer, local_continuity_module = load_lora_and_lcm_weights(transformer, None, None, None, task)
+     pipeline = Lotus2Pipeline.from_pretrained(
+         'black-forest-labs/FLUX.1-dev',
+         scheduler=noise_scheduler,
+         transformer=transformer,
+         revision=None,
+         variant=None,
+         torch_dtype=weight_dtype,
+     )
+     pipeline.local_continuity_module = local_continuity_module
+     pipeline = pipeline.to(device)
+
+ @spaces.GPU
+ def fn(image_path):
+     global pipeline, device, task
+     pipeline.set_progress_bar_config(disable=True)
+     with nullcontext():
+         _, output_vis, _ = process_single_image(
+             image_path, pipeline,
+             task_name=task,
+             device=device,
+             num_inference_steps=10,
+             process_res=1024
+         )
+     return [Image.open(image_path), output_vis]
+
+ def build_demo():
+     global task
+     inputs = [
+         gr.Image(label="Image", type="filepath")
+     ]
+     outputs = [
+         gr.ImageSlider(
+             label=f"{task.title()}",
+             type="pil",
+             slider_position=20,
+         )
+     ]
+     examples = glob(f"assets/demo_examples/{task}/*.png") + glob(f"assets/demo_examples/{task}/*.jpg")
+     demo = gr.Interface(
+         fn=fn,
+         title="Lotus-2: Advancing Geometric Dense Prediction with Powerful Image Generative Model",
+         description=f"""
+         <strong>Please consider starring <span style="color: orange">&#9733;</span> our <a href="https://github.com/EnVision-Research/Lotus-2" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful! 😊</strong>
+         <br>
+         <strong>Current Task: </strong><strong style="color: red;">{task.title()}</strong>
+         """,
+         inputs=inputs,
+         outputs=outputs,
+         examples=examples,
+         examples_per_page=10
+     )
+     return demo
+
+ def main(task_name):
+     global task
+     task = task_name
+     load_pipeline()
+     demo = build_demo()
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=6381,
+     )
+
+ if __name__ == "__main__":
+     task_name = "depth"
+     if task_name not in ['depth', 'normal']:
+         raise ValueError("Invalid task. Please choose from 'depth' and 'normal'.")
+     main(task_name)
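
The demo above hard-codes `task_name = "depth"` and drives everything through module-level globals. A minimal local smoke test of the same entry points (a sketch, not part of the commit; it assumes a CUDA GPU, access to the FLUX.1-dev weights, and a hypothetical input file `demo_example.png`) could look like:

```python
# Hypothetical local check of the Gradio app's entry points (not part of this commit).
import app

app.task = "normal"                    # switch the global task before loading weights
app.load_pipeline()                    # builds Lotus2Pipeline with the normal-task LoRA/LCM weights
original, prediction = app.fn("demo_example.png")   # hypothetical input path
prediction.save("demo_normal_vis.png")
```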
infer.py ADDED
@@ -0,0 +1,472 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ """
+ Lotus-2 Inference Script
+
+ Usage:
+     python infer.py --pretrained_model_name_or_path <model_path> [other_args]
+
+ If --core_predictor_model_path, --lcm_model_path, or --detail_sharpener_model_path
+ are not provided, the script will automatically download the corresponding model
+ weights from the default HuggingFace repositories.
+ """
+
+ import argparse
+ import logging
+ import os
+ from contextlib import nullcontext
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.utils.checkpoint
+ from peft import LoraConfig, set_peft_model_state_dict
+ from PIL import Image
+ from torch import nn
+ from tqdm.auto import tqdm
+
+ try:
+     from huggingface_hub import snapshot_download
+     HF_AVAILABLE = True
+ except ImportError:
+     HF_AVAILABLE = False
+     logging.warning("huggingface_hub not available. Model auto-download will not work.")
+
+ from diffusers import (
+     FlowMatchEulerDiscreteScheduler,
+     FluxTransformer2DModel,
+ )
+ from diffusers.utils import convert_unet_state_dict_to_peft
+ from utils.image_utils import colorize_depth_map
+ from pipeline import Lotus2Pipeline
+ from utils.seed_all import seed_all
+
+ # Default HuggingFace repositories and model filenames
+ DEFAULT_CORE_PREDICTOR_REPO = "jingheya/Lotus-2"
+ DEFAULT_LCM_REPO = "jingheya/Lotus-2"
+ DEFAULT_DETAIL_SHARPENER_REPO = "jingheya/Lotus-2"
+
+ CORE_PREDICTOR_FILENAME = {
+     "depth": "lotus-2_core_predictor_depth.safetensors",
+     "normal": "lotus-2_core_predictor_normal.safetensors"
+ }
+
+ LCM_FILENAME = {
+     "depth": "lotus-2_lcm_depth.safetensors",
+     "normal": "lotus-2_lcm_normal.safetensors"
+ }
+
+ DETAIL_SHARPENER_FILENAME = {
+     "depth": "lotus-2_detail_sharpener_depth.safetensors",
+     "normal": "lotus-2_detail_sharpener_normal.safetensors"
+ }
+
+ def get_model_path(model_path, repo_id, filename):
+     """
+     Get the local path for a model. If model_path is None, download from HuggingFace.
+
+     Args:
+         model_path: Local path to model or None to download from HF
+         repo_id: HuggingFace repository ID
+         filename: Model filename in the repository
+
+     Returns:
+         Local path to the model file
+     """
+     if model_path is not None:
+         return model_path
+
+     if not HF_AVAILABLE:
+         raise ImportError(
+             f"huggingface_hub is required for auto-downloading {filename} model weights. "
+             "Please install it with: pip install huggingface_hub"
+         )
+
+     logging.info(f"Downloading {filename} model weights from {repo_id}/{filename}")
+
+     try:
+         # Create cache directory if it doesn't exist
+         cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
+         os.makedirs(cache_dir, exist_ok=True)
+
+         # Download the entire repository and get the specific file
+         repo_path = snapshot_download(
+             repo_id=repo_id,
+             cache_dir=cache_dir,
+             local_files_only=False,
+         )
+
+         # Construct the full path to the specific file
+         full_path = os.path.join(repo_path, filename)
+
+         if not os.path.exists(full_path):
+             # Try to find the file in the repo
+             for root, dirs, files in os.walk(repo_path):
+                 if filename in files:
+                     full_path = os.path.join(root, filename)
+                     break
+             else:
+                 raise FileNotFoundError(f"Could not find {filename} in the downloaded repository")
+
+         logging.info(f"Successfully downloaded {filename} model to: {full_path}")
+         return full_path
+
+     except Exception as e:
+         raise RuntimeError(f"Failed to download {filename} model from {repo_id}: {str(e)}")
+
+
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+ # check_min_version("0.33.0.dev0")
+
+
+ class Local_Continuity_Module(nn.Module):
+     def __init__(self, num_channels):
+         super().__init__()
+         self.lcm = nn.Sequential(
+             nn.Conv2d(num_channels, num_channels * 2, kernel_size=3, padding=1),
+             nn.GELU(),
+             nn.Conv2d(num_channels * 2, num_channels, kernel_size=3, padding=1),
+         )
+
+     def forward(self, x):
+         lcm_dtype = next(self.lcm.parameters()).dtype
+         if x.dtype != lcm_dtype:
+             x = x.to(dtype=lcm_dtype)
+         return x + self.lcm(x)
+
+ def parse_args(input_args=None):
+     parser = argparse.ArgumentParser(description="Run Lotus-2.")
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         type=str,
+         default=None,
+         required=True,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--core_predictor_model_path",
+         type=str,
+         default=None,
+         help="Path to core predictor model weights",
+     )
+     parser.add_argument(
+         "--lcm_model_path",
+         type=str,
+         default=None,
+         help="Path to local continuity module model weights",
+     )
+     parser.add_argument(
+         "--detail_sharpener_model_path",
+         type=str,
+         default=None,
+         help="Path to detail sharpener model weights",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default=None,
+         required=False,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--variant",
+         type=str,
+         default=None,
+         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. 'fp16'.",
+     )
+     parser.add_argument(
+         "--process_res",
+         type=int,
+         default=768,
+         help="The resolution for processing the images.",
+     )
+     parser.add_argument(
+         "--num_inference_steps",
+         type=int,
+         default=10,
+         help="Number of timesteps to infer the model.",
+     )
+     parser.add_argument(
+         "--input_dir",
+         type=str,
+         default=None,
+         help="The directory where the input images are stored.",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="flux-dreambooth-lora",
+         help="The output directory where the model predictions will be written.",
+     )
+     parser.add_argument("--seed", type=int, default=None, help="Random seed.")
+     parser.add_argument(
+         "--task_name",
+         type=str,
+         default="depth",  # "normal"
+     )
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default=None,
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
+             " 1.10 and an NVIDIA Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+         ),
+     )
+
+     if input_args is not None:
+         args = parser.parse_args(input_args)
+     else:
+         args = parser.parse_args()
+
+     return args
+
+ def process_single_image(image_path, pipeline, task_name, device,
+                          num_inference_steps, process_res=768):
+     image = Image.open(image_path).convert("RGB")
+     image_np = np.array(image).astype(np.float32)
+     image_ts = torch.tensor(image_np).permute(2, 0, 1).unsqueeze(0)
+     image_ts = image_ts / 127.5 - 1.0
+     image_ts = image_ts.to(device)
+
+     prediction = pipeline(
+         rgb_in=image_ts,
+         prompt='',
+         num_inference_steps=num_inference_steps,
+         output_type='np',
+         process_res=process_res,
+     ).images[0]
+
+     if task_name == "depth":
+         output_npy = prediction.mean(axis=-1)
+         output_vis = colorize_depth_map(output_npy, reverse_color=True)
+     elif task_name == "normal":
+         output_npy = prediction
+         output_vis = Image.fromarray((output_npy * 255).astype(np.uint8))
+     else:
+         raise ValueError(f"Invalid task name: {task_name}")
+
+     return image, output_vis, output_npy
+
+ def load_lora_and_lcm_weights(transformer, core_predictor_model_path, lcm_model_path, detail_sharpener_model_path, task_name):
+     lora_rank = 128 if task_name == 'depth' else 256
+     device = transformer.device
+     weight_dtype = transformer.dtype
+
+     target_lora_modules = [
+         "attn.to_k",
+         "attn.to_q",
+         "attn.to_v",
+         "attn.to_out.0",
+         "attn.add_k_proj",
+         "attn.add_q_proj",
+         "attn.add_v_proj",
+         "attn.to_add_out",
+         "ff.net.0.proj",
+         "ff.net.2",
+         "ff_context.net.0.proj",
+         "ff_context.net.2",
+     ]
+
+     # Auto-download models if paths are None
+     core_predictor_model_path = get_model_path(
+         core_predictor_model_path,
+         DEFAULT_CORE_PREDICTOR_REPO,
+         CORE_PREDICTOR_FILENAME[task_name]
+     )
+
+     lcm_model_path = get_model_path(
+         lcm_model_path,
+         DEFAULT_LCM_REPO,
+         LCM_FILENAME[task_name]
+     )
+
+     detail_sharpener_model_path = get_model_path(
+         detail_sharpener_model_path,
+         DEFAULT_DETAIL_SHARPENER_REPO,
+         DETAIL_SHARPENER_FILENAME[task_name]
+     )
+
+     # load lora weights for core predictor
+     core_transformer_lora_config = LoraConfig(
+         r=lora_rank,
+         lora_alpha=lora_rank,
+         init_lora_weights="gaussian",
+         target_modules=target_lora_modules,
+     )
+     transformer.add_adapter(core_transformer_lora_config, adapter_name="core_predictor")
+
+     core_lora_state_dict = Lotus2Pipeline.lora_state_dict(core_predictor_model_path)
+     core_transformer_state_dict = {
+         f'{k.replace("transformer.", "")}': v for k, v in core_lora_state_dict.items() if k.startswith("transformer.")
+     }
+     core_transformer_state_dict = convert_unet_state_dict_to_peft(core_transformer_state_dict)
+     incompatible_keys = set_peft_model_state_dict(transformer, core_transformer_state_dict, adapter_name="core_predictor")
+     if incompatible_keys is not None:
+         # check only for unexpected keys
+         unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+         if unexpected_keys:
+             logging.warning(
+                 f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                 f" {unexpected_keys}. "
+             )
+
+     for name, param in transformer.named_parameters():
+         if "core_predictor" in name:
+             param.requires_grad = False
+     # transformer.to(device=device, dtype=weight_dtype)
+     logging.info(f"Successfully loaded lora weights for [core predictor].")
+
+     # stage1 lcm weights
+     local_continuity_module = Local_Continuity_Module(transformer.config.in_channels//4)
+     lcm_state_dict = torch.load(lcm_model_path, map_location="cpu", weights_only=True)
+     local_continuity_module.load_state_dict(lcm_state_dict)
+     local_continuity_module.requires_grad_(False)
+     local_continuity_module.to(device=device, dtype=weight_dtype)
+     logging.info(f"Successfully loaded weights for [local continuity module (LCM)].")
+
+     # stage2 lora weights (detail sharpener)
+     sharpener_transformer_lora_config = LoraConfig(
+         r=lora_rank,
+         lora_alpha=lora_rank,
+         init_lora_weights="gaussian",
+         target_modules=target_lora_modules,
+     )
+     transformer.add_adapter(sharpener_transformer_lora_config, adapter_name="detail_sharpener")
+
+     sharpener_lora_state_dict = Lotus2Pipeline.lora_state_dict(detail_sharpener_model_path)
+     sharpener_transformer_state_dict = {
+         f'{k.replace("transformer.", "")}': v for k, v in sharpener_lora_state_dict.items() if k.startswith("transformer.")
+     }
+     sharpener_transformer_state_dict = convert_unet_state_dict_to_peft(sharpener_transformer_state_dict)
+     incompatible_keys = set_peft_model_state_dict(transformer, sharpener_transformer_state_dict, adapter_name="detail_sharpener")
+     if incompatible_keys is not None:
+         # check only for unexpected keys
+         unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+         if unexpected_keys:
+             logging.warning(
+                 f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                 f" {unexpected_keys}. "
+             )
+
+     # freeze the stage2 lora
+     for name, param in transformer.named_parameters():
+         if "detail_sharpener" in name:
+             param.requires_grad = False
+     # transformer.to(device=device, dtype=weight_dtype)
+     logging.info(f"Successfully loaded lora weights for [detail sharpener].")
+
+     return transformer, local_continuity_module
+
+ def main(args):
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     logging.info("Run Lotus-2! ")
+
+     # -------------------- Preparation --------------------
+     # Check if model paths are provided, if not, they will be auto-downloaded from HuggingFace
+     if args.core_predictor_model_path is None or args.lcm_model_path is None or args.detail_sharpener_model_path is None:
+         if HF_AVAILABLE:
+             logging.info("Some model paths are not provided. Model weights will be automatically downloaded from HuggingFace.")
+             logging.info(f"Core predictor repo: {DEFAULT_CORE_PREDICTOR_REPO}")
+             logging.info(f"LCM repo: {DEFAULT_LCM_REPO}")
+             logging.info(f"Detail sharpener repo: {DEFAULT_DETAIL_SHARPENER_REPO}")
+         else:
+             logging.warning("Some model paths are not provided and huggingface_hub is not available.")
+             logging.warning("Please install huggingface_hub: pip install huggingface_hub")
+             logging.warning("Or provide local paths for all model weights.")
+
+     # Random seed
+     if args.seed is not None:
+         seed_all(args.seed)
+
+     # Output directories
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     output_dir_vis = os.path.join(args.output_dir, f'{args.task_name}_vis')
+     output_dir_npy = os.path.join(args.output_dir, f'{args.task_name}_npy')
+     if not os.path.exists(output_dir_vis): os.makedirs(output_dir_vis)
+     if not os.path.exists(output_dir_npy): os.makedirs(output_dir_npy)
+
+     logging.info(f"Output dir = {args.output_dir}")
+
+     # Mixed precision
+     if args.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif args.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+     else:
+         weight_dtype = torch.float32
+     logging.info(f"Running with {weight_dtype} precision.")
+
+     # Device
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+         logging.warning("CUDA is not available. Running on CPU will be slow.")
+     logging.info(f"Device = {device}")
+
+     # -------------------- Data --------------------
+     input_dir = Path(args.input_dir)
+     test_images = list(input_dir.rglob('*.png')) + list(input_dir.rglob('*.jpg'))
+     test_images = sorted(test_images)
+     logging.info(f'==> There are {len(test_images)} images for validation.')
+
+     # -------------------- Load scheduler and models --------------------
+     # scheduler
+     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="scheduler", num_train_timesteps=10
+     )
+     # transformer
+     transformer = FluxTransformer2DModel.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+     )
+     transformer.requires_grad_(False)
+     transformer.to(device=device, dtype=weight_dtype)
+
+     # load weights
+     transformer, local_continuity_module = load_lora_and_lcm_weights(transformer,
+         args.core_predictor_model_path,
+         args.lcm_model_path,
+         args.detail_sharpener_model_path,
+         args.task_name
+     )
+
+     # -------------------- Pipeline --------------------
+     pipeline = Lotus2Pipeline.from_pretrained(
+         args.pretrained_model_name_or_path,
+         scheduler=noise_scheduler,
+         transformer=transformer,
+         revision=args.revision,
+         variant=args.variant,
+         torch_dtype=weight_dtype,
+     )
+     pipeline.local_continuity_module = local_continuity_module
+     pipeline = pipeline.to(device)
+
+     # -------------------- Run inference! --------------------
+     pipeline.set_progress_bar_config(disable=True)
+
+     with nullcontext():
+         for image_path in tqdm(test_images):
+             # print("\n", image_path)
+             _, output_vis, output_npy = process_single_image(
+                 image_path, pipeline,
+                 task_name=args.task_name,
+                 device=device,
+                 num_inference_steps=args.num_inference_steps,
+                 process_res=args.process_res
+             )
+
+             output_vis.save(os.path.join(output_dir_vis, f'{image_path.stem}.png'))
+             np.save(os.path.join(output_dir_npy, f'{image_path.stem}.npy'), output_npy)
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
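
`process_single_image` hands the pipeline images normalized to [-1, 1] via `x / 127.5 - 1.0` before VAE encoding. A quick standalone check of that mapping (independent of the model weights):

```python
# The normalization used in process_single_image maps 8-bit pixel values
# 0 -> -1.0, 127.5 -> 0.0 and 255 -> 1.0.
import numpy as np
import torch

pixels = torch.tensor(np.array([0.0, 127.5, 255.0], dtype=np.float32))
print(pixels / 127.5 - 1.0)  # tensor([-1.,  0.,  1.])
```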
infer.sh ADDED
@@ -0,0 +1,30 @@
+ export OPENCV_IO_ENABLE_OPENEXR=1
+ export TOKENIZERS_PARALLELISM=false
+
+ export TASK_NAME="normal"
+
+ # paths
+ export MODEL_NAME="black-forest-labs/FLUX.1-dev"
+ # export CORE_PREDICTOR_MODEL_PATH="weights/lotus-2_core_predictor_$TASK_NAME.safetensors"
+ # export DETAIL_SHARPENER_MODEL_PATH="weights/lotus-2_detail_sharpener_$TASK_NAME.safetensors"
+ # export LCM_MODEL_PATH="weights/lotus-2_lcm_$TASK_NAME.safetensors"
+
+ export INPUT_DIR="assets"
+ export OUTPUT_DIR="outputs/infer/"
+
+ # configs
+ export NUM_INFERENCE_STEPS=10
+
+
+ CUDA_VISIBLE_DEVICES=0 python infer.py \
+     --pretrained_model_name_or_path=$MODEL_NAME \
+     --input_dir=$INPUT_DIR \
+     --output_dir=$OUTPUT_DIR \
+     --mixed_precision="bf16" \
+     --num_inference_steps=$NUM_INFERENCE_STEPS \
+     --seed="0" \
+     --task_name=$TASK_NAME \
+     --process_res=1024
+     # --core_predictor_model_path=$CORE_PREDICTOR_MODEL_PATH \
+     # --detail_sharpener_model_path=$DETAIL_SHARPENER_MODEL_PATH \
+     # --lcm_model_path=$LCM_MODEL_PATH \
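
Because `parse_args` accepts an explicit argument list, the same run can also be launched from Python. A minimal sketch equivalent to the script above (same paths and flags):

```python
# Programmatic equivalent of infer.sh (a sketch; uses the same paths as the script above).
from infer import main, parse_args

args = parse_args([
    "--pretrained_model_name_or_path", "black-forest-labs/FLUX.1-dev",
    "--input_dir", "assets",
    "--output_dir", "outputs/infer/",
    "--mixed_precision", "bf16",
    "--num_inference_steps", "10",
    "--seed", "0",
    "--task_name", "normal",
    "--process_res", "1024",
])
main(args)
```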
pipeline.py ADDED
@@ -0,0 +1,214 @@
+ from typing import Union, Optional, List, Dict, Any
+ import numpy as np
+ import torch
+ from diffusers import FluxPipeline
+ from diffusers.pipelines.flux import FluxPipelineOutput
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+ from diffusers.utils import is_torch_xla_available
+
+ from utils.image_utils import resize_image, resize_image_first
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+ class Lotus2Pipeline(FluxPipeline):
+     @torch.no_grad()
+     def __call__(
+         self,
+         rgb_in: Optional[torch.FloatTensor] = None,
+         prompt: Union[str, List[str]] = None,
+         num_inference_steps: int = 10,
+         output_type: Optional[str] = "pil",
+         process_res: Optional[int] = None,
+         timestep_core_predictor: int = 1,
+         guidance_scale: float = 3.5,
+         return_dict: bool = True,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             rgb_in (`torch.FloatTensor`, *optional*):
+                 The input image to be used for generation.
+             prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide the prediction. Default is ''.
+             num_inference_steps (`int`, *optional*, defaults to 10):
+                 The number of denoising steps. More denoising steps usually lead to a sharper prediction at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 3.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+             joint_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+         Examples:
+
+         Returns:
+             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+             is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+             images.
+         """
+         # 1. prepare
+         batch_size = rgb_in.shape[0]
+         input_size = rgb_in.shape[2:]
+         rgb_in = resize_image_first(rgb_in, process_res)
+         height, width = rgb_in.shape[2:]
+
+         self._guidance_scale = guidance_scale
+         self._joint_attention_kwargs = joint_attention_kwargs
+         self._interrupt = False
+
+         device = self._execution_device
+
+         # 2. encode prompt
+         (
+             prompt_embeds,
+             pooled_prompt_embeds,
+             text_ids,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             prompt_2=None,
+             device=device,
+         )
+
+         # 3. prepare latent variables
+         rgb_in = rgb_in.to(device=device, dtype=self.dtype)
+         rgb_latents = self.vae.encode(rgb_in).latent_dist.sample()
+         rgb_latents = (rgb_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+         packed_rgb_latents = self._pack_latents(
+             rgb_latents,
+             batch_size=rgb_latents.shape[0],
+             num_channels_latents=rgb_latents.shape[1],
+             height=rgb_latents.shape[2],
+             width=rgb_latents.shape[3],
+         )
+
+         latent_image_ids_core_predictor = self._prepare_latent_image_ids(batch_size, rgb_latents.shape[2]//2, rgb_latents.shape[3]//2, device, rgb_latents.dtype)
+         latent_image_ids = self._prepare_latent_image_ids(batch_size, rgb_latents.shape[2]//2, rgb_latents.shape[3]//2, device, rgb_latents.dtype)
+
+         # 4. prepare timesteps
+         timestep_core_predictor = torch.tensor(timestep_core_predictor).expand(batch_size).to(device=rgb_in.device, dtype=rgb_in.dtype)
+
+         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+         image_seq_len = packed_rgb_latents.shape[1]
+         mu = calculate_shift(
+             image_seq_len,
+             self.scheduler.config.base_image_seq_len,
+             self.scheduler.config.max_image_seq_len,
+             self.scheduler.config.base_shift,
+             self.scheduler.config.max_shift,
+         )
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler,
+             num_inference_steps,
+             device,
+             sigmas=sigmas,
+             mu=mu,
+         )
+         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)  # 0
+         self._num_timesteps = len(timesteps)
+
+         # 5. handle guidance
+         if self.transformer.config.guidance_embeds:
+             guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+             guidance = guidance.expand(packed_rgb_latents.shape[0])
+         else:
+             guidance = None
+
+         if self.joint_attention_kwargs is None:
+             self._joint_attention_kwargs = {}
+
+         # 6. core predictor
+         self.transformer.set_adapter("core_predictor")
+         latents = self.transformer(
+             hidden_states=packed_rgb_latents,
+             timestep=timestep_core_predictor / 1000,
+             guidance=guidance,
+             pooled_projections=pooled_prompt_embeds,
+             encoder_hidden_states=prompt_embeds,
+             txt_ids=text_ids,
+             img_ids=latent_image_ids_core_predictor,
+             joint_attention_kwargs=self.joint_attention_kwargs,  # {}
+             return_dict=False,
+         )[0]
+         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+         latents = self.local_continuity_module(latents)
+
+         # 7. Denoising loop for detail sharpener
+         self.transformer.set_adapter("detail_sharpener")
+         latents = self._pack_latents(
+             latents,
+             batch_size=latents.shape[0],
+             num_channels_latents=latents.shape[1],
+             height=latents.shape[2],
+             width=latents.shape[3],
+         )
+
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                 noise_pred = self.transformer(
+                     hidden_states=latents,
+                     timestep=timestep / 1000,
+                     guidance=guidance,
+                     pooled_projections=pooled_prompt_embeds,
+                     encoder_hidden_states=prompt_embeds,
+                     txt_ids=text_ids,
+                     img_ids=latent_image_ids,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents_dtype = latents.dtype
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                 if latents.dtype != latents_dtype:
+                     if torch.backends.mps.is_available():
+                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                         latents = latents.to(latents_dtype)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
+         latents = latents.to(dtype=self.dtype)
+
+         if output_type == "latent":
+             image = latents
+
+         else:
+             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+             image = self.vae.decode(latents, return_dict=False)[0]
+             image = self.image_processor.postprocess(image, output_type=output_type)
+
+         # Resize output image to match input size
+         image = resize_image(image, input_size)
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (image,)
+
+         return FluxPipelineOutput(images=image)
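
The `_pack_latents` / `_unpack_latents` calls inherited from `FluxPipeline` turn the VAE latents into a token sequence by grouping 2×2 spatial patches, which is why both the core predictor and the detail sharpener operate on sequences of length (H/2)·(W/2) with 4·C channels per token. The following is an illustrative re-implementation of just that shape bookkeeping (a sketch for intuition, not the diffusers code itself):

```python
# Shape bookkeeping behind the pack/unpack steps (illustration only):
# (B, C, H, W) latents -> (B, (H/2)*(W/2), 4*C) sequence of 2x2 patches.
import torch

def pack_latents_like_flux(latents: torch.Tensor) -> torch.Tensor:
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)     # split H and W into 2x2 patches
    x = x.permute(0, 2, 4, 1, 3, 5)                  # (B, H/2, W/2, C, 2, 2)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)  # one token per patch

latents = torch.randn(1, 16, 64, 64)                 # e.g. a 512x512 input after the FLUX VAE
print(pack_latents_like_flux(latents).shape)         # torch.Size([1, 1024, 64])
```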
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ numpy==1.26.4
+ matplotlib==3.10.0
+ peft==0.14.0
+ protobuf==5.29.0
+ sentencepiece==0.2.0
+ opencv-python==4.11.0.86
+ huggingface-hub==0.36.0
+ diffusers==0.32.2
+ torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121
+ torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
+ gradio==5.49.0
+ gradio-client==1.13.3
+ gradio-imageslider==0.0.20
+ spaces==0.42.1
utils/image_utils.py ADDED
@@ -0,0 +1,514 @@
+ from PIL import Image
+ import matplotlib
+ import numpy as np
+ from typing import List
+ import csv
+ import cv2
+
+ import torch
+ from torchvision.transforms import InterpolationMode
+ from torchvision.transforms.functional import resize
+
+ def numpy_to_pil(images: np.ndarray) -> List[Image.Image]:
+     r"""
+     Convert a numpy image or a batch of images to a PIL image.
+
+     Args:
+         images (`np.ndarray`):
+             The image array to convert to PIL format.
+
+     Returns:
+         `List[PIL.Image.Image]`:
+             A list of PIL images.
+     """
+     if images.ndim == 3:
+         images = images[None, ...]
+     images = (images * 255).round().astype("uint8")
+     if images.shape[-1] == 1:
+         # special case for grayscale (single channel) images
+         pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+     else:
+         pil_images = [Image.fromarray(image) for image in images]
+
+     return pil_images
+
+ def resize_output(image, target_size):
+     """
+     Resize output image to target size
+     Args:
+         image: Image in PIL.Image, numpy.array or torch.tensor format
+         target_size: tuple, target size (H, W)
+     Returns:
+         Resized image in original format
+     """
+     if isinstance(image, list):
+         return [resize_output(img, target_size) for img in image]
+
+     if isinstance(image, Image.Image):
+         return image.resize(target_size[::-1], Image.BILINEAR)
+     elif isinstance(image, np.ndarray):
+         # Handle numpy array with shape (1, H, W, 3)
+         if image.ndim == 4:
+             resized = np.stack([cv2.resize(img, target_size[::-1]) for img in image])
+             return resized
+         else:
+             return cv2.resize(image, target_size[::-1])
+     elif isinstance(image, torch.Tensor):
+         # Handle tensor with shape (1, 3, H, W)
+         if image.dim() == 4:
+             return torch.nn.functional.interpolate(
+                 image,
+                 size=target_size,
+                 mode='bilinear',
+                 align_corners=False
+             )
+         else:
+             return torch.nn.functional.interpolate(
+                 image.unsqueeze(0),
+                 size=target_size,
+                 mode='bilinear',
+                 align_corners=False
+             ).squeeze(0)
+     else:
+         raise ValueError(f"Unsupported image format: {type(image)}")
+
+ def resize_image(image, target_size):
+     """
+     Resize image to target size
+     Args:
+         image: Image in PIL.Image, numpy.array or torch.tensor format
+         target_size: tuple, target size (H, W)
+     Returns:
+         Resized image in original format
+     """
+     if isinstance(image, list):
+         return [resize_image(img, target_size) for img in image]
+
+     if isinstance(image, Image.Image):
+         return image.resize(target_size[::-1], Image.BILINEAR)
+     elif isinstance(image, np.ndarray):
+         # Handle numpy array with shape (1, H, W, 3)
+         if image.ndim == 4:
+             resized = np.stack([cv2.resize(img, target_size[::-1]) for img in image])
+             return resized
+         else:
+             return cv2.resize(image, target_size[::-1])
+     elif isinstance(image, torch.Tensor):
+         # Handle tensor with shape (1, 3, H, W)
+         if image.dim() == 4:
+             return torch.nn.functional.interpolate(
+                 image,
+                 size=target_size,
+                 mode='bilinear',
+                 align_corners=False
+             )
+         else:
+             return torch.nn.functional.interpolate(
+                 image.unsqueeze(0),
+                 size=target_size,
+                 mode='bilinear',
+                 align_corners=False
+             ).squeeze(0)
+     else:
+         raise ValueError(f"Unsupported image format: {type(image)}")
+
+ def resize_image_first(image_tensor, process_res=None):
+     if process_res:
+         max_edge = max(image_tensor.shape[2], image_tensor.shape[3])
+         if max_edge > process_res:
+             scale = process_res / max_edge
+             new_height = int(image_tensor.shape[2] * scale)
+             new_width = int(image_tensor.shape[3] * scale)
+             image_tensor = resize_image(image_tensor, (new_height, new_width))
+
+     image_tensor = resize_to_multiple_of_16(image_tensor)
+
+     return image_tensor
+
+
+ def smooth_image(image, method='gaussian', kernel_size=31, sigma=15.0, bilateral_d=9, bilateral_color=75, bilateral_space=75):
+     """
+     Apply one of several smoothing methods to remove grid artifacts from an image.
+
+     Args:
+         image: Image in PIL.Image, numpy.array or torch.tensor format
+         method: Smoothing method, one of 'gaussian' (Gaussian blur), 'bilateral' (bilateral filter),
+                 'median' (median filter), 'guided' (guided filter), or 'strong' (aggressive smoothing
+                 that combines several filters)
+         kernel_size: Kernel size for Gaussian and median filtering, default 31, must be odd
+         sigma: Standard deviation for Gaussian filtering, default 15.0
+         bilateral_d: Diameter for bilateral filtering, default 9
+         bilateral_color: Color-space sigma for bilateral filtering, default 75
+         bilateral_space: Coordinate-space sigma for bilateral filtering, default 75
+
+     Returns:
+         Smoothed image in the original format
+     """
+     if isinstance(image, list):
+         return [smooth_image(img, method, kernel_size, sigma, bilateral_d, bilateral_color, bilateral_space) for img in image]
+
+     # Make sure kernel_size is odd
+     if kernel_size % 2 == 0:
+         kernel_size += 1
+
+     # Convert to a numpy array for processing
+     is_pil = isinstance(image, Image.Image)
+     is_tensor = isinstance(image, torch.Tensor)
+
+     if is_pil:
+         img_array = np.array(image)
+     elif is_tensor:
+         device = image.device
+         if image.dim() == 4:  # (B, C, H, W)
+             batch_size, channels, height, width = image.shape
+             img_array = image.permute(0, 2, 3, 1).cpu().numpy()  # (B, H, W, C)
+         else:  # (C, H, W)
+             img_array = image.permute(1, 2, 0).cpu().numpy()  # (H, W, C)
+     else:
+         img_array = image
+
+     # Remember the original dtype
+     original_dtype = img_array.dtype
+
+     # Apply the selected smoothing method
+     if method == 'gaussian':
+         # Standard Gaussian blur, suitable for mild smoothing
+         if img_array.ndim == 4:
+             smoothed = np.stack([cv2.GaussianBlur(img, (kernel_size, kernel_size), sigma) for img in img_array])
+         else:
+             smoothed = cv2.GaussianBlur(img_array, (kernel_size, kernel_size), sigma)
+
+     elif method == 'bilateral':
+         # Bilateral filtering smooths flat regions while preserving edges
+         if img_array.ndim == 4:
+             # Make sure the images are 8-bit
+             imgs_uint8 = [img.astype(np.uint8) if img.dtype != np.uint8 else img for img in img_array]
+             smoothed = np.stack([cv2.bilateralFilter(img, bilateral_d, bilateral_color, bilateral_space) for img in imgs_uint8])
+             # Convert back to the original dtype
+             if original_dtype != np.uint8:
+                 smoothed = smoothed.astype(original_dtype)
+         else:
+             # Make sure the image is 8-bit
+             img_uint8 = img_array.astype(np.uint8) if img_array.dtype != np.uint8 else img_array
+             smoothed = cv2.bilateralFilter(img_uint8, bilateral_d, bilateral_color, bilateral_space)
+             # Convert back to the original dtype
+             if original_dtype != np.uint8:
+                 smoothed = smoothed.astype(original_dtype)
+
+     elif method == 'median':
+         # Median filtering is very effective against salt-and-pepper noise and small grid cells
+         # Median filtering requires uint8 or uint16 input
+         if img_array.ndim == 4:
+             # Convert to 8-bit unsigned integers and make sure the format is correct
+             imgs_uint8 = []
+             for img in img_array:
+                 # Scale floating-point images to the 0-255 range
+                 if img.dtype != np.uint8:
+                     if img.max() <= 1.0:  # check whether the values are floats in [0, 1]
+                         img = (img * 255).astype(np.uint8)
+                     else:
+                         img = img.astype(np.uint8)
+                 imgs_uint8.append(img)
+
+             smoothed = np.stack([cv2.medianBlur(img, kernel_size) for img in imgs_uint8])
+             # Convert back to the original dtype
+             if original_dtype != np.uint8:
+                 if original_dtype == np.float32 or original_dtype == np.float64:
+                     if img_array.max() <= 1.0:  # check whether the original data was in [0, 1]
+                         smoothed = smoothed.astype(float) / 255.0
+
+         else:
+             # Convert to 8-bit unsigned integers
+             if img_array.dtype != np.uint8:
+                 if img_array.max() <= 1.0:  # check whether the values are floats in [0, 1]
+                     img_uint8 = (img_array * 255).astype(np.uint8)
+                 else:
+                     img_uint8 = img_array.astype(np.uint8)
+             else:
+                 img_uint8 = img_array
+
+             smoothed = cv2.medianBlur(img_uint8, kernel_size)
+             # Convert back to the original dtype
+             if original_dtype != np.uint8:
+                 if original_dtype == np.float32 or original_dtype == np.float64:
+                     if img_array.max() <= 1.0:  # check whether the original data was in [0, 1]
+                         smoothed = smoothed.astype(float) / 255.0
+                 else:
+                     smoothed = smoothed.astype(original_dtype)
+
+     elif method == 'guided':
+         # Guided filtering smooths regions while preserving edges
+         if img_array.ndim == 4:
+             smoothed = np.stack([cv2.ximgproc.guidedFilter(
+                 guide=img, src=img, radius=kernel_size//2, eps=1e-6) for img in img_array])
+         else:
+             smoothed = cv2.ximgproc.guidedFilter(
+                 guide=img_array, src=img_array, radius=kernel_size//2, eps=1e-6)
+
+     elif method == 'strong':
+         # Aggressive smoothing: a median filter to remove sharp noise, then a bilateral filter to
+         # preserve edges, and finally a Gaussian blur for further smoothing
+         if img_array.ndim == 4:
+             # Convert to 8-bit unsigned integers
+             imgs_uint8 = []
+             for img in img_array:
+                 # Scale floating-point images to the 0-255 range
+                 if img.dtype != np.uint8:
+                     if img.max() <= 1.0:  # check whether the values are floats in [0, 1]
+                         img = (img * 255).astype(np.uint8)
+                     else:
+                         img = img.astype(np.uint8)
+                 imgs_uint8.append(img)
+
+             temp = np.stack([cv2.medianBlur(img, min(15, kernel_size)) for img in imgs_uint8])
+             temp = np.stack([cv2.bilateralFilter(img, bilateral_d, bilateral_color, bilateral_space) for img in temp])
+             smoothed = np.stack([cv2.GaussianBlur(img, (kernel_size, kernel_size), sigma) for img in temp])
+
+             # Convert back to the original dtype
+             if original_dtype != np.uint8:
+                 if original_dtype == np.float32 or original_dtype == np.float64:
+                     if img_array.max() <= 1.0:  # check whether the original data was in [0, 1]
+                         smoothed = smoothed.astype(float) / 255.0
+                 else:
+                     smoothed = smoothed.astype(original_dtype)
+         else:
+             # Convert to 8-bit unsigned integers
+             if img_array.dtype != np.uint8:
+                 if img_array.max() <= 1.0:  # check whether the values are floats in [0, 1]
+                     img_uint8 = (img_array * 255).astype(np.uint8)
+                 else:
+                     img_uint8 = img_array.astype(np.uint8)
+             else:
+                 img_uint8 = img_array
+
+             temp = cv2.medianBlur(img_uint8, min(15, kernel_size))
+             temp = cv2.bilateralFilter(temp, bilateral_d, bilateral_color, bilateral_space)
+             smoothed = cv2.GaussianBlur(temp, (kernel_size, kernel_size), sigma)
+
+             # Convert back to the original dtype
+             if original_dtype != np.uint8:
+                 if original_dtype == np.float32 or original_dtype == np.float64:
+                     if img_array.max() <= 1.0:  # check whether the original data was in [0, 1]
+                         smoothed = smoothed.astype(float) / 255.0
+                 else:
+                     smoothed = smoothed.astype(original_dtype)
+
+     else:
+         raise ValueError(f"Unsupported smoothing method: {method}, please choose 'gaussian', 'bilateral', 'median', 'guided' or 'strong'")
+
+     # Convert the result back to the original format
+     if is_pil:
+         # If the result is floating point with values in [0, 1], convert to uint8 in [0, 255] first
+         if smoothed.dtype == np.float32 or smoothed.dtype == np.float64:
+             if smoothed.max() <= 1.0:
+                 smoothed = (smoothed * 255).astype(np.uint8)
+         return Image.fromarray(smoothed.astype(np.uint8))
+     elif is_tensor:
+         if image.dim() == 4:
+             return torch.from_numpy(smoothed).permute(0, 3, 1, 2).to(device)
+         else:
+             return torch.from_numpy(smoothed).permute(2, 0, 1).to(device)
+     else:
+         return smoothed
+
+ def resize_to_multiple_of_16(image_tensor):
+     """
+     Resize image tensor to make shorter side closest multiple of 16 while maintaining aspect ratio
+     Args:
+         image_tensor: Input tensor of shape (B, C, H, W)
+     Returns:
+         Resized tensor where shorter side is multiple of 16
+     """
+     # Calculate scale ratio based on shorter side to make it closest multiple of 16
+     h, w = image_tensor.shape[2], image_tensor.shape[3]
+     min_side = min(h, w)
+     scale = (min_side // 16) * 16 / min_side
+
+     # Calculate new height and width
+     new_h = int(h * scale)
+     new_w = int(w * scale)
+
+     # Ensure both height and width are multiples of 16
+     new_h = (new_h // 16) * 16
+     new_w = (new_w // 16) * 16
+
+     # Resize image while maintaining aspect ratio
+     resized_tensor = torch.nn.functional.interpolate(
+         image_tensor,
+         size=(new_h, new_w),
+         mode='bilinear',
+         align_corners=False
+     )
+     return resized_tensor
+
+ def load_color_list(csv_path):
+     color_list = []
+     with open(csv_path, newline='') as file:
+         reader = csv.reader(file)
+
+         next(reader)
+
+         for row in reader:
+             last_three = tuple(map(int, row[-3:]))
+             color_list.append(last_three)
+
+     color_list = [(0, 0, 0)] + color_list
+
+     return color_list
+
+ def conver_rgb_to_semantic_map(image: Image, color_list: List):
+     # Convert PIL Image to numpy array
+     image_array = np.array(image)
+
+     # Initialize an empty array for the indexed image
+     indexed_image = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=int)
+
+     # Loop through each pixel in the image
+     for i in range(image_array.shape[0]):
+         for j in range(image_array.shape[1]):
+             # Get the color of the current pixel
+             pixel_color = tuple(image_array[i, j][:3])  # Exclude the alpha channel if present
+
+             # Find the closest color from the color list and get its index
+             # Here, the Euclidean distance is used to find the closest color
+             distances = np.sqrt(np.sum((np.array(color_list) - np.array(pixel_color))**2, axis=1))
+             closest_color_index = np.argmin(distances)
+
+             # Set the index in the indexed image
+             indexed_image[i, j] = closest_color_index
+
+     indexed_image = indexed_image - 1
+
+     return indexed_image
+
+
+ def concatenate_images(*image_lists):
+     # Ensure at least one image list is provided
+     if not image_lists or not image_lists[0]:
+         raise ValueError("At least one non-empty image list must be provided")
+
+     # Determine the maximum width of any single row and the total height
+     max_width = 0
+     total_height = 0
+     row_widths = []
+     row_heights = []
+
+     # Compute dimensions for each row
+     for image_list in image_lists:
+         if image_list:  # Ensure the list is not empty
+             width = sum(img.width for img in image_list)
+             height = max(img.height for img in image_list)
+             max_width = max(max_width, width)
+             total_height += height
+             row_widths.append(width)
+             row_heights.append(height)
+
+     # Create a new image to concatenate everything into
+     new_image = Image.new('RGB', (max_width, total_height))
+
+     # Concatenate each row of images
+     y_offset = 0
+     for i, image_list in enumerate(image_lists):
+         x_offset = 0
+         for img in image_list:
+             new_image.paste(img, (x_offset, y_offset))
+             x_offset += img.width
+         y_offset += row_heights[i]  # Move the offset down to the next row
+
+     return new_image
+
+
+ # def concatenate_images(image_list1, image_list2):
+ #     # Ensure both image lists are not empty
+ #     if not image_list1 or not image_list2:
+ #         raise ValueError("Image lists cannot be empty")
+
+ #     # Get the width and height of the first image
+ #     width, height = image_list1[0].size
+
+ #     # Calculate the total width and height
+ #     total_width = max(len(image_list1), len(image_list2)) * width
+ #     total_height = 2 * height  # For two rows
+
+ #     # Create a new image to concatenate everything into
+ #     new_image = Image.new('RGB', (total_width, total_height))
+
+ #     # Concatenate the first row of images
+ #     x_offset = 0
+ #     for img in image_list1:
+ #         new_image.paste(img, (x_offset, 0))
+ #         x_offset += img.width
+
+ #     # Concatenate the second row of images
+ #     x_offset = 0
+ #     for img in image_list2:
+ #         new_image.paste(img, (x_offset, height))
+ #         x_offset += img.width
+
+ #     return new_image
+
+ def colorize_depth_map(depth, mask=None, reverse_color=False):
+     cm = matplotlib.colormaps["Spectral"]
+     # normalize
+     depth = ((depth - depth.min()) / (depth.max() - depth.min()))
+     # colorize
+     if reverse_color:
+         img_colored_np = cm(1 - depth, bytes=False)[:, :, 0:3]  # Invert the depth values before applying colormap
+     else:
+         img_colored_np = cm(depth, bytes=False)[:, :, 0:3]  # (h,w,3)
+
+     depth_colored = (img_colored_np * 255).astype(np.uint8)
+     if mask is not None:
+         masked_image = np.zeros_like(depth_colored)
+         masked_image[mask.numpy()] = depth_colored[mask.numpy()]
+         depth_colored_img = Image.fromarray(masked_image)
+     else:
+         depth_colored_img = Image.fromarray(depth_colored)
+     return depth_colored_img
+
+
+ def resize_max_res(
+     img: torch.Tensor,
+     max_edge_resolution: int,
+     resample_method: InterpolationMode = InterpolationMode.BILINEAR,
+ ) -> torch.Tensor:
+     """
+     Resize image to limit maximum edge length while keeping aspect ratio.
+
+     Args:
+         img (`torch.Tensor`):
+             Image tensor to be resized. Expected shape: [B, C, H, W]
+         max_edge_resolution (`int`):
+             Maximum edge length (pixel).
+         resample_method (`torchvision.transforms.InterpolationMode`):
+             Resampling method used to resize images.
+
+     Returns:
+         `torch.Tensor`: Resized image.
+     """
+     assert 4 == img.dim(), f"Invalid input shape {img.shape}"
+
+     original_height, original_width = img.shape[-2:]
+     downscale_factor = min(
+         max_edge_resolution / original_width, max_edge_resolution / original_height
+     )
+
+     new_width = int(original_width * downscale_factor)
+     new_height = int(original_height * downscale_factor)
+
+     resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
+     return resized_img
+
+
+ def get_tv_resample_method(method_str: str) -> InterpolationMode:
+     resample_method_dict = {
+         "bilinear": InterpolationMode.BILINEAR,
+         "bicubic": InterpolationMode.BICUBIC,
+         "nearest": InterpolationMode.NEAREST_EXACT,
+         "nearest-exact": InterpolationMode.NEAREST_EXACT,
+     }
+     resample_method = resample_method_dict.get(method_str, None)
+     if resample_method is None:
+         raise ValueError(f"Unknown resampling method: {method_str}")
+     else:
+         return resample_method
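
`resize_image_first` caps the longer edge at `process_res` and then calls `resize_to_multiple_of_16`, so whatever reaches the VAE has both sides divisible by 16 (needed for the 8× VAE downsampling followed by 2×2 latent packing). A quick check of that invariant (a standalone sketch, run from the repository root):

```python
# Both spatial sides of the resized tensor end up divisible by 16.
import torch
from utils.image_utils import resize_to_multiple_of_16

x = torch.zeros(1, 3, 750, 1000)
h, w = resize_to_multiple_of_16(x).shape[-2:]
print(h % 16, w % 16)  # 0 0
```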
utils/seed_all.py ADDED
@@ -0,0 +1,33 @@
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+ # More information about the method can be found at https://marigoldmonodepth.github.io
+ # --------------------------------------------------------------------------
+
+
+ import numpy as np
+ import random
+ import torch
+
+
+ def seed_all(seed: int = 0):
+     """
+     Set random seeds of all components.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)