Spaces:
Running
Running
Update lama_inpaint.py
Browse files- lama_inpaint.py +27 -15
lama_inpaint.py
CHANGED
|
@@ -5,7 +5,6 @@ import torch
|
|
| 5 |
import yaml
|
| 6 |
import glob
|
| 7 |
import argparse
|
| 8 |
-
from PIL import Image
|
| 9 |
from omegaconf import OmegaConf
|
| 10 |
from pathlib import Path
|
| 11 |
|
|
@@ -20,6 +19,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parent / "lama"))
|
|
| 20 |
from saicinpainting.evaluation.utils import move_to_device
|
| 21 |
from saicinpainting.training.trainers import load_checkpoint
|
| 22 |
from saicinpainting.evaluation.data import pad_tensor_to_modulo
|
|
|
|
| 23 |
|
| 24 |
from utils import load_img_to_array, save_array_to_img
|
| 25 |
|
|
@@ -53,8 +53,7 @@ def inpaint_img_with_lama(
|
|
| 53 |
train_config, checkpoint_path, strict=False, map_location=device
|
| 54 |
)
|
| 55 |
model.freeze()
|
| 56 |
-
|
| 57 |
-
model.to(device)
|
| 58 |
|
| 59 |
batch = {}
|
| 60 |
batch["image"] = img.permute(2, 0, 1).unsqueeze(0)
|
|
@@ -62,16 +61,30 @@ def inpaint_img_with_lama(
|
|
| 62 |
unpad_to_size = [batch["image"].shape[2], batch["image"].shape[3]]
|
| 63 |
batch["image"] = pad_tensor_to_modulo(batch["image"], mod)
|
| 64 |
batch["mask"] = pad_tensor_to_modulo(batch["mask"], mod)
|
| 65 |
-
batch = move_to_device(batch, device)
|
| 66 |
-
batch["mask"] = (batch["mask"] > 0) * 1
|
| 67 |
-
|
| 68 |
-
batch = model(batch)
|
| 69 |
-
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
|
| 70 |
-
cur_res = cur_res.detach().cpu().numpy()
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
cur_res = cur_res[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
| 77 |
return cur_res
|
|
@@ -98,8 +111,7 @@ def build_lama_model(config_p: str, ckpt_p: str, device="cuda"):
|
|
| 98 |
train_config, checkpoint_path, strict=False, map_location=device
|
| 99 |
)
|
| 100 |
model.freeze()
|
| 101 |
-
|
| 102 |
-
model.to(device)
|
| 103 |
|
| 104 |
return model
|
| 105 |
|
|
|
|
| 5 |
import yaml
|
| 6 |
import glob
|
| 7 |
import argparse
|
|
|
|
| 8 |
from omegaconf import OmegaConf
|
| 9 |
from pathlib import Path
|
| 10 |
|
|
|
|
| 19 |
from saicinpainting.evaluation.utils import move_to_device
|
| 20 |
from saicinpainting.training.trainers import load_checkpoint
|
| 21 |
from saicinpainting.evaluation.data import pad_tensor_to_modulo
|
| 22 |
+
from saicinpainting.evaluation.refinement import refine_predict
|
| 23 |
|
| 24 |
from utils import load_img_to_array, save_array_to_img
|
| 25 |
|
|
|
|
| 53 |
train_config, checkpoint_path, strict=False, map_location=device
|
| 54 |
)
|
| 55 |
model.freeze()
|
| 56 |
+
model.to(device)
|
|
|
|
| 57 |
|
| 58 |
batch = {}
|
| 59 |
batch["image"] = img.permute(2, 0, 1).unsqueeze(0)
|
|
|
|
| 61 |
unpad_to_size = [batch["image"].shape[2], batch["image"].shape[3]]
|
| 62 |
batch["image"] = pad_tensor_to_modulo(batch["image"], mod)
|
| 63 |
batch["mask"] = pad_tensor_to_modulo(batch["mask"], mod)
|
| 64 |
+
# batch = move_to_device(batch, device)
|
| 65 |
+
# batch["mask"] = (batch["mask"] > 0) * 1
|
| 66 |
+
|
| 67 |
+
# batch = model(batch)
|
| 68 |
+
# cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
|
| 69 |
+
# cur_res = cur_res.detach().cpu().numpy()
|
| 70 |
+
if predict_config.get("refine", False):
|
| 71 |
+
batch["unpad_to_size"] = [torch.tensor([size]) for size in unpad_to_size]
|
| 72 |
+
cur_res = refine_predict(batch, model, **predict_config.refiner)
|
| 73 |
+
cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
|
| 74 |
+
else:
|
| 75 |
+
batch = move_to_device(batch, device)
|
| 76 |
+
batch["mask"] = (batch["mask"] > 0) * 1
|
| 77 |
+
batch = model(batch)
|
| 78 |
+
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
|
| 79 |
+
cur_res = cur_res.detach().cpu().numpy()
|
| 80 |
+
|
| 81 |
+
if unpad_to_size is not None:
|
| 82 |
+
orig_height, orig_width = unpad_to_size
|
| 83 |
+
cur_res = cur_res[:orig_height, :orig_width]
|
| 84 |
+
|
| 85 |
+
# if unpad_to_size is not None:
|
| 86 |
+
# orig_height, orig_width = unpad_to_size
|
| 87 |
+
# cur_res = cur_res[:orig_height, :orig_width]
|
| 88 |
|
| 89 |
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
| 90 |
return cur_res
|
|
|
|
| 111 |
train_config, checkpoint_path, strict=False, map_location=device
|
| 112 |
)
|
| 113 |
model.freeze()
|
| 114 |
+
model.to(device)
|
|
|
|
| 115 |
|
| 116 |
return model
|
| 117 |
|