Spaces:
Runtime error
Runtime error
| import os | |
| import imageio | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import pydiffvg | |
| import skimage | |
| import skimage.io | |
| import torch | |
| import wandb | |
| import PIL | |
| from PIL import Image | |
| from torchvision import transforms | |
| from torchvision.utils import make_grid | |
| from skimage.transform import resize | |
| from U2Net_.model import U2NET | |
def imwrite(img, filename, gamma=2.2, normalize=False, use_wandb=False, wandb_name="", step=0, input_im=None):
    """Write ``img`` to ``filename`` with gamma correction, optionally logging it to wandb.

    Args:
        img: numpy array or tensor, either 2-D (grayscale) or H x W x C. Values are
            clipped to [0, 1] before quantisation. Tensors may be on any device.
        filename: output path; missing parent directories are created.
        gamma: the first three channels are raised to ``1 / gamma`` before being
            scaled to uint8 (a 4th channel, if any, is left uncorrected).
        normalize: if True, min-max rescale to [0, 1] first (skipped when the image
            is constant, to avoid dividing by zero).
        use_wandb: if True, log the written image to wandb under ``wandb_name + "_"``.
        wandb_name: prefix of the wandb key.
        step: wandb step; when it is 0, ``input_im`` is logged alongside the output.
        input_im: optional reference image logged with caption "input" on step 0.
    """
    directory = os.path.dirname(filename)
    if directory:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(directory, exist_ok=True)

    if not isinstance(img, np.ndarray):
        # Tensors may live on the GPU or carry autograd history; bring them to host memory.
        img = img.detach().cpu().numpy()
    if normalize:
        img_rng = np.max(img) - np.min(img)
        if img_rng > 0:
            img = (img - np.min(img)) / img_rng
    # np.clip returns a new array, so the in-place gamma step below never mutates the caller's data.
    img = np.clip(img, 0.0, 1.0)
    if img.ndim == 2:
        # Grayscale: replicate into three channels. A (H, W, 1) uint8 array is rejected
        # by PIL's Image.fromarray.
        img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
    img[:, :, :3] = np.power(img[:, :, :3], 1.0 / gamma)
    img = (img * 255).astype(np.uint8)

    skimage.io.imsave(filename, img, check_contrast=False)
    if use_wandb:
        # Only build wandb objects when logging is actually enabled.
        images = [wandb.Image(Image.fromarray(img), caption="output")]
        if input_im is not None and step == 0:
            images.append(wandb.Image(input_im, caption="input"))
        wandb.log({wandb_name + "_": images}, step=step)
def plot_batch(inputs, outputs, output_dir, step, use_wandb, title):
    """Save a two-row figure (inputs on top, outputs below) to ``output_dir/title``.

    The input grid is min-max normalised by ``make_grid``; the output grid is drawn
    as-is. When ``use_wandb`` is set, the figure is also logged under ``"output"``
    at ``step``.
    """
    def _draw_panel(position, tensor_grid, label):
        # One row of the 2x1 layout: CHW grid -> HWC numpy image.
        plt.subplot(2, 1, position)
        pixels = np.transpose(tensor_grid.detach().cpu().numpy(), (1, 2, 0))
        plt.imshow(pixels, interpolation='nearest')
        plt.axis("off")
        plt.title(label)

    plt.figure()
    _draw_panel(1, make_grid(inputs.clone().detach(), normalize=True, pad_value=2), "inputs")
    _draw_panel(2, make_grid(outputs, normalize=False, pad_value=2), "outputs")
    plt.tight_layout()

    if use_wandb:
        wandb.log({"output": wandb.Image(plt)}, step=step)
    plt.savefig("{}/{}".format(output_dir, title))
    plt.close()
def log_input(use_wandb, epoch, inputs, output_dir):
    """Log a grid of ``inputs`` to wandb and save the first input as ``input.png``.

    Args:
        use_wandb: if True, log the grid figure under ``"input"`` at step ``epoch``.
        epoch: wandb step.
        inputs: batch tensor indexed as ``inputs[0]`` and permuted (1, 2, 0) below,
            i.e. assumed N x C x H x W — NOTE(review): confirm with callers.
        output_dir: directory that receives ``input.png``.
    """
    # Use a dedicated figure so an unrelated current figure is neither drawn on nor closed.
    plt.figure()
    grid = make_grid(inputs.clone().detach(), normalize=True, pad_value=2)
    npgrid = grid.cpu().numpy()
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
    plt.axis("off")
    plt.tight_layout()
    if use_wandb:
        wandb.log({"input": wandb.Image(plt)}, step=epoch)
    plt.close()

    input_ = inputs[0].detach().cpu().permute(1, 2, 0).numpy()
    lo, hi = input_.min(), input_.max()
    if hi > lo:
        input_ = (input_ - lo) / (hi - lo)
    else:
        # A constant image has no range; write black instead of NaN-derived garbage.
        input_ = np.zeros_like(input_)
    input_ = (input_ * 255).astype(np.uint8)
    imageio.imwrite(os.path.join(output_dir, "input.png"), input_)
def log_sketch_summary_final(path_svg, use_wandb, device, epoch, loss, title):
    """Render the SVG at ``path_svg`` over white and log it to wandb under ``title``.

    Rendering and white-background compositing are delegated to ``read_svg`` (same
    diffvg settings: 2x2 samples, seed 0), so the two code paths cannot drift apart.

    Args:
        path_svg: SVG file to render.
        use_wandb: if True, log the figure under the key ``title``.
        device: device for the white background tensor used in compositing.
        epoch: shown in the figure title.
        loss: shown in the figure title.
        title: figure title prefix and wandb key.
    """
    img = read_svg(path_svg, device)
    # Dedicated figure: do not draw on (and then close) whatever figure is current.
    plt.figure()
    plt.imshow(img.detach().cpu().numpy())
    plt.axis("off")
    plt.title(f"{title} best res [{epoch}] [{loss}.]")
    if use_wandb:
        wandb.log({title: wandb.Image(plt)})
    plt.close()
def log_sketch_summary(sketch, title, use_wandb):
    """Plot ``sketch`` as a normalised image grid titled ``title``.

    When ``use_wandb`` is set, the figure is stored in the active wandb run's
    summary under ``"best_loss_im"``. The figure is always closed afterwards.
    """
    plt.figure()
    sketch_grid = make_grid(sketch.clone().detach(), normalize=True, pad_value=2)
    hwc_image = sketch_grid.cpu().numpy().transpose(1, 2, 0)
    plt.imshow(hwc_image, interpolation='nearest')
    plt.axis("off")
    plt.title(title)
    plt.tight_layout()

    if use_wandb:
        wandb.run.summary["best_loss_im"] = wandb.Image(plt)
    plt.close()
def load_svg(path_svg):
    """Parse an SVG file with diffvg.

    Returns:
        (canvas_width, canvas_height, shapes, shape_groups) as produced by
        ``pydiffvg.svg_to_scene``.
    """
    # os.fspath mirrors what a single-argument os.path.join did: str/bytes pass
    # through unchanged, path-like objects are converted to their filesystem form.
    scene = pydiffvg.svg_to_scene(os.fspath(path_svg))
    canvas_width, canvas_height, shapes, shape_groups = scene
    return canvas_width, canvas_height, shapes, shape_groups
def read_svg(path_svg, device, multiply=False):
    """Rasterise an SVG with diffvg and composite it over a white background.

    Args:
        path_svg: SVG file to load via ``pydiffvg.svg_to_scene``.
        device: device on which the white background tensor is created.
        multiply: if True, double the canvas size and, in place, every shape's
            ``points`` and ``stroke_width`` before rendering.

    Returns:
        H x W x 3 tensor: rendered RGB blended over white using the render's alpha.
    """
    scene = pydiffvg.svg_to_scene(path_svg)
    canvas_width, canvas_height, shapes, shape_groups = scene
    if multiply:
        canvas_width, canvas_height = canvas_width * 2, canvas_height * 2
        for shape in shapes:
            shape.points *= 2
            shape.stroke_width *= 2

    render = pydiffvg.RenderFunction.apply
    serialized = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    # Positional args: width, height, num_samples_x, num_samples_y, seed, background image.
    rgba = render(canvas_width, canvas_height, 2, 2, 0, None, *serialized)

    alpha = rgba[:, :, 3:4]
    white = torch.ones(rgba.shape[0], rgba.shape[1], 3, device=device)
    composited = alpha * rgba[:, :, :3] + white * (1 - alpha)
    return composited[:, :, :3]
def plot_attn_dino(attn, threshold_map, inputs, inds, use_wandb, output_path):
    """Save a figure of attention maps and their thresholded counterparts.

    Layout is 2 rows x (N + 2) columns, where N = ``attn.shape[0]``:
      * top row: input image with ``inds`` as red dots, ``attn.sum(0)``, then ``attn[i]``;
      * bottom row: ``threshold_map[-1]``, ``threshold_map[:-1].sum(0)``, then
        ``threshold_map[i]``.

    Args:
        attn: stack of 2-D maps indexed on the first axis (presumably one per DINO
            head — confirm with caller); must support ``.numpy()``.
        threshold_map: stack indexed on the first axis; from the indexing it needs at
            least N + 1 entries, the last one being plotted as "prob sum".
        inputs: image tensor passed to ``make_grid``.
        inds: array of (row, col) points; column 1 is plotted as x and column 0 as y.
        use_wandb: if True, log the figure under ``"attention_map"``.
        output_path: file the figure is saved to.
    """
    # currently supports one image (and not a batch)
    n_maps = attn.shape[0]
    n_cols = n_maps + 2
    plt.figure(figsize=(10, 5))

    plt.subplot(2, n_cols, 1)
    grid_np = make_grid(inputs, normalize=True, pad_value=2).cpu().numpy()
    plt.imshow(grid_np.transpose(1, 2, 0), interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    # (subplot index, data, title) for the aggregate panels; n_cols + 1 and n_cols + 2
    # are the first two cells of the bottom row.
    summary_panels = (
        (2, attn.sum(0), "atn map sum"),
        (n_cols + 1, threshold_map[-1], "prob sum"),
        (n_cols + 2, threshold_map[:-1].sum(0), "thresh sum"),
    )
    for index, data, label in summary_panels:
        plt.subplot(2, n_cols, index)
        plt.imshow(data.numpy(), interpolation='nearest')
        plt.title(label)
        plt.axis("off")

    for k in range(n_maps):
        plt.subplot(2, n_cols, k + 3)
        plt.imshow(attn[k].numpy())
        plt.axis("off")
        plt.subplot(2, n_cols, n_cols + k + 3)
        plt.imshow(threshold_map[k].numpy())
        plt.axis("off")

    plt.tight_layout()
    if use_wandb:
        wandb.log({"attention_map": wandb.Image(plt)})
    plt.savefig(output_path)
    plt.close()
def plot_attn_clip(attn, threshold_map, inputs, inds, use_wandb, output_path, display_logs):
    """Save a 1x3 figure: input with sampled points, attention map, normalised probability map.

    Args:
        attn: 2-D attention map, drawn with a fixed [0, 1] colour range.
        threshold_map: 2-D map (numpy array or tensor); min-max normalised for display.
        inputs: image tensor passed to ``make_grid``.
        inds: array of (row, col) points; column 1 is plotted as x and column 0 as y.
        use_wandb: if True, log the figure under ``"attention_map"``.
        output_path: file the figure is saved to.
        display_logs: accepted for interface compatibility; not used here.
    """
    # currently supports one image (and not a batch)
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 3, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.detach().cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
    plt.title("atn map")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    # Min-max normalise for display. A constant map has zero range; leave it at 0
    # rather than dividing by zero and rendering NaNs.
    threshold_map_ = threshold_map - threshold_map.min()
    value_range = threshold_map.max() - threshold_map.min()
    if value_range > 0:
        threshold_map_ = threshold_map_ / value_range
    plt.imshow(threshold_map_, interpolation='nearest', vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.axis("off")

    plt.tight_layout()
    if use_wandb:
        wandb.log({"attention_map": wandb.Image(plt)})
    plt.savefig(output_path)
    plt.close()
def plot_atten(attn, threshold_map, inputs, inds, use_wandb, output_path, saliency_model, display_logs):
    """Dispatch to the attention plot matching ``saliency_model``.

    ``"dino"`` uses ``plot_attn_dino``; ``"clip"`` uses ``plot_attn_clip`` (the only
    one that receives ``display_logs``). Any other value plots nothing and returns None.
    """
    if saliency_model == "dino":
        plot_attn_dino(attn, threshold_map, inputs, inds, use_wandb, output_path)
        return
    if saliency_model == "clip":
        plot_attn_clip(attn, threshold_map, inputs, inds, use_wandb, output_path, display_logs)
def fix_image_scale(im):
    """Centre ``im`` on a square white canvas whose side is the longer edge plus 20 px.

    Args:
        im: image convertible with ``np.array``. The values are treated as 8-bit
            (divided by 255). Accepts 2-D grayscale or H x W x C with C in 1..4.
            When C is 2 or 4, the last channel is treated as alpha and composited
            over white.

    Returns:
        PIL RGB image of shape (max(H, W) + 20) squared.
    """
    im_np = np.array(im) / 255
    # Bring every supported layout to H x W x 3 so it fits the RGB canvas below.
    if im_np.ndim == 2:
        im_np = im_np[:, :, np.newaxis]
    if im_np.shape[2] in (2, 4):
        # Trailing alpha channel: blend over white, matching the canvas colour.
        alpha = im_np[:, :, -1:]
        im_np = im_np[:, :, :-1] * alpha + (1.0 - alpha)
    if im_np.shape[2] == 1:
        im_np = np.repeat(im_np, 3, axis=2)

    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20
    new_background = np.ones((max_len, max_len, 3))
    y, x = max_len // 2 - height // 2, max_len // 2 - width // 2
    new_background[y: y + height, x: x + width] = im_np
    new_background = (new_background / new_background.max()
                      * 255).astype(np.uint8)
    new_im = Image.fromarray(new_background)
    return new_im
def get_mask_u2net(args, pil_im):
    """Segment the salient object in ``pil_im`` with U^2-Net and whiten the background.

    Args:
        args: namespace providing ``device`` (where the input and network run) and
            ``output_dir`` (receives ``mask.png``).
        pil_im: PIL image. It is multiplied by a 3-channel mask below, so it is
            expected to be RGB — NOTE(review): callers should convert first.

    Returns:
        (im_final, predict):
          * im_final: PIL image, with pixels outside the mask set to white.
          * predict: binary 0/1 prediction at the network's input resolution,
            shape (1, H', W'), on ``args.device``.

    Side effects:
        Loads the weights from ./U2Net_/saved_models/u2net.pth on every call and
        writes the binary mask to ``{args.output_dir}/mask.png``.
    """
    w, h = pil_im.size[0], pil_im.size[1]
    im_size = min(w, h)
    data_transforms = transforms.Compose([
        transforms.Resize(min(320, im_size), interpolation=PIL.Image.BICUBIC),
        transforms.ToTensor(),
        # NOTE(review): these are CLIP's normalisation statistics, not the ImageNet
        # ones U^2-Net is usually trained with — kept to preserve current masks.
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(
            0.26862954, 0.26130258, 0.27577711)),
    ])
    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(args.device)

    model_path = "./U2Net_/saved_models/u2net.pth"
    net = U2NET(3, 1)
    # Load on the CPU first, then always move the network to args.device.
    # Previously the net stayed on the CPU whenever use_gpu was off, even though the
    # input had already been moved to args.device. That caused a device mismatch.
    net.load_state_dict(torch.load(model_path, map_location='cpu'))
    net.to(args.device)
    net.eval()
    with torch.no_grad():
        d1, *_ = net(input_im_trans.detach())

    pred = d1[:, 0, :, :]
    pred_range = pred.max() - pred.min()
    if pred_range > 0:
        pred = (pred - pred.min()) / pred_range
    else:
        # A constant prediction has no range; treat it as all background instead of NaN.
        pred = torch.zeros_like(pred)
    predict = (pred >= 0.5).to(pred.dtype)

    mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
    mask = mask.cpu().numpy()
    # Back to the original resolution, then re-binarise after interpolation.
    mask = resize(mask, (h, w), anti_aliasing=False)
    mask = (mask >= 0.5).astype(mask.dtype)

    mask_im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    mask_im.save(f"{args.output_dir}/mask.png")

    im_np = np.array(pil_im).astype(np.float64)
    # NOTE(review): scaling by the image maximum (not 255) stretches contrast for
    # images whose brightest pixel is below 255; kept to preserve existing output.
    im_peak = im_np.max()
    if im_peak > 0:
        im_np = im_np / im_peak
    im_np = mask * im_np
    im_np[mask == 0] = 1  # background -> white
    out_peak = im_np.max()
    if out_peak > 0:
        im_np = im_np / out_peak
    im_final = Image.fromarray((im_np * 255).astype(np.uint8))

    return im_final, predict