| | ''' |
| | python colorflow_cli.py \ |
| | --input_image ./input.jpg \ |
| | --reference_images ./ref1.jpg ./ref2.jpg \ |
| | --output_dir ./results \ |
| | --input_style Sketch \ |
| | --resolution 640x640 \ |
| | --seed 123 \ |
| | --num_inference_steps 20 |
| | ''' |
| |
|
| | |
| | from app_func import * |
| | import argparse |
| | import torch |
| | from PIL import Image |
| | import os |
| | import logging |
| |
|
| | |
| | |
| |
|
def parse_args(argv=None) -> argparse.Namespace:
    """Build the ColorFlow CLI argument parser and parse the command line.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default), argparse falls back to ``sys.argv[1:]``,
            which is the original behavior. Passing an explicit list lets the
            CLI be driven programmatically or from tests.

    Returns:
        An ``argparse.Namespace`` with the attributes ``input_image``,
        ``reference_images`` (list of str), ``output_dir``, ``input_style``,
        ``resolution`` (one of the "WIDTHxHEIGHT" choices), ``seed`` and
        ``num_inference_steps``.

    Raises:
        SystemExit: raised by argparse when the arguments are missing or
            invalid (for example, an unknown ``--resolution`` choice).
    """
    parser = argparse.ArgumentParser(description="ColorFlow命令行图像上色工具")
    parser.add_argument("--input_image", type=str, required=True, help="输入图像路径")
    parser.add_argument("--reference_images", type=str, nargs='+', required=True, help="参考图像路径列表")
    parser.add_argument("--output_dir", type=str, default="./output", help="输出目录")
    parser.add_argument("--input_style", type=str, default="GrayImage(ScreenStyle)",
                        choices=["GrayImage(ScreenStyle)", "Sketch"], help="输入样式类型")
    parser.add_argument("--resolution", type=str, default="640x640",
                        choices=["640x640", "512x800", "800x512"], help="分辨率设置")
    parser.add_argument("--seed", type=int, default=0, help="随机种子")
    parser.add_argument("--num_inference_steps", type=int, default=10, help="推理步数")
    # parse_args(None) reads sys.argv[1:], so existing no-argument callers are unchanged.
    return parser.parse_args(argv)
| |
|
def save_image(image: Image.Image, path: str, format: str = "PNG") -> None:
    """Write *image* to *path* in the given file *format* and log the outcome.

    Args:
        image: The image object to write; its ``save`` method is called.
        path: Destination file path.
        format: Format name forwarded to ``image.save`` (default ``"PNG"``).

    Raises:
        Exception: whatever ``image.save`` raises is logged at ERROR level
            and then re-raised unchanged.
    """
    try:
        image.save(path, format=format)
    except Exception as err:
        # Record the failure, then let the original exception propagate.
        logging.error(f"保存图像失败: {str(err)}")
        raise
    else:
        logging.info(f"成功保存图像至: {path}")
| |
|
def _check_input_files(input_image, reference_images):
    """Fail fast when an image path given on the command line is not a file.

    Args:
        input_image: Path of the image to colorize.
        reference_images: Iterable of reference image paths.

    Raises:
        FileNotFoundError: for the first path that is not an existing file.
    """
    for path in (input_image, *reference_images):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"找不到图像文件: {path}")


def main():
    """Run the ColorFlow CLI end to end.

    Steps:
      1. Parse the command line.
      2. Check that the input image and every reference image path exist.
      3. Create the output directory.
      4. Decode the input image.
      5. Load the checkpoint for the chosen input style.
      6. Extract the line image.
      7. Colorize it using the reference images.
      8. Write four images (PNG, via ``save_image``) into ``--output_dir``.

    Raises:
        FileNotFoundError: if the input image or any reference image path is
            not an existing file. This check runs before the checkpoint is
            loaded.
    """
    args = parse_args()

    # Validate the cheap things first so a typo in a path does not cost a
    # full checkpoint load before failing.
    _check_input_files(args.input_image, args.reference_images)
    os.makedirs(args.output_dir, exist_ok=True)

    # convert() returns a new, fully loaded image, so the source file can be
    # closed as soon as the conversion is done.
    with Image.open(args.input_image) as opened:
        input_img = opened.convert("RGB")

    # load_ckpt / extract_line_image / colorize_image come in through
    # `from app_func import *`. Python resolves a function's globals in its
    # defining module, so any model state they manage lives there. Rebinding
    # names in this module (e.g. with `global`) would not affect them.
    load_ckpt(args.input_style)

    input_context, extracted_line, _ = extract_line_image(
        input_img, args.input_style, args.resolution
    )

    high_res_img, up_img, raw_output, preprocessed_bw = colorize_image(
        VAE_input=extracted_line,
        input_context=input_context,
        reference_images=args.reference_images,
        resolution=args.resolution,
        seed=args.seed,
        input_style=args.input_style,
        num_inference_steps=args.num_inference_steps,
    )

    # Output file name -> image. Dict insertion order keeps the original save order.
    outputs = {
        "colorized_result.png": high_res_img,
        "upsampled_intermediate.png": up_img,
        "raw_generated_output.png": raw_output,
        "preprocessed_bw.png": preprocessed_bw,
    }
    for filename, image in outputs.items():
        save_image(image, os.path.join(args.output_dir, filename))
| |
|
if __name__ == "__main__":
    # Configure root logging (INFO and above, timestamped) before running the
    # CLI so that save_image's success/failure messages are emitted.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    main()