import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import gc
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoModelForImageSegmentation

from .remove_background import RemoveBackground
from custom_logger import logger_config


class RemoveBackgroundAI(RemoveBackground):

    def __init__(self, model_name='briaai/RMBG-2.0',
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 image_size=(1024, 1024)):
        """
        Initialize RemoveBackgroundAI with a pre-trained segmentation model.
        """
        super().__init__("remove_background_ai")
        self.device = device
        self.image_size = image_size

        self.model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
        if device == 'cuda':
            torch.set_float32_matmul_precision('high')
        self.model.to(device)
        self.model.eval()

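        # Preprocessing pipeline: resize to the model's expected input size and
        # apply standard ImageNet mean/std normalization.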
        self.transform = transforms.Compose([
            transforms.Resize(image_size, antialias=True),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def process(self, image_path, output_path=None, save_alpha=True):
        """
        Remove the background while preserving the exact foreground position and size.

        Args:
            image_path (str or Path): Path to the input image.
            output_path (str or Path, optional): Path to save the output image.
            save_alpha (bool): If True, save with transparency (PNG RGBA); otherwise use a black background.

        Returns:
            PIL.Image: Processed image with the background removed.
        """
        image = Image.open(image_path).convert("RGB")
        original_size = image.size
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

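        # The model returns a list of predictions; the last entry is the final mask.
        # Sigmoid maps the logits to per-pixel foreground probabilities, which are
        # then resized back to the original image dimensions.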
        with torch.no_grad():
            preds = self.model(input_tensor)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        mask_pil = transforms.ToPILImage()(pred)
        mask = mask_pil.resize(image.size, Image.LANCZOS)

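        # Write the predicted mask directly into the alpha channel so the
        # foreground keeps its exact original position and scale.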
        if save_alpha:
            image_rgba = image.convert("RGBA")

            mask_array = np.array(mask)
            image_array = np.array(image_rgba)

            image_array[:, :, 3] = mask_array
            result_image = Image.fromarray(image_array)
        else:
            result_image = Image.new("RGB", original_size, (0, 0, 0))
            result_image.paste(image, mask=mask)

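        # PNG is required to keep the alpha channel, so the extension is rewritten if needed.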
        if output_path:
            if isinstance(output_path, Path):
                output_path = str(output_path)

            if save_alpha:
                if not output_path.lower().endswith('.png'):
                    output_path = os.path.splitext(output_path)[0] + '.png'
                result_image.save(output_path, format="PNG")
            else:
                result_image.save(output_path)

        return result_image

    def remove_background(self, image_path, output_path=None, crop=False, bg_color=(0, 0, 0), save_alpha=False, keep_size=True, skip_crop=False):
        """
        Original method, kept for backward compatibility.
        To preserve the exact foreground position and size, use process() instead.
        """
        image = Image.open(image_path).convert("RGB")
        original_size = image.size
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            preds = self.model(input_tensor)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        mask_pil = transforms.ToPILImage()(pred)
        mask = mask_pil.resize(image.size)

        image_rgba = image.convert("RGBA")
        temp_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
        temp_image.paste(image_rgba, mask=mask)

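        # Optional post-processing: crop to the foreground bounding box and, if
        # keep_size is set, re-center the crop on a canvas of the original size.
        # Note that re-centering changes the foreground position.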
        if skip_crop:
            print(f"Skipping cropping for {image_path}")
        else:
            if crop:
                temp_image = self._crop_to_content(temp_image)
            if keep_size:
                padded_image = Image.new("RGBA", original_size, (0, 0, 0, 0))
                offset_x = (original_size[0] - temp_image.size[0]) // 2
                offset_y = (original_size[1] - temp_image.size[1]) // 2
                padded_image.paste(temp_image, (offset_x, offset_y))
                temp_image = padded_image

        if output_path:
            if isinstance(output_path, Path):
                output_path = str(output_path)

            if save_alpha:
                if not output_path.lower().endswith('.png'):
                    output_path = os.path.splitext(output_path)[0] + '.png'
                temp_image.save(output_path, format="PNG")
            else:
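                # No transparency requested: flatten the cut-out onto a solid
                # background color, using its own alpha channel as the paste mask.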
                bg_image = Image.new("RGB", temp_image.size, bg_color)
                bg_image.paste(temp_image, mask=temp_image.split()[3])
                bg_image.save(output_path)

        return temp_image

    def _crop_to_content(self, image):
        """
        Crop the image to the bounding box of the non-transparent content.
        """
        img_array = np.array(image)
        alpha_channel = img_array[:, :, 3]
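        # Rows and columns that contain at least one non-transparent pixel;
        # their extrema define the content bounding box.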
        non_empty_columns = np.where(alpha_channel.max(axis=0) > 0)[0]
        non_empty_rows = np.where(alpha_channel.max(axis=1) > 0)[0]

        if len(non_empty_columns) > 0 and len(non_empty_rows) > 0:
            crop_box = (
                non_empty_columns.min(),
                non_empty_rows.min(),
                non_empty_columns.max() + 1,
                non_empty_rows.max() + 1
            )
            return image.crop(crop_box)

        return image

    def cleanup(self):
        """
        Clean up resources used by the model.
        """
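        # Move the model off the GPU before deleting it so CUDA memory can actually be freed.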
        if self.device == 'cuda':
            self.model.to('cpu')
        del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print("Model unloaded and resources cleaned up")


def remove_background_batch(folder, output_path=None, save_alpha=True, preserve_position=True):
    """
    Process all images in a folder and remove backgrounds while preserving position and size.

    Args:
        folder (str): Folder containing the images to process.
        output_path (str): Output folder path.
        save_alpha (bool): If True, saves output as PNG with transparency.
        preserve_position (bool): If True, keeps the foreground in its exact original position.
    """
    remover = RemoveBackgroundAI()
    input_path = Path(folder)

    if output_path:
        Path(output_path).mkdir(parents=True, exist_ok=True)

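    # Collect images by extension, checking both lower- and upper-case variants
    # since glob patterns are case-sensitive on case-sensitive filesystems.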
    image_files = []
    for ext in ['jpg', 'jpeg', 'png', 'bmp', 'tiff', 'webp']:
        image_files.extend(input_path.glob(f"*.{ext}"))
        image_files.extend(input_path.glob(f"*.{ext.upper()}"))

    print(f"Found {len(image_files)} images to process")

    try:
        for img_path in tqdm(image_files, desc="Removing Background", unit="image"):
            try:
                if output_path:
                    output_filename = os.path.basename(img_path)
                    if save_alpha and not output_filename.lower().endswith('.png'):
                        output_filename = os.path.splitext(output_filename)[0] + '.png'
                    output_file = os.path.join(output_path, output_filename)
                else:
                    output_file = img_path

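                # process() keeps the foreground in its exact original position;
                # remove_background() is the legacy path, called here with cropping disabled.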
                if preserve_position:
                    remover.process(
                        image_path=img_path,
                        output_path=output_file,
                        save_alpha=save_alpha
                    )
                else:
                    remover.remove_background(
                        image_path=img_path,
                        output_path=output_file,
                        crop=False,
                        bg_color=(0, 0, 0),
                        save_alpha=save_alpha,
                        skip_crop=True
                    )

            except Exception as e:
                print(f"Error processing {img_path}: {str(e)}")

    except KeyboardInterrupt:
        print("\nProcessing interrupted by user")
    finally:
        remover.cleanup()


def process(image_path, output_path=None, save_alpha=True):
    """
    Process a single image and remove the background while preserving position and size.

    Args:
        image_path (str): Path to the input image.
        output_path (str, optional): Path to save the output image.
        save_alpha (bool): If True, saves with transparency.

    Returns:
        PIL.Image: Processed image, or None if processing failed.
    """
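    # A model instance is loaded for this single call; when processing many images,
    # prefer remove_background_batch(), which reuses one model for the whole folder.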
    remover = RemoveBackgroundAI()

    try:
        result = remover.process(
            image_path=image_path,
            output_path=output_path,
            save_alpha=save_alpha
        )
        print(f"Successfully processed: {os.path.basename(image_path)}")
        return result
    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        return None
    finally:
        remover.cleanup()


if __name__ == "__main__":
    remove_background_batch(
        folder="../CaptionCreator/media/puzzle_x_pic/",
        output_path="../CaptionCreator/media/processed/",
        save_alpha=True,
        preserve_position=True
    )