| | import os |
| | import random |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import Dataset |
| | from PIL import Image |
| | from torchvision import transforms |
| | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor |
| | import matplotlib.pyplot as plt |
| | import cv2 |
| | import torch.nn.functional as F |
| |
|
| | |
| |
|
| |
|
| | class _bn_relu_conv(nn.Module): |
| | def __init__(self, in_filters, nb_filters, fw, fh, subsample=1): |
| | super(_bn_relu_conv, self).__init__() |
| | self.model = nn.Sequential( |
| | nn.BatchNorm2d(in_filters, eps=1e-3), |
| | nn.LeakyReLU(0.2), |
| | nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros') |
| | ) |
| |
|
| | def forward(self, x): |
| | return self.model(x) |
| |
|
| | |
| | print("****", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape) |
| | for i,layer in enumerate(self.model): |
| | if i != 2: |
| | x = layer(x) |
| | else: |
| | x = layer(x) |
| | |
| | print("____", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape) |
| | print(x[0]) |
| | return x |
| |
|
| |
|
| | class _u_bn_relu_conv(nn.Module): |
| | def __init__(self, in_filters, nb_filters, fw, fh, subsample=1): |
| | super(_u_bn_relu_conv, self).__init__() |
| | self.model = nn.Sequential( |
| | nn.BatchNorm2d(in_filters, eps=1e-3), |
| | nn.LeakyReLU(0.2), |
| | nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)), |
| | nn.Upsample(scale_factor=2, mode='nearest') |
| | ) |
| |
|
| | def forward(self, x): |
| | return self.model(x) |
| |
|
| |
|
| |
|
| | class _shortcut(nn.Module): |
| | def __init__(self, in_filters, nb_filters, subsample=1): |
| | super(_shortcut, self).__init__() |
| | self.process = False |
| | self.model = None |
| | if in_filters != nb_filters or subsample != 1: |
| | self.process = True |
| | self.model = nn.Sequential( |
| | nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample) |
| | ) |
| |
|
| | def forward(self, x, y): |
| | |
| | if self.process: |
| | y0 = self.model(x) |
| | |
| | return y0 + y |
| | else: |
| | |
| | return x + y |
| |
|
| | class _u_shortcut(nn.Module): |
| | def __init__(self, in_filters, nb_filters, subsample): |
| | super(_u_shortcut, self).__init__() |
| | self.process = False |
| | self.model = None |
| | if in_filters != nb_filters: |
| | self.process = True |
| | self.model = nn.Sequential( |
| | nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'), |
| | nn.Upsample(scale_factor=2, mode='nearest') |
| | ) |
| |
|
| | def forward(self, x, y): |
| | if self.process: |
| | return self.model(x) + y |
| | else: |
| | return x + y |
| |
|
| |
|
class basic_block(nn.Module):
    """Two stacked 3x3 ``_bn_relu_conv`` units with a ``_shortcut`` around them.

    Args:
        in_filters: Channels of the block input.
        nb_filters: Channels produced by both convolutions.
        init_subsample: Stride of the first convolution and of the shortcut.
    """

    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super().__init__()
        # Registration order (conv1, residual, shortcut) is kept as is.
        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        """Return ``shortcut(x, residual(conv1(x)))``."""
        branch = self.residual(self.conv1(x))
        return self.shortcut(x, branch)
| |
|
class _u_basic_block(nn.Module):
    """Upsampling residual block.

    The main path is ``_u_bn_relu_conv`` (3x3, includes 2x upsampling) followed
    by a 3x3 ``_bn_relu_conv``; the skip path is a ``_u_shortcut``.
    """

    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super().__init__()
        # Registration order (conv1, residual, shortcut) is kept as is.
        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        """Return ``shortcut(x, residual(conv1(x)))``."""
        upsampled = self.conv1(x)
        refined = self.residual(upsampled)
        return self.shortcut(x, refined)
| |
|
| |
|
class _residual_block(nn.Module):
    """A stack of ``repetitions`` ``basic_block`` modules.

    The first block maps ``in_filters`` to ``nb_filters``; the rest keep
    ``nb_filters``. Unless ``is_first_layer`` is set, the last block in the
    stack uses stride 2.
    """

    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
        super().__init__()
        final_idx = repetitions - 1
        stack = []
        for idx in range(repetitions):
            downsample_here = idx == final_idx and not is_first_layer
            stride = 2 if downsample_here else 1
            channels_in = in_filters if idx == 0 else nb_filters
            stack.append(
                basic_block(
                    in_filters=channels_in,
                    nb_filters=nb_filters,
                    init_subsample=stride,
                )
            )
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through every block in order."""
        out = self.model(x)
        return out
| |
|
| |
|
class _upsampling_residual_block(nn.Module):
    """One ``_u_basic_block`` followed by ``repetitions - 1`` ``basic_block`` modules.

    Only the first block changes the channel count (``in_filters`` to
    ``nb_filters``); the remaining ones keep ``nb_filters``.
    """

    def __init__(self, in_filters, nb_filters, repetitions):
        super().__init__()
        stack = [
            _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)
            if idx == 0
            else basic_block(in_filters=nb_filters, nb_filters=nb_filters)
            for idx in range(repetitions)
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through every block in order."""
        out = self.model(x)
        return out
| |
|
| |
|
class res_skip(nn.Module):
    """Residual encoder/decoder with additive skip connections.

    ``block0``-``block4`` form the encoder (1 -> 24 -> 48 -> 96 -> 192 -> 384
    channels), ``block5``-``block8`` the upsampling decoder (384 -> 192 -> 96
    -> 48 -> 24), and ``res1``-``res4`` add the matching encoder output to each
    decoder stage. ``block9`` and ``conv15`` reduce the result to one channel.
    """

    def __init__(self):
        super().__init__()
        # Encoder. Attribute names and creation order are kept unchanged so
        # that parameter names and ordering stay the same.
        self.block0 = _residual_block(1, 24, repetitions=2, is_first_layer=True)
        self.block1 = _residual_block(24, 48, repetitions=3)
        self.block2 = _residual_block(48, 96, repetitions=5)
        self.block3 = _residual_block(96, 192, repetitions=7)
        self.block4 = _residual_block(192, 384, repetitions=12)

        # Decoder stages, each paired with the skip-addition that follows it.
        self.block5 = _upsampling_residual_block(384, 192, repetitions=7)
        self.res1 = _shortcut(192, 192)

        self.block6 = _upsampling_residual_block(192, 96, repetitions=5)
        self.res2 = _shortcut(96, 96)

        self.block7 = _upsampling_residual_block(96, 48, repetitions=3)
        self.res3 = _shortcut(48, 48)

        self.block8 = _upsampling_residual_block(48, 24, repetitions=2)
        self.res4 = _shortcut(24, 24)

        # Output head.
        self.block9 = _residual_block(24, 16, repetitions=2, is_first_layer=True)
        self.conv15 = _bn_relu_conv(16, 1, fw=1, fh=1, subsample=1)

    def forward(self, x):
        """Encode ``x``, decode with skip additions, and return the head output."""
        enc0 = self.block0(x)
        enc1 = self.block1(enc0)
        enc2 = self.block2(enc1)
        enc3 = self.block3(enc2)
        bottleneck = self.block4(enc3)

        dec = self.res1(enc3, self.block5(bottleneck))
        dec = self.res2(enc2, self.block6(dec))
        dec = self.res3(enc1, self.block7(dec))
        dec = self.res4(enc0, self.block8(dec))

        return self.conv15(self.block9(dec))
| |
|
class MyDataset(Dataset):
    """Dataset over a list of image paths that yields ``(image, file_name)``.

    Args:
        image_paths: Indexable collection of image file paths.
        transform: Optional callable applied to the opened image.
    """

    def __init__(self, image_paths, transform=None):
        self.transform = transform
        self.image_paths = image_paths

    def get_class_label(self, image_name):
        """Return the final path component of ``image_name``."""
        return os.path.basename(image_name)

    def __getitem__(self, index):
        """Open the image at ``index`` and return it with its file name."""
        path = self.image_paths[index]
        image = Image.open(path)
        label = self.get_class_label(path.split('/')[-1])
        if self.transform is None:
            return image, label
        return self.transform(image), label

    def __len__(self):
        """Number of image paths held by the dataset."""
        return len(self.image_paths)
| |
|
def loadImages(folder):
    """Return the paths of all regular files directly inside ``folder``.

    Sub-directories are skipped; the order is whatever ``os.listdir`` yields.
    """
    candidates = (os.path.join(folder, entry) for entry in os.listdir(folder))
    return [path for path in candidates if os.path.isfile(path)]
| |
|
| |
|
def crop_center_square(image):
    """
    Crop the centre of an image to a square.

    :param image: PIL.Image object
    :return: the cropped PIL.Image object
    """
    w, h = image.size
    side = min(w, h)

    # Offsets that centre a side x side window inside the image.
    x0 = (w - side) // 2
    y0 = (h - side) // 2

    return image.crop((x0, y0, x0 + side, y0 + side))
| |
|
def crop_image(image, crop_size, stride):
    """
    Cut an image into crops of a given size on a regular grid.

    Crops are produced row by row (top to bottom, left to right within a row).

    :param image: PIL.Image object
    :param crop_size: crop size as (width, height), e.g. (384, 384)
    :param stride: step between neighbouring crops, e.g. 128
    :return: list of cropped images
    """
    img_w, img_h = image.size
    win_w, win_h = crop_size
    return [
        image.crop((x, y, x + win_w, y + win_h))
        for y in range(0, img_h - win_h + 1, stride)
        for x in range(0, img_w - win_w + 1, stride)
    ]
| |
|
def process_image_ref(image):
    """
    Resize a PIL image to 512x512 and return it together with its crops.

    The first list element is the resized image itself; it is followed by the
    384x384 crops taken with a stride of 128.

    :param image: PIL.Image object
    :return: list containing the resized image and all crops
    """
    base = image.resize((512, 512))
    window = (384, 384)
    step = 128
    return [base, *crop_image(base, window, step)]
| |
|
| |
|
def process_image_Q(image):
    """
    Resize a PIL image to 512x512 RGB and return its 384x384 crops.

    Unlike ``process_image_ref`` the resized image itself is not part of the
    result; only the crops taken with a stride of 128 are returned.

    :param image: PIL.Image object
    :return: list containing all cropped images
    """
    # A single RGB conversion is enough; the original chained a second,
    # identical ``convert("RGB")`` that only copied the image again.
    resized_image_512 = image.resize((512, 512)).convert("RGB")

    crop_size_384 = (384, 384)
    stride_384 = 128
    return list(crop_image(resized_image_512, crop_size_384, stride_384))
| |
|
def process_image(image, target_width=512, target_height = 512):
    """Resize ``image`` to the target size, centre-cropping if the aspect differs.

    When the relative difference between the image's aspect ratio and the
    target's is below 15 %, the image is resized directly. Otherwise it is
    first centre-cropped to the target aspect ratio and then resized. Bicubic
    resampling is used in both cases.

    :param image: PIL.Image object
    :param target_width: output width in pixels
    :param target_height: output height in pixels
    :return: the resized image converted to RGB
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h
    dst_ratio = target_width / target_height
    dst_size = (target_width, target_height)

    if abs(src_ratio - dst_ratio) / dst_ratio < 0.15:
        # Close enough: accept the small distortion of a direct resize.
        result = image.resize(dst_size, Image.BICUBIC)
        return result.convert('RGB')

    if src_ratio > dst_ratio:
        # Too wide: keep the full height and trim the sides equally.
        crop_w = int(src_h * dst_ratio)
        x0 = int((src_w - crop_w) / 2)
        box = (x0, 0, x0 + crop_w, src_h)
    else:
        # Too tall: keep the full width and trim top and bottom equally.
        crop_h = int(src_w / dst_ratio)
        y0 = int((src_h - crop_h) / 2)
        box = (0, y0, src_w, y0 + crop_h)

    result = image.crop(box).resize(dst_size, Image.BICUBIC)
    return result.convert('RGB')
| |
|
def crop_image_varres(image, crop_size, h_stride, w_stride):
    """
    Cut an image into crops on a grid with separate vertical/horizontal steps.

    Crops are produced row by row (top to bottom, left to right within a row).

    :param image: PIL.Image object
    :param crop_size: crop size as (width, height), e.g. (384, 384)
    :param h_stride: vertical step between neighbouring crops
    :param w_stride: horizontal step between neighbouring crops
    :return: list of cropped images
    """
    img_w, img_h = image.size
    win_w, win_h = crop_size
    return [
        image.crop((x, y, x + win_w, y + win_h))
        for y in range(0, img_h - win_h + 1, h_stride)
        for x in range(0, img_w - win_w + 1, w_stride)
    ]
| |
|
def process_image_ref_varres(image, target_width=512, target_height = 512):
    """
    Resize a PIL image to the target size and return it together with its crops.

    The crop window is three quarters of the target size in each dimension and
    moves in steps of one quarter of the target size.

    :param image: PIL.Image object
    :param target_width: width to resize to
    :param target_height: height to resize to
    :return: list containing the resized image followed by all crops
    """
    full = image.resize((target_width, target_height))

    step_w = target_width // 4
    step_h = target_height // 4
    window = (step_w * 3, step_h * 3)

    crops = crop_image_varres(full, window, h_stride=step_h, w_stride=step_w)
    return [full, *crops]
| |
|
| |
|
def process_image_Q_varres(image, target_width=512, target_height = 512):
    """
    Resize a PIL image to the target size (RGB) and return its crops.

    The crop window is three quarters of the target size in each dimension and
    moves in steps of one quarter of the target size. The resized image itself
    is not included in the result.

    :param image: PIL.Image object
    :param target_width: width to resize to
    :param target_height: height to resize to
    :return: list containing all cropped images
    """
    # A single RGB conversion is enough; the original chained a second,
    # identical ``convert("RGB")`` that only copied the image again.
    resized_image = image.resize((target_width, target_height)).convert("RGB")

    w_stride = target_width // 4
    h_stride = target_height // 4
    crop_size = (target_width // 4 * 3, target_height // 4 * 3)

    return list(
        crop_image_varres(resized_image, crop_size, h_stride=h_stride, w_stride=w_stride)
    )
| |
|
| |
|
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | |
class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv+BatchNorm layers plus a skip connection.

    The skip path is an empty ``nn.Sequential`` (identity) unless the stride is
    not 1 or the channel count changes, in which case it is a bias-free 1x1
    convolution followed by BatchNorm.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        merged = self.bn2(self.conv2(hidden))
        merged += self.shortcut(x)
        return F.relu(merged)
| |
|
| | |
class TwoLayerResNet(nn.Module):
    """Four ``ResNetBlock`` modules applied in sequence.

    Only the first block changes the channel count (``in_channels`` to
    ``out_channels``); the other three keep ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block1 = ResNetBlock(in_channels, out_channels)
        self.block2 = ResNetBlock(out_channels, out_channels)
        self.block3 = ResNetBlock(out_channels, out_channels)
        self.block4 = ResNetBlock(out_channels, out_channels)

    def forward(self, x):
        """Pass ``x`` through block1..block4 and return the result."""
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
        return x
| | |
| |
|
class MultiHiddenResNetModel(nn.Module):
    """One ``TwoLayerResNet`` per input tensor, applied position by position.

    The network at position ``idx`` takes ``channels_list[idx] * 2`` input
    channels and produces ``channels_list[min(len(channels_list) - 1, idx + 2)]``
    output channels.
    """

    def __init__(self, channels_list, num_tensors):
        super().__init__()
        last_idx = len(channels_list) - 1
        nets = []
        for idx in range(num_tensors):
            out_idx = min(last_idx, idx + 2)
            nets.append(TwoLayerResNet(channels_list[idx] * 2, channels_list[out_idx]))
        self.two_layer_resnets = nn.ModuleList(nets)

    def forward(self, tensor_list):
        """Return a list with the i-th network applied to the i-th tensor."""
        return [
            self.two_layer_resnets[pos](tensor)
            for pos, tensor in enumerate(tensor_list)
        ]
| | |
| |
|
def calculate_target_size(h, w):
    """Round ``h`` and ``w`` down to multiples of 8, using 8 when that gives 0.

    The original drew from ``random.random()`` to choose between three
    branches that all computed the same value, so the result never depended on
    the draw. The branches are collapsed here; as a consequence this function
    no longer consumes numbers from the global ``random`` stream.

    :param h: height
    :param w: width
    :return: ``(target_h, target_w)``
    """
    target_h = (h // 8) * 8
    target_w = (w // 8) * 8

    # Inputs smaller than 8 would round down to 0; fall back to 8.
    if target_h == 0:
        target_h = 8
    if target_w == 0:
        target_w = 8

    return target_h, target_w
| |
|
| |
|
def downsample_tensor(tensor):
    """Bilinearly resize a 4-D tensor to the size given by ``calculate_target_size``.

    The last two dimensions of ``tensor`` are passed as ``(h, w)`` to
    ``calculate_target_size`` and the tensor is interpolated to the returned
    size with ``align_corners=False``.
    """
    _, _, height, width = tensor.shape
    new_h, new_w = calculate_target_size(height, width)
    return F.interpolate(
        tensor,
        size=(new_h, new_w),
        mode='bilinear',
        align_corners=False,
    )
| |
|
| |
|
| |
|
def get_pixart_config():
    """Return a new dict of hard-coded ``Transformer2DModel`` configuration values."""
    return dict(
        _class_name="Transformer2DModel",
        _diffusers_version="0.22.0.dev0",
        activation_fn="gelu-approximate",
        attention_bias=True,
        attention_head_dim=72,
        attention_type="default",
        caption_channels=4096,
        cross_attention_dim=1152,
        double_self_attention=False,
        dropout=0.0,
        in_channels=4,
        norm_elementwise_affine=False,
        norm_eps=1e-06,
        norm_num_groups=32,
        norm_type="ada_norm_single",
        num_attention_heads=16,
        num_embeds_ada_norm=1000,
        num_layers=28,
        num_vector_embeds=None,
        only_cross_attention=False,
        out_channels=8,
        patch_size=2,
        sample_size=128,
        upcast_attention=False,
        use_linear_projection=False,
    )
| |
|
| |
|
| |
|
class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv2d, BatchNorm2d, ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Both convolutions use stride 1 and padding 1.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        out = self.double_conv(x)
        return out
| |
|
| |
|
class UNet(nn.Module):
    """U-Net with four down/up levels, 6 input channels and 3 output channels.

    Each encoder level is a ``DoubleConv`` followed by 2x2 max pooling; each
    decoder level is a stride-2 transposed convolution whose output is
    concatenated (channel dimension) with the matching encoder feature map and
    passed through a ``DoubleConv``. A 1x1 convolution produces the output.
    """

    def __init__(self):
        super().__init__()
        # Encoder. Attribute names and creation order are kept unchanged so
        # that parameter names and ordering stay the same.
        self.left_conv_1 = DoubleConv(in_channels=6, out_channels=64)
        self.down_1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_2 = DoubleConv(in_channels=64, out_channels=128)
        self.down_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_3 = DoubleConv(in_channels=128, out_channels=256)
        self.down_3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_4 = DoubleConv(in_channels=256, out_channels=512)
        self.down_4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck.
        self.center_conv = DoubleConv(in_channels=512, out_channels=1024)

        # Decoder: the DoubleConv input is twice the upsampled channel count
        # because the encoder feature map is concatenated onto it.
        self.up_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.right_conv_1 = DoubleConv(in_channels=1024, out_channels=512)

        self.up_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.right_conv_2 = DoubleConv(in_channels=512, out_channels=256)

        self.up_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.right_conv_3 = DoubleConv(in_channels=256, out_channels=128)

        self.up_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.right_conv_4 = DoubleConv(in_channels=128, out_channels=64)

        # Output head.
        self.output = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the head output."""
        enc1 = self.left_conv_1(x)
        enc2 = self.left_conv_2(self.down_1(enc1))
        enc3 = self.left_conv_3(self.down_2(enc2))
        enc4 = self.left_conv_4(self.down_3(enc3))

        feat = self.center_conv(self.down_4(enc4))

        feat = self.right_conv_1(torch.cat((self.up_1(feat), enc4), dim=1))
        feat = self.right_conv_2(torch.cat((self.up_2(feat), enc3), dim=1))
        feat = self.right_conv_3(torch.cat((self.up_3(feat), enc2), dim=1))
        feat = self.right_conv_4(torch.cat((self.up_4(feat), enc1), dim=1))

        return self.output(feat)
| |
|
| |
|
| | |
| | import sys |
| | sys.path.append('./BidirectionalTranslation') |
| | from data.base_dataset import BaseDataset, get_params, get_transform |
| | from data.image_folder import make_dataset |
| |
|
def get_ScreenVAE_input(A_img, opt):
    """Build the input dict for the ScreenVAE / bidirectional-translation model.

    :param A_img: PIL.Image object
    :param opt: options object forwarded to ``get_params`` / ``get_transform``
    :return: dict with the keys
        ``'A'``  - transformed image with a leading batch dimension,
        ``'Ai'`` - weighted sum (0.299, 0.587, 0.114) of the first three
                   channels of ``A``, with two leading singleton dimensions,
        ``'L'``  - grayscale-transformed image with a leading batch dimension,
        ``'A_paths'`` - empty string,
        ``'h'``, ``'w'`` - height and width of ``A_img`` after the optional halving,
        ``'B'``, ``'Bs'``, ``'Bi'``, ``'Bl'`` - ``torch.zeros(1)`` placeholders.
    """
    # NOTE(review): L_img keeps the image as passed in, so for inputs taller
    # than 2500 px 'L' is built from the un-halved image while 'A' is built
    # from the halved one - confirm this is intended.
    L_img = A_img

    # The original also resized A_img when its size differed from L_img's,
    # but L_img is the same object at this point, so that branch could never
    # run and has been removed.
    if A_img.size[1] > 2500:
        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
        # filter under its current name.
        A_img = A_img.resize((A_img.size[0] // 2, A_img.size[1] // 2), Image.LANCZOS)

    ow, oh = A_img.size
    transform_params = get_params(opt, A_img.size)

    A_transform = get_transform(opt, transform_params, grayscale=False)
    L_transform = get_transform(opt, transform_params, grayscale=True)
    A = A_transform(A_img)
    L = L_transform(L_img)

    # Collapse the first three channels of A into a single channel.
    tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
    Ai = tmp.unsqueeze(0)

    return {
        'A': A.unsqueeze(0),
        'Ai': Ai.unsqueeze(0),
        'L': L.unsqueeze(0),
        'A_paths': '',
        'h': oh,
        'w': ow,
        'B': torch.zeros(1),
        'Bs': torch.zeros(1),
        'Bi': torch.zeros(1),
        'Bl': torch.zeros(1),
    }
| |
|
| |
|
def get_bidirectional_translation_opt(opt):
    """Overwrite attributes on ``opt`` with fixed settings and return it.

    ``opt`` is modified in place; the same object is returned.
    """
    overrides = (
        ('results_dir', './results/test/western2manga'),
        ('dataroot', './datasets/color2manga'),
        ('checkpoints_dir', '/group/40034/zhuangjunhao/ScreenStyle/BidirectionalTranslation/checkpoints/color2manga/'),
        ('name', 'color2manga_cycle_ganstft'),
        ('model', 'cycle_ganstft'),
        ('direction', 'BtoA'),
        ('preprocess', 'none'),
        ('load_size', 512),
        ('crop_size', 1024),
        ('input_nc', 1),
        ('output_nc', 3),
        ('nz', 64),
        ('netE', 'conv_256'),
        ('num_test', 30),
        ('n_samples', 1),
        ('upsample', 'bilinear'),
        ('ngf', 48),
        ('nef', 48),
        ('ndf', 32),
        ('center_crop', True),
        ('color2screen', True),
        ('no_flip', True),
        ('num_threads', 1),
        ('batch_size', 1),
        ('serial_batches', True),
    )
    for attr, value in overrides:
        setattr(opt, attr, value)
    return opt