# sudo cog push r8.im/yael-vinker/clipasso
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import warnings

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

from cog import BasePredictor, Input, Path
import subprocess as sp
import os
import re
import imageio
import matplotlib.pyplot as plt
import numpy as np
import pydiffvg
import torch
from PIL import Image
import multiprocessing as mp
from shutil import copyfile
import argparse
import math
import sys
import time
import traceback
import PIL
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torchvision import models, transforms
from tqdm import tqdm

import config
import sketch_utils as utils
from models.loss import Loss
from models.painter_params import Painter, PainterOptimizer
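
# A minimal local usage sketch (assuming Cog is installed and the model weights
# are present; the image filename below is hypothetical):
#   cog predict -i target_image=@camel.png -i num_strokes=16 -i trials=3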


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.num_iter = 2001
        self.save_interval = 100
        self.num_sketches = 3
        self.use_gpu = True
    def predict(
        self,
        target_image: Path = Input(description="Input image (square, without background)"),
        num_strokes: int = Input(description="The number of strokes used to create the sketch, which determines the level of abstraction", default=16),
        trials: int = Input(description="Number of sketches to generate; 3 trials is recommended for the best result, but more trials take longer", default=3),
        mask_object: int = Input(description="Images without a background are recommended; if your image does contain a background, set this to 1 to mask it out", default=0),
        fix_scale: int = Input(description="If your image is not square it might be cropped; set this to 1 to rescale it automatically without cropping", default=0),
    ) -> Path:
        self.num_sketches = trials
        target_image_name = os.path.basename(str(target_image))
        multiprocess = False
        abs_path = os.path.abspath(os.getcwd())

        target = str(target_image)
        assert os.path.isfile(target), f"{target} does not exist!"

        test_name = os.path.splitext(target_image_name)[0]
        output_dir = f"{abs_path}/output_sketches/{test_name}/"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        print("=" * 50)
        print(f"Processing [{target_image_name}] ...")
        print(f"Results will be saved to \n[{output_dir}] ...")
        print("=" * 50)
        if not torch.cuda.is_available():
            self.use_gpu = False
            print("CUDA is not available, running on CPU instead.")
            print("Note that this will be very slow; a GPU (e.g. Colab) is recommended.")
        print(f"GPU: {self.use_gpu}")
        seeds = list(range(0, self.num_sketches * 1000, 1000))
        losses_all = {}
        for seed in seeds:
            wandb_name = f"{test_name}_{num_strokes}strokes_seed{seed}"
            sp.run(["python", "config.py", target,
                    "--num_paths", str(num_strokes),
                    "--output_dir", output_dir,
                    "--wandb_name", wandb_name,
                    "--num_iter", str(self.num_iter),
                    "--save_interval", str(self.save_interval),
                    "--seed", str(seed),
                    "--use_gpu", str(int(self.use_gpu)),
                    "--fix_scale", str(fix_scale),
                    "--mask_object", str(mask_object),
                    "--mask_object_attention", str(mask_object),
                    "--display_logs", str(int(0))])
            config_init = np.load(f"{output_dir}/{wandb_name}/config_init.npy",
                                  allow_pickle=True)[()]
            args = Args(config_init)
            args.cog_display = True
            final_config = vars(args)
            try:
                configs_to_save = main(args)
            except BaseException as err:
                print(f"Unexpected error occurred:\n {err}")
                print(traceback.format_exc())
                sys.exit(1)
            for k in configs_to_save.keys():
                final_config[k] = configs_to_save[k]
            np.save(f"{args.output_dir}/config.npy", final_config)
            if args.use_wandb:
                wandb.finish()

            config = np.load(f"{output_dir}/{wandb_name}/config.npy",
                             allow_pickle=True)[()]
            loss_eval = np.array(config['loss_eval'])
            inds = np.argsort(loss_eval)
            losses_all[wandb_name] = loss_eval[inds][0]

        # return Path(f"{output_dir}/{wandb_name}/best_iter.svg")
        sorted_final = dict(sorted(losses_all.items(), key=lambda item: item[1]))
        copyfile(f"{output_dir}/{list(sorted_final.keys())[0]}/best_iter.svg",
                 f"{output_dir}/{list(sorted_final.keys())[0]}_best.svg")
        target_path = f"{abs_path}/target_images/{target_image_name}"
        svg_files = os.listdir(output_dir)
        svg_files = [f for f in svg_files if "best.svg" in f]
        svg_output_path = f"{output_dir}/{svg_files[0]}"
        sketch_res = read_svg(svg_output_path, multiply=True).cpu().numpy()
        sketch_res = Image.fromarray((sketch_res * 255).astype('uint8'), 'RGB')
        sketch_res.save(f"{abs_path}/output_sketches/sketch.png")
        return Path(svg_output_path)
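

# Thin namespace wrapper that exposes the saved config dictionary as attributes,
# mimicking the argparse.Namespace the training code expects.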
class Args:
    def __init__(self, config):
        for k in config.keys():
            setattr(self, k, config[k])
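

# Build the differentiable stroke renderer (Painter) for the target image and
# its saliency mask, and move it to the requested device.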
def load_renderer(args, target_im=None, mask=None):
    renderer = Painter(num_strokes=args.num_paths, args=args,
                       num_segments=args.num_segments,
                       imsize=args.image_scale,
                       device=args.device,
                       target_im=target_im,
                       mask=mask)
    renderer = renderer.to(args.device)
    return renderer
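

# Load and preprocess the target image: composite transparent images onto a
# white background, optionally mask out the background with U2-Net, optionally
# fix the aspect ratio, then resize/center-crop to image_scale and convert to
# a [0, 1] tensor on the target device.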
def get_target(args):
    target = Image.open(args.target)

    if target.mode == "RGBA":
        # Create a white rgba background
        new_image = Image.new("RGBA", target.size, "WHITE")
        # Paste the image on the background.
        new_image.paste(target, (0, 0), target)
        target = new_image
    target = target.convert("RGB")

    masked_im, mask = utils.get_mask_u2net(args, target)
    if args.mask_object:
        target = masked_im
    if args.fix_scale:
        target = utils.fix_image_scale(target)

    transforms_ = []
    if target.size[0] != target.size[1]:
        transforms_.append(transforms.Resize(
            (args.image_scale, args.image_scale), interpolation=PIL.Image.BICUBIC))
    else:
        transforms_.append(transforms.Resize(
            args.image_scale, interpolation=PIL.Image.BICUBIC))
        transforms_.append(transforms.CenterCrop(args.image_scale))
    transforms_.append(transforms.ToTensor())
    data_transforms = transforms.Compose(transforms_)
    target_ = data_transforms(target).unsqueeze(0).to(args.device)
    return target_, mask
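

# Core optimization loop: render the current stroke parameters, compute the
# CLIP-based losses against the target, backpropagate, and keep the sketch with
# the lowest evaluation loss. Stops early once the eval loss plateaus.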
def main(args):
    loss_func = Loss(args)
    inputs, mask = get_target(args)
    utils.log_input(args.use_wandb, 0, inputs, args.output_dir)
    renderer = load_renderer(args, inputs, mask)

    optimizer = PainterOptimizer(args, renderer)
    counter = 0
    configs_to_save = {"loss_eval": []}
    best_loss, best_fc_loss = 100, 100
    best_iter, best_iter_fc = 0, 0
    min_delta = 1e-5
    terminate = False

    renderer.set_random_noise(0)
    img = renderer.init_image(stage=0)
    optimizer.init_optimizers()
    for epoch in tqdm(range(args.num_iter)):
        renderer.set_random_noise(epoch)
        if args.lr_scheduler:
            optimizer.update_lr(counter)

        start = time.time()
        optimizer.zero_grad_()
        sketches = renderer.get_image().to(args.device)
        losses_dict = loss_func(sketches, inputs.detach(),
                                renderer.get_color_parameters(), renderer, counter, optimizer)
        loss = sum(list(losses_dict.values()))
        loss.backward()
        optimizer.step_()

        if epoch % args.save_interval == 0:
            utils.plot_batch(inputs, sketches, f"{args.output_dir}/jpg_logs", counter,
                             use_wandb=args.use_wandb, title=f"iter{epoch}.jpg")
            renderer.save_svg(
                f"{args.output_dir}/svg_logs", f"svg_iter{epoch}")
            # if args.cog_display:
            #     yield Path(f"{args.output_dir}/svg_logs/svg_iter{epoch}.svg")
        if epoch % args.eval_interval == 0:
            with torch.no_grad():
                losses_dict_eval = loss_func(sketches, inputs,
                                             renderer.get_color_parameters(),
                                             renderer.get_points_parans(),
                                             counter, optimizer, mode="eval")
                loss_eval = sum(list(losses_dict_eval.values()))
                configs_to_save["loss_eval"].append(loss_eval.item())
                for k in losses_dict_eval.keys():
                    if k not in configs_to_save.keys():
                        configs_to_save[k] = []
                    configs_to_save[k].append(losses_dict_eval[k].item())

                if args.clip_fc_loss_weight:
                    if losses_dict_eval["fc"].item() < best_fc_loss:
                        best_fc_loss = losses_dict_eval["fc"].item() / args.clip_fc_loss_weight
                        best_iter_fc = epoch
                # print(
                #     f"eval iter[{epoch}/{args.num_iter}] loss[{loss.item()}] time[{time.time() - start}]")

                cur_delta = loss_eval.item() - best_loss
                if abs(cur_delta) > min_delta:
                    if cur_delta < 0:
                        best_loss = loss_eval.item()
                        best_iter = epoch
                        terminate = False
                        utils.plot_batch(
                            inputs, sketches, args.output_dir, counter,
                            use_wandb=args.use_wandb, title="best_iter.jpg")
                        renderer.save_svg(args.output_dir, "best_iter")

                if args.use_wandb:
                    wandb.run.summary["best_loss"] = best_loss
                    wandb.run.summary["best_loss_fc"] = best_fc_loss
                    wandb_dict = {"delta": cur_delta,
                                  "loss_eval": loss_eval.item()}
                    for k in losses_dict_eval.keys():
                        wandb_dict[k + "_eval"] = losses_dict_eval[k].item()
                    wandb.log(wandb_dict, step=counter)
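
                # Early stopping: if the eval loss has not changed by more than
                # min_delta for two consecutive evaluations, stop optimizing.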
                if abs(cur_delta) <= min_delta:
                    if terminate:
                        break
                    terminate = True
        if counter == 0 and args.attention_init:
            utils.plot_atten(renderer.get_attn(), renderer.get_thresh(), inputs, renderer.get_inds(),
                             args.use_wandb, "{}/{}.jpg".format(args.output_dir, "attention_map"),
                             args.saliency_model, args.display_logs)

        if args.use_wandb:
            wandb_dict = {"loss": loss.item(), "lr": optimizer.get_lr()}
            for k in losses_dict.keys():
                wandb_dict[k] = losses_dict[k].item()
            wandb.log(wandb_dict, step=counter)

        counter += 1

    renderer.save_svg(args.output_dir, "final_svg")
    path_svg = os.path.join(args.output_dir, "best_iter.svg")
    utils.log_sketch_summary_final(
        path_svg, args.use_wandb, args.device, best_iter, best_loss, "best total")

    return configs_to_save
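

# Rasterize an SVG sketch with pydiffvg; with multiply=True the canvas and
# stroke widths are doubled so the image is rendered at 2x resolution.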
def read_svg(path_svg, multiply=False):
    device = torch.device("cuda" if (
        torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu")
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        path_svg)
    if multiply:
        canvas_width *= 2
        canvas_height *= 2
        for path in shapes:
            path.points *= 2
            path.stroke_width *= 2
    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,   # width
                  canvas_height,  # height
                  2,              # num_samples_x
                  2,              # num_samples_y
                  0,              # seed
                  None,
                  *scene_args)
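    # Composite the rendered RGBA image over a white background:
    # out = alpha * rgb + (1 - alpha) * white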
    img = img[:, :, 3:4] * img[:, :, :3] + \
        torch.ones(img.shape[0], img.shape[1], 3,
                   device=device) * (1 - img[:, :, 3:4])
    img = img[:, :, :3]
    return img