import os, io, json, requests, spaces, argparse, traceback, tempfile, zipfile, re, ast, time
import gradio as gr
import numpy as np
import huggingface_hub
import onnxruntime as ort
import pandas as pd
from datetime import datetime, timezone
from collections import defaultdict
from PIL import Image, ImageOps
from apscheduler.schedulers.background import BackgroundScheduler
from modules.classifyTags import categorize_tags_output, generate_tags_json, process_tags_for_misc
from modules.pixai import create_pixai_interface
from modules.booru import create_booru_interface
from modules.multi_comfy import create_multi_comfy
from modules.media_handler import handle_single_media_upload, handle_multiple_media_uploads

# For GPU install all the requirements.txt and the following:
#   pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
#   (or any other Torch version)
#   pip install onnxruntime-gpu
#
# It's recommended to create a venv if you want to use it offline:
#   python -m venv venv
#   venv\Scripts\activate
#   pip install ...
#   python app.py

TITLE = 'Multi-Tagger v1.4'
DESCRIPTION = '\nMulti-Tagger is a versatile application for advanced image analysis and captioning. \nSupports CUDA and CPU.\n'

SWINV2_MODEL_DSV3_REPO = 'SmilingWolf/wd-swinv2-tagger-v3'
CONV_MODEL_DSV3_REPO = 'SmilingWolf/wd-convnext-tagger-v3'
VIT_MODEL_DSV3_REPO = 'SmilingWolf/wd-vit-tagger-v3'
VIT_LARGE_MODEL_DSV3_REPO = 'SmilingWolf/wd-vit-large-tagger-v3'
EVA02_LARGE_MODEL_DSV3_REPO = 'SmilingWolf/wd-eva02-large-tagger-v3'
MOAT_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-moat-tagger-v2'
SWIN_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-swinv2-tagger-v2'
CONV_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
CONV2_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-convnextv2-tagger-v2'
VIT_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-vit-tagger-v2'
EVA02_LARGE_MODEL_IS_DSV1_REPO = 'deepghs/idolsankaku-eva02-large-tagger-v1'
SWINV2_MODEL_IS_DSV1_REPO = 'deepghs/idolsankaku-swinv2-tagger-v1'

# Global variables for model components (for memory management).
# Exactly one model is active at a time; see _load_model_components_optimized / unload_model.
CURRENT_MODEL = None
CURRENT_MODEL_NAME = None
CURRENT_TAGS_DF = None
CURRENT_TAG_NAMES = None
CURRENT_RATING_INDEXES = None
CURRENT_GENERAL_INDEXES = None
CURRENT_CHARACTER_INDEXES = None
CURRENT_MODEL_TARGET_SIZE = None

# Custom CSS for gallery styling
css = """
#custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
#custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
#custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
#custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;}
#custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
#custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
.gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
.thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
#custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
"""

MODEL_FILENAME = 'model.onnx'
LABEL_FILENAME = 'selected_tags.csv'

# Image modes Pillow's PNG writer accepts; anything else is converted before saving.
_PNG_SAVE_MODES = {'1', 'L', 'LA', 'I', 'I;16', 'P', 'RGB', 'RGBA'}
# Modes outside _PNG_SAVE_MODES that carry an alpha band (converted to RGBA, not RGB).
_ALPHA_MODES_TO_RGBA = {'PA', 'La', 'RGBa'}


class Timer:
    """Utility class for measuring execution time of different operations"""

    def __init__(self):
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]

    def checkpoint(self, label='Checkpoint'):
        """Add a checkpoint with a label"""
        now = time.perf_counter()
        self.checkpoints.append((label, now))

    def report(self, is_clear_checkpoints=True):
        """Print the time elapsed between consecutive checkpoints.

        When ``is_clear_checkpoints`` is true, the checkpoint list is reset to a single
        fresh checkpoint so the next report measures from this moment.
        """
        max_label_length = max((len(label) for label, _ in self.checkpoints), default=0)
        prev_time = self.checkpoints[0][1] if self.checkpoints else self.start_time
        for label, curr_time in self.checkpoints[1:]:
            elapsed = curr_time - prev_time
            print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
            prev_time = curr_time
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print the remaining checkpoint intervals plus the total time since start, then clear."""
        print('\n> Execution Time Report:')
        max_label_length = max((len(label) for label, _ in self.checkpoints), default=0)
        # Measure from the oldest remaining checkpoint: after report() the list starts with the
        # checkpoint report() inserted, not with start_time.
        prev_time = self.checkpoints[0][1] if self.checkpoints else self.start_time
        for label, curr_time in self.checkpoints[1:]:
            elapsed = curr_time - prev_time
            print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
            prev_time = curr_time
        total_time = self.checkpoints[-1][1] - self.start_time if self.checkpoints else 0
        print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
        self.checkpoints.clear()

    def restart(self):
        """Restart the timer"""
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]


def parse_args() -> argparse.Namespace:
    """Parse command line arguments"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--score-slider-step', type=float, default=0.05)
    parser.add_argument('--score-general-threshold', type=float, default=0.35)
    parser.add_argument('--score-character-threshold', type=float, default=0.85)
    parser.add_argument('--share', action='store_true')
    return parser.parse_args()


def load_labels(dataframe) -> tuple:
    """Return ``(tag_names, rating_indexes, general_indexes, character_indexes)`` from the tag CSV.

    Categories are read from the ``category`` column: 9 = rating, 0 = general, 4 = character.
    """
    tag_names = dataframe['name'].tolist()
    rating_indexes = list(np.where(dataframe['category'] == 9)[0])
    general_indexes = list(np.where(dataframe['category'] == 0)[0])
    character_indexes = list(np.where(dataframe['category'] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes


def mcut_threshold(probs):
    """Calculate threshold using Maximum Change in second derivative (MCut) method.

    Sorts ``probs`` descending, finds the largest gap between neighbours and returns the
    midpoint of that gap. Expects at least two probabilities.
    """
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh


def _download_model_files(model_repo):
    """Download (or reuse the cached) label CSV and ONNX model for ``model_repo``; return their paths."""
    csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
    model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
    return csv_path, model_path


def create_optimized_ort_session(model_path):
    """Create an ONNX Runtime session, preferring CUDA and always falling back to CPU."""
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.intra_op_num_threads = 0  # 0 lets ONNX Runtime choose the thread count
    sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    sess_options.enable_mem_pattern = True
    sess_options.enable_cpu_mem_arena = True

    available_providers = ort.get_available_providers()
    print(f"Available ONNX Runtime providers: {available_providers}")

    providers = []
    if 'CUDAExecutionProvider' in available_providers:
        providers.append('CUDAExecutionProvider')
        print("Using CUDA provider for ONNX inference")
    else:
        print("CUDA provider not available, falling back to CPU")
    # Always include CPU as fallback
    providers.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, sess_options, providers=providers)
        print(f"Model loaded with providers: {session.get_providers()}")
        return session
    except Exception as e:
        print(f"Failed to create ONNX session: {e}")
        raise


def _load_model_components_optimized(model_repo):
    """Make ``model_repo`` the active model by filling the CURRENT_* globals.

    Does nothing when that repo is already loaded. The previous session reference is
    dropped before the new one is created, so the old session can be freed first.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
    global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE

    if model_repo == CURRENT_MODEL_NAME and CURRENT_MODEL is not None:
        return

    csv_path, model_path = _download_model_files(model_repo)

    # Release the previous session before building the next one.
    CURRENT_MODEL = None
    CURRENT_MODEL_NAME = None
    CURRENT_MODEL = create_optimized_ort_session(model_path)

    tags_df = pd.read_csv(csv_path)
    tag_names, rating_indexes, general_indexes, character_indexes = load_labels(tags_df)
    CURRENT_TAGS_DF = tags_df
    CURRENT_TAG_NAMES = tag_names
    CURRENT_RATING_INDEXES = rating_indexes
    CURRENT_GENERAL_INDEXES = general_indexes
    CURRENT_CHARACTER_INDEXES = character_indexes

    # Input is (batch, height, width, channels); only the (square) height is used.
    _, height, _, _ = CURRENT_MODEL.get_inputs()[0].shape
    CURRENT_MODEL_TARGET_SIZE = height
    CURRENT_MODEL_NAME = model_repo


def _raw_predict(image_array, model_session):
    """Run ``model_session`` on one batch and return the first row of its first output as floats."""
    input_name = model_session.get_inputs()[0].name
    label_name = model_session.get_outputs()[0].name
    preds = model_session.run([label_name], {input_name: image_array})[0]
    return preds[0].astype(float)


def unload_model():
    """Drop the active model session and its tag metadata, then try to reclaim memory."""
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
    global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE

    CURRENT_MODEL = None
    CURRENT_TAGS_DF = None
    CURRENT_TAG_NAMES = None
    CURRENT_RATING_INDEXES = None
    CURRENT_GENERAL_INDEXES = None
    CURRENT_CHARACTER_INDEXES = None
    CURRENT_MODEL_TARGET_SIZE = None
    CURRENT_MODEL_NAME = None

    import gc
    gc.collect()

    # torch is optional; when present with CUDA, also release PyTorch's cached GPU memory.
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass


def cleanup_after_processing():
    """Cleanup resources after processing"""
    unload_model()


def _escape_parentheses(text):
    """Backslash-escape ``(`` and ``)`` in ``text``."""
    return text.replace('(', '\\(').replace(')', '\\)')


def _format_character_tags(character_res):
    """Join character tag names with ', ', escaping parentheses and turning underscores into spaces."""
    return ', '.join(_escape_parentheses(name).replace('_', ' ') for name in character_res)


def _empty_selection_outputs():
    """Return fresh placeholder values for the six per-image result outputs."""
    return '', {}, '', {}, '', {}


def _selection_outputs(result):
    """Map a stored per-image result dict to the six per-image UI outputs.

    Order: (string output, rating, character tags text, general tags, categorized string,
    categorized JSON).
    """
    return (
        result['strings'],
        result['rating'],
        _format_character_tags(result['character_res']),
        result['general_res'],
        result.get('categorized_strings', ''),
        result.get('categorized_json', {}),
    )


def _split_tags(text):
    """Split a comma-separated tag string into stripped, non-empty tags (``None`` gives [])."""
    return [tag.strip() for tag in (text or '').split(',') if tag.strip()]


def _merge_max_scores(score_dicts):
    """Union several ``{tag: prob}`` dicts, keeping each tag's highest probability.

    Key order is first-seen order across the inputs.
    """
    merged = {}
    for scores in score_dicts:
        for tag, prob in scores.items():
            if tag not in merged or prob > merged[tag]:
                merged[tag] = prob
    return merged


class Predictor:
    """Main predictor class for handling image tagging"""

    def __init__(self):
        # Kept for backward compatibility; the loaded model lives in the CURRENT_* globals.
        self.model_components = None
        self.last_loaded_repo = None

    def load_model(self, model_repo):
        """Make ``model_repo`` the active model (reuses it if it is already loaded)."""
        _load_model_components_optimized(model_repo)
        self.last_loaded_repo = model_repo

    def prepare_image(self, path):
        """Load the image at ``path`` and turn it into a 1-image input batch for the active model.

        Transparency is composited onto white, the image is padded to a white square,
        resized to CURRENT_MODEL_TARGET_SIZE (bicubic) when needed, converted to float32
        with the channel order reversed (RGB -> BGR), and given a leading batch axis.
        Requires a loaded model because it reads CURRENT_MODEL_TARGET_SIZE.
        """
        target_size = CURRENT_MODEL_TARGET_SIZE
        with Image.open(path) as source:
            image = source.convert('RGBA')

        canvas = Image.new('RGBA', image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert('RGB')

        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        if max_dim != target_size:
            padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)

        image_array = np.asarray(padded_image, dtype=np.float32)
        image_array = image_array[:, :, ::-1]  # RGB -> BGR
        return np.expand_dims(image_array, axis=0)

    def create_file(self, content: str, directory: str, fileName: str) -> str:
        """Write ``content`` as UTF-8 text to ``directory/fileName`` and return that path."""
        file_path = os.path.join(directory, fileName)
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(content)
        return file_path

    @staticmethod
    def _name_gallery_items(gallery):
        """Return ``(idx, image_path, image_name)`` per gallery entry with unique output base names.

        A repeated base name gets a ``_02``, ``_03``, ... suffix.
        """
        name_counters = defaultdict(int)
        items = []
        for idx, value in enumerate(gallery):
            image_path = value[0]
            image_name = os.path.splitext(os.path.basename(image_path))[0]
            name_counters[image_name] += 1
            if name_counters[image_name] > 1:
                image_name = f"{image_name}_{name_counters[image_name]:02d}"
            items.append((idx, image_path, image_name))
        return items

    @staticmethod
    def _select_tags(labels, indexes, threshold, mcut_enabled, mcut_floor=None):
        """Return ``{tag: prob}`` for the labels at ``indexes`` whose probability exceeds the threshold.

        With ``mcut_enabled`` the fixed ``threshold`` is replaced by the MCut threshold of
        these probabilities, raised to at least ``mcut_floor`` when one is given.
        """
        candidates = [labels[i] for i in indexes]
        if mcut_enabled:
            threshold = mcut_threshold(np.array([prob for _, prob in candidates]))
            if mcut_floor is not None:
                threshold = max(mcut_floor, threshold)
        return {name: prob for name, prob in candidates if prob > threshold}

    def _run_active_model(self, image_path, general_thresh, general_mcut_enabled,
                          character_thresh, character_mcut_enabled):
        """Tag one image with the currently loaded model.

        Returns ``(rating, general_res, character_res)``, each a ``{tag: prob}`` dict.
        """
        image = self.prepare_image(image_path)
        preds = _raw_predict(image, CURRENT_MODEL)
        labels = list(zip(CURRENT_TAG_NAMES, preds))
        rating = dict(labels[i] for i in CURRENT_RATING_INDEXES)
        general_res = self._select_tags(labels, CURRENT_GENERAL_INDEXES, general_thresh, general_mcut_enabled)
        # Character MCut thresholds are never allowed below 0.15.
        character_res = self._select_tags(labels, CURRENT_CHARACTER_INDEXES, character_thresh,
                                          character_mcut_enabled, mcut_floor=0.15)
        return rating, general_res, character_res

    @staticmethod
    def _compose_tag_string(general_res, character_res, characters_merge_enabled,
                            additional_tags_prepend, additional_tags_append):
        """Build the final comma-separated prompt string.

        General tags are ordered by descending confidence. User prepend/append tags are
        placed at the ends, and the model's copies of those tags are removed. Character tags
        lead the string only when ``characters_merge_enabled``. Duplicates are dropped
        (first occurrence wins). Parentheses are escaped and underscores become spaces.
        """
        general_list = [name for name, _ in sorted(general_res.items(), key=lambda item: item[1], reverse=True)]
        prepend_list = _split_tags(additional_tags_prepend)
        prepend_set = set(prepend_list)
        append_list = [tag for tag in _split_tags(additional_tags_append) if tag not in prepend_set]
        excluded = prepend_set.union(append_list)
        general_list = prepend_list + [tag for tag in general_list if tag not in excluded] + append_list

        character_list = list(character_res) if characters_merge_enabled else []
        final_tags = list(dict.fromkeys(character_list + general_list))
        return _escape_parentheses(', '.join(final_tags)).replace('_', ' ')

    def _combine_results(self, model_results, characters_merge_enabled,
                         additional_tags_prepend, additional_tags_append):
        """Merge the per-model ``(rating, general_res, character_res)`` tuples of one image into a result dict.

        The rating comes from the first model. Tag dicts are unioned, keeping each tag's
        highest score, so the stored/displayed tags match the tags in the output string.
        """
        rating = model_results[0][0]
        general_res = _merge_max_scores(res[1] for res in model_results)
        character_res = _merge_max_scores(res[2] for res in model_results)

        sorted_general_strings = self._compose_tag_string(
            general_res, character_res, characters_merge_enabled,
            additional_tags_prepend, additional_tags_append,
        )
        categorized_strings = _escape_parentheses(categorize_tags_output(sorted_general_strings, character_res))
        categorized_json = generate_tags_json(sorted_general_strings, character_res)
        return {
            'strings': sorted_general_strings,
            'categorized_strings': categorized_strings,
            'categorized_json': categorized_json,
            'rating': rating,
            'character_res': character_res,
            'general_res': general_res,
        }

    def _write_outputs(self, output_dir, image_path, image_name, result):
        """Write one image's output files into ``output_dir``.

        Returns the zip entries as ``[{'path': ..., 'name': ...}, ...]``.
        """
        entries = []

        output_name = f"{image_name}_output.txt"
        txt_content = f"Output (string): {result['strings']}\n\nCategorized Output: {result['categorized_strings']}"
        entries.append({'path': self.create_file(txt_content, output_dir, output_name), 'name': output_name})

        # PNG copy of the source image; convert modes the PNG writer rejects (e.g. CMYK JPEGs).
        png_name = f"{image_name}.png"
        png_path = os.path.join(output_dir, png_name)
        with Image.open(image_path) as source:
            if source.mode in _PNG_SAVE_MODES:
                png_image = source
            else:
                png_image = source.convert('RGBA' if source.mode in _ALPHA_MODES_TO_RGBA else 'RGB')
            png_image.save(png_path, format='PNG')
        entries.append({'path': png_path, 'name': png_name})

        tags_name = image_name + '.txt'
        tags_file = self.create_file(result['strings'], output_dir, tags_name)
        categorized_name = f"{image_name}_categorized.txt"
        categorized_file = self.create_file(result['categorized_strings'], output_dir, categorized_name)
        entries.append({'path': categorized_file, 'name': categorized_name})
        entries.append({'path': tags_file, 'name': tags_name})

        json_name = f"{image_name}_categorized.json"
        json_content = json.dumps(result['categorized_json'], indent=2, ensure_ascii=False)
        entries.append({'path': self.create_file(json_content, output_dir, json_name), 'name': json_name})
        return entries

    def predict(self, gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
                character_thresh, character_mcut_enabled, characters_merge_enabled,
                additional_tags_prepend, additional_tags_append, tag_results, progress=gr.Progress()):
        """Tag every gallery image, write per-image files and bundle them into a zip.

        Each selected model is loaded once and run over the whole gallery. Models are not
        swapped per image, and every image is preprocessed for the model that consumes it.
        An image is skipped if any model fails on it.

        Returns ``(download, strings, rating, character_tags, general_res, categorized_strings,
        categorized_json, tag_results)``. The per-image values are those of the first gallery image.
        ``tag_results`` is cleared and refilled with results keyed by gallery path.
        """
        tag_results.clear()
        gallery = gallery or []  # an empty Gradio gallery can arrive as None
        gallery_len = len(gallery)
        print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
        if gallery_len == 0:
            return ([], *_empty_selection_outputs(), tag_results)
        if not model_repo:
            raise gr.Error('Please select the 1st model.')

        timer = Timer()
        model_repos = [model_repo]
        if model_repo_2 and model_repo_2 != model_repo:
            model_repos.append(model_repo_2)

        # One progress step per model load plus one per (model, image) run.
        progress_total = len(model_repos) * (gallery_len + 1)
        progress_step = 0

        images = self._name_gallery_items(gallery)
        per_image_results = defaultdict(list)  # idx -> [(rating, general_res, character_res), ...]
        failed = set()

        for model_number, repo in enumerate(model_repos, start=1):
            self.load_model(repo)
            progress_step += 1
            progress(progress_step / progress_total, desc='Initialize wd model finished')
            timer.checkpoint(f"Initialize wd model ({repo})")
            timer.report()

            for idx, image_path, _ in images:
                if idx not in failed:
                    print(f"Gallery {idx:02d}: Starting run model {model_number} ({repo})...")
                    try:
                        per_image_results[idx].append(self._run_active_model(
                            image_path, general_thresh, general_mcut_enabled,
                            character_thresh, character_mcut_enabled,
                        ))
                    except Exception as e:
                        failed.add(idx)
                        print(traceback.format_exc())
                        print('Error predict: ' + str(e))
                progress_step += 1
                progress(progress_step / progress_total, desc=f"image{idx:02d}, predict finished")
                timer.checkpoint(f"image{idx:02d}, model {model_number} predict finished")
                timer.report()

        output_dir = tempfile.mkdtemp()
        txt_infos = []
        for idx, image_path, image_name in images:
            if idx in failed:
                continue
            try:
                result = self._combine_results(
                    per_image_results[idx], characters_merge_enabled,
                    additional_tags_prepend, additional_tags_append,
                )
                txt_infos.extend(self._write_outputs(output_dir, image_path, image_name, result))
                tag_results[image_path] = result
            except Exception as e:
                print(traceback.format_exc())
                print('Error predict: ' + str(e))

        download = []
        if txt_infos:
            download_zip_path = os.path.join(
                output_dir, 'Multi-Tagger-' + datetime.now().strftime('%Y%m%d-%H%M%S') + '.zip'
            )
            with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
                for info in txt_infos:
                    taggers_zip.write(info['path'], arcname=info['name'])
            download.append(download_zip_path)

        # The model is unloaded after every run to free memory; remove this call to keep it resident.
        cleanup_after_processing()

        progress(1, desc="Predict completed")
        timer.report_all()
        print('Predict is complete.')

        first_result = tag_results.get(gallery[0][0])
        first_outputs = _selection_outputs(first_result) if first_result is not None else _empty_selection_outputs()
        return (download, *first_outputs, tag_results)


def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
    """Return the six per-image outputs for the selected gallery image.

    Without a selection, return the first image's results when they exist. Unknown images
    yield empty placeholders.
    """
    if not selected_state:
        if gallery and gallery[0][0] in tag_results:
            return _selection_outputs(tag_results[gallery[0][0]])
        return _empty_selection_outputs()

    selected_value = selected_state.value
    if isinstance(selected_value, dict) and 'image' in selected_value:
        image_path = selected_value['image']['path']
    elif isinstance(selected_value, (list, tuple)) and len(selected_value) > 0:
        image_path = selected_value[0]
    else:
        image_path = str(selected_value)

    result = tag_results.get(image_path)
    return _selection_outputs(result) if result is not None else _empty_selection_outputs()


def append_gallery(gallery: list, image: str):
    """Add a single media file (image or video) to the gallery"""
    return handle_single_media_upload(image, gallery)


def extend_gallery(gallery: list, images):
    """Add multiple media files (images or videos) to the gallery"""
    return handle_multiple_media_uploads(images, gallery)


# Parse arguments and initialize predictor
args = parse_args()
predictor = Predictor()

dropdown_list = [
    EVA02_LARGE_MODEL_DSV3_REPO,
    VIT_LARGE_MODEL_DSV3_REPO,
    SWINV2_MODEL_DSV3_REPO,
    CONV_MODEL_DSV3_REPO,
    VIT_MODEL_DSV3_REPO,
    MOAT_MODEL_DSV2_REPO,
    SWIN_MODEL_DSV2_REPO,
    CONV_MODEL_DSV2_REPO,
    CONV2_MODEL_DSV2_REPO,
    VIT_MODEL_DSV2_REPO,
    EVA02_LARGE_MODEL_IS_DSV1_REPO,
    SWINV2_MODEL_IS_DSV1_REPO,
]


def _restart_space():
    """Restart the HuggingFace Space (scheduled job); requires the HF_TOKEN environment variable."""
    HF_TOKEN = os.getenv('HF_TOKEN')
    if not HF_TOKEN:
        raise ValueError('HF_TOKEN environment variable is not set.')
    huggingface_hub.HfApi().restart_space(
        repo_id='Werli/Multi-Tagger',
        token=HF_TOKEN,
        factory_reboot=False,
    )


# Setup scheduler for periodic restarts (172800 s = 2 days)
scheduler = BackgroundScheduler()
restart_space_job = scheduler.add_job(_restart_space, 'interval', seconds=172800)
scheduler.start()
next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."

with gr.Blocks(title=TITLE, css=css, theme="Werli/Purple-Crimson-Gradio-Theme", fill_width=True) as demo:
    gr.Markdown(value=f"\n{DESCRIPTION}\n")
    with gr.Tab(label='Waifu Diffusion'):
        with gr.Row():
            with gr.Column():
                with gr.Column(variant='panel'):
                    image_input = gr.Image(
                        label='Upload an Image (or paste from clipboard)',
                        type='filepath',
                        sources=['upload', 'clipboard'],
                        height=150,
                    )
                    with gr.Row():
                        upload_button = gr.UploadButton(
                            'Upload multiple images or videos',
                            file_types=['image', 'video'],
                            file_count='multiple',
                            size='md',
                        )
                    gallery = gr.Gallery(
                        columns=2,
                        show_share_button=False,
                        interactive=True,
                        height='auto',
                        label='Grid of images',
                        preview=False,
                        elem_id='custom-gallery',
                    )
                    submit = gr.Button(value='Analyze Images', variant='primary', size='lg')
                    clear = gr.ClearButton(components=[gallery], value='Clear Gallery', variant='secondary', size='sm')
                with gr.Column(variant='panel'):
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=EVA02_LARGE_MODEL_DSV3_REPO,
                        label='1st Model',
                    )
                    PLUS = '+?'
                    gr.Markdown(value=f"{PLUS}\n")
                    model_repo_2 = gr.Dropdown(
                        [None] + dropdown_list,
                        value=None,
                        label='2nd Model (Optional)',
                        info='Select another model for diversified results.',
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0, 1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label='General Tags Threshold',
                            scale=3,
                        )
                        general_mcut_enabled = gr.Checkbox(value=False, label='Use MCut threshold', scale=1)
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0, 1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label='Character Tags Threshold',
                            scale=3,
                        )
                        character_mcut_enabled = gr.Checkbox(value=False, label='Use MCut threshold', scale=1)
                    with gr.Row():
                        characters_merge_enabled = gr.Checkbox(
                            value=False,
                            label='Merge characters into the string output',
                            scale=1,
                        )
                    with gr.Row():
                        additional_tags_prepend = gr.Text(label='Prepend Additional tags (comma split)')
                        additional_tags_append = gr.Text(label='Append Additional tags (comma split)')
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                gallery,
                                model_repo,
                                general_thresh,
                                general_mcut_enabled,
                                character_thresh,
                                character_mcut_enabled,
                                characters_merge_enabled,
                                additional_tags_prepend,
                                additional_tags_append,
                            ],
                            value='Clear Everything',
                            variant='secondary',
                            size='lg',
                        )
            with gr.Column(variant='panel'):
                download_file = gr.File(label='Download')
                character_res = gr.Textbox(label="Character tags", show_copy_button=True, lines=3)
                sorted_general_strings = gr.Textbox(label='Output', show_label=True, show_copy_button=True, lines=5)
                categorized_strings = gr.Textbox(label='Categorized', show_label=True, show_copy_button=True, lines=5)
                tags_json = gr.JSON(label='Categorized Tags (JSON)', visible=True)
                rating = gr.Label(label='Rating')
                general_res = gr.Textbox(
                    label="General tags",
                    show_copy_button=True,
                    lines=3,
                    visible=False,  # Temp
                )

        # Per-session store of results, keyed by gallery image path
        tag_results = gr.State({})

        # Event handlers
        image_input.change(
            append_gallery,
            inputs=[gallery, image_input],
            outputs=[gallery, image_input],
        )
        upload_button.upload(
            extend_gallery,
            inputs=[gallery, upload_button],
            outputs=gallery,
        )
        gallery.select(
            get_selection_from_gallery,
            inputs=[gallery, tag_results],
            outputs=[sorted_general_strings, rating, character_res, general_res, categorized_strings, tags_json],
        )
        submit.click(
            predictor.predict,
            inputs=[
                gallery,
                model_repo,
                model_repo_2,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
                characters_merge_enabled,
                additional_tags_prepend,
                additional_tags_append,
                tag_results,
            ],
            outputs=[download_file, sorted_general_strings, rating, character_res, general_res,
                     categorized_strings, tags_json, tag_results],
        )
        gr.Markdown('[Based on SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger) ')

    with gr.Tab("PixAI"):
        pixai_interface = create_pixai_interface()
    with gr.Tab("Booru Image Fetcher"):
        booru_interface = create_booru_interface()
    with gr.Tab("ComfyUI Extractor"):
        comfy_interface = create_multi_comfy()
    with gr.Tab(label="Misc"):
        with gr.Row():
            with gr.Column(variant="panel"):
                tag_string = gr.Textbox(
                    label="Input Tags",
                    placeholder="1girl, cat, horns, blue hair, ...\nor\n? 1girl 1234567? cat 1234567? horns 1234567? blue hair 1234567? ...",
                    lines=4,
                )
                submit_button = gr.Button(value="START", variant="primary", size="lg")
            with gr.Column(variant="panel"):
                cleaned_tags_output = gr.Textbox(
                    label="Cleaned Tags",
                    show_label=True,
                    show_copy_button=True,
                    lines=4,
                    info="Tags with ? and numbers removed, formatted with commas. Useful for clearing tags from Booru sites.",
                )
                classify_tags_for_display = gr.Textbox(
                    label="Categorized (string)",
                    show_label=True,
                    show_copy_button=True,
                    lines=8,
                    info="Tags organized by categories",
                )
                generate_categorized_json = gr.JSON(label="Categorized JSON (tags)")
        submit_button.click(
            process_tags_for_misc,
            inputs=[tag_string],
            outputs=[cleaned_tags_output, classify_tags_for_display, generate_categorized_json],
        )
    gr.Markdown(NEXT_RESTART)

# Honour the --share command-line flag (it was previously parsed but never used).
demo.queue(max_size=5).launch(show_error=True, show_api=False, share=args.share)