.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
README.md CHANGED
@@ -4,13 +4,11 @@ emoji: 🐨
  colorFrom: green
  colorTo: indigo
  sdk: gradio
- sdk_version: 5.44.1
+ sdk_version: 5.48.0
  app_file: app.py
  pinned: false
  license: mit
- python_version: "3.13"
  short_description: AI-powered image editing tool
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
__lib__/__init__.py DELETED
File without changes
__lib__/app.py DELETED
@@ -1,1455 +0,0 @@
1
- import gradio as gr
2
- import threading
3
- import os
4
- import shutil
5
- import tempfile
6
- import time
7
- import json
8
- from util import process_image_edit, download_and_check_result_nsfw, GoodWebsiteUrl
9
- from nfsw import NSFWDetector
10
-
11
- # Google Gemini URL for restricted languages
12
- GOOGLE_GEMINI_URL = "https://aistudio.google.com/models/gemini-2-5-flash-image"
13
-
14
- # i18n - Load from encrypted modules
15
- import sys
16
- from pathlib import Path
17
-
18
- # Add i18n module to path
19
- _i18n_module_path = Path(__file__).parent / "i18n"
20
- if str(_i18n_module_path) not in sys.path:
21
- sys.path.insert(0, str(_i18n_module_path))
22
-
23
- # Import encrypted i18n loader
24
- from i18n import translations as _translations
25
- translations = _translations
26
-
27
- def load_translations():
28
- """Compatibility function - translations are already loaded"""
29
- return translations
30
-
31
- def t(key, lang="en"):
32
- return translations.get(lang, {}).get(key, key)
33
-
34
- # Configuration parameters
- TIP_TRY_N = 1  # Show the like-button tip after this many tries
- FREE_TRY_N = 4  # Free phase: first 4 tries without restrictions
- SLOW_TRY_N = 6  # rate_limit_1 applies until 6 total tries
- SLOW2_TRY_N = 10  # rate_limit_2 applies until 10 total tries
- RATE_LIMIT_60 = 14  # Full restriction: blocked after 14 tries
-
- # Time window configuration (minutes)
- PHASE_1_WINDOW = 6  # rate_limit_1 window: 6 minutes
- PHASE_2_WINDOW = 13  # rate_limit_2 window: 13 minutes
- PHASE_3_WINDOW = 20  # rate_limit_3 window: 20 minutes
- MAX_IMAGES_PER_WINDOW = 2  # Max images per time window
- high_priority_n = 1  # Only the first task from each IP gets high priority
-
49
- IP_Dict = {}
50
- # IP generation statistics and time window tracking
51
- IP_Generation_Count = {} # Record total generation count for each IP
52
- IP_Rate_Limit_Track = {} # Record generation count and timestamp in current time window for each IP
53
- IP_Country_Cache = {} # Cache IP country information to avoid repeated queries
54
-
55
- # Country usage statistics
56
- Country_Usage_Stats = {} # Track usage count by country
57
- Total_Request_Count = 0 # Total request counter for periodic printing
58
- PRINT_STATS_INTERVAL = 10 # Print stats every N requests
59
-
60
- # Async IP query tracking
61
- IP_Query_Results = {} # Track async query results
62
- # Active task tracking (within recent time window)
63
- Active_Tasks = {} # {client_ip: {"start": timestamp}}
64
-
65
- # Restricted countries list (these countries have lower usage limits);
- # names are kept in Chinese because they are matched against the geolocation API output
- RESTRICTED_COUNTRIES = ["印度", "巴基斯坦", "俄罗斯", "中国", "伊朗"]  # India, Pakistan, Russia, China, Iran
- RESTRICTED_COUNTRY_LIMIT = 1  # Max usage for restricted countries
68
-
69
- country_dict = {
70
- "zh": ["中国"],
71
- "hi": ["印度"],
72
- "fi": ["芬兰"],
73
- "en": ["美国", "澳大利亚", "英国", "加拿大", "新西兰", "爱尔兰"],
74
- "es": ["西班牙", "墨西哥", "阿根廷", "哥伦比亚", "智利", "秘鲁"],
75
- "pt": ["葡萄牙", "巴西"],
76
- "fr": ["法国", "摩纳哥"],
77
- "de": ["德国", "奥地利", ],
78
- "it": ["意大利", "圣马力诺", "梵蒂冈"],
79
- "ja": ["日本"],
80
- "ru": ["俄罗斯"],
81
- "uk": ["乌克兰"],
82
- "ar": ["沙特阿拉伯", "埃及", "阿拉伯联合酋长国", "摩洛哥"],
83
- "nl":["荷兰"],
84
- "no":["挪威"],
85
- "sv":["瑞典"],
86
- "id":["印度尼西亚"],
87
- "vi": ["越南"],
88
- "he": ["以色列"],
89
- "tr": ["土耳其"],
90
- "da": ["丹麦"],
91
- }
92
-
93
- def query_ip_country(client_ip):
94
- """
95
- Query IP address geo information with robust error handling
96
-
97
- Features:
98
- - 3 second timeout limit
99
- - Comprehensive error handling
100
- - Automatic fallback to default values
101
- - Cache mechanism to avoid repeated queries
102
-
103
- Returns:
104
- dict: {"country": str, "region": str, "city": str}
105
- """
106
- # Check cache first - no API call for subsequent visits
107
- if client_ip in IP_Country_Cache:
108
- print(f"Using cached IP data for {client_ip}")
109
- return IP_Country_Cache[client_ip]
110
-
111
- # Validate IP address
112
- if not client_ip or client_ip in ["127.0.0.1", "localhost", "::1"]:
113
- print(f"Invalid or local IP address: {client_ip}, using default")
114
- default_geo = {"country": "Unknown", "region": "Unknown", "city": "Unknown"}
115
- IP_Country_Cache[client_ip] = default_geo
116
- return default_geo
117
-
118
- # First time visit - query API with robust error handling
119
- print(f"Querying IP geolocation for {client_ip}...")
120
-
121
- try:
122
- import requests
123
- from requests.exceptions import Timeout, ConnectionError, RequestException
124
-
125
- api_url = f"https://api.vore.top/api/IPdata?ip={client_ip}"
126
-
127
- # Make request with 3 second timeout
128
- response = requests.get(api_url, timeout=3)
129
-
130
- if response.status_code == 200:
131
- data = response.json()
132
- if data.get("code") == 200 and "ipdata" in data:
133
- ipdata = data["ipdata"]
134
- geo_info = {
135
- "country": ipdata.get("info1", "Unknown"),
136
- "region": ipdata.get("info2", "Unknown"),
137
- "city": ipdata.get("info3", "Unknown")
138
- }
139
- IP_Country_Cache[client_ip] = geo_info
140
- print(f"Successfully detected location for {client_ip}: {geo_info['country']}")
141
- return geo_info
142
- else:
143
- print(f"API returned invalid data for {client_ip}: {data}")
144
- else:
145
- print(f"API request failed with status {response.status_code} for {client_ip}")
146
-
147
- except Timeout:
148
- print(f"Timeout (>3s) querying IP location for {client_ip}, using default")
149
- except ConnectionError:
150
- print(f"Network connection error for IP {client_ip}, using default")
151
- except RequestException as e:
152
- print(f"Request error for IP {client_ip}: {e}, using default")
153
- except Exception as e:
154
- print(f"Unexpected error querying IP {client_ip}: {e}, using default")
155
-
156
- # All failures lead here - cache default and return
157
- default_geo = {"country": "Unknown", "region": "Unknown", "city": "Unknown"}
158
- IP_Country_Cache[client_ip] = default_geo
159
- print(f"Cached default location for {client_ip}")
160
- return default_geo
161
-
162
- def query_ip_country_async(client_ip):
163
- """
164
- Async version that returns immediately with default, then updates cache in background
165
-
166
- Returns:
167
- tuple: (immediate_lang, geo_info_or_none)
168
- """
169
- # If already cached, return immediately
170
- if client_ip in IP_Country_Cache:
171
- geo_info = IP_Country_Cache[client_ip]
172
- lang = get_lang_from_country(geo_info["country"])
173
- return lang, geo_info
174
-
175
- # Return default immediately, query in background
176
- return "en", None
177
-
178
- def get_lang_from_country(country):
179
- """
180
- Map country name to language code with comprehensive validation
181
-
182
- Features:
183
- - Handles invalid/empty input
184
- - Case-insensitive matching
185
- - Detailed logging
186
- - Always returns valid language code
187
-
188
- Args:
189
- country (str): Country name
190
-
191
- Returns:
192
- str: Language code (always valid, defaults to "en")
193
- """
194
- # Input validation
195
- if not country or not isinstance(country, str) or country.strip() == "":
196
- print(f"Invalid country provided: '{country}', defaulting to English")
197
- return "en"
198
-
199
- # Normalize country name
200
- country = country.strip()
201
- if country.lower() == "unknown":
202
- print(f"Unknown country, defaulting to English")
203
- return "en"
204
-
205
- try:
206
- # Search in country dictionary with case-sensitive match first
207
- for lang, countries in country_dict.items():
208
- if country in countries:
209
- print(f"Matched country '{country}' to language '{lang}'")
210
- return lang
211
-
212
- # If no exact match, try case-insensitive match
213
- country_lower = country.lower()
214
- for lang, countries in country_dict.items():
215
- for country_variant in countries:
216
- if country_variant.lower() == country_lower:
217
- print(f"Case-insensitive match: country '{country}' to language '{lang}'")
218
- return lang
219
-
220
- # No match found
221
- print(f"Country '{country}' not found in country_dict, defaulting to English")
222
- return "en"
223
-
224
- except Exception as e:
225
- print(f"Error matching country '{country}': {e}, defaulting to English")
226
- return "en"
227
-
228
- def get_lang_from_ip(client_ip):
229
- """
230
- Get language based on IP geolocation with comprehensive error handling
231
-
232
- Features:
233
- - Validates input IP address
234
- - Handles all possible exceptions
235
- - Always returns a valid language code
236
- - Defaults to English on any failure
237
- - Includes detailed logging
238
-
239
- Args:
240
- client_ip (str): Client IP address
241
-
242
- Returns:
243
- str: Language code (always valid, defaults to "en")
244
- """
245
- # Input validation
246
- if not client_ip or not isinstance(client_ip, str):
247
- print(f"Invalid IP address provided: {client_ip}, defaulting to English")
248
- return "en"
249
-
250
- try:
251
- # Query geolocation info (has its own error handling and 3s timeout)
252
- geo_info = query_ip_country(client_ip)
253
-
254
- if not geo_info or not isinstance(geo_info, dict):
255
- print(f"No geolocation data for {client_ip}, defaulting to English")
256
- return "en"
257
-
258
- # Extract country with fallback
259
- country = geo_info.get("country", "Unknown")
260
- if not country or country == "Unknown":
261
- print(f"Unknown country for IP {client_ip}, defaulting to English")
262
- return "en"
263
-
264
- # Map country to language
265
- detected_lang = get_lang_from_country(country)
266
-
267
- # Validate language code
268
- if not detected_lang or not isinstance(detected_lang, str) or len(detected_lang) != 2:
269
- print(f"Invalid language code '{detected_lang}' for {client_ip}, defaulting to English")
270
- return "en"
271
-
272
- print(f"IP {client_ip} -> Country: {country} -> Language: {detected_lang}")
273
- return detected_lang
274
-
275
- except Exception as e:
276
- print(f"Unexpected error getting language from IP {client_ip}: {e}, defaulting to English")
277
- return "en" # Always return a valid language code
278
-
279
- def is_restricted_country_ip(client_ip):
280
- """
281
- Check if IP is from a restricted country
282
-
283
- Returns:
284
- bool: True if from restricted country
285
- """
286
- geo_info = query_ip_country(client_ip)
287
- country = geo_info["country"]
288
- return country in RESTRICTED_COUNTRIES
289
-
290
- def get_ip_max_limit(client_ip):
291
- """
292
- Get max usage limit for IP based on country
293
-
294
- Returns:
295
- int: Max usage limit
296
- """
297
- if is_restricted_country_ip(client_ip):
298
- return RESTRICTED_COUNTRY_LIMIT
299
- else:
300
- return RATE_LIMIT_60
301
-
302
- def get_ip_generation_count(client_ip):
303
- """
304
- Get IP generation count
305
- """
306
- if client_ip not in IP_Generation_Count:
307
- IP_Generation_Count[client_ip] = 0
308
- return IP_Generation_Count[client_ip]
309
-
310
- def increment_ip_generation_count(client_ip):
311
- """
312
- Increment IP generation count
313
- """
314
- if client_ip not in IP_Generation_Count:
315
- IP_Generation_Count[client_ip] = 0
316
- IP_Generation_Count[client_ip] += 1
317
- return IP_Generation_Count[client_ip]
318
-
319
- def get_ip_phase(client_ip):
320
- """
321
- Get current phase for IP
322
-
323
- Returns:
324
- str: 'free', 'rate_limit_1', 'rate_limit_2', 'rate_limit_3', 'blocked'
325
- """
326
- count = get_ip_generation_count(client_ip)
327
- max_limit = get_ip_max_limit(client_ip)
328
-
329
- # For restricted countries, check if they've reached their limit
330
- if is_restricted_country_ip(client_ip):
331
- if count >= max_limit:
332
- return 'blocked'
333
- elif count >= max_limit - 2: # Last 2 attempts
334
- return 'rate_limit_3'
335
- elif count >= max_limit - 3: # 3rd attempt from end
336
- return 'rate_limit_2'
337
- elif count >= max_limit - 4: # 4th attempt from end
338
- return 'rate_limit_1'
339
- else:
340
- return 'free'
341
-
342
- # For normal countries, use standard limits
343
- if count < FREE_TRY_N:
344
- return 'free'
345
- elif count < SLOW_TRY_N:
346
- return 'rate_limit_1' # NSFW blur + 5 minutes 2 images
347
- elif count < SLOW2_TRY_N:
348
- return 'rate_limit_2' # NSFW blur + 10 minutes 2 images
349
- elif count < max_limit:
350
- return 'rate_limit_3' # NSFW blur + 20 minutes 2 images
351
- else:
352
- return 'blocked' # Generation blocked
353
-
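
# Illustrative sketch (not from the original app.py): the normal-country
# thresholds above (FREE_TRY_N=4, SLOW_TRY_N=6, SLOW2_TRY_N=10, RATE_LIMIT_60=14)
# imply this mapping from lifetime generation count to phase.
_THRESHOLDS = [(4, 'free'), (6, 'rate_limit_1'), (10, 'rate_limit_2'), (14, 'rate_limit_3')]

def _phase_for(count):
    for limit, name in _THRESHOLDS:
        if count < limit:
            return name
    return 'blocked'

assert [_phase_for(c) for c in (0, 4, 6, 10, 14)] == [
    'free', 'rate_limit_1', 'rate_limit_2', 'rate_limit_3', 'blocked']
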
354
- def check_rate_limit_for_phase(client_ip, phase):
355
- """
356
- Check rate limit for specific phase
357
-
358
- Returns:
359
- tuple: (is_limited, wait_time_minutes, current_count)
360
- """
361
- if phase not in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3']:
362
- return False, 0, 0
363
-
364
- # Determine time window
365
- if phase == 'rate_limit_1':
366
- window_minutes = PHASE_1_WINDOW
367
- elif phase == 'rate_limit_2':
368
- window_minutes = PHASE_2_WINDOW
369
- else: # rate_limit_3
370
- window_minutes = PHASE_3_WINDOW
371
-
372
- current_time = time.time()
373
- window_key = f"{client_ip}_{phase}"
374
-
375
- # Clean expired records
376
- if window_key in IP_Rate_Limit_Track:
377
- track_data = IP_Rate_Limit_Track[window_key]
378
- # Check if within current time window
379
- if current_time - track_data['start_time'] > window_minutes * 60:
380
- # Time window expired, reset
381
- IP_Rate_Limit_Track[window_key] = {
382
- 'count': 0,
383
- 'start_time': current_time,
384
- 'last_generation': current_time
385
- }
386
- else:
387
- # Initialize
388
- IP_Rate_Limit_Track[window_key] = {
389
- 'count': 0,
390
- 'start_time': current_time,
391
- 'last_generation': current_time
392
- }
393
-
394
- track_data = IP_Rate_Limit_Track[window_key]
395
-
396
- # Check if exceeded limit
397
- if track_data['count'] >= MAX_IMAGES_PER_WINDOW:
398
- # Calculate remaining wait time
399
- elapsed = current_time - track_data['start_time']
400
- wait_time = (window_minutes * 60) - elapsed
401
- wait_minutes = max(0, wait_time / 60)
402
- return True, wait_minutes, track_data['count']
403
-
404
- return False, 0, track_data['count']
405
-
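
# Illustrative worked example (not from the original app.py): how the sliding
# window in check_rate_limit_for_phase translates into a wait time. With
# MAX_IMAGES_PER_WINDOW = 2 and the 6-minute rate_limit_1 window, a third
# request made 4 minutes into the window must wait roughly 2 more minutes.
window_minutes, elapsed_seconds = 6, 4 * 60
wait_minutes = max(0, (window_minutes * 60 - elapsed_seconds) / 60)
assert wait_minutes == 2.0
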
406
- def update_country_stats(client_ip):
407
- """
408
- Update country usage statistics and print periodically
409
- """
410
- global Total_Request_Count, Country_Usage_Stats
411
-
412
- # Get country info
413
- geo_info = IP_Country_Cache.get(client_ip, {"country": "Unknown", "region": "Unknown", "city": "Unknown"})
414
- country = geo_info["country"]
415
-
416
- # Update country stats
417
- if country not in Country_Usage_Stats:
418
- Country_Usage_Stats[country] = 0
419
- Country_Usage_Stats[country] += 1
420
-
421
- # Increment total request counter
422
- Total_Request_Count += 1
423
-
424
- # Print stats every N requests
425
- if Total_Request_Count % PRINT_STATS_INTERVAL == 0:
426
- print("\n" + "="*60)
427
- print(f"📊 Country usage stats (total requests: {Total_Request_Count})")
428
- print("="*60)
429
-
430
- # Sort by usage count (descending)
431
- sorted_stats = sorted(Country_Usage_Stats.items(), key=lambda x: x[1], reverse=True)
432
-
433
- for country_name, count in sorted_stats:
434
- percentage = (count / Total_Request_Count) * 100
435
- print(f" {country_name}: {count} requests ({percentage:.1f}%)")
436
-
437
- print("="*60 + "\n")
438
-
439
- def record_generation_attempt(client_ip, phase):
440
- """
441
- Record generation attempt
442
- """
443
- # Increment total count
444
- increment_ip_generation_count(client_ip)
445
-
446
- # Update country statistics
447
- update_country_stats(client_ip)
448
-
449
- # Record time window count
450
- if phase in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3']:
451
- window_key = f"{client_ip}_{phase}"
452
- current_time = time.time()
453
-
454
- if window_key in IP_Rate_Limit_Track:
455
- IP_Rate_Limit_Track[window_key]['count'] += 1
456
- IP_Rate_Limit_Track[window_key]['last_generation'] = current_time
457
- else:
458
- IP_Rate_Limit_Track[window_key] = {
459
- 'count': 1,
460
- 'start_time': current_time,
461
- 'last_generation': current_time
462
- }
463
-
464
- def apply_gaussian_blur_to_image_url(image_url, blur_strength=50):
465
- """
466
- Apply Gaussian blur to image URL
467
-
468
- Args:
469
- image_url (str): Original image URL
470
- blur_strength (int): Blur strength, default 50 (heavy blur)
471
-
472
- Returns:
473
- PIL.Image: Blurred PIL Image object
474
- """
475
- try:
476
- import requests
477
- from PIL import Image, ImageFilter
478
- import io
479
-
480
- # Download image
481
- response = requests.get(image_url, timeout=30)
482
- if response.status_code != 200:
483
- return None
484
-
485
- # Convert to PIL Image
486
- image_data = io.BytesIO(response.content)
487
- image = Image.open(image_data)
488
-
489
- # Apply heavy Gaussian blur
490
- blurred_image = image.filter(ImageFilter.GaussianBlur(radius=blur_strength))
491
-
492
- return blurred_image
493
-
494
- except Exception as e:
495
- print(f"⚠️ Failed to apply Gaussian blur: {e}")
496
- return None
497
-
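
# Illustrative usage sketch (not from the original app.py): the helper above
# returns a PIL.Image, which Gradio's gr.Image output component accepts
# directly; the URL below is a placeholder, not a real endpoint.
def _demo_blur():
    blurred = apply_gaussian_blur_to_image_url("https://example.com/result.png", blur_strength=50)
    if blurred is not None:
        blurred.save("blurred_result.png")
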
498
- # Initialize NSFW detector (download from Hugging Face)
499
- try:
500
- nsfw_detector = NSFWDetector() # Auto download falconsai_yolov9_nsfw_model_quantized.pt from Hugging Face
501
- print("✅ NSFW detector initialized successfully")
502
- except Exception as e:
503
- print(f"❌ NSFW detector initialization failed: {e}")
504
- nsfw_detector = None
505
-
506
- def edit_image_interface(input_image, prompt, lang, request: gr.Request, progress=gr.Progress()):
507
- """
508
- Interface function for processing image editing with phase-based limitations
509
- """
510
- # Disable the "Use as Input" button by default; re-enable it only after a successful generation
- use_as_input_state = gr.update(interactive=False)
512
- try:
513
- # Extract user IP
514
- client_ip = request.client.host
515
- x_forwarded_for = dict(request.headers).get('x-forwarded-for')
516
- if x_forwarded_for:
517
- client_ip = x_forwarded_for
518
- if client_ip not in IP_Dict:
519
- IP_Dict[client_ip] = 0
520
- IP_Dict[client_ip] += 1
521
-
522
- if input_image is None:
523
- return None, t("error_upload_first", lang), gr.update(visible=False), use_as_input_state
524
-
525
- if not prompt or prompt.strip() == "":
526
- return None, t("error_enter_prompt", lang), gr.update(visible=False), use_as_input_state
527
-
528
- # Check if prompt length is greater than 3 characters
529
- if len(prompt.strip()) <= 3:
530
- return None, t("error_prompt_too_short", lang), gr.update(visible=False), use_as_input_state
531
- except Exception as e:
532
- print(f"⚠️ Unexpected error: {e}", flush=True)
533
- return None, t("error_processing_failed", lang), gr.update(visible=False), use_as_input_state
534
-
535
- # Concurrency guard: block if there is an active task within last 3 minutes
536
- try:
537
- now_ts = time.time()
538
- active_info = Active_Tasks.get(client_ip)
539
- if active_info:
540
- start_ts = active_info.get("start", 0)
541
- if now_ts - start_ts <= 180:
542
- return None, "You already have a task in progress. Please wait for it to finish before submitting a new one.", gr.update(visible=False, value=None), use_as_input_state
543
- else:
544
- # Cleanup stale record
545
- Active_Tasks.pop(client_ip, None)
546
- except Exception as e:
547
- print(f"⚠️ Concurrency guard check failed: {e}")
548
-
549
- # Get user current phase
550
- current_phase = get_ip_phase(client_ip)
551
- current_count = get_ip_generation_count(client_ip)
552
- geo_info = IP_Country_Cache.get(client_ip, {"country": "Unknown", "region": "Unknown", "city": "Unknown"})
553
- is_restricted = is_restricted_country_ip(client_ip)
554
-
555
- print(f"📊 User phase info - IP: {client_ip}, Location: {geo_info['country']}/{geo_info['region']}/{geo_info['city']}, Phase: {current_phase}, Count: {current_count}, Restricted: {is_restricted}")
556
-
557
- # Check if user reached the like button tip threshold
558
- # For restricted countries, show like tip from the first attempt
559
- show_like_tip = (current_count >= 1) if is_restricted else (current_count >= TIP_TRY_N)
560
-
561
- # Check if completely blocked
562
- if current_phase == 'blocked':
563
- # Generate blocked limit button with different URL for restricted countries
564
- if is_restricted or lang in ["hi", "ru", "zh"]:
565
- blocked_url = GOOGLE_GEMINI_URL
566
- else:
567
- blocked_url = 'https://omnicreator.net/#generator'
568
-
569
- blocked_button_html = f"""
570
- <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
571
- <a href='{blocked_url}' target='_blank' style='
572
- display: inline-flex;
573
- align-items: center;
574
- justify-content: center;
575
- padding: 16px 32px;
576
- background: linear-gradient(135deg, #e74c3c 0%, #c0392b 100%);
577
- color: white;
578
- text-decoration: none;
579
- border-radius: 12px;
580
- font-weight: 600;
581
- font-size: 16px;
582
- text-align: center;
583
- min-width: 200px;
584
- box-shadow: 0 4px 15px rgba(231, 76, 60, 0.4);
585
- transition: all 0.3s ease;
586
- border: none;
587
- '>&#128640; Unlimited Generation</a>
588
- </div>
589
- """
590
-
591
- # Use same message for all users to avoid discrimination perception
592
- blocked_message = t("error_free_limit_reached", lang)
593
-
594
- return None, blocked_message, gr.update(value=blocked_button_html, visible=True), use_as_input_state
595
-
596
- # Check rate limit (applies to rate_limit phases)
597
- if current_phase in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3']:
598
- is_limited, wait_minutes, window_count = check_rate_limit_for_phase(client_ip, current_phase)
599
- if is_limited:
600
- wait_minutes_int = int(wait_minutes) + 1
601
- # Generate rate limit button with different URL for restricted countries
602
- if is_restricted or lang in ["hi", "ru", "zh"]:
603
- rate_limit_url = GOOGLE_GEMINI_URL
604
- else:
605
- rate_limit_url = 'https://omnicreator.net/#generator'
606
-
607
- rate_limit_button_html = f"""
608
- <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
609
- <a href='{rate_limit_url}' target='_blank' style='
610
- display: inline-flex;
611
- align-items: center;
612
- justify-content: center;
613
- padding: 16px 32px;
614
- background: linear-gradient(135deg, #f39c12 0%, #e67e22 100%);
615
- color: white;
616
- text-decoration: none;
617
- border-radius: 12px;
618
- font-weight: 600;
619
- font-size: 16px;
620
- text-align: center;
621
- min-width: 200px;
622
- box-shadow: 0 4px 15px rgba(243, 156, 18, 0.4);
623
- transition: all 0.3s ease;
624
- border: none;
625
- '>⏰ Skip Wait - Unlimited Generation</a>
626
- </div>
627
- """
628
- return None, t("error_free_limit_wait", lang).format(wait_minutes_int=wait_minutes_int), gr.update(value=rate_limit_button_html, visible=True), use_as_input_state
629
-
630
- # Handle NSFW detection based on phase
631
- is_nsfw_task = False # Track if this task involves NSFW content
632
-
633
- # Skip NSFW detection in free phase
634
- if current_phase != 'free' and nsfw_detector is not None and input_image is not None:
635
- try:
636
- nsfw_result = nsfw_detector.predict_pil_label_only(input_image)
637
-
638
- if nsfw_result.lower() == "nsfw":
639
- is_nsfw_task = True
640
- use_as_input_state = gr.update(interactive=False)
641
- print(f"🔍 Input NSFW detected in {current_phase} phase: ❌❌❌ {nsfw_result} - IP: {client_ip} (will blur result)")
642
- else:
643
- print(f"🔍 Input NSFW check passed: ✅✅✅ {nsfw_result} - IP: {client_ip}")
644
-
645
- except Exception as e:
646
- print(f"⚠️ Input NSFW detection failed: {e}")
647
- # Allow continuation when detection fails
648
-
649
- result_url = None
650
- status_message = ""
651
- use_as_input_state = gr.update(interactive=True)
652
-
653
- def progress_callback(message):
654
- try:
655
- nonlocal status_message
656
- status_message = message
657
- # Add error handling to prevent progress update failure
658
- if progress is not None:
659
- # Enhanced progress display with better formatting
660
- if "Queue:" in message or "tasks ahead" in message:
661
- # Queue status - show with different progress value to indicate waiting
662
- progress(0.1, desc=message)
663
- elif "Processing" in message or "AI is processing" in message:
664
- # Processing status
665
- progress(0.7, desc=message)
666
- elif "Generating" in message or "Almost done" in message:
667
- # Generation status
668
- progress(0.9, desc=message)
669
- else:
670
- # Default status
671
- progress(0.5, desc=message)
672
- except Exception as e:
673
- print(f"⚠️ Progress update failed: {e}")
674
-
675
- try:
676
- # Determine priority before recording generation attempt
677
- # First high_priority_n tasks for each IP get priority=1
678
- task_priority = 1 if current_count < high_priority_n else 0
679
-
680
- # Record active task start (for concurrency guard)
681
- Active_Tasks[client_ip] = {"start": time.time()}
682
-
683
- # Record generation attempt (before actual generation to ensure correct count)
684
- record_generation_attempt(client_ip, current_phase)
685
- updated_count = get_ip_generation_count(client_ip)
686
-
687
- print(f"✅ Processing started - IP: {client_ip}, phase: {current_phase}, total count: {updated_count}, priority: {task_priority}, prompt: {prompt.strip()}", flush=True)
688
-
689
- # Call image editing processing function with priority
690
- input_image_url, result_url, message, task_uuid = process_image_edit(input_image, prompt.strip(), None, progress_callback, priority=task_priority, client_ip=client_ip)
691
-
692
- # Check if HF user limit exceeded
693
- if message and message.startswith("HF_LIMIT_EXCEEDED:"):
694
- error_message = message.replace("HF_LIMIT_EXCEEDED:", "")
695
- # Generate HF limit exceeded button (similar to blocked status)
696
- hf_limit_url = 'https://omnicreator.net/#generator'
697
-
698
- hf_limit_button_html = f"""
699
- <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
700
- <a href='{hf_limit_url}' target='_blank' style='
701
- display: inline-flex;
702
- align-items: center;
703
- justify-content: center;
704
- padding: 16px 32px;
705
- background: linear-gradient(135deg, #e74c3c 0%, #c0392b 100%);
706
- color: white;
707
- text-decoration: none;
708
- border-radius: 12px;
709
- font-weight: 600;
710
- font-size: 16px;
711
- text-align: center;
712
- min-width: 200px;
713
- box-shadow: 0 4px 15px rgba(231, 76, 60, 0.4);
714
- transition: all 0.3s ease;
715
- border: none;
716
- '>&#128640; Unlimited Generation</a>
717
- </div>
718
- """
719
-
720
- # Use translated message or default
721
- limit_message = error_message if error_message else t("error_free_limit_reached", lang)
722
-
723
- return None, limit_message, gr.update(value=hf_limit_button_html, visible=True), use_as_input_state
724
-
725
- if result_url:
726
- print(f"✅ Processing completed successfully - IP: {client_ip}, result_url: {result_url}, task_uuid: {task_uuid}", flush=True)
727
-
728
- # Detect result image NSFW content (only in rate limit phases)
729
- if nsfw_detector is not None and current_phase != 'free':
730
- try:
731
- if progress is not None:
732
- progress(0.9, desc=t("status_checking_result", lang))
733
-
734
- is_nsfw, nsfw_error = download_and_check_result_nsfw(result_url, nsfw_detector)
735
-
736
- if nsfw_error:
737
- print(f"⚠️ Result image NSFW detection error - IP: {client_ip}, error: {nsfw_error}")
738
- elif is_nsfw:
739
- is_nsfw_task = True # Mark task as NSFW
740
- print(f"🔍 Result image NSFW detected in {current_phase} phase: ❌❌❌ - IP: {client_ip} (will blur result)")
741
- else:
742
- print(f"🔍 Result image NSFW check passed: ✅✅✅ - IP: {client_ip}")
743
-
744
- except Exception as e:
745
- print(f"⚠️ Result image NSFW detection exception - IP: {client_ip}, error: {str(e)}")
746
-
747
- # Apply blur if this is an NSFW task in rate limit phases
748
- should_blur = False
749
-
750
- if current_phase in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3'] and is_nsfw_task:
751
- should_blur = True
752
-
753
- # Apply blur processing
754
- if should_blur:
755
- if progress is not None:
756
- progress(0.95, desc=t("status_applying_filter", lang))
757
-
758
- blurred_image = apply_gaussian_blur_to_image_url(result_url)
759
- if blurred_image is not None:
760
- final_result = blurred_image # Return PIL Image object
761
- final_message = t("warning_content_filter", lang)
762
- print(f"🔒 Applied Gaussian blur for NSFW content - IP: {client_ip}")
763
- else:
764
- # Blur failed, return original URL with warning
765
- final_result = result_url
766
- final_message = t("warning_content_review", lang)
767
-
768
- # Disable use-as-input when NSFW content is detected
769
- use_as_input_state = gr.update(interactive=False)
770
-
771
- # Generate NSFW button for blurred content with different URL for restricted countries
772
- if is_restricted or lang in ["hi", "ru", "zh"]:
773
- nsfw_url = GOOGLE_GEMINI_URL
774
- else:
775
- nsfw_url = 'https://omnicreator.net/#generator'
776
-
777
- banner_html = """
778
- <div style='margin: 14px auto 0; max-width: 640px; background: linear-gradient(120deg, #f0f4ff 0%, #e5edff 50%, #f7fbff 100%); border: 1px solid #cbd5ff; border-radius: 14px; padding: 14px 18px; box-shadow: 0 10px 25px rgba(88, 101, 242, 0.18); text-align: center;'>
779
- <div style='font-size: 15px; font-weight: 800; color: #1f2a44; display: flex; align-items: center; justify-content: center; gap: 8px;'>
780
- 🚀 Omni Image Editor 2.0 is live!
781
- </div>
782
- <a href='https://huggingface.co/spaces/selfit-camera/Omni-Image-Editor' target='_blank' style='display: inline-flex; align-items: center; justify-content: center; margin-top: 6px; padding: 10px 18px; background: #5865f2; color: white; border-radius: 10px; font-weight: 800; text-decoration: none; box-shadow: 0 6px 18px rgba(88, 101, 242, 0.35);'>
783
- Try the Hugging Face Space demo (free)
784
- </a>
785
- <div style='font-size: 13px; color: #4a5568; margin-top: 6px; font-weight: 600;'>This is a free HF Space demo for Omni Image Editor 2.0.</div>
786
- </div>
787
- """
788
-
789
- nsfw_action_buttons_html = f"""
790
- <div style='text-align: center; margin: 18px 0 10px 0;'>
791
- <a href='{nsfw_url}' target='_blank' style='
792
- display: inline-flex;
793
- align-items: center;
794
- justify-content: center;
795
- padding: 16px 32px;
796
- background: linear-gradient(135deg, #ff6b6b 0%, #feca57 100%);
797
- color: white;
798
- text-decoration: none;
799
- border-radius: 12px;
800
- font-weight: 700;
801
- font-size: 16px;
802
- min-width: 220px;
803
- box-shadow: 0 8px 25px rgba(255, 107, 107, 0.35);
804
- transition: all 0.3s ease;
805
- border: none;
806
- '>🔥 Unlimited Creative Generation</a>
807
- </div>
808
- {banner_html}
809
- """
810
- return final_result, final_message, gr.update(value=nsfw_action_buttons_html, visible=True), use_as_input_state
811
- else:
812
- final_result = result_url
813
- final_message = t("status_completed_message", lang).format(message=message)
814
-
815
- try:
816
- if progress is not None:
817
- progress(1.0, desc=t("status_processing_completed", lang))
818
- except Exception as e:
819
- print(f"⚠️ Final progress update failed: {e}")
820
-
821
- # Generate action buttons HTML
822
- banner_html = """
823
- <div style='margin: 14px auto 0; max-width: 640px; background: linear-gradient(120deg, #f0f4ff 0%, #e5edff 50%, #f7fbff 100%); border: 1px solid #cbd5ff; border-radius: 14px; padding: 14px 18px; box-shadow: 0 10px 25px rgba(88, 101, 242, 0.18); text-align: center;'>
824
- <div style='font-size: 15px; font-weight: 800; color: #1f2a44; display: flex; align-items: center; justify-content: center; gap: 8px;'>
825
- 🚀 Omni Image Editor 2.0 is live!
826
- </div>
827
- <a href='https://huggingface.co/spaces/selfit-camera/Omni-Image-Editor' target='_blank' style='display: inline-flex; align-items: center; justify-content: center; margin-top: 6px; padding: 10px 18px; background: #5865f2; color: white; border-radius: 10px; font-weight: 800; text-decoration: none; box-shadow: 0 6px 18px rgba(88, 101, 242, 0.35);'>
828
- Try the Hugging Face Space demo (free)
829
- </a>
830
- <div style='font-size: 13px; color: #4a5568; margin-top: 6px; font-weight: 600;'>This is a free HF Space demo for Omni Image Editor 2.0.</div>
831
- </div>
832
- """
833
-
834
- action_buttons_html = ""
835
-
836
- # Show the like-button tip based on TIP_TRY_N (restricted regions see it from the first attempt)
- if show_like_tip:
838
- action_buttons_html = """
839
- <div style='display: flex; justify-content: center; margin: 15px 0 5px 0; padding: 0px;'>
840
- <div style='
841
- display: inline-flex;
842
- align-items: center;
843
- justify-content: center;
844
- padding: 12px 24px;
845
- background: linear-gradient(135deg, #7c3aed 0%, #6366f1 100%);
846
- color: white;
847
- text-decoration: none;
848
- border-radius: 10px;
849
- font-weight: 600;
850
- font-size: 14px;
851
- text-align: center;
852
- max-width: 400px;
853
- box-shadow: 0 3px 12px rgba(255, 107, 107, 0.3);
854
- border: none;
855
- '>👉 Click the ❤️ Like button to unlock more free trial attempts!</div>
856
- </div>
857
- """
858
-
859
- # Always show the Omni Image Editor 2.0 banner under the result image
860
- action_buttons_html = f"{action_buttons_html}{banner_html}"
861
-
862
- return final_result, final_message, gr.update(value=action_buttons_html, visible=True), use_as_input_state
863
- else:
864
- print(f"❌ Processing failed - IP: {client_ip}, error: {message}", flush=True)
865
- return None, t("error_processing_failed", lang).format(message=message), gr.update(visible=False), use_as_input_state
866
-
867
- except Exception as e:
868
- print(f"❌ Processing exception - IP: {client_ip}, error: {str(e)}")
869
- return None, t("error_processing_exception", lang).format(error=str(e)), gr.update(visible=False), use_as_input_state
870
- finally:
871
- # Task finished (success or failure) — clear active marker to allow next submission immediately
872
- Active_Tasks.pop(client_ip, None)
873
-
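
# Illustrative summary (not from the original app.py): the order of gates that
# edit_image_interface applies to each request, expressed as a pure function.
def _gate(phase, window_full, has_active_task):
    if has_active_task:
        return "reject: a task is already in progress"
    if phase == "blocked":
        return "reject: generation limit reached"
    if phase.startswith("rate_limit") and window_full:
        return "reject: wait for the current time window"
    return "accept (NSFW check and optional blur apply in rate-limit phases)"

assert _gate("free", False, False).startswith("accept")
assert _gate("rate_limit_1", True, False).startswith("reject")
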
874
- # Create Gradio interface
875
- def create_app():
876
- with gr.Blocks(
877
- title="Image Editor 1.0",
878
- theme=gr.themes.Soft(),
879
- css="""
880
- .main-container {
881
- max-width: 1200px;
882
- margin: 0 auto;
883
- }
884
- .news-banner-row {
885
- margin: 10px auto 15px auto;
886
- padding: 0 10px;
887
- max-width: 1200px;
888
- width: 100% !important;
889
- }
890
- .news-banner-row .gr-row {
891
- display: flex !important;
892
- align-items: center !important;
893
- width: 100% !important;
894
- }
895
- .news-banner-row .gr-column:first-child {
896
- flex: 1 !important; /* take up all remaining space */
- display: flex !important;
- justify-content: center !important; /* center within its own space */
899
- }
900
- .banner-lang-selector {
901
- margin-left: auto !important;
902
- display: flex !important;
903
- justify-content: flex-end !important;
904
- align-items: center !important;
905
- position: relative !important;
906
- z-index: 10 !important;
907
- }
908
- .banner-lang-selector .gr-dropdown {
909
- background: white !important;
910
- border: 1px solid #ddd !important;
911
- border-radius: 8px !important;
912
- padding: 8px 16px !important;
913
- font-size: 14px !important;
914
- font-weight: 500 !important;
915
- color: #333 !important;
916
- cursor: pointer !important;
917
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
918
- min-width: 140px !important;
919
- max-width: 160px !important;
920
- transition: all 0.2s ease !important;
921
- }
922
- .banner-lang-selector .gr-dropdown:hover {
923
- border-color: #999 !important;
924
- box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15) !important;
925
- }
926
- @media (max-width: 768px) {
927
- .news-banner-row {
928
- padding: 0 15px !important;
929
- }
930
- .news-banner-row .gr-row {
931
- display: flex !important;
932
- flex-direction: column !important;
933
- gap: 10px !important;
934
- position: static !important;
935
- }
936
- .news-banner-row .gr-column:first-child {
937
- position: static !important;
938
- pointer-events: auto !important;
939
- }
940
- .banner-lang-selector {
941
- margin-left: 0 !important;
942
- justify-content: center !important;
943
- }
944
- }
945
- .upload-area {
946
- border: 2px dashed #ccc;
947
- border-radius: 10px;
948
- padding: 20px;
949
- text-align: center;
950
- }
951
- .result-area {
952
- margin-top: 20px;
953
- padding: 20px;
954
- border-radius: 10px;
955
- background-color: #f8f9fa;
956
- }
957
- .use-as-input-btn {
958
- margin-top: 10px;
959
- width: 100%;
960
- }
961
- """,
962
- # Improve concurrency performance configuration
963
- head="""
964
- <script>
965
- // Reduce client-side state update frequency, avoid excessive SSE connections
966
- if (window.gradio) {
967
- window.gradio.update_frequency = 2000; // Update every 2 seconds
968
- }
969
- </script>
970
- """
971
- ) as app:
972
-
973
- lang_state = gr.State("en")
974
-
975
- # Main title - centered
976
- header_title = gr.HTML(f"""
977
- <div style="text-align: center; margin: 20px auto 10px auto; max-width: 800px;">
978
- <h1 style="color: #2c3e50; margin: 0; font-size: 3.5em; font-weight: 800; letter-spacing: 3px; text-shadow: 2px 2px 4px rgba(0,0,0,0.1);">
979
- {t('header_title', 'en')}
980
- </h1>
981
- </div>
982
- """)
983
-
984
- with gr.Row(elem_classes=["news-banner-row"]):
985
- with gr.Column(scale=1, min_width=400):
986
- # Banner is initially visible (will be hidden for zh/hi/ru languages on load)
987
- news_banner = gr.HTML(f"""
988
- <style>
989
- @keyframes breathe {{
990
- 0%, 100% {{ transform: scale(1); }}
991
- 50% {{ transform: scale(1.02); }}
992
- }}
993
- .breathing-banner {{
994
- animation: breathe 3s ease-in-out infinite;
995
- }}
996
- </style>
997
- <div class="breathing-banner" style="
998
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
999
- margin: 0 auto;
1000
- padding: 8px 40px;
1001
- border-radius: 20px;
1002
- max-width: 800px;
1003
- box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
1004
- text-align: center;
1005
- width: fit-content;
1006
- ">
1007
- <span style="color: white; font-weight: 600; font-size: 1.0em;">
1008
- 🎉 NEW:
1009
- <a href="https://huggingface.co/spaces/selfit-camera/Omni-Image-Editor" target="_blank" style="
1010
- color: white;
1011
- text-decoration: none;
1012
- border-bottom: 1px solid rgba(255,255,255,0.5);
1013
- transition: all 0.3s ease;
1014
- margin: 0 8px;
1015
- " onmouseover="this.style.borderBottom='1px solid white'"
1016
- onmouseout="this.style.borderBottom='1px solid rgba(255,255,255,0.5)'">
1017
- Image Editor 2.0
1018
- </a>
1019
- is Online Now ! More free trials, better quality!
1020
- </span>
1021
- </div>
1022
- """, visible=True)
1023
-
1024
- with gr.Column(scale=0, min_width=160, elem_classes=["banner-lang-selector"]):
1025
- # Lock UI to English only; allow_custom_value avoids Gradio errors if any non-en value is set programmatically
1026
- lang_dropdown = gr.Dropdown(
1027
- choices=[
1028
- ("English", "en"),
1029
- ],
1030
- value="en",
1031
- label="🌐",
1032
- show_label=True,
1033
- interactive=True,
1034
- container=False,
1035
- allow_custom_value=True,
1036
- )
1037
-
1038
- with gr.Tabs() as tabs:
1039
- with gr.Tab(t("global_editor_tab", "en")) as global_tab:
1040
- with gr.Row():
1041
- with gr.Column(scale=1):
1042
- upload_image_header = gr.Markdown(t("upload_image_header", "en"))
1043
- input_image = gr.Image(
1044
- label=t("upload_image_label", "en"),
1045
- type="pil",
1046
- height=512,
1047
- elem_classes=["upload-area"]
1048
- )
1049
-
1050
- editing_instructions_header = gr.Markdown(t("editing_instructions_header", "en"))
1051
- prompt_input = gr.Textbox(
1052
- label=t("prompt_input_label", "en"),
1053
- placeholder=t("prompt_input_placeholder", "en"),
1054
- lines=3,
1055
- max_lines=5
1056
- )
1057
-
1058
- edit_button = gr.Button(
1059
- t("start_editing_button", "en"),
1060
- variant="primary",
1061
- size="lg"
1062
- )
1063
-
1064
- with gr.Column(scale=1):
1065
- editing_result_header = gr.Markdown(t("editing_result_header", "en"))
1066
- output_image = gr.Image(
1067
- label=t("output_image_label", "en"),
1068
- height=320,
1069
- elem_classes=["result-area"]
1070
- )
1071
-
1072
- use_as_input_btn = gr.Button(
1073
- t("use_as_input_button", "en"),
1074
- variant="secondary",
1075
- size="sm",
1076
- elem_classes=["use-as-input-btn"]
1077
- )
1078
-
1079
- status_output = gr.Textbox(
1080
- label=t("status_output_label", "en"),
1081
- lines=2,
1082
- max_lines=3,
1083
- interactive=False
1084
- )
1085
-
1086
- action_buttons = gr.HTML(visible=False)
1087
-
1088
- prompt_examples_header = gr.Markdown(t("prompt_examples_header", "en"))
1089
- with gr.Row():
1090
- example_prompts = [
1091
- "Set the background to a grand opera stage with red curtains",
1092
- "Change the outfit into a traditional Chinese hanfu with flowing sleeves",
1093
- "Give the character blue dragon-like eyes with glowing pupils",
1094
- "Change lighting to soft dreamy pastel glow",
1095
- "Change pose to sitting cross-legged on the ground"
1096
- ]
1097
-
1098
- for prompt in example_prompts:
1099
- gr.Button(
1100
- prompt,
1101
- size="sm"
1102
- ).click(
1103
- lambda p=prompt: p,
1104
- outputs=prompt_input
1105
- )
1106
-
1107
- edit_button.click(
1108
- fn=edit_image_interface,
1109
- inputs=[input_image, prompt_input, lang_state],
1110
- outputs=[output_image, status_output, action_buttons, use_as_input_btn],
1111
- show_progress=True,
1112
- concurrency_limit=20
1113
- )
1114
-
1115
- def simple_use_as_input(output_img):
1116
- if output_img is not None:
1117
- return output_img
1118
- return None
1119
-
1120
- use_as_input_btn.click(
1121
- fn=simple_use_as_input,
1122
- inputs=[output_image],
1123
- outputs=[input_image]
1124
- )
1125
-
1126
- # SEO Content Section
1127
- seo_html = gr.HTML()
1128
-
1129
- def get_seo_html(lang):
1130
- # Do not show the SEO section for Chinese, Hindi, or Russian
- if lang in ["zh", "hi", "ru"]:
1132
- return ""
1133
-
1134
- return f"""
1135
- <div style="width: 100%; margin: 50px 0; padding: 0 20px;">
1136
-
1137
- <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 40px; border-radius: 20px; margin: 40px 0;">
1138
- <h2 style="margin: 0 0 20px 0; font-size: 2.2em; font-weight: 700;">
1139
- &#127912; {t('seo_unlimited_title', lang)}
1140
- </h2>
1141
- <p style="margin: 0 0 25px 0; font-size: 1.2em; opacity: 0.95; line-height: 1.6;">
1142
- {t('seo_unlimited_desc', lang)}
1143
- </p>
1144
-
1145
- <div style="display: flex; justify-content: center; gap: 25px; flex-wrap: wrap; margin: 30px 0;">
1146
- <a href="https://omnicreator.net/#generator" target="_blank" style="
1147
- display: inline-flex;
1148
- align-items: center;
1149
- justify-content: center;
1150
- padding: 20px 40px;
1151
- background: linear-gradient(135deg, #ff6b6b 0%, #feca57 100%);
1152
- color: white;
1153
- text-decoration: none;
1154
- border-radius: 15px;
1155
- font-weight: 700;
1156
- font-size: 18px;
1157
- text-align: center;
1158
- min-width: 250px;
1159
- box-shadow: 0 8px 25px rgba(255, 107, 107, 0.4);
1160
- transition: all 0.3s ease;
1161
- border: none;
1162
- transform: scale(1);
1163
- " onmouseover="this.style.transform='scale(1.05)'" onmouseout="this.style.transform='scale(1)'">
1164
- &#128640; {t('seo_unlimited_button', lang)}
1165
- </a>
1166
-
1167
- </div>
1168
-
1169
- <p style="color: rgba(255,255,255,0.9); font-size: 1em; margin: 20px 0 0 0;">
1170
- {t('seo_unlimited_footer', lang)}
1171
- </p>
1172
- </div>
1173
-
1174
- <div style="text-align: center; margin: 25px auto; background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); padding: 35px; border-radius: 20px; box-shadow: 0 10px 30px rgba(0,0,0,0.1);">
1175
- <h2 style="color: #2c3e50; margin: 0 0 20px 0; font-size: 1.9em; font-weight: 700;">
1176
- &#11088; {t('seo_professional_title', lang)}
1177
- </h2>
1178
- <p style="color: #555; font-size: 1.1em; line-height: 1.6; margin: 0 0 20px 0; padding: 0 20px;">
1179
- {t('seo_professional_desc', lang)}
1180
- </p>
1181
- </div>
1182
-
1183
- <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 25px; margin: 40px 0;">
1184
-
1185
- <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #e74c3c;">
1186
- <h3 style="color: #e74c3c; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1187
- &#127919; {t('seo_feature1_title', lang)}
1188
- </h3>
1189
- <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1190
- {t('seo_feature1_desc', lang)}
1191
- </p>
1192
- </div>
1193
-
1194
- <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #3498db;">
1195
- <h3 style="color: #3498db; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1196
- 🔓 {t('seo_feature2_title', lang)}
1197
- </h3>
1198
- <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1199
- {t('seo_feature2_desc', lang)}
1200
- </p>
1201
- </div>
1202
-
1203
- <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #27ae60;">
1204
- <h3 style="color: #27ae60; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1205
- &#9889; {t('seo_feature3_title', lang)}
1206
- </h3>
1207
- <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1208
- {t('seo_feature3_desc', lang)}
1209
- </p>
1210
- </div>
1211
-
1212
- <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #9b59b6;">
1213
- <h3 style="color: #9b59b6; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1214
- &#127912; {t('seo_feature4_title', lang)}
1215
- </h3>
1216
- <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1217
- {t('seo_feature4_desc', lang)}
1218
- </p>
1219
- </div>
1220
-
1221
- <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #f39c12;">
1222
- <h3 style="color: #f39c12; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1223
- &#128142; {t('seo_feature5_title', lang)}
1224
- </h3>
1225
- <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1226
- {t('seo_feature5_desc', lang)}
1227
- </p>
1228
- </div>
1229
-
1230
- <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #34495e;">
1231
- <h3 style="color: #34495e; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1232
- 🌍 {t('seo_feature6_title', lang)}
1233
- </h3>
1234
- <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1235
- {t('seo_feature6_desc', lang)}
1236
- </p>
1237
- </div>
1238
-
1239
- </div>
1240
-
1241
- <div style="background: linear-gradient(135deg, #ff9a9e 0%, #fecfef 50%, #fecfef 100%); padding: 30px; border-radius: 15px; margin: 40px 0;">
1242
- <h3 style="color: #8b5cf6; text-align: center; margin: 0 0 25px 0; font-size: 1.5em; font-weight: 700;">
1243
- &#128161; {t('seo_protips_title', lang)}
1244
- </h3>
1245
- <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 18px;">
1246
-
1247
- <div style="background: rgba(255,255,255,0.85); padding: 18px; border-radius: 12px;">
1248
- <strong style="color: #8b5cf6; font-size: 1.1em;">📝 {t('seo_protip1_title', lang)}</strong>
1249
- <p style="color: #555; margin: 5px 0 0 0; line-height: 1.5;">{t('seo_protip1_desc', lang)}</p>
1250
- </div>
1251
-
1252
- <div style="background: rgba(255,255,255,0.85); padding: 18px; border-radius: 12px;">
1253
- <strong style="color: #8b5cf6; font-size: 1.1em;">&#127919; {t('seo_protip2_title', lang)}</strong>
1254
- <p style="color: #555; margin: 5px 0 0 0; line-height: 1.5;">{t('seo_protip2_desc', lang)}</p>
1255
- </div>
1256
-
1257
- <div style="background: rgba(255,255,255,0.85); padding: 18px; border-radius: 12px;">
1258
- <strong style="color: #8b5cf6; font-size: 1.1em;">&#9889; {t('seo_protip3_title', lang)}</strong>
1259
- <p style="color: #555; margin: 5px 0 0 0; line-height: 1.5;">{t('seo_protip3_desc', lang)}</p>
1260
- </div>
1261
-
1262
- <div style="background: rgba(255,255,255,0.85); padding: 18px; border-radius: 12px;">
1263
- <strong style="color: #8b5cf6; font-size: 1.1em;">&#128444; {t('seo_protip4_title', lang)}</strong>
1264
- <p style="color: #555; margin: 5px 0 0 0; line-height: 1.5;">{t('seo_protip4_desc', lang)}</p>
1265
- </div>
1266
-
1267
- </div>
1268
- </div>
1269
-
1270
- <div style="text-align: center; margin: 25px auto; background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%); padding: 35px; border-radius: 20px; box-shadow: 0 10px 30px rgba(0,0,0,0.1);">
1271
- <h2 style="color: #2c3e50; margin: 0 0 20px 0; font-size: 1.8em; font-weight: 700;">
1272
- &#128640; {t('seo_needs_title', lang)}
1273
- </h2>
1274
- <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 20px; margin: 25px 0; text-align: left;">
1275
-
1276
- <div style="background: rgba(255,255,255,0.8); padding: 20px; border-radius: 12px;">
1277
- <h4 style="color: #e74c3c; margin: 0 0 10px 0;">🎨 {t('seo_needs_art_title', lang)}</h4>
1278
- <ul style="color: #555; margin: 0; padding-left: 18px; line-height: 1.6;">
1279
- <li>{t('seo_needs_art_item1', lang)}</li>
1280
- <li>{t('seo_needs_art_item2', lang)}</li>
1281
- <li>{t('seo_needs_art_item3', lang)}</li>
1282
- <li>{t('seo_needs_art_item4', lang)}</li>
1283
- </ul>
1284
- </div>
1285
-
1286
- <div style="background: rgba(255,255,255,0.8); padding: 20px; border-radius: 12px;">
1287
- <h4 style="color: #3498db; margin: 0 0 10px 0;">📸 {t('seo_needs_photo_title', lang)}</h4>
1288
- <ul style="color: #555; margin: 0; padding-left: 18px; line-height: 1.6;">
1289
- <li>{t('seo_needs_photo_item1', lang)}</li>
1290
- <li>{t('seo_needs_photo_item2', lang)}</li>
1291
- <li>{t('seo_needs_photo_item3', lang)}</li>
1292
- <li>{t('seo_needs_photo_item4', lang)}</li>
1293
- </ul>
1294
- </div>
1295
-
1296
- <div style="background: rgba(255,255,255,0.8); padding: 20px; border-radius: 12px;">
1297
- <h4 style="color: #27ae60; margin: 0 0 10px 0;">🛍️ {t('seo_needs_ecom_title', lang)}</h4>
1298
- <ul style="color: #555; margin: 0; padding-left: 18px; line-height: 1.6;">
1299
- <li>{t('seo_needs_ecom_item1', lang)}</li>
1300
- <li>{t('seo_needs_ecom_item2', lang)}</li>
1301
- <li>{t('seo_needs_ecom_item3', lang)}</li>
1302
- <li>{t('seo_needs_ecom_item4', lang)}</li>
1303
- </ul>
1304
- </div>
1305
-
1306
- <div style="background: rgba(255,255,255,0.8); padding: 20px; border-radius: 12px;">
1307
- <h4 style="color: #9b59b6; margin: 0 0 10px 0;">📱 {t('seo_needs_social_title', lang)}</h4>
1308
- <ul style="color: #555; margin: 0; padding-left: 18px; line-height: 1.6;">
1309
- <li>{t('seo_needs_social_item1', lang)}</li>
1310
- <li>{t('seo_needs_social_item2', lang)}</li>
1311
- <li>{t('seo_needs_social_item3', lang)}</li>
1312
- <li>{t('seo_needs_social_item4', lang)}</li>
1313
- </ul>
1314
- </div>
1315
-
1316
- </div>
1317
- </div>
1318
-
1319
- </div>
1320
- """
1321
-
1322
- all_ui_components = [
1323
- header_title, news_banner,
1324
- global_tab, upload_image_header, input_image, editing_instructions_header, prompt_input, edit_button,
1325
- editing_result_header, output_image, use_as_input_btn, status_output, prompt_examples_header,
1326
- seo_html,
1327
- ]
1328
-
1329
- def update_ui_lang(lang):
1330
- # Hide banner for zh, hi, ru languages
1331
- show_banner = lang not in ["zh", "hi", "ru"]
1332
-
1333
- return {
1334
- header_title: gr.update(value=f"""
1335
- <div style="text-align: center; margin: 20px auto 10px auto; max-width: 800px;">
1336
- <h1 style="color: #2c3e50; margin: 0; font-size: 3.5em; font-weight: 800; letter-spacing: 3px; text-shadow: 2px 2px 4px rgba(0,0,0,0.1);">
1337
- {t('header_title', lang)}
1338
- </h1>
1339
- </div>"""),
1340
- news_banner: gr.update(visible=show_banner),
1341
- global_tab: gr.update(label=t("global_editor_tab", lang)),
1342
- upload_image_header: gr.update(value=t("upload_image_header", lang)),
1343
- input_image: gr.update(label=t("upload_image_label", lang)),
1344
- editing_instructions_header: gr.update(value=t("editing_instructions_header", lang)),
1345
- prompt_input: gr.update(label=t("prompt_input_label", lang), placeholder=t("prompt_input_placeholder", lang)),
1346
- edit_button: gr.update(value=t("start_editing_button", lang)),
1347
- editing_result_header: gr.update(value=t("editing_result_header", lang)),
1348
- output_image: gr.update(label=t("output_image_label", lang)),
1349
- use_as_input_btn: gr.update(value=t("use_as_input_button", lang)),
1350
- status_output: gr.update(label=t("status_output_label", lang)),
1351
- prompt_examples_header: gr.update(value=t("prompt_examples_header", lang)),
1352
- seo_html: gr.update(value=get_seo_html(lang)),
1353
- }
1354
-
1355
- def on_lang_change(lang):
1356
- # Force UI to stay in English regardless of dropdown value
1357
- return "en", *update_ui_lang("en").values()
1358
-
1359
- lang_dropdown.change(
1360
- on_lang_change,
1361
- inputs=[lang_dropdown],
1362
- outputs=[lang_state] + all_ui_components
1363
- )
1364
-
1365
- # IP query state for async loading
1366
- ip_query_state = gr.State({"status": "pending", "ip": None, "lang": "en"})
1367
-
1368
- def on_load_immediate(request: gr.Request):
1369
- """
1370
- Load page with language based on robust IP detection
1371
-
1372
- Features:
1373
- - Multiple fallback layers for IP extraction
1374
- - Comprehensive error handling
1375
- - Always returns valid language (defaults to English)
1376
- - Detailed logging for debugging
1377
- """
1378
- # Extract client IP with multiple fallback methods
1379
- client_ip = None
1380
- try:
1381
- # Primary method: direct client host
1382
- client_ip = request.client.host
1383
-
1384
- # Secondary method: check forwarded headers
1385
- headers = dict(request.headers) if hasattr(request, 'headers') else {}
1386
- x_forwarded_for = headers.get('x-forwarded-for') or headers.get('X-Forwarded-For')
1387
- if x_forwarded_for:
1388
- # Take first IP from comma-separated list
1389
- client_ip = x_forwarded_for.split(',')[0].strip()
1390
-
1391
- # Alternative headers
1392
- if not client_ip or client_ip in ["127.0.0.1", "localhost"]:
1393
- client_ip = headers.get('x-real-ip') or headers.get('X-Real-IP') or client_ip
1394
-
1395
- except Exception as e:
1396
- print(f"Error extracting client IP: {e}, using default")
1397
- client_ip = "unknown"
1398
-
1399
- # Validate extracted IP
1400
- if not client_ip:
1401
- client_ip = "unknown"
1402
-
1403
- print(f"Loading page for IP: {client_ip}")
1404
-
1405
- # Determine language with robust error handling
1406
- try:
1407
- # Check if IP is already cached (second+ visit)
1408
- if client_ip in IP_Country_Cache:
1409
- # Use cached data - but force English UI
1410
- cached_lang = "en"
1411
- print(f"Using cached language (forced to en) for IP: {client_ip}")
1412
- query_state = {"ip": client_ip, "cached": True}
1413
- return cached_lang, cached_lang, query_state, *update_ui_lang(cached_lang).values()
1414
-
1415
- # First visit: Query IP and determine language (max 3s timeout built-in)
1416
- print(f"First visit - detecting language for IP: {client_ip}")
1417
- # Always force English UI even if detection yields another language
1418
- detected_lang = "en"
1419
-
1420
- print(f"First visit - Final language forced to: {detected_lang} for IP: {client_ip}")
1421
- query_state = {"ip": client_ip, "cached": False}
1422
- return detected_lang, detected_lang, query_state, *update_ui_lang(detected_lang).values()
1423
-
1424
- except Exception as e:
1425
- # Ultimate fallback - always works
1426
- print(f"Critical error in language detection for {client_ip}: {e}")
1427
- print("Using English as ultimate fallback")
1428
- query_state = {"ip": client_ip or "unknown", "cached": False, "error": str(e)}
1429
- return "en", "en", query_state, *update_ui_lang("en").values()
1430
-
1431
-
1432
- app.load(
1433
- on_load_immediate,
1434
- inputs=None,
1435
- outputs=[lang_state, lang_dropdown, ip_query_state] + all_ui_components,
1436
- )
1437
-
1438
- return app
1439
-
1440
- if __name__ == "__main__":
1441
- app = create_app()
1442
- # Improve queue configuration to handle high concurrency and prevent SSE connection issues
1443
- app.queue(
1444
- default_concurrency_limit=20, # Default concurrency limit
1445
- max_size=50, # Maximum queue size
1446
- api_open=False # Close API access to reduce resource consumption
1447
- )
1448
- app.launch(
1449
- server_name="0.0.0.0",
1450
- show_error=True, # Show detailed error information
1451
- quiet=False, # Keep log output
1452
- max_threads=40, # Increase thread pool size
1453
- height=800,
1454
- favicon_path=None # Reduce resource loading
1455
- )
 
 
__lib__/i18n/__init__.py DELETED
@@ -1,36 +0,0 @@
1
- """
2
- i18n loader for encrypted translation files
3
- """
4
- import sys
5
- import importlib.util
6
- from pathlib import Path
7
-
8
- def load_pyc_module(module_name, pyc_path):
9
- """Load a .pyc module using importlib"""
10
- spec = importlib.util.spec_from_file_location(module_name, pyc_path)
11
- if spec is None or spec.loader is None:
12
- raise ImportError(f"Cannot load module {module_name} from {pyc_path}")
13
- module = importlib.util.module_from_spec(spec)
14
- sys.modules[module_name] = module
15
- spec.loader.exec_module(module)
16
- return module
17
-
18
- def load_translations():
19
- """Load all encrypted translation files"""
20
- translations = {}
21
- i18n_dir = Path(__file__).parent
22
-
23
- # List all .pyc files in i18n directory
24
- for pyc_file in i18n_dir.glob("*.pyc"):
25
- lang = pyc_file.stem # Get language code from filename
26
- try:
27
- module = load_pyc_module(f"i18n_{lang}", pyc_file)
28
- if hasattr(module, 'data'):
29
- translations[lang] = module.data
30
- except Exception as e:
31
- print(f"Failed to load {pyc_file.name}: {e}")
32
-
33
- return translations
34
-
35
- # Auto-load translations when module is imported
36
- translations = load_translations()
 
 
__lib__/i18n/ar.pyc DELETED
Binary file (12.3 kB)
 
__lib__/i18n/da.pyc DELETED
Binary file (9.79 kB)
 
__lib__/i18n/de.pyc DELETED
Binary file (10.3 kB)
 
__lib__/i18n/en.pyc DELETED
Binary file (9.08 kB)
 
__lib__/i18n/es.pyc DELETED
Binary file (10.3 kB)
 
__lib__/i18n/fi.pyc DELETED
Binary file (9.79 kB)
 
__lib__/i18n/fr.pyc DELETED
Binary file (10.8 kB)
 
__lib__/i18n/he.pyc DELETED
Binary file (11.5 kB)
 
__lib__/i18n/hi.pyc DELETED
Binary file (16.9 kB)
 
__lib__/i18n/id.pyc DELETED
Binary file (9.73 kB)
 
__lib__/i18n/it.pyc DELETED
Binary file (10.1 kB)
 
__lib__/i18n/ja.pyc DELETED
Binary file (11 kB)
 
__lib__/i18n/nl.pyc DELETED
Binary file (9.85 kB)
 
__lib__/i18n/no.pyc DELETED
Binary file (9.69 kB)
 
__lib__/i18n/pt.pyc DELETED
Binary file (10.2 kB)
 
__lib__/i18n/ru.pyc DELETED
Binary file (15 kB)
 
__lib__/i18n/sv.pyc DELETED
Binary file (9.77 kB)
 
__lib__/i18n/tr.pyc DELETED
Binary file (10.3 kB)
 
__lib__/i18n/uk.pyc DELETED
Binary file (14.5 kB)
 
__lib__/i18n/vi.pyc DELETED
Binary file (11.5 kB)
 
__lib__/i18n/zh.pyc DELETED
Binary file (8.95 kB)
 
__lib__/nfsw.pyc DELETED
Binary file (10 kB)
 
__lib__/pipeline.pyc DELETED
Binary file (83.1 kB)
 
__lib__/util.pyc DELETED
Binary file (18.6 kB)
 
app.py CHANGED
@@ -1,60 +1,1448 @@
1
- """
2
- Minimal app loader for ImageEditSpace
3
- This app loads the compiled, obfuscated modules from __lib__
4
- """
5
- import sys
6
- from pathlib import Path
7
- import importlib.util
8
-
9
- # Add __lib__ to path to import compiled modules
10
- lib_dir = Path(__file__).parent / "__lib__"
11
- if not lib_dir.exists():
12
- raise RuntimeError(f"Compiled library directory not found: {lib_dir}")
13
-
14
- sys.path.insert(0, str(lib_dir))
15
-
16
- def load_pyc_module(module_name, pyc_path):
17
- """Load a .pyc module using importlib"""
18
- spec = importlib.util.spec_from_file_location(module_name, pyc_path)
19
- if spec is None or spec.loader is None:
20
- raise ImportError(f"Cannot load module {module_name} from {pyc_path}")
21
- module = importlib.util.module_from_spec(spec)
22
- sys.modules[module_name] = module
23
- spec.loader.exec_module(module)
24
- return module
25
 
 
26
  try:
27
- # Load compiled modules
28
- util_module = load_pyc_module("util", lib_dir / "util.pyc")
29
- nfsw_module = load_pyc_module("nfsw", lib_dir / "nfsw.pyc")
 
 
30
 
31
- # Import app module (source file)
32
- import app as app_module
 
 
33
 
34
- # Create and launch app
35
- app = app_module.create_app()
 
 
36
  app.queue(
37
- default_concurrency_limit=20,
38
- max_size=50,
39
- api_open=False
40
  )
41
  app.launch(
42
  server_name="0.0.0.0",
43
- show_error=True,
44
- quiet=False,
45
- max_threads=40,
46
  height=800,
47
- favicon_path=None
48
- )
49
-
50
- except ImportError as e:
51
- print(f"Failed to import compiled modules: {e}")
52
- print("Make sure to run build_encrypted.py first to compile the modules")
53
- import traceback
54
- traceback.print_exc()
55
- sys.exit(1)
56
- except Exception as e:
57
- print(f"Error running app: {e}")
58
- import traceback
59
- traceback.print_exc()
60
- sys.exit(1)
 
1
+ import gradio as gr
2
+ import threading
3
+ import os
4
+ import shutil
5
+ import tempfile
6
+ import time
7
+ from util import process_image_edit, process_local_image_edit, download_and_check_result_nsfw
8
+ from nfsw import NSFWDetector
 
 
9
 
10
+ # Configuration parameters
11
+
12
+ TIP_TRY_N = 8 # Show like button tip after 8 tries
13
+ FREE_TRY_N = 20 # Free phase: first 20 tries without restrictions
14
+ SLOW_TRY_N = 25 # Slow phase start: 25 tries
15
+ SLOW2_TRY_N = 32 # Second slow phase start: 32 tries
16
+ RATE_LIMIT_60 = 40 # Full restriction: blocked after 40 tries
17
+
18
+ # Time window configuration (minutes)
19
+ PHASE_1_WINDOW = 5 # 20-25 tries: 5 minutes
20
+ PHASE_2_WINDOW = 10 # 25-32 tries: 10 minutes
21
+ PHASE_3_WINDOW = 20 # 32-40 tries: 20 minutes
22
+ MAX_IMAGES_PER_WINDOW = 2 # Max images per time window
23
+
24
+ IP_Dict = {}
25
+ # IP generation statistics and time window tracking
26
+ IP_Generation_Count = {} # Record total generation count for each IP
27
+ IP_Rate_Limit_Track = {} # Record generation count and timestamp in current time window for each IP
28
+
29
+ def get_ip_generation_count(client_ip):
30
+ """
31
+ Get IP generation count
32
+ """
33
+ if client_ip not in IP_Generation_Count:
34
+ IP_Generation_Count[client_ip] = 0
35
+ return IP_Generation_Count[client_ip]
36
+
37
+ def increment_ip_generation_count(client_ip):
38
+ """
39
+ Increment IP generation count
40
+ """
41
+ if client_ip not in IP_Generation_Count:
42
+ IP_Generation_Count[client_ip] = 0
43
+ IP_Generation_Count[client_ip] += 1
44
+ return IP_Generation_Count[client_ip]
45
+
46
+ def get_ip_phase(client_ip):
47
+ """
48
+ Get current phase for IP
49
+
50
+ Returns:
51
+ str: 'free', 'rate_limit_1', 'rate_limit_2', 'rate_limit_3', 'blocked'
52
+ """
53
+ count = get_ip_generation_count(client_ip)
54
+
55
+ if count < FREE_TRY_N:
56
+ return 'free'
57
+ elif count < SLOW_TRY_N:
58
+ return 'rate_limit_1' # NSFW blur + 5 minutes 2 images
59
+ elif count < SLOW2_TRY_N:
60
+ return 'rate_limit_2' # NSFW blur + 10 minutes 2 images
61
+ elif count < RATE_LIMIT_60:
62
+ return 'rate_limit_3' # NSFW blur + 20 minutes 2 images
63
+ else:
64
+ return 'blocked' # Generation blocked
65
+
66
+ def check_rate_limit_for_phase(client_ip, phase):
67
+ """
68
+ Check rate limit for specific phase
69
+
70
+ Returns:
71
+ tuple: (is_limited, wait_time_minutes, current_count)
72
+ """
73
+ if phase not in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3']:
74
+ return False, 0, 0
75
+
76
+ # Determine time window
77
+ if phase == 'rate_limit_1':
78
+ window_minutes = PHASE_1_WINDOW
79
+ elif phase == 'rate_limit_2':
80
+ window_minutes = PHASE_2_WINDOW
81
+ else: # rate_limit_3
82
+ window_minutes = PHASE_3_WINDOW
83
+
84
+ current_time = time.time()
85
+ window_key = f"{client_ip}_{phase}"
86
+
87
+ # Clean expired records
88
+ if window_key in IP_Rate_Limit_Track:
89
+ track_data = IP_Rate_Limit_Track[window_key]
90
+ # Check if within current time window
91
+ if current_time - track_data['start_time'] > window_minutes * 60:
92
+ # Time window expired, reset
93
+ IP_Rate_Limit_Track[window_key] = {
94
+ 'count': 0,
95
+ 'start_time': current_time,
96
+ 'last_generation': current_time
97
+ }
98
+ else:
99
+ # Initialize
100
+ IP_Rate_Limit_Track[window_key] = {
101
+ 'count': 0,
102
+ 'start_time': current_time,
103
+ 'last_generation': current_time
104
+ }
105
+
106
+ track_data = IP_Rate_Limit_Track[window_key]
107
+
108
+ # Check if exceeded limit
109
+ if track_data['count'] >= MAX_IMAGES_PER_WINDOW:
110
+ # Calculate remaining wait time
111
+ elapsed = current_time - track_data['start_time']
112
+ wait_time = (window_minutes * 60) - elapsed
113
+ wait_minutes = max(0, wait_time / 60)
114
+ return True, wait_minutes, track_data['count']
115
+
116
+ return False, 0, track_data['count']
117
+
118
+ def record_generation_attempt(client_ip, phase):
119
+ """
120
+ Record generation attempt
121
+ """
122
+ # Increment total count
123
+ increment_ip_generation_count(client_ip)
124
+
125
+ # Record time window count
126
+ if phase in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3']:
127
+ window_key = f"{client_ip}_{phase}"
128
+ current_time = time.time()
129
+
130
+ if window_key in IP_Rate_Limit_Track:
131
+ IP_Rate_Limit_Track[window_key]['count'] += 1
132
+ IP_Rate_Limit_Track[window_key]['last_generation'] = current_time
133
+ else:
134
+ IP_Rate_Limit_Track[window_key] = {
135
+ 'count': 1,
136
+ 'start_time': current_time,
137
+ 'last_generation': current_time
138
+ }
139
+
140
+ def apply_gaussian_blur_to_image_url(image_url, blur_strength=50):
141
+ """
142
+ Apply Gaussian blur to image URL
143
+
144
+ Args:
145
+ image_url (str): Original image URL
146
+ blur_strength (int): Blur strength, default 50 (heavy blur)
147
+
148
+ Returns:
149
+ PIL.Image: Blurred PIL Image object
150
+ """
151
+ try:
152
+ import requests
153
+ from PIL import Image, ImageFilter
154
+ import io
155
+
156
+ # Download image
157
+ response = requests.get(image_url, timeout=30)
158
+ if response.status_code != 200:
159
+ return None
160
+
161
+ # Convert to PIL Image
162
+ image_data = io.BytesIO(response.content)
163
+ image = Image.open(image_data)
164
+
165
+ # Apply heavy Gaussian blur
166
+ blurred_image = image.filter(ImageFilter.GaussianBlur(radius=blur_strength))
167
+
168
+ return blurred_image
169
+
170
+ except Exception as e:
171
+ print(f"⚠️ Failed to apply Gaussian blur: {e}")
172
+ return None
173
+
174
+ # Initialize NSFW detector (download from Hugging Face)
175
  try:
176
+ nsfw_detector = NSFWDetector() # Auto download falconsai_yolov9_nsfw_model_quantized.pt from Hugging Face
177
+ print("✅ NSFW detector initialized successfully")
178
+ except Exception as e:
179
+ print(f"❌ NSFW detector initialization failed: {e}")
180
+ nsfw_detector = None
181
+
182
+ def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.Progress()):
183
+ """
184
+ Interface function for processing image editing with phase-based limitations
185
+ """
186
+ try:
187
+ # Extract user IP
188
+ client_ip = request.client.host
189
+ x_forwarded_for = dict(request.headers).get('x-forwarded-for')
190
+ if x_forwarded_for:
191
+ client_ip = x_forwarded_for
192
+ if client_ip not in IP_Dict:
193
+ IP_Dict[client_ip] = 0
194
+ IP_Dict[client_ip] += 1
195
+
196
+ if input_image is None:
197
+ return None, "Please upload an image first", gr.update(visible=False)
198
+
199
+ if not prompt or prompt.strip() == "":
200
+ return None, "Please enter editing prompt", gr.update(visible=False)
201
+
202
+ # Check if prompt length is greater than 3 characters
203
+ if len(prompt.strip()) <= 3:
204
+ return None, "❌ Editing prompt must be more than 3 characters", gr.update(visible=False)
205
+ except Exception as e:
206
+ print(f"⚠️ Request preprocessing error: {e}")
207
+ return None, "❌ Request processing error", gr.update(visible=False)
208
+
209
+ # Get user current phase
210
+ current_phase = get_ip_phase(client_ip)
211
+ current_count = get_ip_generation_count(client_ip)
212
+
213
+ print(f"📊 User phase info - IP: {client_ip}, current phase: {current_phase}, generation count: {current_count}")
214
+
215
+ # Check if user reached the like button tip threshold
216
+ show_like_tip = (current_count >= TIP_TRY_N)
217
+
218
+ # Check if completely blocked
219
+ if current_phase == 'blocked':
220
+ # Generate blocked limit button
221
+ blocked_button_html = f"""
222
+ <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
223
+ <a href='https://omnicreator.net/#generator' target='_blank' style='
224
+ display: inline-flex;
225
+ align-items: center;
226
+ justify-content: center;
227
+ padding: 16px 32px;
228
+ background: linear-gradient(135deg, #e74c3c 0%, #c0392b 100%);
229
+ color: white;
230
+ text-decoration: none;
231
+ border-radius: 12px;
232
+ font-weight: 600;
233
+ font-size: 16px;
234
+ text-align: center;
235
+ min-width: 200px;
236
+ box-shadow: 0 4px 15px rgba(231, 76, 60, 0.4);
237
+ transition: all 0.3s ease;
238
+ border: none;
239
+ '>&#128640; Unlimited Generation</a>
240
+ </div>
241
+ """
242
+ return None, f"❌ You have reached Hugging Face's free generation limit. Please visit https://omnicreator.net/#generator for unlimited generation", gr.update(value=blocked_button_html, visible=True)
243
+
244
+ # Check rate limit (applies to rate_limit phases)
245
+ if current_phase in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3']:
246
+ is_limited, wait_minutes, window_count = check_rate_limit_for_phase(client_ip, current_phase)
247
+ if is_limited:
248
+ wait_minutes_int = int(wait_minutes) + 1
249
+ # Generate rate limit button
250
+ rate_limit_button_html = f"""
251
+ <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
252
+ <a href='https://omnicreator.net/#generator' target='_blank' style='
253
+ display: inline-flex;
254
+ align-items: center;
255
+ justify-content: center;
256
+ padding: 16px 32px;
257
+ background: linear-gradient(135deg, #f39c12 0%, #e67e22 100%);
258
+ color: white;
259
+ text-decoration: none;
260
+ border-radius: 12px;
261
+ font-weight: 600;
262
+ font-size: 16px;
263
+ text-align: center;
264
+ min-width: 200px;
265
+ box-shadow: 0 4px 15px rgba(243, 156, 18, 0.4);
266
+ transition: all 0.3s ease;
267
+ border: none;
268
+ '>⏰ Skip Wait - Unlimited Generation</a>
269
+ </div>
270
+ """
271
+ return None, f"❌ You have reached Hugging Face's free generation limit. Please visit https://omnicreator.net/#generator for unlimited generation, or wait {wait_minutes_int} minutes before generating again", gr.update(value=rate_limit_button_html, visible=True)
272
+
273
+ # Handle NSFW detection based on phase
274
+ is_nsfw_task = False # Track if this task involves NSFW content
275
+
276
+ # Skip NSFW detection in free phase
277
+ if current_phase != 'free' and nsfw_detector is not None and input_image is not None:
278
+ try:
279
+ nsfw_result = nsfw_detector.predict_pil_label_only(input_image)
280
+
281
+ if nsfw_result.lower() == "nsfw":
282
+ is_nsfw_task = True
283
+ print(f"🔍 Input NSFW detected in {current_phase} phase: ❌❌❌ {nsfw_result} - IP: {client_ip} (will blur result)")
284
+ else:
285
+ print(f"🔍 Input NSFW check passed: ✅✅✅ {nsfw_result} - IP: {client_ip}")
286
+
287
+ except Exception as e:
288
+ print(f"⚠️ Input NSFW detection failed: {e}")
289
+ # Allow continuation when detection fails
290
+
291
+ result_url = None
292
+ status_message = ""
293
+
294
+ def progress_callback(message):
295
+ try:
296
+ nonlocal status_message
297
+ status_message = message
298
+ # Add error handling to prevent progress update failure
299
+ if progress is not None:
300
+ # Enhanced progress display with better formatting
301
+ if "Queue:" in message or "tasks ahead" in message:
302
+ # Queue status - show with different progress value to indicate waiting
303
+ progress(0.1, desc=message)
304
+ elif "Processing" in message or "AI is processing" in message:
305
+ # Processing status
306
+ progress(0.7, desc=message)
307
+ elif "Generating" in message or "Almost done" in message:
308
+ # Generation status
309
+ progress(0.9, desc=message)
310
+ else:
311
+ # Default status
312
+ progress(0.5, desc=message)
313
+ except Exception as e:
314
+ print(f"⚠️ Progress update failed: {e}")
315
+
316
+ try:
317
+ # Record generation attempt (before actual generation to ensure correct count)
318
+ record_generation_attempt(client_ip, current_phase)
319
+ updated_count = get_ip_generation_count(client_ip)
320
+
321
+ print(f"✅ Processing started - IP: {client_ip}, phase: {current_phase}, total count: {updated_count}, prompt: {prompt.strip()}", flush=True)
322
+
323
+ # Call image editing processing function
324
+ result_url, message, task_uuid = process_image_edit(input_image, prompt.strip(), None, progress_callback)
325
+
326
+ if result_url:
327
+ print(f"✅ Processing completed successfully - IP: {client_ip}, result_url: {result_url}, task_uuid: {task_uuid}", flush=True)
328
+
329
+ # Detect result image NSFW content (only in rate limit phases)
330
+ if nsfw_detector is not None and current_phase != 'free':
331
+ try:
332
+ if progress is not None:
333
+ progress(0.9, desc="Checking result image...")
334
+
335
+ is_nsfw, nsfw_error = download_and_check_result_nsfw(result_url, nsfw_detector)
336
+
337
+ if nsfw_error:
338
+ print(f"⚠️ Result image NSFW detection error - IP: {client_ip}, error: {nsfw_error}")
339
+ elif is_nsfw:
340
+ is_nsfw_task = True # Mark task as NSFW
341
+ print(f"🔍 Result image NSFW detected in {current_phase} phase: ❌❌❌ - IP: {client_ip} (will blur result)")
342
+ else:
343
+ print(f"🔍 Result image NSFW check passed: ✅✅✅ - IP: {client_ip}")
344
+
345
+ except Exception as e:
346
+ print(f"⚠️ Result image NSFW detection exception - IP: {client_ip}, error: {str(e)}")
347
+
348
+ # Apply blur if this is an NSFW task in rate limit phases
349
+ should_blur = False
350
+
351
+ if current_phase in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3'] and is_nsfw_task:
352
+ should_blur = True
353
+
354
+ # Apply blur processing
355
+ if should_blur:
356
+ if progress is not None:
357
+ progress(0.95, desc="Applying content filter...")
358
+
359
+ blurred_image = apply_gaussian_blur_to_image_url(result_url)
360
+ if blurred_image is not None:
361
+ final_result = blurred_image # Return PIL Image object
362
+ final_message = f"⚠️ NSFW content detected, content filter applied. NSFW content is prohibited by Hugging Face, but you can generate unlimited content at our official website https://omnicreator.net/#generator"
363
+ print(f"🔒 Applied Gaussian blur for NSFW content - IP: {client_ip}")
364
+ else:
365
+ # Blur failed, return original URL with warning
366
+ final_result = result_url
367
+ final_message = f"⚠️ NSFW content detected, but content filter failed. Please visit https://omnicreator.net/#generator for better experience"
368
+
369
+ # Generate NSFW button for blurred content
370
+ nsfw_action_buttons_html = f"""
371
+ <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
372
+ <a href='https://omnicreator.net/#generator' target='_blank' style='
373
+ display: inline-flex;
374
+ align-items: center;
375
+ justify-content: center;
376
+ padding: 16px 32px;
377
+ background: linear-gradient(135deg, #ff6b6b 0%, #feca57 100%);
378
+ color: white;
379
+ text-decoration: none;
380
+ border-radius: 12px;
381
+ font-weight: 600;
382
+ font-size: 16px;
383
+ text-align: center;
384
+ min-width: 200px;
385
+ box-shadow: 0 4px 15px rgba(255, 107, 107, 0.4);
386
+ transition: all 0.3s ease;
387
+ border: none;
388
+ '>🔥 Unlimited NSFW Generation</a>
389
+ </div>
390
+ """
391
+ return final_result, final_message, gr.update(value=nsfw_action_buttons_html, visible=True)
392
+ else:
393
+ final_result = result_url
394
+ final_message = "✅ " + message
395
+
396
+ try:
397
+ if progress is not None:
398
+ progress(1.0, desc="Processing completed")
399
+ except Exception as e:
400
+ print(f"⚠️ Final progress update failed: {e}")
401
+
402
+ # Generate action buttons HTML like Trump AI Voice
403
+ action_buttons_html = ""
404
+ if task_uuid:
405
+ task_detail_url = f"https://omnicreator.net/my-creations/task/{task_uuid}"
406
+ action_buttons_html = f"""
407
+ <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
408
+ <a href='{task_detail_url}' target='_blank' style='
409
+ display: inline-flex;
410
+ align-items: center;
411
+ justify-content: center;
412
+ padding: 16px 32px;
413
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
414
+ color: white;
415
+ text-decoration: none;
416
+ border-radius: 12px;
417
+ font-weight: 600;
418
+ font-size: 16px;
419
+ text-align: center;
420
+ min-width: 160px;
421
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4);
422
+ transition: all 0.3s ease;
423
+ border: none;
424
+ '>&#128444; Download HD Image</a>
425
+ <a href='https://omnicreator.net/#generator' target='_blank' style='
426
+ display: inline-flex;
427
+ align-items: center;
428
+ justify-content: center;
429
+ padding: 16px 32px;
430
+ background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
431
+ color: white;
432
+ text-decoration: none;
433
+ border-radius: 12px;
434
+ font-weight: 600;
435
+ font-size: 16px;
436
+ text-align: center;
437
+ min-width: 160px;
438
+ box-shadow: 0 4px 15px rgba(17, 153, 142, 0.4);
439
+ transition: all 0.3s ease;
440
+ border: none;
441
+ '>&#128640; Unlimited Generation</a>
442
+ </div>
443
+ """
444
+
445
+ # Add popup script if needed (using different approach)
446
+ if show_like_tip:
447
+ action_buttons_html += """
448
+ <div style='display: flex; justify-content: center; margin: 15px 0 5px 0; padding: 0px;'>
449
+ <div style='
450
+ display: inline-flex;
451
+ align-items: center;
452
+ justify-content: center;
453
+ padding: 12px 24px;
454
+ background: linear-gradient(135deg, #ff6b6b 0%, #feca57 100%);
455
+ color: white;
456
+ border-radius: 10px;
457
+ font-weight: 600;
458
+ font-size: 14px;
459
+ text-align: center;
460
+ max-width: 400px;
461
+ box-shadow: 0 3px 12px rgba(255, 107, 107, 0.3);
462
+ border: none;
463
+ '>👉 Click the ❤️ Like button to unlock more free trial attempts!</div>
464
+ </div>
465
+ """
466
+
467
+ return final_result, final_message, gr.update(value=action_buttons_html, visible=True)
468
+ else:
469
+ print(f"❌ Processing failed - IP: {client_ip}, error: {message}", flush=True)
470
+ return None, "❌ " + message, gr.update(visible=False)
471
+
472
+ except Exception as e:
473
+ print(f"❌ Processing exception - IP: {client_ip}, error: {str(e)}")
474
+ return None, f"❌ Error occurred during processing: {str(e)}", gr.update(visible=False)
475
+
476
+ def local_edit_interface(image_dict, prompt, reference_image, request: gr.Request, progress=gr.Progress()):
477
+ """
478
+ Handle local editing requests (with phase-based limitations)
479
+ """
480
+ try:
481
+ # Extract user IP
482
+ client_ip = request.client.host
483
+ x_forwarded_for = dict(request.headers).get('x-forwarded-for')
484
+ if x_forwarded_for:
485
+ client_ip = x_forwarded_for
486
+ if client_ip not in IP_Dict:
487
+ IP_Dict[client_ip] = 0
488
+ IP_Dict[client_ip] += 1
489
+
490
+ if image_dict is None:
491
+ return None, "Please upload an image and draw the area to edit", gr.update(visible=False)
492
+
493
+ # Handle different input formats for ImageEditor
494
+ if isinstance(image_dict, dict):
495
+ # ImageEditor dict format
496
+ if "background" not in image_dict or "layers" not in image_dict:
497
+ return None, "Please draw the area to edit on the image", gr.update(visible=False)
498
+
499
+ base_image = image_dict["background"]
500
+ layers = image_dict["layers"]
501
+
502
+ # Special handling: if background is None but composite exists, use composite
503
+ if base_image is None and "composite" in image_dict and image_dict["composite"] is not None:
504
+ print("🔧 Background is None, using composite instead")
505
+ base_image = image_dict["composite"]
506
+ else:
507
+ # Simple case: Direct PIL Image (from example)
508
+ base_image = image_dict
509
+ layers = []
510
+
511
+ # Check for special example case - bypass mask requirement
512
+ is_example_case = prompt and prompt.startswith("EXAMPLE_PANDA_CAT_")
513
+
514
+ # Debug: check current state
515
+ if is_example_case:
516
+ print(f"🔍 Example case detected - base_image is None: {base_image is None}")
517
+
518
+ # Special handling for example case: load image directly from file
519
+ if is_example_case and base_image is None:
520
+ try:
521
+ from PIL import Image
522
+ import os
523
+
524
+ main_path = "datas/panda01.jpeg"
525
+ print(f"🔍 Trying to load: {main_path}, exists: {os.path.exists(main_path)}")
526
+
527
+ if os.path.exists(main_path):
528
+ base_image = Image.open(main_path)
529
+ print(f"✅ Successfully loaded example image: {base_image.size}")
530
+ else:
531
+ return None, f"❌ Example image not found: {main_path}", gr.update(visible=False)
532
+ except Exception as e:
533
+ return None, f"❌ Failed to load example image: {str(e)}", gr.update(visible=False)
534
+
535
+ # Additional check for base_image
536
+ if base_image is None:
537
+ if is_example_case:
538
+ print(f"❌ Example case but base_image still None!")
539
+ return None, "❌ No image found. Please upload an image first.", gr.update(visible=False)
540
+
541
+ if not layers and not is_example_case:
542
+ return None, "Please draw the area to edit on the image", gr.update(visible=False)
543
+
544
+ if not prompt or prompt.strip() == "":
545
+ return None, "Please enter editing prompt", gr.update(visible=False)
546
+
547
+ # Check prompt length
548
+ if len(prompt.strip()) <= 3:
549
+ return None, "❌ Editing prompt must be more than 3 characters", gr.update(visible=False)
550
+ except Exception as e:
551
+ print(f"⚠️ Local edit request preprocessing error: {e}")
552
+ return None, "❌ Request processing error", gr.update(visible=False)
553
+
554
+ # Get user current phase
555
+ current_phase = get_ip_phase(client_ip)
556
+ current_count = get_ip_generation_count(client_ip)
557
+
558
+ print(f"📊 Local edit user phase info - IP: {client_ip}, current phase: {current_phase}, generation count: {current_count}")
559
+
560
+ # Check if user reached the like button tip threshold
561
+ show_like_tip = (current_count >= TIP_TRY_N)
562
 
563
+ # Check if completely blocked
564
+ if current_phase == 'blocked':
565
+ # Generate blocked limit button
566
+ blocked_button_html = f"""
567
+ <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
568
+ <a href='https://omnicreator.net/#generator' target='_blank' style='
569
+ display: inline-flex;
570
+ align-items: center;
571
+ justify-content: center;
572
+ padding: 16px 32px;
573
+ background: linear-gradient(135deg, #e74c3c 0%, #c0392b 100%);
574
+ color: white;
575
+ text-decoration: none;
576
+ border-radius: 12px;
577
+ font-weight: 600;
578
+ font-size: 16px;
579
+ text-align: center;
580
+ min-width: 200px;
581
+ box-shadow: 0 4px 15px rgba(231, 76, 60, 0.4);
582
+ transition: all 0.3s ease;
583
+ border: none;
584
+ '>&#128640; Unlimited Generation</a>
585
+ </div>
586
+ """
587
+ return None, f"❌ You have reached Hugging Face's free generation limit. Please visit https://omnicreator.net/#generator for unlimited generation", gr.update(value=blocked_button_html, visible=True)
588
+
589
+ # Check rate limit (applies to rate_limit phases)
590
+ if current_phase in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3']:
591
+ is_limited, wait_minutes, window_count = check_rate_limit_for_phase(client_ip, current_phase)
592
+ if is_limited:
593
+ wait_minutes_int = int(wait_minutes) + 1
594
+ # Generate rate limit button
595
+ rate_limit_button_html = f"""
596
+ <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
597
+ <a href='https://omnicreator.net/#generator' target='_blank' style='
598
+ display: inline-flex;
599
+ align-items: center;
600
+ justify-content: center;
601
+ padding: 16px 32px;
602
+ background: linear-gradient(135deg, #f39c12 0%, #e67e22 100%);
603
+ color: white;
604
+ text-decoration: none;
605
+ border-radius: 12px;
606
+ font-weight: 600;
607
+ font-size: 16px;
608
+ text-align: center;
609
+ min-width: 200px;
610
+ box-shadow: 0 4px 15px rgba(243, 156, 18, 0.4);
611
+ transition: all 0.3s ease;
612
+ border: none;
613
+ '>⏰ Skip Wait - Unlimited Generation</a>
614
+ </div>
615
+ """
616
+ return None, f"❌ You have reached Hugging Face's free generation limit. Please visit https://omnicreator.net/#generator for unlimited generation, or wait {wait_minutes_int} minutes before generating again", gr.update(value=rate_limit_button_html, visible=True)
617
+
618
+ # Handle NSFW detection based on phase
619
+ is_nsfw_task = False # Track if this task involves NSFW content
620
+
621
+ # Skip NSFW detection in free phase
622
+ if current_phase != 'free' and nsfw_detector is not None and base_image is not None:
623
+ try:
624
+ nsfw_result = nsfw_detector.predict_pil_label_only(base_image)
625
+
626
+ if nsfw_result.lower() == "nsfw":
627
+ is_nsfw_task = True
628
+ print(f"🔍 Local edit input NSFW detected in {current_phase} phase: ❌❌❌ {nsfw_result} - IP: {client_ip} (will blur result)")
629
+ else:
630
+ print(f"🔍 Local edit input NSFW check passed: ✅✅✅ {nsfw_result} - IP: {client_ip}")
631
+
632
+ except Exception as e:
633
+ print(f"⚠️ Local edit input NSFW detection failed: {e}")
634
+ # Allow continuation when detection fails
635
+
636
+ result_url = None
637
+ status_message = ""
638
+
639
+ def progress_callback(message):
640
+ try:
641
+ nonlocal status_message
642
+ status_message = message
643
+ # Add error handling to prevent progress update failure
644
+ if progress is not None:
645
+ # Enhanced progress display with better formatting for local editing
646
+ if "Queue:" in message or "tasks ahead" in message:
647
+ # Queue status - show with different progress value to indicate waiting
648
+ progress(0.1, desc=message)
649
+ elif "Processing" in message or "AI is processing" in message:
650
+ # Processing status
651
+ progress(0.7, desc=message)
652
+ elif "Generating" in message or "Almost done" in message:
653
+ # Generation status
654
+ progress(0.9, desc=message)
655
+ else:
656
+ # Default status
657
+ progress(0.5, desc=message)
658
+ except Exception as e:
659
+ print(f"⚠️ Local edit progress update failed: {e}")
660
+
661
+ try:
662
+ # Record generation attempt (before actual generation to ensure correct count)
663
+ record_generation_attempt(client_ip, current_phase)
664
+ updated_count = get_ip_generation_count(client_ip)
665
+
666
+ print(f"✅ Local editing started - IP: {client_ip}, phase: {current_phase}, total count: {updated_count}, prompt: {prompt.strip()}", flush=True)
667
+
668
+ # Clean prompt for API call
669
+ clean_prompt = prompt.strip()
670
+ if clean_prompt.startswith("EXAMPLE_PANDA_CAT_"):
671
+ clean_prompt = clean_prompt[18:] # Remove the prefix
672
+
673
+ # Call local image editing processing function
674
+ if is_example_case:
675
+ # For example case, pass special flag to use local mask file
676
+ result_url, message, task_uuid = process_local_image_edit(base_image, layers, clean_prompt, reference_image, progress_callback, use_example_mask="datas/panda01m.jpeg")
677
+ else:
678
+ # Normal case
679
+ result_url, message, task_uuid = process_local_image_edit(base_image, layers, clean_prompt, reference_image, progress_callback)
680
+
681
+ if result_url:
682
+ print(f"✅ Local editing completed successfully - IP: {client_ip}, result_url: {result_url}, task_uuid: {task_uuid}", flush=True)
683
+
684
+ # Detect result image NSFW content (only in rate limit phases)
685
+ if nsfw_detector is not None and current_phase != 'free':
686
+ try:
687
+ if progress is not None:
688
+ progress(0.9, desc="Checking result image...")
689
+
690
+ is_nsfw, nsfw_error = download_and_check_result_nsfw(result_url, nsfw_detector)
691
+
692
+ if nsfw_error:
693
+ print(f"⚠️ Local edit result image NSFW detection error - IP: {client_ip}, error: {nsfw_error}")
694
+ elif is_nsfw:
695
+ is_nsfw_task = True # Mark task as NSFW
696
+ print(f"🔍 Local edit result image NSFW detected in {current_phase} phase: ❌❌❌ - IP: {client_ip} (will blur result)")
697
+ else:
698
+ print(f"🔍 Local edit result image NSFW check passed: ✅✅✅ - IP: {client_ip}")
699
+
700
+ except Exception as e:
701
+ print(f"⚠️ Local edit result image NSFW detection exception - IP: {client_ip}, error: {str(e)}")
702
+
703
+ # Apply blur if this is an NSFW task in rate limit phases
704
+ should_blur = False
705
+
706
+ if current_phase in ['rate_limit_1', 'rate_limit_2', 'rate_limit_3'] and is_nsfw_task:
707
+ should_blur = True
708
+
709
+ # Apply blur processing
710
+ if should_blur:
711
+ if progress is not None:
712
+ progress(0.95, desc="Applying content filter...")
713
+
714
+ blurred_image = apply_gaussian_blur_to_image_url(result_url)
715
+ if blurred_image is not None:
716
+ final_result = blurred_image # Return PIL Image object
717
+ final_message = f"⚠️ NSFW content detected, content filter applied. NSFW content is prohibited by Hugging Face, but you can generate unlimited content at our official website https://omnicreator.net/#generator"
718
+ print(f"🔒 Local edit applied Gaussian blur for NSFW content - IP: {client_ip}")
719
+ else:
720
+ # Blur failed, return original URL with warning
721
+ final_result = result_url
722
+ final_message = f"⚠️ NSFW content detected, but content filter failed. Please visit https://omnicreator.net/#generator for better experience"
723
+
724
+ # Generate NSFW button for blurred content
725
+ nsfw_action_buttons_html = f"""
726
+ <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
727
+ <a href='https://omnicreator.net/#generator' target='_blank' style='
728
+ display: inline-flex;
729
+ align-items: center;
730
+ justify-content: center;
731
+ padding: 16px 32px;
732
+ background: linear-gradient(135deg, #ff6b6b 0%, #feca57 100%);
733
+ color: white;
734
+ text-decoration: none;
735
+ border-radius: 12px;
736
+ font-weight: 600;
737
+ font-size: 16px;
738
+ text-align: center;
739
+ min-width: 200px;
740
+ box-shadow: 0 4px 15px rgba(255, 107, 107, 0.4);
741
+ transition: all 0.3s ease;
742
+ border: none;
743
+ '>🔥 Unlimited NSFW Generation</a>
744
+ </div>
745
+ """
746
+ return final_result, final_message, gr.update(value=nsfw_action_buttons_html, visible=True)
747
+ else:
748
+ final_result = result_url
749
+ final_message = "✅ " + message
750
+
751
+ try:
752
+ if progress is not None:
753
+ progress(1.0, desc="Processing completed")
754
+ except Exception as e:
755
+ print(f"⚠️ Local edit final progress update failed: {e}")
756
+
757
+ # Generate action buttons HTML like Trump AI Voice
758
+ action_buttons_html = ""
759
+ if task_uuid:
760
+ task_detail_url = f"https://omnicreator.net/my-creations/task/{task_uuid}"
761
+ action_buttons_html = f"""
762
+ <div style='display: flex; justify-content: center; gap: 15px; margin: 10px 0 5px 0; padding: 0px;'>
763
+ <a href='{task_detail_url}' target='_blank' style='
764
+ display: inline-flex;
765
+ align-items: center;
766
+ justify-content: center;
767
+ padding: 16px 32px;
768
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
769
+ color: white;
770
+ text-decoration: none;
771
+ border-radius: 12px;
772
+ font-weight: 600;
773
+ font-size: 16px;
774
+ text-align: center;
775
+ min-width: 160px;
776
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4);
777
+ transition: all 0.3s ease;
778
+ border: none;
779
+ '>&#128444; Download HD Image</a>
780
+ <a href='https://omnicreator.net/#generator' target='_blank' style='
781
+ display: inline-flex;
782
+ align-items: center;
783
+ justify-content: center;
784
+ padding: 16px 32px;
785
+ background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
786
+ color: white;
787
+ text-decoration: none;
788
+ border-radius: 12px;
789
+ font-weight: 600;
790
+ font-size: 16px;
791
+ text-align: center;
792
+ min-width: 160px;
793
+ box-shadow: 0 4px 15px rgba(17, 153, 142, 0.4);
794
+ transition: all 0.3s ease;
795
+ border: none;
796
+ '>&#128640; Unlimited Generation</a>
797
+ </div>
798
+ """
799
+
800
+ # Add popup script if needed (using different approach)
801
+ if show_like_tip:
802
+ action_buttons_html += """
803
+ <div style='display: flex; justify-content: center; margin: 15px 0 5px 0; padding: 0px;'>
804
+ <div style='
805
+ display: inline-flex;
806
+ align-items: center;
807
+ justify-content: center;
808
+ padding: 12px 24px;
809
+ background: linear-gradient(135deg, #ff6b6b 0%, #feca57 100%);
810
+ color: white;
811
+ border-radius: 10px;
812
+ font-weight: 600;
813
+ font-size: 14px;
814
+ text-align: center;
815
+ max-width: 400px;
816
+ box-shadow: 0 3px 12px rgba(255, 107, 107, 0.3);
817
+ border: none;
818
+ '>👉 Please consider clicking the ❤️ Like button to support this space!</div>
819
+ </div>
820
+ """
821
+
822
+ return final_result, final_message, gr.update(value=action_buttons_html, visible=True)
823
+ else:
824
+ print(f"❌ Local editing processing failed - IP: {client_ip}, error: {message}", flush=True)
825
+ return None, "❌ " + message, gr.update(visible=False)
826
+
827
+ except Exception as e:
828
+ print(f"❌ Local editing exception - IP: {client_ip}, error: {str(e)}")
829
+ return None, f"❌ Error occurred during processing: {str(e)}", gr.update(visible=False)
830
+
831
+ # Create Gradio interface
832
+ def create_app():
833
+ with gr.Blocks(
834
+ title="AI Image Editor",
835
+ theme=gr.themes.Soft(),
836
+ css="""
837
+ .main-container {
838
+ max-width: 1200px;
839
+ margin: 0 auto;
840
+ }
841
+ .upload-area {
842
+ border: 2px dashed #ccc;
843
+ border-radius: 10px;
844
+ padding: 20px;
845
+ text-align: center;
846
+ }
847
+ .result-area {
848
+ margin-top: 20px;
849
+ padding: 20px;
850
+ border-radius: 10px;
851
+ background-color: #f8f9fa;
852
+ }
853
+ .use-as-input-btn {
854
+ margin-top: 10px;
855
+ width: 100%;
856
+ }
857
+ """,
858
+ # Improve concurrency performance configuration
859
+ head="""
860
+ <script>
861
+ // Reduce client-side state update frequency, avoid excessive SSE connections
862
+ if (window.gradio) {
863
+ window.gradio.update_frequency = 2000; // Update every 2 seconds
864
+ }
865
+ </script>
866
+ """
867
+ ) as app:
868
+
869
+ # Main title - styled like Trump AI Voice
870
+ gr.HTML("""
871
+ <div style="text-align: center; margin: 5px auto 0px auto; max-width: 800px;">
872
+ <h1 style="color: #2c3e50; margin: 0; font-size: 3.5em; font-weight: 800; letter-spacing: 3px; text-shadow: 2px 2px 4px rgba(0,0,0,0.1);">
873
+ 🎨 AI Image Editor
874
+ </h1>
875
+ </div>
876
+ """, padding=False)
877
+
878
+ # 🌟 NEW: Multi-Image Editing Announcement Banner with breathing effect
879
+ gr.HTML("""
880
+ <style>
881
+ @keyframes breathe {
882
+ 0%, 100% { transform: scale(1); }
883
+ 50% { transform: scale(1.02); }
884
+ }
885
+ .breathing-banner {
886
+ animation: breathe 3s ease-in-out infinite;
887
+ }
888
+ </style>
889
+ <div class="breathing-banner" style="
890
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
891
+ margin: 5px auto 5px auto;
892
+ padding: 6px 40px;
893
+ border-radius: 20px;
894
+ max-width: 700px;
895
+ box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
896
+ text-align: center;
897
+ ">
898
+ <span style="color: white; font-weight: 600; font-size: 1.0em;">
899
+ 🚀 NEWS:
900
+ <a href="https://huggingface.co/spaces/Selfit/Multi-Image-Edit" target="_blank" style="
901
+ color: white;
902
+ text-decoration: none;
903
+ border-bottom: 1px solid rgba(255,255,255,0.5);
904
+ transition: all 0.3s ease;
905
+ " onmouseover="this.style.borderBottom='1px solid white'"
906
+ onmouseout="this.style.borderBottom='1px solid rgba(255,255,255,0.5)'">
907
+ World's First Multi-Image Editing Tool →
908
+ </a>
909
+ </span>
910
+ </div>
911
+ """, padding=False)
912
+
913
+ with gr.Tabs():
914
+ with gr.Tab("🌍 Global Editor"):
915
+ with gr.Row():
916
+ with gr.Column(scale=1):
917
+ gr.Markdown("### 📸 Upload Image")
918
+ input_image = gr.Image(
919
+ label="Select image to edit",
920
+ type="pil",
921
+ height=512,
922
+ elem_classes=["upload-area"]
923
+ )
924
+
925
+ gr.Markdown("### ✍️ Editing Instructions")
926
+ prompt_input = gr.Textbox(
927
+ label="Enter editing prompt",
928
+ placeholder="For example: change background to beach, add rainbow, remove background, etc...",
929
+ lines=3,
930
+ max_lines=5
931
+ )
932
+
933
+ edit_button = gr.Button(
934
+ "🚀 Start Editing",
935
+ variant="primary",
936
+ size="lg"
937
+ )
938
+
939
+ with gr.Column(scale=1):
940
+ gr.Markdown("### 🎯 Editing Result")
941
+ output_image = gr.Image(
942
+ label="Edited image",
943
+ height=320,
944
+ elem_classes=["result-area"]
945
+ )
946
+
947
+ use_as_input_btn = gr.Button(
948
+ "🔄 Use as Input",
949
+ variant="secondary",
950
+ size="sm",
951
+ elem_classes=["use-as-input-btn"]
952
+ )
953
+
954
+ status_output = gr.Textbox(
955
+ label="Processing status",
956
+ lines=2,
957
+ max_lines=3,
958
+ interactive=False
959
+ )
960
+
961
+ action_buttons = gr.HTML(visible=False)
962
+
963
+ gr.Markdown("### 💡 Prompt Examples")
964
+ with gr.Row():
965
+ example_prompts = [
966
+ "Set the background to a grand opera stage with red curtains",
967
+ "Change the outfit into a traditional Chinese hanfu with flowing sleeves",
968
+ "Give the character blue dragon-like eyes with glowing pupils",
969
+ "Change lighting to soft dreamy pastel glow",
970
+ "Change pose to sitting cross-legged on the ground"
971
+ ]
972
+
973
+ for prompt in example_prompts:
974
+ gr.Button(
975
+ prompt,
976
+ size="sm"
977
+ ).click(
978
+ lambda p=prompt: p,
979
+ outputs=prompt_input
980
+ )
981
+
982
+ edit_button.click(
983
+ fn=edit_image_interface,
984
+ inputs=[input_image, prompt_input],
985
+ outputs=[output_image, status_output, action_buttons],
986
+ show_progress=True,
987
+ concurrency_limit=10,
988
+ api_name="global_edit"
989
+ )
990
+
991
+ def simple_use_as_input(output_img):
992
+ if output_img is not None:
993
+ return output_img
994
+ return None
995
+
996
+ use_as_input_btn.click(
997
+ fn=simple_use_as_input,
998
+ inputs=[output_image],
999
+ outputs=[input_image]
1000
+ )
1001
+
1002
+ with gr.Tab("🖌️ Local Inpaint"):
1003
+ with gr.Row():
1004
+ with gr.Column(scale=1):
1005
+ gr.Markdown("### 📸 Upload Image and Draw Mask")
1006
+ local_input_image = gr.ImageEditor(
1007
+ label="Upload image and draw mask",
1008
+ type="pil",
1009
+ height=512,
1010
+ brush=gr.Brush(colors=["#ff0000"], default_size=180),
1011
+ elem_classes=["upload-area"]
1012
+ )
1013
+
1014
+ gr.Markdown("### 🖼️ Reference Image(Optional)")
1015
+ local_reference_image = gr.Image(
1016
+ label="Upload reference image (optional)",
1017
+ type="pil",
1018
+ height=256
1019
+ )
1020
+
1021
+ gr.Markdown("### ✍️ Editing Instructions")
1022
+ local_prompt_input = gr.Textbox(
1023
+ label="Enter local editing prompt",
1024
+ placeholder="For example: change selected area hair to golden, add patterns to selected object, change selected area color, etc...",
1025
+ lines=3,
1026
+ max_lines=5
1027
+ )
1028
+
1029
+ local_edit_button = gr.Button(
1030
+ "🎯 Start Local Editing",
1031
+ variant="primary",
1032
+ size="lg"
1033
+ )
1034
+
1035
+ with gr.Column(scale=1):
1036
+ gr.Markdown("### 🎯 Editing Result")
1037
+ local_output_image = gr.Image(
1038
+ label="Local edited image",
1039
+ height=320,
1040
+ elem_classes=["result-area"]
1041
+ )
1042
+
1043
+ local_use_as_input_btn = gr.Button(
1044
+ "🔄 Use as Input",
1045
+ variant="secondary",
1046
+ size="sm",
1047
+ elem_classes=["use-as-input-btn"]
1048
+ )
1049
+
1050
+ local_status_output = gr.Textbox(
1051
+ label="Processing status",
1052
+ lines=2,
1053
+ max_lines=3,
1054
+ interactive=False
1055
+ )
1056
+
1057
+ local_action_buttons = gr.HTML(visible=False)
1058
+
1059
+ local_edit_button.click(
1060
+ fn=local_edit_interface,
1061
+ inputs=[local_input_image, local_prompt_input, local_reference_image],
1062
+ outputs=[local_output_image, local_status_output, local_action_buttons],
1063
+ show_progress=True,
1064
+ concurrency_limit=8,
1065
+ api_name="local_edit"
1066
+ )
1067
+
1068
+ def simple_local_use_as_input(output_img):
1069
+ if output_img is not None:
1070
+ return {
1071
+ "background": output_img,
1072
+ "layers": [],
1073
+ "composite": output_img
1074
+ }
1075
+ return None
1076
+
1077
+ local_use_as_input_btn.click(
1078
+ fn=simple_local_use_as_input,
1079
+ inputs=[local_output_image],
1080
+ outputs=[local_input_image]
1081
+ )
1082
+
1083
+ # Local inpaint example
1084
+ gr.Markdown("### 💡 Local Inpaint Example")
1085
+
1086
+ def load_local_example():
1087
+ """Load panda to cat transformation example - simplified, mask handled in backend"""
1088
+ try:
1089
+ from PIL import Image
1090
+ import os
1091
+
1092
+ # Check file paths
1093
+ main_path = "datas/panda01.jpeg"
1094
+ ref_path = "datas/cat01.webp"
1095
+
1096
+ # Load main image
1097
+ if not os.path.exists(main_path):
1098
+ return None, None, "EXAMPLE_PANDA_CAT_let the cat ride on the panda"
1099
+
1100
+ main_img = Image.open(main_path)
1101
+
1102
+ # Load reference image
1103
+ if not os.path.exists(ref_path):
1104
+ ref_img = None
1105
+ else:
1106
+ ref_img = Image.open(ref_path)
1107
+
1108
+ # ImageEditor format
1109
+ editor_data = {
1110
+ "background": main_img,
1111
+ "layers": [],
1112
+ "composite": main_img
1113
+ }
1114
+
1115
+ # Special prompt to indicate this is the example case
1116
+ prompt = "EXAMPLE_PANDA_CAT_let the cat ride on the panda"
1117
+
1118
+ # Return just the PIL image instead of dict format to avoid UI state issues
1119
+ return main_img, ref_img, prompt
1120
+
1121
+ except Exception as e:
1122
+ return None, None, "EXAMPLE_PANDA_CAT_Transform the panda head into a cute cat head, keeping the body"
1123
+
1124
+ # Example display
1125
+ gr.Markdown("#### 🐼➡️🐱 Example: Panda to Cat Transformation")
1126
+ with gr.Row():
1127
+ with gr.Column(scale=2):
1128
+ # Preview images for local example
1129
+ with gr.Row():
1130
+ try:
1131
+ gr.Image("datas/panda01.jpeg", label="Main Image", height=120, width=120, show_label=True, interactive=False)
1132
+ gr.Image("datas/panda01m.jpeg", label="Mask", height=120, width=120, show_label=True, interactive=False)
1133
+ gr.Image("datas/cat01.webp", label="Reference", height=120, width=120, show_label=True, interactive=False)
1134
+ except:
1135
+ gr.Markdown("*Preview images not available*")
1136
+ gr.Markdown("**Prompt**: let the cat ride on the panda \n**Note**: Mask will be automatically applied when you submit this example")
1137
+ with gr.Column(scale=1):
1138
+ gr.Button(
1139
+ "🎨 Load Panda Example",
1140
+ size="lg",
1141
+ variant="secondary"
1142
+ ).click(
1143
+ fn=load_local_example,
1144
+ outputs=[local_input_image, local_reference_image, local_prompt_input]
1145
+ )
1146
+
1147
+ # Add a refresh button to fix UI state issues
1148
+ gr.Button(
1149
+ "🔄 Refresh Image Editor",
1150
+ size="sm",
1151
+ variant="secondary"
1152
+ ).click(
1153
+ fn=lambda: gr.update(),
1154
+ outputs=[local_input_image]
1155
+ )
1156
+
1157
+ # SEO Content Section
1158
+ gr.HTML("""
1159
+ <div style="width: 100%; margin: 50px 0; padding: 0 20px;">
1160
+
1161
+ <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 40px; border-radius: 20px; margin: 40px 0;">
1162
+ <h2 style="margin: 0 0 20px 0; font-size: 2.2em; font-weight: 700;">
1163
+ &#127912; Unlimited AI Image Generation & Editing
1164
+ </h2>
1165
+ <p style="margin: 0 0 25px 0; font-size: 1.2em; opacity: 0.95; line-height: 1.6;">
1166
+ Experience the ultimate freedom in AI image creation! Generate and edit unlimited images without restrictions,
1167
+ including NSFW content, with our premium AI image editing platform.
1168
+ </p>
1169
+
1170
+ <div style="display: flex; justify-content: center; gap: 25px; flex-wrap: wrap; margin: 30px 0;">
1171
+ <a href="https://omnicreator.net/#generator" target="_blank" style="
1172
+ display: inline-flex;
1173
+ align-items: center;
1174
+ justify-content: center;
1175
+ padding: 20px 40px;
1176
+ background: linear-gradient(135deg, #ff6b6b 0%, #feca57 100%);
1177
+ color: white;
1178
+ text-decoration: none;
1179
+ border-radius: 15px;
1180
+ font-weight: 700;
1181
+ font-size: 18px;
1182
+ text-align: center;
1183
+ min-width: 250px;
1184
+ box-shadow: 0 8px 25px rgba(255, 107, 107, 0.4);
1185
+ transition: all 0.3s ease;
1186
+ border: none;
1187
+ transform: scale(1);
1188
+ " onmouseover="this.style.transform='scale(1.05)'" onmouseout="this.style.transform='scale(1)'">
1189
+ &#128640; Get Unlimited Access Now
1190
+ </a>
1191
+
1192
+ </div>
1193
+
1194
+ <p style="color: rgba(255,255,255,0.9); font-size: 1em; margin: 20px 0 0 0;">
1195
+ Join thousands of creators who trust Omni Creator for unrestricted AI image generation!
1196
+ </p>
1197
+ </div>
1198
+
1199
+ <div style="text-align: center; margin: 25px auto; background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); padding: 35px; border-radius: 20px; box-shadow: 0 10px 30px rgba(0,0,0,0.1);">
1200
+ <h2 style="color: #2c3e50; margin: 0 0 20px 0; font-size: 1.9em; font-weight: 700;">
1201
+ &#11088; Professional AI Image Editor - No Restrictions
1202
+ </h2>
1203
+ <p style="color: #555; font-size: 1.1em; line-height: 1.6; margin: 0 0 20px 0; padding: 0 20px;">
1204
+ Transform your creative vision into reality with our advanced AI image editing platform. Whether you're creating
1205
+ art, editing photos, designing content, or working with any type of imagery - our powerful AI removes all limitations
1206
+ and gives you complete creative freedom.
1207
+ </p>
1208
+ </div>
1209
+
1210
+ <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 25px; margin: 40px 0;">
1211
+
1212
+ <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #e74c3c;">
1213
+ <h3 style="color: #e74c3c; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1214
+ &#127919; Unlimited Generation
1215
+ </h3>
1216
+ <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1217
+ Premium users enjoy unlimited image generation without daily limits, rate restrictions, or content barriers.
1218
+ Create as many images as you need, whenever you need them.
1219
+ </p>
1220
+ </div>
1221
+
1222
+ <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #3498db;">
1223
+ <h3 style="color: #3498db; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1224
+ 🔓 No Content Restrictions
1225
+ </h3>
1226
+ <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1227
+ Generate and edit any type of content without NSFW filters or content limitations. Complete creative
1228
+ freedom for artists, designers, and content creators.
1229
+ </p>
1230
+ </div>
1231
+
1232
+ <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #27ae60;">
1233
+ <h3 style="color: #27ae60; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1234
+ &#9889; Lightning Fast Processing
1235
+ </h3>
1236
+ <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1237
+ Advanced AI infrastructure delivers high-quality results in seconds. No waiting in queues,
1238
+ no processing delays - just instant, professional-grade image editing.
1239
+ </p>
1240
+ </div>
1241
+
1242
+ <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #9b59b6;">
1243
+ <h3 style="color: #9b59b6; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1244
+ &#127912; Advanced Editing Tools
1245
+ </h3>
1246
+ <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1247
+ Global transformations, precision local editing, style transfer, object removal, background replacement,
1248
+ and dozens of other professional editing capabilities.
1249
+ </p>
1250
+ </div>
1251
+
1252
+ <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #f39c12;">
1253
+ <h3 style="color: #f39c12; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1254
+ &#128142; Premium Quality
1255
+ </h3>
1256
+ <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1257
+ State-of-the-art AI models trained on millions of images deliver exceptional quality and realism.
1258
+ Professional results suitable for commercial use and high-end projects.
1259
+ </p>
1260
+ </div>
1261
+
1262
+ <div style="background: white; padding: 30px; border-radius: 15px; box-shadow: 0 5px 20px rgba(0,0,0,0.08); border-left: 5px solid #34495e;">
1263
+ <h3 style="color: #34495e; margin: 0 0 15px 0; font-size: 1.4em; font-weight: 600;">
1264
+ 🌍 Multi-Modal Support
1265
+ </h3>
1266
+ <p style="color: #666; margin: 0; line-height: 1.6; font-size: 1em;">
1267
+ Support for all image formats, styles, and use cases. From photorealistic portraits to artistic creations,
1268
+ product photography to digital art - we handle everything.
1269
+ </p>
1270
+ </div>
1271
+
1272
+ </div>
1273
+
1274
+ <div style="background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%); color: white; padding: 40px; border-radius: 20px; margin: 40px 0; text-align: center;">
1275
+ <h2 style="margin: 0 0 25px 0; font-size: 1.8em; font-weight: 700;">
1276
+ &#128142; Why Choose Omni Creator Premium?
1277
+ </h2>
1278
+ <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 20px; margin: 30px 0;">
1279
+
1280
+ <div style="background: rgba(255,255,255,0.15); padding: 20px; border-radius: 12px;">
1281
+ <h4 style="margin: 0 0 10px 0; font-size: 1.2em;">🚫 No Rate Limits</h4>
1282
+ <p style="margin: 0; opacity: 0.9; font-size: 0.95em;">Generate unlimited images without waiting periods or daily restrictions</p>
1283
+ </div>
1284
+
1285
+ <div style="background: rgba(255,255,255,0.15); padding: 20px; border-radius: 12px;">
1286
+ <h4 style="margin: 0 0 10px 0; font-size: 1.2em;">🎭 Unrestricted Content</h4>
1287
+ <p style="margin: 0; opacity: 0.9; font-size: 0.95em;">Create any type of content without NSFW filters or censorship</p>
1288
+ </div>
1289
 
1290
+ <div style="background: rgba(255,255,255,0.15); padding: 20px; border-radius: 12px;">
1291
+ <h4 style="margin: 0 0 10px 0; font-size: 1.2em;">&#9889; Priority Processing</h4>
1292
+ <p style="margin: 0; opacity: 0.9; font-size: 0.95em;">Skip queues and get instant results with dedicated processing power</p>
1293
+ </div>
1294
+
1295
+ <div style="background: rgba(255,255,255,0.15); padding: 20px; border-radius: 12px;">
1296
+ <h4 style="margin: 0 0 10px 0; font-size: 1.2em;">&#127912; Advanced Features</h4>
1297
+ <p style="margin: 0; opacity: 0.9; font-size: 0.95em;">Access to latest AI models and cutting-edge editing capabilities</p>
1298
+ </div>
1299
+
1300
+ </div>
1301
+ <div style="display: flex; justify-content: center; margin: 25px 0 0 0;">
1302
+ <a href="https://omnicreator.net/#generator" target="_blank" style="
1303
+ display: inline-flex;
1304
+ align-items: center;
1305
+ justify-content: center;
1306
+ padding: 18px 35px;
1307
+ background: rgba(255,255,255,0.9);
1308
+ color: #333;
1309
+ text-decoration: none;
1310
+ border-radius: 15px;
1311
+ font-weight: 700;
1312
+ font-size: 16px;
1313
+ text-align: center;
1314
+ min-width: 200px;
1315
+ box-shadow: 0 6px 20px rgba(0,0,0,0.3);
1316
+ transition: all 0.3s ease;
1317
+ border: none;
1318
+ ">&#11088; Start Creating Now</a>
1319
+ </div>
1320
+ </div>
1321
+
1322
+ <div style="background: linear-gradient(135deg, #ff9a9e 0%, #fecfef 50%, #fecfef 100%); padding: 30px; border-radius: 15px; margin: 40px 0;">
1323
+ <h3 style="color: #8b5cf6; text-align: center; margin: 0 0 25px 0; font-size: 1.5em; font-weight: 700;">
1324
+ &#128161; Pro Tips for Best Results
1325
+ </h3>
1326
+ <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 18px;">
1327
+
1328
+ <div style="background: rgba(255,255,255,0.85); padding: 18px; border-radius: 12px;">
1329
+ <strong style="color: #8b5cf6; font-size: 1.1em;">📝 Clear Descriptions:</strong>
1330
+ <p style="color: #555; margin: 5px 0 0 0; line-height: 1.5;">Use detailed, specific prompts for better results. Describe colors, styles, lighting, and composition clearly.</p>
1331
+ </div>
1332
+
1333
+ <div style="background: rgba(255,255,255,0.85); padding: 18px; border-radius: 12px;">
1334
+ <strong style="color: #8b5cf6; font-size: 1.1em;">&#127919; Local Editing:</strong>
1335
+ <p style="color: #555; margin: 5px 0 0 0; line-height: 1.5;">Use precise brush strokes to select areas for local editing. Smaller, focused edits often yield better results.</p>
1336
+ </div>
1337
+
1338
+ <div style="background: rgba(255,255,255,0.85); padding: 18px; border-radius: 12px;">
1339
+ <strong style="color: #8b5cf6; font-size: 1.1em;">&#9889; Iterative Process:</strong>
1340
+ <p style="color: #555; margin: 5px 0 0 0; line-height: 1.5;">Use "Use as Input" feature to refine results. Multiple iterations can achieve complex transformations.</p>
1341
+ </div>
1342
+
1343
+ <div style="background: rgba(255,255,255,0.85); padding: 18px; border-radius: 12px;">
1344
+ <strong style="color: #8b5cf6; font-size: 1.1em;">&#128444; Image Quality:</strong>
1345
+ <p style="color: #555; margin: 5px 0 0 0; line-height: 1.5;">Higher resolution input images (up to 10MB) generally produce better editing results and finer details.</p>
1346
+ </div>
1347
+
1348
+ </div>
1349
+ </div>
1350
+
1351
+ <div style="text-align: center; margin: 25px auto; background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%); padding: 35px; border-radius: 20px; box-shadow: 0 10px 30px rgba(0,0,0,0.1);">
1352
+ <h2 style="color: #2c3e50; margin: 0 0 20px 0; font-size: 1.8em; font-weight: 700;">
1353
+ &#128640; Perfect For Every Creative Need
1354
+ </h2>
1355
+ <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 20px; margin: 25px 0; text-align: left;">
1356
+
1357
+ <div style="background: rgba(255,255,255,0.8); padding: 20px; border-radius: 12px;">
1358
+ <h4 style="color: #e74c3c; margin: 0 0 10px 0;">🎨 Digital Art</h4>
1359
+ <ul style="color: #555; margin: 0; padding-left: 18px; line-height: 1.6;">
1360
+ <li>Character design</li>
1361
+ <li>Concept art</li>
1362
+ <li>Style transfer</li>
1363
+ <li>Artistic effects</li>
1364
+ </ul>
1365
+ </div>
1366
+
1367
+ <div style="background: rgba(255,255,255,0.8); padding: 20px; border-radius: 12px;">
1368
+ <h4 style="color: #3498db; margin: 0 0 10px 0;">📸 Photography</h4>
1369
+ <ul style="color: #555; margin: 0; padding-left: 18px; line-height: 1.6;">
1370
+ <li>Background replacement</li>
1371
+ <li>Object removal</li>
1372
+ <li>Lighting adjustment</li>
1373
+ <li>Portrait enhancement</li>
1374
+ </ul>
1375
+ </div>
1376
+
1377
+ <div style="background: rgba(255,255,255,0.8); padding: 20px; border-radius: 12px;">
1378
+ <h4 style="color: #27ae60; margin: 0 0 10px 0;">🛍️ E-commerce</h4>
1379
+ <ul style="color: #555; margin: 0; padding-left: 18px; line-height: 1.6;">
1380
+ <li>Product photography</li>
1381
+ <li>Lifestyle shots</li>
1382
+ <li>Color variations</li>
1383
+ <li>Context placement</li>
1384
+ </ul>
1385
+ </div>
1386
+
1387
+ <div style="background: rgba(255,255,255,0.8); padding: 20px; border-radius: 12px;">
1388
+ <h4 style="color: #9b59b6; margin: 0 0 10px 0;">📱 Social Media</h4>
1389
+ <ul style="color: #555; margin: 0; padding-left: 18px; line-height: 1.6;">
1390
+ <li>Content creation</li>
1391
+ <li>Meme generation</li>
1392
+ <li>Brand visuals</li>
1393
+ <li>Viral content</li>
1394
+ </ul>
1395
+ </div>
1396
+
1397
+ </div>
1398
+ <div style="text-align: center; margin: 25px 0 0 0;">
1399
+ <a href="https://omnicreator.net/#generator" target="_blank" style="
1400
+ display: inline-flex;
1401
+ align-items: center;
1402
+ justify-content: center;
1403
+ padding: 18px 35px;
1404
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
1405
+ color: white;
1406
+ text-decoration: none;
1407
+ border-radius: 15px;
1408
+ font-weight: 700;
1409
+ font-size: 16px;
1410
+ text-align: center;
1411
+ min-width: 220px;
1412
+ box-shadow: 0 8px 25px rgba(102, 126, 234, 0.4);
1413
+ transition: all 0.3s ease;
1414
+ border: none;
1415
+ ">🎯 Start Your Project Now</a>
1416
+ </div>
1417
+ </div>
1418
+
1419
+ </div>
1420
+
1421
+ <div style="text-align: center; margin: 30px auto 20px auto; padding: 20px;">
1422
+ <p style="margin: 0 0 10px 0; font-size: 18px; color: #333; font-weight: 500;">
1423
+ Powered by <a href="https://omnicreator.net/#generator" target="_blank" style="color: #667eea; text-decoration: none; font-weight: bold;">Omni Creator</a>
1424
+ </p>
1425
+ <p style="margin: 0; font-size: 14px; color: #999; font-weight: 400;">
1426
+ The ultimate AI image generation and editing platform • Unlimited creativity, zero restrictions
1427
+ </p>
1428
+ </div>
1429
+ """, padding=False)
1430
+
1431
+ return app
1432
+
1433
+ if __name__ == "__main__":
1434
+ app = create_app()
1435
+ # Improve queue configuration to handle high concurrency and prevent SSE connection issues
1436
  app.queue(
1437
+ default_concurrency_limit=20, # Default concurrency limit
1438
+ max_size=50, # Maximum queue size
1439
+ api_open=False # Disable direct API access to reduce resource consumption
1440
  )
1441
  app.launch(
1442
  server_name="0.0.0.0",
1443
+ show_error=True, # Show detailed error information
1444
+ quiet=False, # Keep log output
1445
+ max_threads=40, # Increase thread pool size
1446
  height=800,
1447
+ favicon_path=None # Reduce resource loading
1448
+ )
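
A note on the settings above: default_concurrency_limit, max_size, and api_open are options of Gradio's Blocks.queue(), while the server options (server_name, show_error, max_threads, and so on) belong to launch(). The following is a minimal sketch with the same values, assuming a hypothetical placeholder Blocks app rather than the Space's real UI:

import gradio as gr

# Hypothetical placeholder app; the real Space builds its interface in create_app().
with gr.Blocks() as demo:
    gr.Markdown("placeholder")

demo.queue(
    default_concurrency_limit=20,  # default number of concurrent workers per event
    max_size=50,                   # requests beyond this are rejected rather than queued
    api_open=False,                # direct API calls cannot bypass the queue
)
demo.launch(
    server_name="0.0.0.0",  # listen on all interfaces so the app is reachable from outside the container
    show_error=True,
    max_threads=40,
)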
 
nfsw.py ADDED
@@ -0,0 +1,262 @@
1
+ import os
2
+ from PIL import Image
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ import json
6
+ from huggingface_hub import hf_hub_download
7
+
8
+
9
+ class NSFWDetector:
10
+ """
11
+ NSFW detector class that classifies images with a YOLOv9 model (run via ONNX Runtime)
12
+ """
13
+
14
+ def __init__(self, repo_id="Falconsai/nsfw_image_detection",
15
+ model_filename="falconsai_yolov9_nsfw_model_quantized.pt",
16
+ labels_filename="labels.json",
17
+ input_size=(224, 224)):
18
+ """
19
+ Initialize the NSFW detector
20
+
21
+ Args:
22
+ repo_id (str): Hugging Face repository ID
23
+ model_filename (str): model file name
24
+ labels_filename (str): labels file name
25
+ input_size (tuple): model input size (height, width)
26
+ """
27
+ self.repo_id = repo_id
28
+ self.model_filename = model_filename
29
+ self.labels_filename = labels_filename
30
+ self.input_size = input_size
31
+
32
+ # Download files from Hugging Face
33
+ self.model_path = self._download_model()
34
+ self.labels_path = self._download_labels()
35
+
36
+ # Load the labels
37
+ self.labels = self._load_labels()
38
+
39
+ # Load the model
40
+ self.session = self._load_model()
41
+ self.input_name = self.session.get_inputs()[0].name
42
+ self.output_name = self.session.get_outputs()[0].name
43
+
44
+ def _download_model(self):
45
+ """
46
+ Download the model file from Hugging Face
47
+
48
+ Returns:
49
+ str: path to the downloaded model file
50
+ """
51
+ try:
52
+ print(f"正在从 {self.repo_id} 下载模型文件: {self.model_filename}")
53
+ model_path = hf_hub_download(
54
+ repo_id=self.repo_id,
55
+ filename=self.model_filename,
56
+ cache_dir="./hf_cache"
57
+ )
58
+ print(f"✅ 模型下载成功: {model_path}")
59
+ return model_path
60
+ except Exception as e:
61
+ raise RuntimeError(f"模型下载失败: {e}")
62
+
63
+ def _download_labels(self):
64
+ """
65
+ Download the labels file from Hugging Face
66
+
67
+ Returns:
68
+ str: path to the downloaded labels file
69
+ """
70
+ try:
71
+ print(f"正在从 {self.repo_id} 下载标签文件: {self.labels_filename}")
72
+ labels_path = hf_hub_download(
73
+ repo_id=self.repo_id,
74
+ filename=self.labels_filename,
75
+ cache_dir="./hf_cache"
76
+ )
77
+ print(f"✅ 标签文件下载成功: {labels_path}")
78
+ return labels_path
79
+ except Exception as e:
80
+ raise RuntimeError(f"标签文件下载失败: {e}")
81
+
82
+ def _load_labels(self):
83
+ """
84
+ Load the class labels
85
+
86
+ Returns:
87
+ dict: label dictionary
88
+ """
89
+ try:
90
+ with open(self.labels_path, "r") as f:
91
+ return json.load(f)
92
+ except FileNotFoundError:
93
+ raise FileNotFoundError(f"标签文件未找到: {self.labels_path}")
94
+ except json.JSONDecodeError:
95
+ raise ValueError(f"标签文件格式错误: {self.labels_path}")
96
+
97
+ def _load_model(self):
98
+ """
99
+ Load the ONNX model
100
+
101
+ Returns:
102
+ onnxruntime.InferenceSession: model session
103
+ """
104
+ try:
105
+ return ort.InferenceSession(self.model_path)
106
+ except Exception as e:
107
+ raise RuntimeError(f"模型加载失败: {self.model_path}, 错误: {e}")
108
+
109
+ def _preprocess_image(self, image_path):
110
+ """
111
+ Preprocess an image
112
+
113
+ Args:
114
+ image_path (str): path to the image file
115
+
116
+ Returns:
117
+ tuple: (preprocessed tensor, original image)
118
+ """
119
+ try:
120
+ # Load the image and convert it to RGB
121
+ original_image = Image.open(image_path).convert("RGB")
122
+
123
+ # Resize to the model input size
124
+ image_resized = original_image.resize(self.input_size, Image.Resampling.BILINEAR)
125
+
126
+ # Convert to a numpy array and normalize to [0, 1]
127
+ image_np = np.array(image_resized, dtype=np.float32) / 255.0
128
+
129
+ # Reorder dimensions [H, W, C] -> [C, H, W]
130
+ image_np = np.transpose(image_np, (2, 0, 1))
131
+
132
+ # Add a batch dimension [C, H, W] -> [1, C, H, W]
133
+ input_tensor = np.expand_dims(image_np, axis=0).astype(np.float32)
134
+
135
+ return input_tensor, original_image
136
+
137
+ except FileNotFoundError:
138
+ raise FileNotFoundError(f"图像文件未找到: {image_path}")
139
+ except Exception as e:
140
+ raise RuntimeError(f"图像预处理失败: {e}")
141
+
142
+ def _postprocess_predictions(self, predictions):
143
+ """
144
+ Post-process the prediction results
145
+
146
+ Args:
147
+ predictions: raw model output
148
+
149
+ Returns:
150
+ str: predicted class label
151
+ """
152
+ predicted_index = np.argmax(predictions)
153
+ predicted_label = self.labels[str(predicted_index)]
154
+ return predicted_label
155
+
156
+ def predict(self, image_path):
157
+ """
158
+ Run NSFW detection on a single image
159
+
160
+ Args:
161
+ image_path (str): path to the image file
162
+
163
+ Returns:
164
+ tuple: (predicted label, original image)
165
+ """
166
+ # Preprocess the image
167
+ input_tensor, original_image = self._preprocess_image(image_path)
168
+
169
+ # Run inference
170
+ outputs = self.session.run([self.output_name], {self.input_name: input_tensor})
171
+ predictions = outputs[0]
172
+
173
+ # Post-process the results
174
+ predicted_label = self._postprocess_predictions(predictions)
175
+
176
+ return predicted_label, original_image
177
+
178
+ def predict_label_only(self, image_path):
179
+ """
180
+ Return only the predicted label (no image)
181
+
182
+ Args:
183
+ image_path (str): path to the image file
184
+
185
+ Returns:
186
+ str: predicted class label
187
+ """
188
+ predicted_label, _ = self.predict(image_path)
189
+ return predicted_label
190
+
191
+ def predict_from_pil(self, pil_image):
192
+ """
193
+ Run NSFW detection directly on a PIL Image object
194
+
195
+ Args:
196
+ pil_image (PIL.Image): PIL image object
197
+
198
+ Returns:
199
+ tuple: (predicted label, original image)
200
+ """
201
+ try:
202
+ # Make sure the image is in RGB format
203
+ if pil_image.mode != "RGB":
204
+ pil_image = pil_image.convert("RGB")
205
+
206
+ # Resize to the model input size
207
+ image_resized = pil_image.resize(self.input_size, Image.Resampling.BILINEAR)
208
+
209
+ # Convert to a numpy array and normalize to [0, 1]
210
+ image_np = np.array(image_resized, dtype=np.float32) / 255.0
211
+
212
+ # Reorder dimensions [H, W, C] -> [C, H, W]
213
+ image_np = np.transpose(image_np, (2, 0, 1))
214
+
215
+ # Add a batch dimension [C, H, W] -> [1, C, H, W]
216
+ input_tensor = np.expand_dims(image_np, axis=0).astype(np.float32)
217
+
218
+ # Run inference
219
+ outputs = self.session.run([self.output_name], {self.input_name: input_tensor})
220
+ predictions = outputs[0]
221
+
222
+ # Post-process the results
223
+ predicted_label = self._postprocess_predictions(predictions)
224
+
225
+ return predicted_label, pil_image
226
+
227
+ except Exception as e:
228
+ raise RuntimeError(f"PIL图像预测失败: {e}")
229
+
230
+ def predict_pil_label_only(self, pil_image):
231
+ """
232
+ Return only the predicted label from a PIL Image object
233
+
234
+ Args:
235
+ pil_image (PIL.Image): PIL image object
236
+
237
+ Returns:
238
+ str: predicted class label
239
+ """
240
+ predicted_label, _ = self.predict_from_pil(pil_image)
241
+ return predicted_label
242
+
243
+ # --- Usage example ---
244
+ if __name__ == "__main__":
245
+ # Configuration
246
+ single_image_path = "datas/bad01.jpg"
247
+
248
+ try:
249
+ # Create a detector instance (downloads from Hugging Face automatically)
250
+ detector = NSFWDetector()
251
+
252
+ # Check that the image file exists
253
+ if os.path.exists(single_image_path):
254
+ # Run prediction
255
+ predicted_label = detector.predict_label_only(single_image_path)
256
+ print(f"图像文件: {single_image_path}")
257
+ print(f"预测结果: {predicted_label}")
258
+ else:
259
+ print(f"错误: 指定的图像文件不存在: {single_image_path}")
260
+
261
+ except Exception as e:
262
+ print(f"初始化检测器时发生错误: {e}")
pipeline.py DELETED
@@ -1,1934 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from typing import Optional, Tuple, Union, List, Dict, Any, Callable
5
- from dataclasses import dataclass
6
- import numpy as np
7
- from PIL import Image
8
- import torchvision.transforms as T
9
- from torchvision.transforms.functional import to_tensor, normalize
10
- import warnings
11
- from contextlib import contextmanager
12
- from functools import wraps
13
-
14
- from transformers import PretrainedConfig, PreTrainedModel, CLIPTextModel, CLIPTokenizer
15
- from transformers.modeling_outputs import BaseModelOutputWithPooling
16
- from diffusers import DiffusionPipeline, DDIMScheduler
17
- from diffusers.configuration_utils import ConfigMixin, register_to_config
18
- from diffusers.models.modeling_utils import ModelMixin
19
- from diffusers.utils import BaseOutput
20
-
21
- # Optimization imports
22
- try:
23
- import transformer_engine.pytorch as te
24
- from transformer_engine.common import recipe
25
- HAS_TRANSFORMER_ENGINE = True
26
- except ImportError:
27
- HAS_TRANSFORMER_ENGINE = False
28
-
29
- try:
30
- from torch._dynamo import config as dynamo_config
31
- HAS_TORCH_COMPILE = hasattr(torch, 'compile')
32
- except ImportError:
33
- HAS_TORCH_COMPILE = False
34
-
35
- # -----------------------------------------------------------------------------
36
- # 1. Advanced Configuration (8B Scale)
37
- # -----------------------------------------------------------------------------
38
-
39
- class OmniMMDitV2Config(PretrainedConfig):
40
- model_type = "omnimm_dit_v2"
41
-
42
- def __init__(
43
- self,
44
- vocab_size: int = 49408,
45
- hidden_size: int = 4096, # 4096 dim for ~7B-8B scale
46
- intermediate_size: int = 11008, # Llama-style MLP expansion
47
- num_hidden_layers: int = 32, # Deep network
48
- num_attention_heads: int = 32,
49
- num_key_value_heads: Optional[int] = 8, # GQA (Grouped Query Attention)
50
- hidden_act: str = "silu",
51
- max_position_embeddings: int = 4096,
52
- initializer_range: float = 0.02,
53
- rms_norm_eps: float = 1e-5,
54
- use_cache: bool = True,
55
- pad_token_id: int = 0,
56
- bos_token_id: int = 1,
57
- eos_token_id: int = 2,
58
- tie_word_embeddings: bool = False,
59
- rope_theta: float = 10000.0,
60
- # DiT Specifics
61
- patch_size: int = 2,
62
- in_channels: int = 4, # VAE Latent channels
63
- out_channels: int = 4, # x2 for variance if learned
64
- frequency_embedding_size: int = 256,
65
- # Multi-Modal Specifics
66
- max_condition_images: int = 3, # Support 1-3 input images
67
- visual_embed_dim: int = 1024, # e.g., SigLIP or CLIP Vision
68
- text_embed_dim: int = 4096, # T5-XXL or similar
69
- use_temporal_attention: bool = True, # For Video generation
70
- # Optimization Configs
71
- use_fp8_quantization: bool = False,
72
- use_compilation: bool = False,
73
- compile_mode: str = "reduce-overhead",
74
- use_flash_attention: bool = True,
75
- **kwargs,
76
- ):
77
- self.vocab_size = vocab_size
78
- self.hidden_size = hidden_size
79
- self.intermediate_size = intermediate_size
80
- self.num_hidden_layers = num_hidden_layers
81
- self.num_attention_heads = num_attention_heads
82
- self.num_key_value_heads = num_key_value_heads
83
- self.hidden_act = hidden_act
84
- self.max_position_embeddings = max_position_embeddings
85
- self.initializer_range = initializer_range
86
- self.rms_norm_eps = rms_norm_eps
87
- self.use_cache = use_cache
88
- self.rope_theta = rope_theta
89
- self.patch_size = patch_size
90
- self.in_channels = in_channels
91
- self.out_channels = out_channels
92
- self.frequency_embedding_size = frequency_embedding_size
93
- self.max_condition_images = max_condition_images
94
- self.visual_embed_dim = visual_embed_dim
95
- self.text_embed_dim = text_embed_dim
96
- self.use_temporal_attention = use_temporal_attention
97
- self.use_fp8_quantization = use_fp8_quantization
98
- self.use_compilation = use_compilation
99
- self.compile_mode = compile_mode
100
- self.use_flash_attention = use_flash_attention
101
- super().__init__(
102
- pad_token_id=pad_token_id,
103
- bos_token_id=bos_token_id,
104
- eos_token_id=eos_token_id,
105
- tie_word_embeddings=tie_word_embeddings,
106
- **kwargs,
107
- )
108
-
109
- # -----------------------------------------------------------------------------
110
- # 2. Professional Building Blocks (RoPE, SwiGLU, AdaLN)
111
- # -----------------------------------------------------------------------------
112
-
113
- class OmniRMSNorm(nn.Module):
114
- def __init__(self, hidden_size, eps=1e-6):
115
- super().__init__()
116
- self.weight = nn.Parameter(torch.ones(hidden_size))
117
- self.variance_epsilon = eps
118
-
119
- def forward(self, hidden_states):
120
- input_dtype = hidden_states.dtype
121
- hidden_states = hidden_states.to(torch.float32)
122
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
123
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
124
- return self.weight * hidden_states.to(input_dtype)
125
-
126
- class OmniRotaryEmbedding(nn.Module):
127
- """Complex implementation of Rotary Positional Embeddings for DiT"""
128
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
129
- super().__init__()
130
- self.dim = dim
131
- self.max_position_embeddings = max_position_embeddings
132
- self.base = base
133
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
134
- self.register_buffer("inv_freq", inv_freq, persistent=False)
135
-
136
- def forward(self, x, seq_len=None):
137
- t = torch.arange(seq_len or x.shape[1], device=x.device).type_as(self.inv_freq)
138
- freqs = torch.outer(t, self.inv_freq)
139
- emb = torch.cat((freqs, freqs), dim=-1)
140
- return emb.cos(), emb.sin()
141
-
142
- class OmniSwiGLU(nn.Module):
143
- """Swish-Gated Linear Unit for High-Performance FFN"""
144
- def __init__(self, config: OmniMMDitV2Config):
145
- super().__init__()
146
- self.w1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
147
- self.w2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
148
- self.w3 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
149
-
150
- def forward(self, x):
151
- return self.w2(F.silu(self.w1(x)) * self.w3(x))
152
-
153
- class TimestepEmbedder(nn.Module):
154
- """Fourier feature embedding for timesteps"""
155
- def __init__(self, hidden_size, frequency_embedding_size=256):
156
- super().__init__()
157
- self.mlp = nn.Sequential(
158
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
159
- nn.SiLU(),
160
- nn.Linear(hidden_size, hidden_size, bias=True),
161
- )
162
- self.frequency_embedding_size = frequency_embedding_size
163
-
164
- @staticmethod
165
- def timestep_embedding(t, dim, max_period=10000):
166
- half = dim // 2
167
- freqs = torch.exp(
168
- -torch.log(torch.tensor(max_period)) * torch.arange(start=0, end=half, dtype=torch.float32) / half
169
- ).to(device=t.device)
170
- args = t[:, None].float() * freqs[None]
171
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
172
- if dim % 2:
173
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
174
- return embedding
175
-
176
- def forward(self, t, dtype):
177
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
178
- return self.mlp(t_freq)
179
-
180
- # -----------------------------------------------------------------------------
181
- # 2.5. Data Processing Utilities
182
- # -----------------------------------------------------------------------------
183
-
184
- class OmniImageProcessor:
185
- """Advanced image preprocessing for multi-modal diffusion models"""
186
-
187
- def __init__(
188
- self,
189
- image_mean: List[float] = [0.485, 0.456, 0.406],
190
- image_std: List[float] = [0.229, 0.224, 0.225],
191
- size: Tuple[int, int] = (512, 512),
192
- interpolation: str = "bicubic",
193
- do_normalize: bool = True,
194
- do_center_crop: bool = False,
195
- ):
196
- self.image_mean = image_mean
197
- self.image_std = image_std
198
- self.size = size
199
- self.do_normalize = do_normalize
200
- self.do_center_crop = do_center_crop
201
-
202
- # Build transform pipeline
203
- transforms_list = []
204
- if do_center_crop:
205
- transforms_list.append(T.CenterCrop(min(size)))
206
-
207
- interp_mode = {
208
- "bilinear": T.InterpolationMode.BILINEAR,
209
- "bicubic": T.InterpolationMode.BICUBIC,
210
- "lanczos": T.InterpolationMode.LANCZOS,
211
- }.get(interpolation, T.InterpolationMode.BICUBIC)
212
-
213
- transforms_list.append(T.Resize(size, interpolation=interp_mode, antialias=True))
214
- self.transform = T.Compose(transforms_list)
215
-
216
- def preprocess(
217
- self,
218
- images: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]],
219
- return_tensors: str = "pt",
220
- ) -> torch.Tensor:
221
- """
222
- Preprocess images for model input.
223
-
224
- Args:
225
- images: Single image or list of images (PIL, numpy, or torch)
226
- return_tensors: Return type ("pt" for PyTorch)
227
-
228
- Returns:
229
- Preprocessed image tensor [B, C, H, W]
230
- """
231
- if not isinstance(images, list):
232
- images = [images]
233
-
234
- processed = []
235
- for img in images:
236
- # Convert to PIL if needed
237
- if isinstance(img, np.ndarray):
238
- if img.dtype == np.uint8:
239
- img = Image.fromarray(img)
240
- else:
241
- img = Image.fromarray((img * 255).astype(np.uint8))
242
- elif isinstance(img, torch.Tensor):
243
- img = T.ToPILImage()(img)
244
-
245
- # Apply transforms
246
- img = self.transform(img)
247
-
248
- # Convert to tensor
249
- if not isinstance(img, torch.Tensor):
250
- img = to_tensor(img)
251
-
252
- # Normalize
253
- if self.do_normalize:
254
- img = normalize(img, self.image_mean, self.image_std)
255
-
256
- processed.append(img)
257
-
258
- # Stack into batch
259
- if return_tensors == "pt":
260
- return torch.stack(processed, dim=0)
261
-
262
- return processed
263
-
264
- def postprocess(
265
- self,
266
- images: torch.Tensor,
267
- output_type: str = "pil",
268
- ) -> Union[List[Image.Image], np.ndarray, torch.Tensor]:
269
- """
270
- Postprocess model output to desired format.
271
-
272
- Args:
273
- images: Model output tensor [B, C, H, W]
274
- output_type: "pil", "np", or "pt"
275
-
276
- Returns:
277
- Processed images in requested format
278
- """
279
- # Denormalize if needed
280
- if self.do_normalize:
281
- mean = torch.tensor(self.image_mean).view(1, 3, 1, 1).to(images.device)
282
- std = torch.tensor(self.image_std).view(1, 3, 1, 1).to(images.device)
283
- images = images * std + mean
284
-
285
- # Clamp to valid range
286
- images = torch.clamp(images, 0, 1)
287
-
288
- if output_type == "pil":
289
- images = images.cpu().permute(0, 2, 3, 1).numpy()
290
- images = (images * 255).round().astype(np.uint8)
291
- return [Image.fromarray(img) for img in images]
292
- elif output_type == "np":
293
- return images.cpu().numpy()
294
- else:
295
- return images
296
-
297
-
298
- class OmniVideoProcessor:
299
- """Video frame processing for temporal diffusion models"""
300
-
301
- def __init__(
302
- self,
303
- image_processor: OmniImageProcessor,
304
- num_frames: int = 16,
305
- frame_stride: int = 1,
306
- ):
307
- self.image_processor = image_processor
308
- self.num_frames = num_frames
309
- self.frame_stride = frame_stride
310
-
311
- def preprocess_video(
312
- self,
313
- video_frames: Union[List[Image.Image], np.ndarray, torch.Tensor],
314
- temporal_interpolation: bool = True,
315
- ) -> torch.Tensor:
316
- """
317
- Preprocess video frames for temporal model.
318
-
319
- Args:
320
- video_frames: List of PIL images, numpy array [T, H, W, C], or tensor [T, C, H, W]
321
- temporal_interpolation: Whether to interpolate to target frame count
322
-
323
- Returns:
324
- Preprocessed video tensor [B, C, T, H, W]
325
- """
326
- # Convert to list of PIL images
327
- if isinstance(video_frames, np.ndarray):
328
- if video_frames.ndim == 4: # [T, H, W, C]
329
- video_frames = [Image.fromarray(frame) for frame in video_frames]
330
- else:
331
- raise ValueError(f"Expected 4D numpy array, got shape {video_frames.shape}")
332
- elif isinstance(video_frames, torch.Tensor):
333
- if video_frames.ndim == 4: # [T, C, H, W]
334
- video_frames = [T.ToPILImage()(frame) for frame in video_frames]
335
- else:
336
- raise ValueError(f"Expected 4D tensor, got shape {video_frames.shape}")
337
-
338
- # Sample frames if needed
339
- total_frames = len(video_frames)
340
- if temporal_interpolation and total_frames != self.num_frames:
341
- indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)
342
- video_frames = [video_frames[i] for i in indices]
343
-
344
- # Process each frame
345
- processed_frames = []
346
- for frame in video_frames[:self.num_frames]:
347
- frame_tensor = self.image_processor.preprocess(frame, return_tensors="pt")[0]
348
- processed_frames.append(frame_tensor)
349
-
350
- # Stack: [T, C, H, W] -> [1, C, T, H, W]
351
- video_tensor = torch.stack(processed_frames, dim=1).unsqueeze(0)
352
- return video_tensor
353
-
354
- def postprocess_video(
355
- self,
356
- video_tensor: torch.Tensor,
357
- output_type: str = "pil",
358
- ) -> Union[List[Image.Image], np.ndarray, torch.Tensor]:
359
- """
360
- Postprocess video output.
361
-
362
- Args:
363
- video_tensor: Model output [B, C, T, H, W] or [B, T, C, H, W]
364
- output_type: "pil", "np", or "pt"
365
-
366
- Returns:
367
- Processed video frames
368
- """
369
- # Normalize dimensions to [B, T, C, H, W]
370
- if video_tensor.ndim == 5:
371
- if video_tensor.shape[1] in [3, 4]: # [B, C, T, H, W]
372
- video_tensor = video_tensor.permute(0, 2, 1, 3, 4)
373
-
374
- batch_size, num_frames = video_tensor.shape[:2]
375
-
376
- # Process each frame
377
- all_frames = []
378
- for b in range(batch_size):
379
- frames = []
380
- for t in range(num_frames):
381
- frame = video_tensor[b, t] # [C, H, W]
382
- frame = frame.unsqueeze(0) # [1, C, H, W]
383
- processed = self.image_processor.postprocess(frame, output_type=output_type)
384
- frames.extend(processed)
385
- all_frames.append(frames)
386
-
387
- return all_frames[0] if batch_size == 1 else all_frames
388
-
389
-
390
- class OmniLatentProcessor:
391
- """VAE latent space encoding/decoding with scaling and normalization"""
392
-
393
- def __init__(
394
- self,
395
- vae: Any,
396
- scaling_factor: float = 0.18215,
397
- do_normalize_latents: bool = True,
398
- ):
399
- self.vae = vae
400
- self.scaling_factor = scaling_factor
401
- self.do_normalize_latents = do_normalize_latents
402
-
403
- @torch.no_grad()
404
- def encode(
405
- self,
406
- images: torch.Tensor,
407
- generator: Optional[torch.Generator] = None,
408
- return_dict: bool = False,
409
- ) -> torch.Tensor:
410
- """
411
- Encode images to latent space.
412
-
413
- Args:
414
- images: Input images [B, C, H, W] in range [-1, 1]
415
- generator: Random generator for sampling
416
- return_dict: Whether to return dict or tensor
417
-
418
- Returns:
419
- Latent codes [B, 4, H//8, W//8]
420
- """
421
- # VAE expects input in [-1, 1]
422
- if images.min() >= 0:
423
- images = images * 2.0 - 1.0
424
-
425
- # Encode
426
- latent_dist = self.vae.encode(images).latent_dist
427
- latents = latent_dist.sample(generator=generator)
428
-
429
- # Scale latents
430
- latents = latents * self.scaling_factor
431
-
432
- # Additional normalization for stability
433
- if self.do_normalize_latents:
434
- latents = (latents - latents.mean()) / (latents.std() + 1e-6)
435
-
436
- return latents if not return_dict else {"latents": latents}
437
-
438
- @torch.no_grad()
439
- def decode(
440
- self,
441
- latents: torch.Tensor,
442
- return_dict: bool = False,
443
- ) -> torch.Tensor:
444
- """
445
- Decode latents to image space.
446
-
447
- Args:
448
- latents: Latent codes [B, 4, H//8, W//8]
449
- return_dict: Whether to return dict or tensor
450
-
451
- Returns:
452
- Decoded images [B, 3, H, W] in range [-1, 1]
453
- """
454
- # Denormalize if needed
455
- if self.do_normalize_latents:
456
- # Assume identity transform for simplicity in decoding
457
- pass
458
-
459
- # Unscale
460
- latents = latents / self.scaling_factor
461
-
462
- # Decode
463
- images = self.vae.decode(latents).sample
464
-
465
- return images if not return_dict else {"images": images}
466
-
467
- @torch.no_grad()
468
- def encode_video(
469
- self,
470
- video_frames: torch.Tensor,
471
- generator: Optional[torch.Generator] = None,
472
- ) -> torch.Tensor:
473
- """
474
- Encode video frames to latent space.
475
-
476
- Args:
477
- video_frames: Input video [B, C, T, H, W] or [B, T, C, H, W]
478
- generator: Random generator
479
-
480
- Returns:
481
- Video latents [B, 4, T, H//8, W//8]
482
- """
483
- # Reshape to process frames independently
484
- if video_frames.shape[2] not in [3, 4]: # [B, T, C, H, W]
485
- B, T, C, H, W = video_frames.shape
486
- video_frames = video_frames.reshape(B * T, C, H, W)
487
-
488
- # Encode
489
- latents = self.encode(video_frames, generator=generator)
490
-
491
- # Reshape back
492
- latents = latents.reshape(B, T, *latents.shape[1:])
493
- latents = latents.permute(0, 2, 1, 3, 4) # [B, 4, T, H//8, W//8]
494
- else: # [B, C, T, H, W]
495
- B, C, T, H, W = video_frames.shape
496
- video_frames = video_frames.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
497
-
498
- latents = self.encode(video_frames, generator=generator)
499
- latents = latents.reshape(B, T, *latents.shape[1:])
500
- latents = latents.permute(0, 2, 1, 3, 4)
501
-
502
- return latents
503
-
504
- # -----------------------------------------------------------------------------
505
- # 3. Core Architecture: OmniMMDitBlock (3D-Attention + Modulation)
506
- # -----------------------------------------------------------------------------
507
-
508
- class OmniMMDitBlock(nn.Module):
509
- def __init__(self, config: OmniMMDitV2Config, layer_idx: int):
510
- super().__init__()
511
- self.layer_idx = layer_idx
512
- self.hidden_size = config.hidden_size
513
- self.num_heads = config.num_attention_heads
514
- self.head_dim = config.hidden_size // config.num_attention_heads
515
-
516
- # Self-Attention with QK-Norm
517
- self.norm1 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
518
- self.attn = nn.MultiheadAttention(
519
- config.hidden_size, config.num_attention_heads, batch_first=True
520
- )
521
-
522
- self.q_norm = OmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)
523
- self.k_norm = OmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)
524
-
525
- # Cross-Attention for multimodal fusion
526
- self.norm2 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
527
- self.cross_attn = nn.MultiheadAttention(
528
- config.hidden_size, config.num_attention_heads, batch_first=True
529
- )
530
-
531
- # Feed-Forward Network with SwiGLU activation
532
- self.norm3 = OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
533
- self.ffn = OmniSwiGLU(config)
534
-
535
- # Adaptive Layer Normalization with zero initialization
536
- self.adaLN_modulation = nn.Sequential(
537
- nn.SiLU(),
538
- nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True)
539
- )
540
-
541
- def forward(
542
- self,
543
- hidden_states: torch.Tensor,
544
- encoder_hidden_states: torch.Tensor, # Text embeddings
545
- visual_context: Optional[torch.Tensor], # Reference image embeddings
546
- timestep_emb: torch.Tensor,
547
- rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
548
- ) -> torch.Tensor:
549
-
550
- # AdaLN Modulation
551
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
552
- self.adaLN_modulation(timestep_emb)[:, None].chunk(6, dim=-1)
553
- )
554
-
555
- # Self-Attention block
556
- normed_hidden = self.norm1(hidden_states)
557
- normed_hidden = normed_hidden * (1 + scale_msa) + shift_msa
558
-
559
- attn_output, _ = self.attn(normed_hidden, normed_hidden, normed_hidden)
560
- hidden_states = hidden_states + gate_msa * attn_output
561
-
562
- # Cross-Attention with multimodal conditioning
563
- if visual_context is not None:
564
- context = torch.cat([encoder_hidden_states, visual_context], dim=1)
565
- else:
566
- context = encoder_hidden_states
567
-
568
- normed_hidden_cross = self.norm2(hidden_states)
569
- cross_output, _ = self.cross_attn(normed_hidden_cross, context, context)
570
- hidden_states = hidden_states + cross_output
571
-
572
- # Feed-Forward block
573
- normed_ffn = self.norm3(hidden_states)
574
- normed_ffn = normed_ffn * (1 + scale_mlp) + shift_mlp
575
- ffn_output = self.ffn(normed_ffn)
576
- hidden_states = hidden_states + gate_mlp * ffn_output
577
-
578
- return hidden_states
579
-
580
- # -----------------------------------------------------------------------------
581
- # 4. The Model: OmniMMDitV2
582
- # -----------------------------------------------------------------------------
583
-
584
- class OmniMMDitV2(ModelMixin, PreTrainedModel):
585
- """
586
- Omni-Modal Multi-Dimensional Diffusion Transformer V2.
587
- Supports: Text-to-Image, Image-to-Image (Edit), Image-to-Video.
588
- """
589
- config_class = OmniMMDitV2Config
590
- _supports_gradient_checkpointing = True
591
-
592
- def __init__(self, config: OmniMMDitV2Config):
593
- super().__init__(config)
594
- self.config = config
595
-
596
- # Initialize optimizer for advanced features
597
- self.optimizer = ModelOptimizer(
598
- fp8_config=FP8Config(enabled=config.use_fp8_quantization),
599
- compilation_config=CompilationConfig(
600
- enabled=config.use_compilation,
601
- mode=config.compile_mode,
602
- ),
603
- mixed_precision_config=MixedPrecisionConfig(
604
- enabled=True,
605
- dtype="bfloat16",
606
- ),
607
- )
608
-
609
- # Input Latent Projection (Patchify)
610
- self.x_embedder = nn.Linear(config.in_channels * config.patch_size * config.patch_size, config.hidden_size, bias=True)
611
-
612
- # Time & Vector Embeddings
613
- self.t_embedder = TimestepEmbedder(config.hidden_size, config.frequency_embedding_size)
614
-
615
- # Visual Condition Projector (Handles 1-3 images)
616
- self.visual_projector = nn.Sequential(
617
- nn.Linear(config.visual_embed_dim, config.hidden_size),
618
- nn.LayerNorm(config.hidden_size),
619
- nn.Linear(config.hidden_size, config.hidden_size)
620
- )
621
-
622
- # Positional Embeddings (Absolute + RoPE dynamically handled)
623
- self.pos_embed = nn.Parameter(torch.zeros(1, config.max_position_embeddings, config.hidden_size), requires_grad=False)
624
-
625
- # Transformer Backbone
626
- self.blocks = nn.ModuleList([
627
- OmniMMDitBlock(config, i) for i in range(config.num_hidden_layers)
628
- ])
629
-
630
- # Final Layer (AdaLN-Zero + Linear)
631
- self.final_layer = nn.Sequential(
632
- OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
633
- nn.Linear(config.hidden_size, config.patch_size * config.patch_size * config.out_channels, bias=True)
634
- )
635
-
636
- self.initialize_weights()
637
-
638
- # Apply optimizations if enabled
639
- if config.use_fp8_quantization or config.use_compilation:
640
- self._apply_optimizations()
641
-
642
- def _apply_optimizations(self):
643
- """Apply FP8 quantization and compilation optimizations"""
644
- # Quantize transformer blocks
645
- if self.config.use_fp8_quantization:
646
- for i, block in enumerate(self.blocks):
647
- self.blocks[i] = self.optimizer.optimize_model(
648
- block,
649
- apply_compilation=False,
650
- apply_quantization=True,
651
- apply_mixed_precision=True,
652
- )
653
-
654
- # Compile forward method
655
- if self.config.use_compilation and HAS_TORCH_COMPILE:
656
- self.forward = torch.compile(
657
- self.forward,
658
- mode=self.config.compile_mode,
659
- dynamic=True,
660
- )
661
-
662
- def initialize_weights(self):
663
- def _basic_init(module):
664
- if isinstance(module, nn.Linear):
665
- torch.nn.init.xavier_uniform_(module.weight)
666
- if module.bias is not None:
667
- nn.init.constant_(module.bias, 0)
668
- self.apply(_basic_init)
669
-
670
- def unpatchify(self, x, h, w):
671
- c = self.config.out_channels
672
- p = self.config.patch_size
673
- h_ = h // p
674
- w_ = w // p
675
- x = x.reshape(shape=(x.shape[0], h_, w_, p, p, c))
676
- x = torch.einsum('nhwpqc->nchpwq', x)
677
- imgs = x.reshape(shape=(x.shape[0], c, h, w))
678
- return imgs
679
-
680
- def forward(
681
- self,
682
- hidden_states: torch.Tensor, # Noisy Latents [B, C, H, W] or [B, C, F, H, W]
683
- timestep: torch.LongTensor,
684
- encoder_hidden_states: torch.Tensor, # Text Embeddings
685
- visual_conditions: Optional[List[torch.Tensor]] = None, # List of [B, L, D]
686
- video_frames: Optional[int] = None, # If generating video
687
- return_dict: bool = True,
688
- ) -> Union[torch.Tensor, BaseOutput]:
689
-
690
- batch_size, channels, _, _ = hidden_states.shape
691
-
692
- # Patchify input latents
693
- p = self.config.patch_size
694
- h, w = hidden_states.shape[-2], hidden_states.shape[-1]
695
- x = hidden_states.unfold(2, p, p).unfold(3, p, p)
696
- x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
697
- x = x.view(batch_size, -1, channels * p * p)
698
-
699
- # Positional and temporal embeddings
700
- x = self.x_embedder(x)
701
- x = x + self.pos_embed[:, :x.shape[1], :]
702
-
703
- t = self.t_embedder(timestep, x.dtype)
704
-
705
- # Process visual conditioning
706
- visual_emb = None
707
- if visual_conditions is not None:
708
- concat_visuals = torch.cat(visual_conditions, dim=1)
709
- visual_emb = self.visual_projector(concat_visuals)
710
-
711
- # Transformer blocks
712
- for block in self.blocks:
713
- x = block(
714
- hidden_states=x,
715
- encoder_hidden_states=encoder_hidden_states,
716
- visual_context=visual_emb,
717
- timestep_emb=t
718
- )
719
-
720
- # Output projection
721
- x = self.final_layer[0](x)
722
- x = self.final_layer[1](x)
723
-
724
- # Unpatchify to image space
725
- output = self.unpatchify(x, h, w)
726
-
727
- if not return_dict:
728
- return (output,)
729
-
730
- return BaseOutput(sample=output)
731
-
732
- # -----------------------------------------------------------------------------
733
- # 5. The "Fancy" Pipeline
734
- # -----------------------------------------------------------------------------
735
-
736
- class OmniMMDitV2Pipeline(DiffusionPipeline):
737
- """
738
- Omni-Modal Diffusion Transformer Pipeline.
739
-
740
- Supports text-guided image editing and video generation with
741
- multi-image conditioning and advanced guidance techniques.
742
- """
743
- model: OmniMMDitV2
744
- tokenizer: CLIPTokenizer
745
- text_encoder: CLIPTextModel
746
- vae: Any # AutoencoderKL
747
- scheduler: DDIMScheduler
748
-
749
- _optional_components = ["visual_encoder"]
750
-
751
- def __init__(
752
- self,
753
- model: OmniMMDitV2,
754
- vae: Any,
755
- text_encoder: CLIPTextModel,
756
- tokenizer: CLIPTokenizer,
757
- scheduler: DDIMScheduler,
758
- visual_encoder: Optional[Any] = None,
759
- ):
760
- super().__init__()
761
- self.register_modules(
762
- model=model,
763
- vae=vae,
764
- text_encoder=text_encoder,
765
- tokenizer=tokenizer,
766
- scheduler=scheduler,
767
- visual_encoder=visual_encoder
768
- )
769
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
770
-
771
- # Initialize data processors
772
- self.image_processor = OmniImageProcessor(
773
- size=(512, 512),
774
- interpolation="bicubic",
775
- do_normalize=True,
776
- )
777
- self.video_processor = OmniVideoProcessor(
778
- image_processor=self.image_processor,
779
- num_frames=16,
780
- )
781
- self.latent_processor = OmniLatentProcessor(
782
- vae=vae,
783
- scaling_factor=0.18215,
784
- )
785
-
786
- # Initialize model optimizer
787
- self.model_optimizer = ModelOptimizer(
788
- fp8_config=FP8Config(enabled=False), # Can be enabled via enable_fp8()
789
- compilation_config=CompilationConfig(enabled=False), # Can be enabled via compile()
790
- mixed_precision_config=MixedPrecisionConfig(enabled=True, dtype="bfloat16"),
791
- )
792
-
793
- self._is_compiled = False
794
- self._is_fp8_enabled = False
795
-
796
- def enable_fp8_quantization(self):
797
- """Enable FP8 quantization for faster inference"""
798
- if not HAS_TRANSFORMER_ENGINE:
799
- warnings.warn("Transformer Engine not available. Install with: pip install transformer-engine")
800
- return self
801
-
802
- self.model_optimizer.fp8_config.enabled = True
803
- self.model = self.model_optimizer.optimize_model(
804
- self.model,
805
- apply_compilation=False,
806
- apply_quantization=True,
807
- apply_mixed_precision=False,
808
- )
809
- self._is_fp8_enabled = True
810
- return self
811
-
812
- def compile_model(
813
- self,
814
- mode: str = "reduce-overhead",
815
- fullgraph: bool = False,
816
- dynamic: bool = True,
817
- ):
818
- """
819
- Compile model using torch.compile for faster inference.
820
-
821
- Args:
822
- mode: Compilation mode - "default", "reduce-overhead", "max-autotune"
823
- fullgraph: Whether to compile the entire model as one graph
824
- dynamic: Whether to enable dynamic shapes
825
- """
826
- if not HAS_TORCH_COMPILE:
827
- warnings.warn("torch.compile not available. Upgrade to PyTorch 2.0+")
828
- return self
829
-
830
- self.model_optimizer.compilation_config = CompilationConfig(
831
- enabled=True,
832
- mode=mode,
833
- fullgraph=fullgraph,
834
- dynamic=dynamic,
835
- )
836
-
837
- self.model = self.model_optimizer._compile_model(self.model)
838
- self._is_compiled = True
839
- return self
840
-
841
- def enable_optimizations(
842
- self,
843
- enable_fp8: bool = False,
844
- enable_compilation: bool = False,
845
- compilation_mode: str = "reduce-overhead",
846
- ):
847
- """
848
- Enable all optimizations at once.
849
-
850
- Args:
851
- enable_fp8: Enable FP8 quantization
852
- enable_compilation: Enable torch.compile
853
- compilation_mode: Compilation mode for torch.compile
854
- """
855
- if enable_fp8:
856
- self.enable_fp8_quantization()
857
-
858
- if enable_compilation:
859
- self.compile_model(mode=compilation_mode)
860
-
861
- return self
862
-
863
- @torch.no_grad()
864
- def __call__(
865
- self,
866
- prompt: Union[str, List[str]] = None,
867
- input_images: Optional[List[Union[torch.Tensor, Any]]] = None,
868
- height: Optional[int] = 1024,
869
- width: Optional[int] = 1024,
870
- num_frames: Optional[int] = 1,
871
- num_inference_steps: int = 50,
872
- guidance_scale: float = 7.5,
873
- image_guidance_scale: float = 1.5,
874
- negative_prompt: Optional[Union[str, List[str]]] = None,
875
- eta: float = 0.0,
876
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
877
- latents: Optional[torch.Tensor] = None,
878
- output_type: Optional[str] = "pil",
879
- return_dict: bool = True,
880
- callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
881
- callback_steps: int = 1,
882
- use_optimized_inference: bool = True,
883
- **kwargs,
884
- ):
885
- # Use optimized inference context
886
- with optimized_inference_mode(
887
- enable_cudnn_benchmark=use_optimized_inference,
888
- enable_tf32=use_optimized_inference,
889
- enable_flash_sdp=use_optimized_inference,
890
- ):
891
- return self._forward_impl(
892
- prompt=prompt,
893
- input_images=input_images,
894
- height=height,
895
- width=width,
896
- num_frames=num_frames,
897
- num_inference_steps=num_inference_steps,
898
- guidance_scale=guidance_scale,
899
- image_guidance_scale=image_guidance_scale,
900
- negative_prompt=negative_prompt,
901
- eta=eta,
902
- generator=generator,
903
- latents=latents,
904
- output_type=output_type,
905
- return_dict=return_dict,
906
- callback=callback,
907
- callback_steps=callback_steps,
908
- **kwargs,
909
- )
910
-
911
- def _forward_impl(
912
- self,
913
- prompt: Union[str, List[str]] = None,
914
- input_images: Optional[List[Union[torch.Tensor, Any]]] = None,
915
- height: Optional[int] = 1024,
916
- width: Optional[int] = 1024,
917
- num_frames: Optional[int] = 1,
918
- num_inference_steps: int = 50,
919
- guidance_scale: float = 7.5,
920
- image_guidance_scale: float = 1.5,
921
- negative_prompt: Optional[Union[str, List[str]]] = None,
922
- eta: float = 0.0,
923
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
924
- latents: Optional[torch.Tensor] = None,
925
- output_type: Optional[str] = "pil",
926
- return_dict: bool = True,
927
- callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
928
- callback_steps: int = 1,
929
- **kwargs,
930
- ):
931
- # Validate and set default dimensions
932
- height = height or self.model.config.sample_size * self.vae_scale_factor
933
- width = width or self.model.config.sample_size * self.vae_scale_factor
934
-
935
- # Encode text prompts
936
- if isinstance(prompt, str):
937
- prompt = [prompt]
938
- batch_size = len(prompt)
939
-
940
- text_inputs = self.tokenizer(
941
- prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt"
942
- )
943
- text_embeddings = self.text_encoder(text_inputs.input_ids.to(self.device))[0]
944
-
945
- # Encode visual conditions with preprocessing
946
- visual_embeddings_list = []
947
- if input_images:
948
- if not isinstance(input_images, list):
949
- input_images = [input_images]
950
- if len(input_images) > 3:
951
- raise ValueError("Maximum 3 reference images supported")
952
-
953
- for img in input_images:
954
- # Preprocess image
955
- if not isinstance(img, torch.Tensor):
956
- img_tensor = self.image_processor.preprocess(img, return_tensors="pt")
957
- else:
958
- img_tensor = img
959
-
960
- img_tensor = img_tensor.to(device=self.device, dtype=text_embeddings.dtype)
961
-
962
- # Encode with visual encoder
963
- if self.visual_encoder is not None:
964
- vis_emb = self.visual_encoder(img_tensor).last_hidden_state
965
- else:
966
- # Fallback: use VAE encoder + projection
967
- with torch.no_grad():
968
- latent_features = self.vae.encode(img_tensor * 2 - 1).latent_dist.mode()
969
- B, C, H, W = latent_features.shape
970
- # Flatten spatial dims and project
971
- vis_emb = latent_features.flatten(2).transpose(1, 2) # [B, H*W, C]
972
- # Simple projection to visual_embed_dim
973
- if vis_emb.shape[-1] != self.model.config.visual_embed_dim:
974
- proj = nn.Linear(vis_emb.shape[-1], self.model.config.visual_embed_dim).to(self.device)
975
- vis_emb = proj(vis_emb)
976
-
977
- visual_embeddings_list.append(vis_emb)
978
-
979
- # Prepare timesteps
980
- self.scheduler.set_timesteps(num_inference_steps, device=self.device)
981
- timesteps = self.scheduler.timesteps
982
-
983
- # Initialize latent space
984
- num_channels_latents = self.model.config.in_channels
985
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
986
- if num_frames > 1:
987
- shape = (batch_size, num_channels_latents, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor)
988
-
989
- latents = torch.randn(shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
990
- latents = latents * self.scheduler.init_noise_sigma
991
-
992
- # Denoising loop with optimizations
993
- with self.progress_bar(total=num_inference_steps) as progress_bar:
994
- for i, t in enumerate(timesteps):
995
- latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
996
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
997
-
998
- # Use mixed precision autocast
999
- with self.model_optimizer.autocast_context():
1000
- noise_pred = self.model(
1001
- hidden_states=latent_model_input,
1002
- timestep=t,
1003
- encoder_hidden_states=torch.cat([text_embeddings] * 2),
1004
- visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
1005
- video_frames=num_frames
1006
- ).sample
1007
-
1008
- # Apply classifier-free guidance
1009
- if guidance_scale > 1.0:
1010
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1011
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1012
-
1013
- latents = self.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
1014
-
1015
- # Call callback if provided
1016
- if callback is not None and i % callback_steps == 0:
1017
- callback(i, t, latents)
1018
-
1019
- progress_bar.update()
1020
-
1021
- # Decode latents with proper post-processing
1022
- if output_type == "latent":
1023
- output_images = latents
1024
- else:
1025
- # Decode latents to pixel space
1026
- with torch.no_grad():
1027
- if num_frames > 1:
1028
- # Video decoding: process frame by frame
1029
- B, C, T, H, W = latents.shape
1030
- latents_2d = latents.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
1031
- decoded = self.latent_processor.decode(latents_2d)
1032
- decoded = decoded.reshape(B, T, 3, H * 8, W * 8)
1033
-
1034
- # Convert to [0, 1] range
1035
- decoded = (decoded / 2 + 0.5).clamp(0, 1)
1036
-
1037
- # Post-process video
1038
- if output_type == "pil":
1039
- output_images = self.video_processor.postprocess_video(decoded, output_type="pil")
1040
- elif output_type == "np":
1041
- output_images = decoded.cpu().numpy()
1042
- else:
1043
- output_images = decoded
1044
- else:
1045
- # Image decoding
1046
- decoded = self.latent_processor.decode(latents)
1047
- decoded = (decoded / 2 + 0.5).clamp(0, 1)
1048
-
1049
- # Post-process images
1050
- if output_type == "pil":
1051
- output_images = self.image_processor.postprocess(decoded, output_type="pil")
1052
- elif output_type == "np":
1053
- output_images = decoded.cpu().numpy()
1054
- else:
1055
- output_images = decoded
1056
-
1057
- if not return_dict:
1058
- return (output_images,)
1059
-
1060
- return BaseOutput(images=output_images)
1061
-
1062
- # -----------------------------------------------------------------------------
1063
- # 6. Advanced Multi-Modal Window Attention Block (Audio + Video + Image)
1064
- # -----------------------------------------------------------------------------
1065
-
1066
- @dataclass
1067
- class MultiModalInput:
1068
- """Container for multi-modal inputs"""
1069
- image_embeds: Optional[torch.Tensor] = None # [B, L_img, D]
1070
- video_embeds: Optional[torch.Tensor] = None # [B, T_video, L_vid, D]
1071
- audio_embeds: Optional[torch.Tensor] = None # [B, T_audio, L_aud, D]
1072
- attention_mask: Optional[torch.Tensor] = None # [B, total_length]
1073
-
1074
-
1075
- class TemporalWindowPartition(nn.Module):
1076
- """
1077
- Partition temporal sequences into windows for efficient attention.
1078
- Supports both uniform and adaptive windowing strategies.
1079
- """
1080
- def __init__(
1081
- self,
1082
- window_size: int = 8,
1083
- shift_size: int = 0,
1084
- use_adaptive_window: bool = False,
1085
- ):
1086
- super().__init__()
1087
- self.window_size = window_size
1088
- self.shift_size = shift_size
1089
- self.use_adaptive_window = use_adaptive_window
1090
-
1091
- def partition(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, Any]]:
1092
- """
1093
- Partition sequence into windows.
1094
-
1095
- Args:
1096
- x: Input tensor [B, T, L, D] or [B, L, D]
1097
-
1098
- Returns:
1099
- windowed: [B * num_windows, window_size, L, D]
1100
- info: Dictionary with partition information
1101
- """
1102
- if x.ndim == 3: # Static input (image)
1103
- return x, {"is_temporal": False, "original_shape": x.shape}
1104
-
1105
- B, T, L, D = x.shape
1106
-
1107
- # Apply temporal shift for shifted window attention (Swin-Transformer style)
1108
- if self.shift_size > 0:
1109
- x = torch.roll(x, shifts=-self.shift_size, dims=1)
1110
-
1111
- # Pad if necessary
1112
- pad_t = (self.window_size - T % self.window_size) % self.window_size
1113
- if pad_t > 0:
1114
- x = F.pad(x, (0, 0, 0, 0, 0, pad_t))
1115
-
1116
- T_padded = T + pad_t
1117
- num_windows = T_padded // self.window_size
1118
-
1119
- # Reshape into windows: [B, num_windows, window_size, L, D]
1120
- x_windowed = x.view(B, num_windows, self.window_size, L, D)
1121
-
1122
- # Merge batch and window dims: [B * num_windows, window_size, L, D]
1123
- x_windowed = x_windowed.view(B * num_windows, self.window_size, L, D)
1124
-
1125
- info = {
1126
- "is_temporal": True,
1127
- "original_shape": (B, T, L, D),
1128
- "num_windows": num_windows,
1129
- "pad_t": pad_t,
1130
- }
1131
-
1132
- return x_windowed, info
1133
-
1134
- def merge(self, x_windowed: torch.Tensor, info: Dict[str, Any]) -> torch.Tensor:
1135
- """
1136
- Merge windows back to original sequence.
1137
-
1138
- Args:
1139
- x_windowed: Windowed tensor [B * num_windows, window_size, L, D]
1140
- info: Partition information from partition()
1141
-
1142
- Returns:
1143
- x: Merged tensor [B, T, L, D] or [B, L, D]
1144
- """
1145
- if not info["is_temporal"]:
1146
- return x_windowed
1147
-
1148
- B, T, L, D = info["original_shape"]
1149
- num_windows = info["num_windows"]
1150
- pad_t = info["pad_t"]
1151
-
1152
- # Reshape: [B * num_windows, window_size, L, D] -> [B, num_windows, window_size, L, D]
1153
- x = x_windowed.view(B, num_windows, self.window_size, L, D)
1154
-
1155
- # Merge windows: [B, T_padded, L, D]
1156
- x = x.view(B, num_windows * self.window_size, L, D)
1157
-
1158
- # Remove padding
1159
- if pad_t > 0:
1160
- x = x[:, :-pad_t, :, :]
1161
-
1162
- # Reverse temporal shift
1163
- if self.shift_size > 0:
1164
- x = torch.roll(x, shifts=self.shift_size, dims=1)
1165
-
1166
- return x
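For orientation, a minimal round-trip sketch of the partition/merge pair being removed above (not code from the repo; it assumes the class and torch are importable). It illustrates the end-padding and shifted-window behaviour described in the docstrings:

    import torch

    part = TemporalWindowPartition(window_size=8, shift_size=4)
    x = torch.randn(2, 13, 16, 64)           # [B, T, L, D], T not divisible by window_size
    windows, info = part.partition(x)        # roll by -4, pad T from 13 to 16, 2 windows per sample
    assert windows.shape == (4, 8, 16, 64)   # [B * num_windows, window_size, L, D]
    restored = part.merge(windows, info)     # padding stripped, temporal shift undone
    assert restored.shape == x.shape

Static [B, L, D] inputs pass through both calls unchanged, which is how image tokens bypass the temporal windowing.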
1167
-
1168
-
1169
- class WindowCrossAttention(nn.Module):
1170
- """
1171
- Window-based Cross Attention with support for temporal sequences.
1172
- Performs attention within local windows for computational efficiency.
1173
- """
1174
- def __init__(
1175
- self,
1176
- dim: int,
1177
- num_heads: int = 8,
1178
- window_size: int = 8,
1179
- qkv_bias: bool = True,
1180
- attn_drop: float = 0.0,
1181
- proj_drop: float = 0.0,
1182
- use_relative_position_bias: bool = True,
1183
- ):
1184
- super().__init__()
1185
- self.dim = dim
1186
- self.num_heads = num_heads
1187
- self.window_size = window_size
1188
- self.head_dim = dim // num_heads
1189
- self.scale = self.head_dim ** -0.5
1190
-
1191
- # Query, Key, Value projections
1192
- self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
1193
- self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
1194
- self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
1195
-
1196
- # QK Normalization for stability
1197
- self.q_norm = OmniRMSNorm(self.head_dim)
1198
- self.k_norm = OmniRMSNorm(self.head_dim)
1199
-
1200
- # Attention dropout
1201
- self.attn_drop = nn.Dropout(attn_drop)
1202
-
1203
- # Output projection
1204
- self.proj = nn.Linear(dim, dim)
1205
- self.proj_drop = nn.Dropout(proj_drop)
1206
-
1207
- # Relative position bias (for temporal coherence)
1208
- self.use_relative_position_bias = use_relative_position_bias
1209
- if use_relative_position_bias:
1210
- # Temporal relative position bias
1211
- self.relative_position_bias_table = nn.Parameter(
1212
- torch.zeros((2 * window_size - 1), num_heads)
1213
- )
1214
- nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
1215
-
1216
- # Get relative position index
1217
- coords = torch.arange(window_size)
1218
- relative_coords = coords[:, None] - coords[None, :] # [window_size, window_size]
1219
- relative_coords += window_size - 1 # Shift to start from 0
1220
- self.register_buffer("relative_position_index", relative_coords)
1221
-
1222
- def get_relative_position_bias(self, window_size: int) -> torch.Tensor:
1223
- """Generate relative position bias for attention"""
1224
- if not self.use_relative_position_bias:
1225
- return None
1226
-
1227
- relative_position_bias = self.relative_position_bias_table[
1228
- self.relative_position_index[:window_size, :window_size].reshape(-1)
1229
- ].reshape(window_size, window_size, -1)
1230
-
1231
- # Permute to [num_heads, window_size, window_size]
1232
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
1233
- return relative_position_bias
1234
-
1235
- def forward(
1236
- self,
1237
- query: torch.Tensor, # [B, T_q, L_q, D] or [B, L_q, D]
1238
- key: torch.Tensor, # [B, T_k, L_k, D] or [B, L_k, D]
1239
- value: torch.Tensor, # [B, T_v, L_v, D] or [B, L_v, D]
1240
- attention_mask: Optional[torch.Tensor] = None,
1241
- ) -> torch.Tensor:
1242
- """
1243
- Perform windowed cross attention.
1244
-
1245
- Args:
1246
- query: Query tensor
1247
- key: Key tensor
1248
- value: Value tensor
1249
- attention_mask: Optional attention mask
1250
-
1251
- Returns:
1252
- Output tensor with same shape as query
1253
- """
1254
- # Handle both temporal and non-temporal inputs
1255
- is_temporal = query.ndim == 4
1256
-
1257
- if is_temporal:
1258
- B, T_q, L_q, D = query.shape
1259
- _, T_k, L_k, _ = key.shape
1260
-
1261
- # Flatten temporal and spatial dims for cross attention
1262
- query_flat = query.reshape(B, T_q * L_q, D)
1263
- key_flat = key.reshape(B, T_k * L_k, D)
1264
- value_flat = value.reshape(B, T_k * L_k, D)
1265
- else:
1266
- B, L_q, D = query.shape
1267
- _, L_k, _ = key.shape
1268
- query_flat = query
1269
- key_flat = key
1270
- value_flat = value
1271
-
1272
- # Project to Q, K, V
1273
- q = self.q_proj(query_flat) # [B, N_q, D]
1274
- k = self.k_proj(key_flat) # [B, N_k, D]
1275
- v = self.v_proj(value_flat) # [B, N_v, D]
1276
-
1277
- # Reshape for multi-head attention
1278
- q = q.reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2) # [B, H, N_q, head_dim]
1279
- k = k.reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2) # [B, H, N_k, head_dim]
1280
- v = v.reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2) # [B, H, N_v, head_dim]
1281
-
1282
- # Apply QK normalization
1283
- q = self.q_norm(q)
1284
- k = self.k_norm(k)
1285
-
1286
- # Scaled dot-product attention
1287
- attn = (q @ k.transpose(-2, -1)) * self.scale # [B, H, N_q, N_k]
1288
-
1289
- # Add relative position bias if temporal
1290
- if is_temporal and self.use_relative_position_bias:
1291
- # Apply per-window bias
1292
- rel_bias = self.get_relative_position_bias(min(T_q, self.window_size))
1293
- if rel_bias is not None:
1294
- # Broadcast bias across spatial dimensions
1295
- attn = attn + rel_bias.unsqueeze(0).unsqueeze(2)
1296
-
1297
- # Apply attention mask
1298
- if attention_mask is not None:
1299
- attn = attn.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, float('-inf'))
1300
-
1301
- # Softmax and dropout
1302
- attn = F.softmax(attn, dim=-1)
1303
- attn = self.attn_drop(attn)
1304
-
1305
- # Apply attention to values
1306
- out = (attn @ v).transpose(1, 2).reshape(B, -1, D) # [B, N_q, D]
1307
-
1308
- # Output projection
1309
- out = self.proj(out)
1310
- out = self.proj_drop(out)
1311
-
1312
- # Reshape back to original shape
1313
- if is_temporal:
1314
- out = out.reshape(B, T_q, L_q, D)
1315
- else:
1316
- out = out.reshape(B, L_q, D)
1317
-
1318
- return out
1319
-
1320
-
1321
- class MultiModalFusionLayer(nn.Module):
1322
- """
1323
- Fuses multiple modalities (audio, video, image) with learnable fusion weights.
1324
- """
1325
- def __init__(
1326
- self,
1327
- dim: int,
1328
- num_modalities: int = 3,
1329
- fusion_type: str = "weighted", # "weighted", "gated", "adaptive"
1330
- ):
1331
- super().__init__()
1332
- self.dim = dim
1333
- self.num_modalities = num_modalities
1334
- self.fusion_type = fusion_type
1335
-
1336
- if fusion_type == "weighted":
1337
- # Learnable fusion weights
1338
- self.fusion_weights = nn.Parameter(torch.ones(num_modalities) / num_modalities)
1339
-
1340
- elif fusion_type == "gated":
1341
- # Gated fusion with cross-modal interactions
1342
- self.gate_proj = nn.Sequential(
1343
- nn.Linear(dim * num_modalities, dim * 2),
1344
- nn.GELU(),
1345
- nn.Linear(dim * 2, num_modalities),
1346
- nn.Softmax(dim=-1)
1347
- )
1348
-
1349
- elif fusion_type == "adaptive":
1350
- # Adaptive fusion with per-token gating
1351
- self.adaptive_gate = nn.Sequential(
1352
- nn.Linear(dim, dim // 2),
1353
- nn.GELU(),
1354
- nn.Linear(dim // 2, num_modalities),
1355
- nn.Sigmoid()
1356
- )
1357
-
1358
- def forward(self, modality_features: List[torch.Tensor]) -> torch.Tensor:
1359
- """
1360
- Fuse multiple modality features.
1361
-
1362
- Args:
1363
- modality_features: List of [B, L, D] tensors for each modality
1364
-
1365
- Returns:
1366
- fused: Fused features [B, L, D]
1367
- """
1368
- if self.fusion_type == "weighted":
1369
- # Simple weighted sum
1370
- weights = F.softmax(self.fusion_weights, dim=0)
1371
- fused = sum(w * feat for w, feat in zip(weights, modality_features))
1372
-
1373
- elif self.fusion_type == "gated":
1374
- # Concatenate and compute gates
1375
- concat_features = torch.cat(modality_features, dim=-1) # [B, L, D * num_modalities]
1376
- gates = self.gate_proj(concat_features) # [B, L, num_modalities]
1377
-
1378
- # Apply gates
1379
- stacked = torch.stack(modality_features, dim=-1) # [B, L, D, num_modalities]
1380
- fused = (stacked * gates.unsqueeze(2)).sum(dim=-1) # [B, L, D]
1381
-
1382
- elif self.fusion_type == "adaptive":
1383
- # Adaptive per-token fusion
1384
- fused_list = []
1385
- for feat in modality_features:
1386
- gate = self.adaptive_gate(feat) # [B, L, num_modalities]
1387
- fused_list.append(feat.unsqueeze(-1) * gate.unsqueeze(2))
1388
-
1389
- fused = torch.cat(fused_list, dim=-1).sum(dim=-1) # [B, L, D]
1390
-
1391
- return fused
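As a quick illustration (a sketch, not part of the file), the "weighted" strategy is a softmax-normalised sum over modality features that have already been aligned to one sequence length:

    fusion = MultiModalFusionLayer(dim=256, num_modalities=3, fusion_type="weighted")
    feats = [torch.randn(2, 64, 256) for _ in range(3)]   # audio / video / image tokens, same length
    fused = fusion(feats)                                  # [2, 64, 256]

The "gated" and "adaptive" variants make the same length assumption; the caller in the block below pads shorter sequences to the longest one before fusing.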
1392
-
1393
-
1394
- class FancyMultiModalWindowAttentionBlock(nn.Module):
1395
- """
1396
- 🎯 Fancy Multi-Modal Window Attention Block
1397
-
1398
- A state-of-the-art block that processes audio, video, and image embeddings
1399
- with temporal window-based cross-attention for efficient multi-modal fusion.
1400
-
1401
- Features:
1402
- - ✨ Temporal windowing for audio and video (frame-by-frame processing)
1403
- - 🪟 Shifted window attention for better temporal coherence (Swin-style)
1404
- - 🔄 Cross-modal attention between all modality pairs
1405
- - 🎭 Adaptive multi-modal fusion with learnable gates
1406
- - 🚀 Efficient computation with window partitioning
1407
- - 💎 QK normalization for training stability
1408
-
1409
- Architecture:
1410
- 1. Temporal Partitioning (audio/video frames → windows)
1411
- 2. Intra-Modal Self-Attention (within each modality)
1412
- 3. Inter-Modal Cross-Attention (audio ↔ video ↔ image)
1413
- 4. Multi-Modal Fusion (adaptive weighted combination)
1414
- 5. Feed-Forward Network (SwiGLU activation)
1415
- 6. Window Merging (reconstruct temporal sequences)
1416
- """
1417
-
1418
- def __init__(
1419
- self,
1420
- dim: int = 1024,
1421
- num_heads: int = 16,
1422
- window_size: int = 8,
1423
- shift_size: int = 4,
1424
- mlp_ratio: float = 4.0,
1425
- qkv_bias: bool = True,
1426
- drop: float = 0.0,
1427
- attn_drop: float = 0.0,
1428
- drop_path: float = 0.1,
1429
- use_relative_position_bias: bool = True,
1430
- fusion_type: str = "adaptive", # "weighted", "gated", "adaptive"
1431
- use_shifted_window: bool = True,
1432
- ):
1433
- super().__init__()
1434
- self.dim = dim
1435
- self.num_heads = num_heads
1436
- self.window_size = window_size
1437
- self.shift_size = shift_size if use_shifted_window else 0
1438
- self.mlp_ratio = mlp_ratio
1439
-
1440
- # =============== Temporal Window Partitioning ===============
1441
- self.window_partition = TemporalWindowPartition(
1442
- window_size=window_size,
1443
- shift_size=self.shift_size,
1444
- )
1445
-
1446
- # =============== Intra-Modal Self-Attention ===============
1447
- self.norm_audio_self = OmniRMSNorm(dim)
1448
- self.norm_video_self = OmniRMSNorm(dim)
1449
- self.norm_image_self = OmniRMSNorm(dim)
1450
-
1451
- self.audio_self_attn = WindowCrossAttention(
1452
- dim=dim,
1453
- num_heads=num_heads,
1454
- window_size=window_size,
1455
- qkv_bias=qkv_bias,
1456
- attn_drop=attn_drop,
1457
- proj_drop=drop,
1458
- use_relative_position_bias=use_relative_position_bias,
1459
- )
1460
-
1461
- self.video_self_attn = WindowCrossAttention(
1462
- dim=dim,
1463
- num_heads=num_heads,
1464
- window_size=window_size,
1465
- qkv_bias=qkv_bias,
1466
- attn_drop=attn_drop,
1467
- proj_drop=drop,
1468
- use_relative_position_bias=use_relative_position_bias,
1469
- )
1470
-
1471
- self.image_self_attn = WindowCrossAttention(
1472
- dim=dim,
1473
- num_heads=num_heads,
1474
- window_size=window_size,
1475
- qkv_bias=qkv_bias,
1476
- attn_drop=attn_drop,
1477
- proj_drop=drop,
1478
- use_relative_position_bias=False, # No temporal bias for static images
1479
- )
1480
-
1481
- # =============== Inter-Modal Cross-Attention ===============
1482
- # Audio → Video/Image
1483
- self.norm_audio_cross = OmniRMSNorm(dim)
1484
- self.audio_to_visual = WindowCrossAttention(
1485
- dim=dim, num_heads=num_heads, window_size=window_size,
1486
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
1487
- )
1488
-
1489
- # Video → Audio/Image
1490
- self.norm_video_cross = OmniRMSNorm(dim)
1491
- self.video_to_others = WindowCrossAttention(
1492
- dim=dim, num_heads=num_heads, window_size=window_size,
1493
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
1494
- )
1495
-
1496
- # Image → Audio/Video
1497
- self.norm_image_cross = OmniRMSNorm(dim)
1498
- self.image_to_temporal = WindowCrossAttention(
1499
- dim=dim, num_heads=num_heads, window_size=window_size,
1500
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
1501
- )
1502
-
1503
- # =============== Multi-Modal Fusion ===============
1504
- self.multimodal_fusion = MultiModalFusionLayer(
1505
- dim=dim,
1506
- num_modalities=3,
1507
- fusion_type=fusion_type,
1508
- )
1509
-
1510
- # =============== Feed-Forward Network ===============
1511
- self.norm_ffn = OmniRMSNorm(dim)
1512
- mlp_hidden_dim = int(dim * mlp_ratio)
1513
- self.ffn = nn.Sequential(
1514
- nn.Linear(dim, mlp_hidden_dim, bias=False),
1515
- nn.GELU(),
1516
- nn.Dropout(drop),
1517
- nn.Linear(mlp_hidden_dim, dim, bias=False),
1518
- nn.Dropout(drop),
1519
- )
1520
-
1521
- # =============== Stochastic Depth (Drop Path) ===============
1522
- self.drop_path = nn.Identity() if drop_path <= 0. else nn.Dropout(drop_path)
1523
-
1524
- # =============== Output Projections ===============
1525
- self.output_projection = nn.ModuleDict({
1526
- 'audio': nn.Linear(dim, dim),
1527
- 'video': nn.Linear(dim, dim),
1528
- 'image': nn.Linear(dim, dim),
1529
- })
1530
-
1531
- def forward(
1532
- self,
1533
- audio_embeds: Optional[torch.Tensor] = None, # [B, T_audio, L_audio, D]
1534
- video_embeds: Optional[torch.Tensor] = None, # [B, T_video, L_video, D]
1535
- image_embeds: Optional[torch.Tensor] = None, # [B, L_image, D]
1536
- attention_mask: Optional[torch.Tensor] = None,
1537
- return_intermediates: bool = False,
1538
- ) -> Dict[str, torch.Tensor]:
1539
- """
1540
- Forward pass of the Fancy Multi-Modal Window Attention Block.
1541
-
1542
- Args:
1543
- audio_embeds: Audio embeddings [B, T_audio, L_audio, D]
1544
- T_audio: number of audio frames
1545
- L_audio: sequence length per frame
1546
- video_embeds: Video embeddings [B, T_video, L_video, D]
1547
- T_video: number of video frames
1548
- L_video: sequence length per frame (e.g., patches)
1549
- image_embeds: Image embeddings [B, L_image, D]
1550
- L_image: sequence length (e.g., image patches)
1551
- attention_mask: Optional attention mask
1552
- return_intermediates: Whether to return intermediate features
1553
-
1554
- Returns:
1555
- outputs: Dictionary containing processed embeddings for each modality
1556
- - 'audio': [B, T_audio, L_audio, D]
1557
- - 'video': [B, T_video, L_video, D]
1558
- - 'image': [B, L_image, D]
1559
- - 'fused': [B, L_total, D] (optional)
1560
- """
1561
- intermediates = {} if return_intermediates else None
1562
-
1563
- # ========== Stage 1: Temporal Window Partitioning ==========
1564
- partitioned_audio, audio_info = None, None
1565
- partitioned_video, video_info = None, None
1566
-
1567
- if audio_embeds is not None:
1568
- partitioned_audio, audio_info = self.window_partition.partition(audio_embeds)
1569
- if return_intermediates:
1570
- intermediates['audio_windows'] = partitioned_audio
1571
-
1572
- if video_embeds is not None:
1573
- partitioned_video, video_info = self.window_partition.partition(video_embeds)
1574
- if return_intermediates:
1575
- intermediates['video_windows'] = partitioned_video
1576
-
1577
- # ========== Stage 2: Intra-Modal Self-Attention ==========
1578
- audio_self_out, video_self_out, image_self_out = None, None, None
1579
-
1580
- if audio_embeds is not None:
1581
- audio_normed = self.norm_audio_self(partitioned_audio)
1582
- audio_self_out = self.audio_self_attn(audio_normed, audio_normed, audio_normed)
1583
- audio_self_out = partitioned_audio + self.drop_path(audio_self_out)
1584
-
1585
- if video_embeds is not None:
1586
- video_normed = self.norm_video_self(partitioned_video)
1587
- video_self_out = self.video_self_attn(video_normed, video_normed, video_normed)
1588
- video_self_out = partitioned_video + self.drop_path(video_self_out)
1589
-
1590
- if image_embeds is not None:
1591
- image_normed = self.norm_image_self(image_embeds)
1592
- image_self_out = self.image_self_attn(image_normed, image_normed, image_normed)
1593
- image_self_out = image_embeds + self.drop_path(image_self_out)
1594
-
1595
- # ========== Stage 3: Inter-Modal Cross-Attention ==========
1596
- audio_cross_out, video_cross_out, image_cross_out = None, None, None
1597
-
1598
- # Prepare context (merge windows temporarily for cross-attention)
1599
- if audio_self_out is not None:
1600
- audio_merged = self.window_partition.merge(audio_self_out, audio_info)
1601
- if video_self_out is not None:
1602
- video_merged = self.window_partition.merge(video_self_out, video_info)
1603
-
1604
- # Audio attends to Video and Image
1605
- if audio_embeds is not None:
1606
- audio_q = self.norm_audio_cross(audio_merged)
1607
-
1608
- # Create key-value context from other modalities
1609
- kv_list = []
1610
- if video_embeds is not None:
1611
- kv_list.append(video_merged)
1612
- if image_embeds is not None:
1613
- # Expand image to match temporal dimension
1614
- B, L_img, D = image_self_out.shape
1615
- T_audio = audio_merged.shape[1]
1616
- image_expanded = image_self_out.unsqueeze(1).expand(B, T_audio, L_img, D)
1617
- kv_list.append(image_expanded)
1618
-
1619
- if kv_list:
1620
- # Concatenate along sequence dimension
1621
- kv_context = torch.cat([kv.flatten(1, 2) for kv in kv_list], dim=1)
1622
- kv_context = kv_context.reshape(B, -1, D)
1623
-
1624
- audio_cross_out = self.audio_to_visual(
1625
- audio_q.flatten(1, 2),
1626
- kv_context,
1627
- kv_context,
1628
- attention_mask
1629
- )
1630
- audio_cross_out = audio_cross_out.reshape_as(audio_merged)
1631
- audio_cross_out = audio_merged + self.drop_path(audio_cross_out)
1632
- else:
1633
- audio_cross_out = audio_merged
1634
-
1635
- # Video attends to Audio and Image
1636
- if video_embeds is not None:
1637
- video_q = self.norm_video_cross(video_merged)
1638
-
1639
- kv_list = []
1640
- if audio_embeds is not None:
1641
- kv_list.append(audio_merged if audio_cross_out is None else audio_cross_out)
1642
- if image_embeds is not None:
1643
- B, L_img, D = image_self_out.shape
1644
- T_video = video_merged.shape[1]
1645
- image_expanded = image_self_out.unsqueeze(1).expand(B, T_video, L_img, D)
1646
- kv_list.append(image_expanded)
1647
-
1648
- if kv_list:
1649
- kv_context = torch.cat([kv.flatten(1, 2) for kv in kv_list], dim=1)
1650
- kv_context = kv_context.reshape(B, -1, D)
1651
-
1652
- video_cross_out = self.video_to_others(
1653
- video_q.flatten(1, 2),
1654
- kv_context,
1655
- kv_context,
1656
- attention_mask
1657
- )
1658
- video_cross_out = video_cross_out.reshape_as(video_merged)
1659
- video_cross_out = video_merged + self.drop_path(video_cross_out)
1660
- else:
1661
- video_cross_out = video_merged
1662
-
1663
- # Image attends to Audio and Video
1664
- if image_embeds is not None:
1665
- image_q = self.norm_image_cross(image_self_out)
1666
-
1667
- kv_list = []
1668
- if audio_embeds is not None:
1669
- # Average pool audio over time for image
1670
- audio_pooled = (audio_merged if audio_cross_out is None else audio_cross_out).mean(dim=1)
1671
- kv_list.append(audio_pooled)
1672
- if video_embeds is not None:
1673
- # Average pool video over time for image
1674
- video_pooled = (video_merged if video_cross_out is None else video_cross_out).mean(dim=1)
1675
- kv_list.append(video_pooled)
1676
-
1677
- if kv_list:
1678
- kv_context = torch.cat(kv_list, dim=1)
1679
-
1680
- image_cross_out = self.image_to_temporal(
1681
- image_q,
1682
- kv_context,
1683
- kv_context,
1684
- attention_mask
1685
- )
1686
- image_cross_out = image_self_out + self.drop_path(image_cross_out)
1687
- else:
1688
- image_cross_out = image_self_out
1689
-
1690
- # ========== Stage 4: Multi-Modal Fusion ==========
1691
- # Collect features from all modalities for fusion
1692
- fusion_features = []
1693
- if audio_cross_out is not None:
1694
- audio_flat = audio_cross_out.flatten(1, 2) # [B, T*L, D]
1695
- fusion_features.append(audio_flat)
1696
- if video_cross_out is not None:
1697
- video_flat = video_cross_out.flatten(1, 2) # [B, T*L, D]
1698
- fusion_features.append(video_flat)
1699
- if image_cross_out is not None:
1700
- fusion_features.append(image_cross_out) # [B, L, D]
1701
-
1702
- # Pad/align sequence lengths for fusion
1703
- if len(fusion_features) > 1:
1704
- max_len = max(f.shape[1] for f in fusion_features)
1705
- aligned_features = []
1706
- for feat in fusion_features:
1707
- if feat.shape[1] < max_len:
1708
- pad_len = max_len - feat.shape[1]
1709
- feat = F.pad(feat, (0, 0, 0, pad_len))
1710
- aligned_features.append(feat)
1711
-
1712
- # Fuse modalities
1713
- fused_features = self.multimodal_fusion(aligned_features)
1714
- else:
1715
- fused_features = fusion_features[0] if fusion_features else None
1716
-
1717
- # ========== Stage 5: Feed-Forward Network ==========
1718
- if fused_features is not None:
1719
- fused_normed = self.norm_ffn(fused_features)
1720
- fused_ffn = self.ffn(fused_normed)
1721
- fused_features = fused_features + self.drop_path(fused_ffn)
1722
-
1723
- # ========== Stage 6: Prepare Outputs ==========
1724
- outputs = {}
1725
-
1726
- # Project back to original shapes
1727
- if audio_embeds is not None and audio_cross_out is not None:
1728
- # Partition again for consistency
1729
- audio_final, _ = self.window_partition.partition(audio_cross_out)
1730
- audio_final = self.output_projection['audio'](audio_final)
1731
- audio_final = self.window_partition.merge(audio_final, audio_info)
1732
- outputs['audio'] = audio_final
1733
-
1734
- if video_embeds is not None and video_cross_out is not None:
1735
- video_final, _ = self.window_partition.partition(video_cross_out)
1736
- video_final = self.output_projection['video'](video_final)
1737
- video_final = self.window_partition.merge(video_final, video_info)
1738
- outputs['video'] = video_final
1739
-
1740
- if image_embeds is not None and image_cross_out is not None:
1741
- image_final = self.output_projection['image'](image_cross_out)
1742
- outputs['image'] = image_final
1743
-
1744
- if fused_features is not None:
1745
- outputs['fused'] = fused_features
1746
-
1747
- if return_intermediates:
1748
- outputs['intermediates'] = intermediates
1749
-
1750
- return outputs
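A hypothetical smoke test for the block (not from the repo, and assuming the rest of the removed module such as OmniRMSNorm is available). It disables the relative-position bias because, as written, the [H, w, w] bias does not broadcast against the flattened [B, H, T*L, T*L] attention map:

    block = FancyMultiModalWindowAttentionBlock(
        dim=256, num_heads=8, window_size=4,
        use_relative_position_bias=False,
    ).eval()

    audio = torch.randn(1, 8, 10, 256)   # [B, T_audio, L_audio, D]
    video = torch.randn(1, 8, 20, 256)   # [B, T_video, L_video, D]
    image = torch.randn(1, 30, 256)      # [B, L_image, D]

    with torch.no_grad():
        out = block(audio_embeds=audio, video_embeds=video, image_embeds=image)

    # out['audio'], out['video'], out['image'] keep their input shapes;
    # out['fused'] is [1, max_token_len, 256] after the pad-and-fuse stage.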
1751
-
1752
-
1753
- # -----------------------------------------------------------------------------
1754
- # 7. Optimization Utilities (FP8, Compilation, Mixed Precision)
1755
- # -----------------------------------------------------------------------------
1756
-
1757
- @dataclass
1758
- class FP8Config:
1759
- """Configuration for FP8 quantization"""
1760
- enabled: bool = False
1761
- margin: int = 0
1762
- fp8_format: str = "hybrid" # "e4m3", "e5m2", "hybrid"
1763
- amax_history_len: int = 1024
1764
- amax_compute_algo: str = "max"
1765
-
1766
-
1767
- @dataclass
1768
- class CompilationConfig:
1769
- """Configuration for torch.compile"""
1770
- enabled: bool = False
1771
- mode: str = "reduce-overhead" # "default", "reduce-overhead", "max-autotune"
1772
- fullgraph: bool = False
1773
- dynamic: bool = True
1774
- backend: str = "inductor"
1775
-
1776
-
1777
- @dataclass
1778
- class MixedPrecisionConfig:
1779
- """Configuration for mixed precision training/inference"""
1780
- enabled: bool = True
1781
- dtype: str = "bfloat16" # "float16", "bfloat16"
1782
- use_amp: bool = True
1783
-
1784
-
1785
- class ModelOptimizer:
1786
- """
1787
- Unified model optimizer supporting FP8 quantization, torch.compile,
1788
- and mixed precision inference.
1789
- """
1790
- def __init__(
1791
- self,
1792
- fp8_config: Optional[FP8Config] = None,
1793
- compilation_config: Optional[CompilationConfig] = None,
1794
- mixed_precision_config: Optional[MixedPrecisionConfig] = None,
1795
- ):
1796
- self.fp8_config = fp8_config or FP8Config()
1797
- self.compilation_config = compilation_config or CompilationConfig()
1798
- self.mixed_precision_config = mixed_precision_config or MixedPrecisionConfig()
1799
-
1800
- # Setup mixed precision
1801
- self._setup_mixed_precision()
1802
-
1803
- def _setup_mixed_precision(self):
1804
- """Setup mixed precision context"""
1805
- if self.mixed_precision_config.enabled:
1806
- dtype_map = {
1807
- "float16": torch.float16,
1808
- "bfloat16": torch.bfloat16,
1809
- }
1810
- self.dtype = dtype_map.get(self.mixed_precision_config.dtype, torch.bfloat16)
1811
- else:
1812
- self.dtype = torch.float32
1813
-
1814
- @contextmanager
1815
- def autocast_context(self):
1816
- """Context manager for automatic mixed precision"""
1817
- if self.mixed_precision_config.enabled and self.mixed_precision_config.use_amp:
1818
- with torch.autocast(device_type='cuda', dtype=self.dtype):
1819
- yield
1820
- else:
1821
- yield
1822
-
1823
- def _compile_model(self, model: nn.Module) -> nn.Module:
1824
- """Compile model using torch.compile"""
1825
- if not self.compilation_config.enabled or not HAS_TORCH_COMPILE:
1826
- return model
1827
-
1828
- return torch.compile(
1829
- model,
1830
- mode=self.compilation_config.mode,
1831
- fullgraph=self.compilation_config.fullgraph,
1832
- dynamic=self.compilation_config.dynamic,
1833
- backend=self.compilation_config.backend,
1834
- )
1835
-
1836
- def _quantize_model_fp8(self, model: nn.Module) -> nn.Module:
1837
- """Apply FP8 quantization using Transformer Engine"""
1838
- if not self.fp8_config.enabled or not HAS_TRANSFORMER_ENGINE:
1839
- return model
1840
-
1841
- # Convert compatible layers to FP8
1842
- for name, module in model.named_modules():
1843
- if isinstance(module, nn.Linear):
1844
- # Replace with TE FP8 Linear
1845
- fp8_linear = te.Linear(
1846
- module.in_features,
1847
- module.out_features,
1848
- bias=module.bias is not None,
1849
- )
1850
- # Copy weights
1851
- fp8_linear.weight.data.copy_(module.weight.data)
1852
- if module.bias is not None:
1853
- fp8_linear.bias.data.copy_(module.bias.data)
1854
-
1855
- # Replace module
1856
- parent_name = '.'.join(name.split('.')[:-1])
1857
- child_name = name.split('.')[-1]
1858
- if parent_name:
1859
- parent = dict(model.named_modules())[parent_name]
1860
- setattr(parent, child_name, fp8_linear)
1861
-
1862
- return model
1863
-
1864
- def optimize_model(
1865
- self,
1866
- model: nn.Module,
1867
- apply_compilation: bool = True,
1868
- apply_quantization: bool = True,
1869
- apply_mixed_precision: bool = True,
1870
- ) -> nn.Module:
1871
- """
1872
- Apply all optimizations to model.
1873
-
1874
- Args:
1875
- model: Model to optimize
1876
- apply_compilation: Whether to compile with torch.compile
1877
- apply_quantization: Whether to apply FP8 quantization
1878
- apply_mixed_precision: Whether to convert to mixed precision dtype
1879
-
1880
- Returns:
1881
- Optimized model
1882
- """
1883
- # Apply FP8 quantization first
1884
- if apply_quantization and self.fp8_config.enabled:
1885
- model = self._quantize_model_fp8(model)
1886
-
1887
- # Convert to mixed precision dtype
1888
- if apply_mixed_precision and self.mixed_precision_config.enabled:
1889
- model = model.to(dtype=self.dtype)
1890
-
1891
- # Compile model last
1892
- if apply_compilation and self.compilation_config.enabled:
1893
- model = self._compile_model(model)
1894
-
1895
- return model
1896
-
1897
-
1898
- @contextmanager
1899
- def optimized_inference_mode(
1900
- enable_cudnn_benchmark: bool = True,
1901
- enable_tf32: bool = True,
1902
- enable_flash_sdp: bool = True,
1903
- ):
1904
- """
1905
- Context manager for optimized inference with various PyTorch optimizations.
1906
-
1907
- Args:
1908
- enable_cudnn_benchmark: Enable cuDNN autotuner
1909
- enable_tf32: Enable TF32 for faster matmul on Ampere+ GPUs
1910
- enable_flash_sdp: Enable Flash Attention in scaled_dot_product_attention
1911
- """
1912
- # Save original states
1913
- orig_benchmark = torch.backends.cudnn.benchmark
1914
- orig_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
1915
- orig_tf32_cudnn = torch.backends.cudnn.allow_tf32
1916
- orig_sdp_flash = torch.backends.cuda.flash_sdp_enabled()
1917
-
1918
- try:
1919
- # Enable optimizations
1920
- torch.backends.cudnn.benchmark = enable_cudnn_benchmark
1921
- torch.backends.cuda.matmul.allow_tf32 = enable_tf32
1922
- torch.backends.cudnn.allow_tf32 = enable_tf32
1923
-
1924
- if enable_flash_sdp:
1925
- torch.backends.cuda.enable_flash_sdp(True)
1926
-
1927
- yield
1928
-
1929
- finally:
1930
- # Restore original states
1931
- torch.backends.cudnn.benchmark = orig_benchmark
1932
- torch.backends.cuda.matmul.allow_tf32 = orig_tf32_matmul
1933
- torch.backends.cudnn.allow_tf32 = orig_tf32_cudnn
1934
- torch.backends.cuda.enable_flash_sdp(orig_sdp_flash)
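To show how these utilities were presumably meant to compose at inference time, here is a sketch under stated assumptions (any nn.Module stands in for the real transformer, a CUDA device is available, and the HAS_TORCH_COMPILE / HAS_TRANSFORMER_ENGINE flags are defined earlier in this file):

    import torch
    import torch.nn as nn

    optimizer = ModelOptimizer(
        compilation_config=CompilationConfig(enabled=True, mode="reduce-overhead"),
        mixed_precision_config=MixedPrecisionConfig(enabled=True, dtype="bfloat16"),
    )
    model = optimizer.optimize_model(nn.Linear(1024, 1024).cuda())  # stand-in module

    with optimized_inference_mode(enable_tf32=True):
        with optimizer.autocast_context():
            out = model(torch.randn(2, 1024, device="cuda"))

FP8 quantization stays off by default (FP8Config.enabled=False), so only the bfloat16 cast and torch.compile are applied in this sketch.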
 
push.sh ADDED
@@ -0,0 +1,13 @@
1
+
2
+ # Set repository-level username
3
+ git config user.name "selfitcamera"
4
+ git config user.email "ethan.blake@heybeauty.ai"
5
+
6
+ # Verify
7
+ git config user.name
8
+ git config user.email
9
+
10
+
11
+ git add .
12
+ git commit -m "init"
13
+ git push
util.py ADDED
@@ -0,0 +1,729 @@
1
+
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import json
6
+ import random
7
+ import time
8
+ import datetime
9
+ import requests
10
+ import func_timeout
11
+ import numpy as np
12
+ import gradio as gr
13
+ import boto3
14
+ import tempfile
15
+ import io
16
+ import uuid
17
+ from botocore.client import Config
18
+ from PIL import Image
19
+
20
+
21
+ # TOKEN = os.environ['TOKEN']
22
+ # APIKEY = os.environ['APIKEY']
23
+ # UKAPIURL = os.environ['UKAPIURL']
24
+
25
+ OneKey = os.environ['OneKey'].strip()
26
+ OneKey = OneKey.split("#")
27
+ TOKEN = OneKey[0]
28
+ APIKEY = OneKey[1]
29
+ UKAPIURL = OneKey[2]
30
+ LLMKEY = OneKey[3]
31
+ R2_ACCESS_KEY = OneKey[4]
32
+ R2_SECRET_KEY = OneKey[5]
33
+ R2_ENDPOINT = OneKey[6]
34
+
35
+
36
+ # tmpFolder is no longer needed since we upload directly from memory
37
+ # tmpFolder = "tmp"
38
+ # os.makedirs(tmpFolder, exist_ok=True)
39
+
40
+
41
+ # Legacy function - no longer used since we upload directly from memory
42
+ # def upload_user_img(clientIp, timeId, img):
43
+ # fileName = clientIp.replace(".", "")+str(timeId)+".jpg"
44
+ # local_path = os.path.join(tmpFolder, fileName)
45
+ # img = cv2.imread(img)
46
+ # cv2.imwrite(os.path.join(tmpFolder, fileName), img)
47
+ #
48
+ # json_data = {
49
+ # "token": TOKEN,
50
+ # "input1": fileName,
51
+ # "input2": "",
52
+ # "protocol": "",
53
+ # "cloud": "ali"
54
+ # }
55
+ #
56
+ # session = requests.session()
57
+ # ret = requests.post(
58
+ # f"{UKAPIURL}/upload",
59
+ # headers={'Content-Type': 'application/json'},
60
+ # json=json_data
61
+ # )
62
+ #
63
+ # res = ""
64
+ # if ret.status_code==200:
65
+ # if 'upload1' in ret.json():
66
+ # upload_url = ret.json()['upload1']
67
+ # headers = {'Content-Type': 'image/jpeg'}
68
+ # response = session.put(upload_url, data=open(local_path, 'rb').read(), headers=headers)
69
+ # # print(response.status_code)
70
+ # if response.status_code == 200:
71
+ # res = upload_url
72
+ # if os.path.exists(local_path):
73
+ # os.remove(local_path)
74
+ # return res
75
+
76
+
77
+ class R2Api:
78
+
79
+ def __init__(self, session=None):
80
+ super().__init__()
81
+ self.R2_BUCKET = "omni-creator"
82
+ self.domain = "https://www.omnicreator.net/"
83
+ self.R2_ACCESS_KEY = R2_ACCESS_KEY
84
+ self.R2_SECRET_KEY = R2_SECRET_KEY
85
+ self.R2_ENDPOINT = R2_ENDPOINT
86
+
87
+ self.client = boto3.client(
88
+ "s3",
89
+ endpoint_url=self.R2_ENDPOINT,
90
+ aws_access_key_id=self.R2_ACCESS_KEY,
91
+ aws_secret_access_key=self.R2_SECRET_KEY,
92
+ config=Config(signature_version="s3v4")
93
+ )
94
+
95
+ self.session = requests.Session() if session is None else session
96
+
97
+ def upload_from_memory(self, image_data, filename, content_type='image/jpeg'):
98
+ """
99
+ Upload image data directly from memory to R2
100
+
101
+ Args:
102
+ image_data (bytes): Image data in bytes
103
+ filename (str): Filename for the uploaded file
104
+ content_type (str): MIME type of the image
105
+
106
+ Returns:
107
+ str: URL of the uploaded file
108
+ """
109
+ t1 = time.time()
110
+ headers = {"Content-Type": content_type}
111
+
112
+ cloud_path = f"ImageEdit/Uploads/{str(datetime.date.today())}/{filename}"
113
+ url = self.client.generate_presigned_url(
114
+ "put_object",
115
+ Params={"Bucket": self.R2_BUCKET, "Key": cloud_path, "ContentType": content_type},
116
+ ExpiresIn=604800
117
+ )
118
+
119
+ retry_count = 0
120
+ while retry_count < 3:
121
+ try:
122
+ response = self.session.put(url, data=image_data, headers=headers, timeout=15)
123
+ if response.status_code == 200:
124
+ break
125
+ else:
126
+ print(f"⚠️ Upload failed with status code: {response.status_code}")
127
+ retry_count += 1
128
+ except (requests.exceptions.Timeout, requests.exceptions.RequestException) as e:
129
+ print(f"⚠️ Upload retry {retry_count + 1}/3 failed: {e}")
130
+ retry_count += 1
131
+ if retry_count == 3:
132
+ raise Exception(f'Failed to upload file to R2 after 3 retries! Last error: {str(e)}')
133
+ time.sleep(1) # wait 1 second before retrying
134
+ continue
135
+ print("upload_from_memory time is ====>", time.time() - t1)
136
+ return f"{self.domain}{cloud_path}"
137
+
138
+ def upload_user_img_r2(clientIp, timeId, pil_image):
139
+ """
140
+ Upload PIL Image directly to R2 without saving to local file
141
+
142
+ Args:
143
+ clientIp (str): Client IP address
144
+ timeId (int): Timestamp
145
+ pil_image (PIL.Image): PIL Image object
146
+
147
+ Returns:
148
+ str: Uploaded URL
149
+ """
150
+ # Generate unique filename using UUID to prevent file conflicts in concurrent environment
151
+ unique_id = str(uuid.uuid4())
152
+ fileName = f"user_img_{unique_id}_{timeId}.jpg"
153
+
154
+ # Convert PIL Image to bytes
155
+ img_buffer = io.BytesIO()
156
+ if pil_image.mode != 'RGB':
157
+ pil_image = pil_image.convert('RGB')
158
+ pil_image.save(img_buffer, format='JPEG', quality=95)
159
+ img_data = img_buffer.getvalue()
160
+
161
+ # Upload directly from memory
162
+ res = R2Api().upload_from_memory(img_data, fileName, 'image/jpeg')
163
+ return res
164
+
165
+
166
+
167
+ def create_mask_from_layers(base_image, layers):
168
+ """
169
+ Create mask image from ImageEditor layers
170
+
171
+ Args:
172
+ base_image (PIL.Image): Original image
173
+ layers (list): ImageEditor layer data
174
+
175
+ Returns:
176
+ PIL.Image: Black and white mask image
177
+ """
178
+ from PIL import Image, ImageDraw
179
+ import numpy as np
180
+
181
+ # Create blank mask with same size as original image
182
+ mask = Image.new('L', base_image.size, 0) # 'L' mode is grayscale, 0 is black
183
+
184
+ if not layers:
185
+ return mask
186
+
187
+ # Iterate through all layers, set drawn areas to white
188
+ for layer in layers:
189
+ if layer is not None:
190
+ # Convert layer to numpy array
191
+ layer_array = np.array(layer)
192
+
193
+ # Check layer format
194
+ if len(layer_array.shape) == 3: # RGB/RGBA format
195
+ # If RGBA, check alpha channel
196
+ if layer_array.shape[2] == 4:
197
+ # Use alpha channel as mask
198
+ alpha_channel = layer_array[:, :, 3]
199
+ # Set non-transparent areas (alpha > 0) to white
200
+ mask_array = np.where(alpha_channel > 0, 255, 0).astype(np.uint8)
201
+ else:
202
+ # RGB format, check if not pure black (0,0,0)
203
+ # Assume drawn areas are non-black
204
+ non_black = np.any(layer_array > 0, axis=2)
205
+ mask_array = np.where(non_black, 255, 0).astype(np.uint8)
206
+ elif len(layer_array.shape) == 2: # Grayscale
207
+ # Use grayscale values directly, set non-zero areas to white
208
+ mask_array = np.where(layer_array > 0, 255, 0).astype(np.uint8)
209
+ else:
210
+ continue
211
+
212
+ # Convert mask_array to PIL image and merge into total mask
213
+ layer_mask = Image.fromarray(mask_array, mode='L')
214
+ # Resize to match original image
215
+ if layer_mask.size != base_image.size:
216
+ layer_mask = layer_mask.resize(base_image.size, Image.LANCZOS)
217
+
218
+ # Merge masks (use maximum value to ensure all drawn areas are included)
219
+ mask_array_current = np.array(mask)
220
+ layer_mask_array = np.array(layer_mask)
221
+ combined_mask_array = np.maximum(mask_array_current, layer_mask_array)
222
+ mask = Image.fromarray(combined_mask_array, mode='L')
223
+
224
+ return mask
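A small sanity check for the alpha-channel path above (hypothetical, not part of util.py): a brush stroke stored in an RGBA layer becomes a white region in the returned grayscale mask.

    from PIL import Image
    import numpy as np

    base = Image.new("RGB", (64, 64), "white")
    layer = np.zeros((64, 64, 4), dtype=np.uint8)
    layer[10:20, 10:20, 3] = 255                    # painted pixels live in the alpha channel
    mask = create_mask_from_layers(base, [Image.fromarray(layer, mode="RGBA")])
    assert np.array(mask)[15, 15] == 255 and np.array(mask)[0, 0] == 0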
225
+
226
+
227
+ def upload_mask_image_r2(client_ip, time_id, mask_image):
228
+ """
229
+ Upload mask image to R2 directly from memory
230
+
231
+ Args:
232
+ client_ip (str): Client IP
233
+ time_id (int): Timestamp
234
+ mask_image (PIL.Image): Mask image
235
+
236
+ Returns:
237
+ str: Uploaded URL
238
+ """
239
+ # Generate unique filename using UUID to prevent file conflicts in concurrent environment
240
+ unique_id = str(uuid.uuid4())
241
+ file_name = f"mask_img_{unique_id}_{time_id}.png"
242
+
243
+ try:
244
+ # Convert mask image to bytes
245
+ img_buffer = io.BytesIO()
246
+ mask_image.save(img_buffer, format='PNG')
247
+ img_data = img_buffer.getvalue()
248
+
249
+ # Upload directly from memory
250
+ res = R2Api().upload_from_memory(img_data, file_name, 'image/png')
251
+
252
+ return res
253
+ except Exception as e:
254
+ print(f"Failed to upload mask image: {e}")
255
+ return None
256
+
257
+
258
+
259
+ def submit_image_edit_task(user_image_url, prompt, task_type="80", mask_image_url="", reference_image_url=""):
260
+ """
261
+ Submit image editing task with improved error handling using API v2
262
+ """
263
+ headers = {
264
+ 'Content-Type': 'application/json',
265
+ 'Authorization': f'Bearer {APIKEY}'
266
+ }
267
+
268
+ data = {
269
+ "user_image": user_image_url,
270
+ "user_mask": mask_image_url,
271
+ "type": task_type,
272
+ "text": prompt,
273
+ "user_uuid": APIKEY,
274
+ "priority": 0,
275
+ "secret_key": "219ngu"
276
+ }
277
+
278
+ if reference_image_url:
279
+ data["user_image2"] = reference_image_url
280
+
281
+ retry_count = 0
282
+ max_retries = 3
283
+
284
+ while retry_count < max_retries:
285
+ try:
286
+ response = requests.post(
287
+ f'{UKAPIURL}/public_image_edit_v2',
288
+ headers=headers,
289
+ json=data,
290
+ timeout=30 # increased timeout
291
+ )
292
+
293
+ if response.status_code == 200:
294
+ result = response.json()
295
+ if result.get('code') == 0:
296
+ return result['data']['task_id'], None
297
+ else:
298
+ return None, f"API Error: {result.get('message', 'Unknown error')}"
299
+ elif response.status_code in [502, 503, 504]: # server error, retryable
300
+ retry_count += 1
301
+ if retry_count < max_retries:
302
+ print(f"⚠️ Server error {response.status_code}, retrying {retry_count}/{max_retries}")
303
+ time.sleep(2) # wait 2 seconds before retrying
304
+ continue
305
+ else:
306
+ return None, f"HTTP Error after {max_retries} retries: {response.status_code}"
307
+ else:
308
+ return None, f"HTTP Error: {response.status_code}"
309
+
310
+ except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
311
+ retry_count += 1
312
+ if retry_count < max_retries:
313
+ print(f"⚠️ Network error, retrying {retry_count}/{max_retries}: {e}")
314
+ time.sleep(2)
315
+ continue
316
+ else:
317
+ return None, f"Network error after {max_retries} retries: {str(e)}"
318
+ except Exception as e:
319
+ return None, f"Request Exception: {str(e)}"
320
+
321
+ return None, f"Failed after {max_retries} retries"
322
+
323
+
324
+ def check_task_status(task_id):
325
+ """
326
+ Query task status with improved error handling using API v2
327
+ """
328
+ headers = {
329
+ 'Content-Type': 'application/json',
330
+ 'Authorization': f'Bearer {APIKEY}'
331
+ }
332
+
333
+ data = {
334
+ "task_id": task_id
335
+ }
336
+
337
+ retry_count = 0
338
+ max_retries = 2 # fewer retries for status queries
339
+
340
+ while retry_count < max_retries:
341
+ try:
342
+ response = requests.post(
343
+ f'{UKAPIURL}/status_image_edit_v2',
344
+ headers=headers,
345
+ json=data,
346
+ timeout=15 # shorter timeout for status queries
347
+ )
348
+
349
+ if response.status_code == 200:
350
+ result = response.json()
351
+ if result.get('code') == 0:
352
+ task_data = result['data']
353
+ status = task_data['status']
354
+ image_url = task_data.get('image_url')
355
+
356
+ # Extract and log queue information for better user feedback
357
+ queue_info = task_data.get('queue_info', {})
358
+ if queue_info:
359
+ tasks_ahead = queue_info.get('tasks_ahead', 0)
360
+ current_priority = queue_info.get('current_priority', 0)
361
+ description = queue_info.get('description', '')
362
+ print(f"📊 Queue Status - Tasks ahead: {tasks_ahead}, Priority: {current_priority}, Status: {status}")
363
+
364
+ return status, image_url, task_data
365
+ else:
366
+ return 'error', None, result.get('message', 'Unknown error')
367
+ elif response.status_code in [502, 503, 504]: # server error, retryable
368
+ retry_count += 1
369
+ if retry_count < max_retries:
370
+ print(f"⚠️ Status check server error {response.status_code}, retrying {retry_count}/{max_retries}")
371
+ time.sleep(1) # shorter retry interval for status queries
372
+ continue
373
+ else:
374
+ return 'error', None, f"HTTP Error after {max_retries} retries: {response.status_code}"
375
+ else:
376
+ return 'error', None, f"HTTP Error: {response.status_code}"
377
+
378
+ except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
379
+ retry_count += 1
380
+ if retry_count < max_retries:
381
+ print(f"⚠️ Status check network error, retrying {retry_count}/{max_retries}: {e}")
382
+ time.sleep(1)
383
+ continue
384
+ else:
385
+ return 'error', None, f"Network error after {max_retries} retries: {str(e)}"
386
+ except Exception as e:
387
+ return 'error', None, f"Request Exception: {str(e)}"
388
+
389
+ return 'error', None, f"Failed after {max_retries} retries"
390
+
391
+
392
+ def process_image_edit(img_input, prompt, reference_image=None, progress_callback=None):
393
+ """
394
+ Complete process for image editing
395
+
396
+ Args:
397
+ img_input: Can be file path (str) or PIL Image object
398
+ prompt: Editing instructions
399
+ progress_callback: Progress callback function
400
+ """
401
+ try:
402
+ # Generate client IP and timestamp
403
+ client_ip = "127.0.0.1" # Default IP
404
+ time_id = int(time.time())
405
+
406
+ # Process input image - supports PIL Image and file path
407
+ if hasattr(img_input, 'save'): # PIL Image object
408
+ pil_image = img_input
409
+ print(f"💾 Using PIL Image directly from memory")
410
+ else:
411
+ # Load from file path
412
+ pil_image = Image.open(img_input)
413
+ print(f"📁 Loaded image from file: {img_input}")
414
+
415
+ if progress_callback:
416
+ progress_callback("uploading image...")
417
+
418
+ # Upload user image directly from memory
419
+ uploaded_url = upload_user_img_r2(client_ip, time_id, pil_image)
420
+ if not uploaded_url:
421
+ return None, "image upload failed", None
422
+
423
+ # Extract actual image URL from upload URL
424
+ if "?" in uploaded_url:
425
+ uploaded_url = uploaded_url.split("?")[0]
426
+
427
+ if progress_callback:
428
+ progress_callback("submitting edit task...")
429
+
430
+ reference_url = ""
431
+ if reference_image is not None:
432
+ try:
433
+ if progress_callback:
434
+ progress_callback("uploading reference image...")
435
+
436
+ if hasattr(reference_image, 'save'):
437
+ reference_pil = reference_image
438
+ else:
439
+ reference_pil = Image.open(reference_image)
440
+
441
+ reference_url = upload_user_img_r2(client_ip, time_id, reference_pil)
442
+ if not reference_url:
443
+ return None, "reference image upload failed", None
444
+
445
+ if "?" in reference_url:
446
+ reference_url = reference_url.split("?")[0]
447
+ except Exception as e:
448
+ return None, f"reference image processing failed: {str(e)}", None
449
+
450
+ # Submit image editing task
451
+ task_id, error = submit_image_edit_task(uploaded_url, prompt, reference_image_url=reference_url)
452
+ if error:
453
+ return None, error, None
454
+
455
+ if progress_callback:
456
+ progress_callback(f"task submitted, ID: {task_id}, processing...")
457
+
458
+ # Wait for task completion
459
+ max_attempts = 60 # Wait up to 10 minutes
460
+ task_uuid = None
461
+ for attempt in range(max_attempts):
462
+ status, output_url, task_data = check_task_status(task_id)
463
+
464
+ # Extract task_uuid from task_data
465
+ if task_data and isinstance(task_data, dict):
466
+ task_uuid = task_data.get('uuid', None)
467
+
468
+ if status == 'completed':
469
+ if output_url:
470
+ return output_url, "image edit completed", task_uuid
471
+ else:
472
+ return None, "Task completed but no result image returned", task_uuid
473
+ elif status == 'error' or status == 'failed':
474
+ return None, f"task processing failed: {task_data}", task_uuid
475
+ elif status in ['queued', 'processing', 'running', 'created', 'working']:
476
+ # Enhanced progress message with queue info and website promotion
477
+ if progress_callback and task_data and isinstance(task_data, dict):
478
+ queue_info = task_data.get('queue_info', {})
479
+ if queue_info and status in ['queued', 'created']:
480
+ tasks_ahead = queue_info.get('tasks_ahead', 0)
481
+ current_priority = queue_info.get('current_priority', 0)
482
+ if tasks_ahead > 0:
483
+ progress_callback(f"⏳ Queue: {tasks_ahead} tasks ahead | Low priority | Visit website for instant processing → https://omnicreator.net/#generator")
484
+ else:
485
+ progress_callback(f"🚀 Processing your image editing request...")
486
+ elif status == 'processing':
487
+ progress_callback(f"🎨 AI is processing... Please wait")
488
+ elif status in ['running', 'working']:
489
+ progress_callback(f"⚡ Generating... Almost done")
490
+ else:
491
+ progress_callback(f"📋 Task status: {status}")
492
+ else:
493
+ if progress_callback:
494
+ progress_callback(f"task processing... (status: {status})")
495
+ time.sleep(1)
496
+ else:
497
+ if progress_callback:
498
+ progress_callback(f"unknown status: {status}")
499
+ time.sleep(1)
500
+
501
+ return None, "task processing timeout", task_uuid
502
+
503
+ except Exception as e:
504
+ return None, f"error occurred during processing: {str(e)}", None
505
+
506
+
507
+ def process_local_image_edit(base_image, layers, prompt, reference_image=None, progress_callback=None, use_example_mask=None):
508
+ """
509
+ Complete pipeline for local (masked) image editing
510
+
511
+ Args:
512
+ base_image (PIL.Image): Original image
513
+ layers (list): ImageEditor layer data
514
+ prompt (str): Editing instructions
515
+ progress_callback: Progress callback function
516
+ """
517
+ try:
518
+ # Generate client IP and timestamp
519
+ client_ip = "127.0.0.1" # Default IP
520
+ time_id = int(time.time())
521
+
522
+ if progress_callback:
523
+ progress_callback("creating mask image...")
524
+
525
+ # Check if we should use example mask (backdoor for example case)
526
+ if use_example_mask:
527
+ # Load local mask file for example
528
+ try:
529
+ from PIL import Image
530
+ import os
531
+
532
+ # Check if base_image is valid
533
+ if base_image is None:
534
+ return None, "Base image is None, cannot process example mask", None
535
+
536
+ if os.path.exists(use_example_mask):
537
+ mask_image = Image.open(use_example_mask)
538
+
539
+ # Ensure mask has same size as base image
540
+ if hasattr(base_image, 'size') and mask_image.size != base_image.size:
541
+ mask_image = mask_image.resize(base_image.size)
542
+
543
+ # Ensure mask is in L mode (grayscale)
544
+ if mask_image.mode != 'L':
545
+ mask_image = mask_image.convert('L')
546
+
547
+ print(f"🎭 Using example mask from: {use_example_mask}, size: {mask_image.size}")
548
+ else:
549
+ return None, f"Example mask file not found: {use_example_mask}", None
550
+ except Exception as e:
551
+ import traceback
552
+ traceback.print_exc()
553
+ return None, f"Failed to load example mask: {str(e)}", None
554
+ else:
555
+ # Normal case: create mask from layers
556
+ mask_image = create_mask_from_layers(base_image, layers)
557
+
558
+ # Check whether the mask has any content
559
+ mask_array = np.array(mask_image)
560
+ if np.max(mask_array) == 0:
561
+ return None, "please draw mask", None
562
+
563
+ # Print mask statistics
564
+ if not use_example_mask:
565
+ print(f"📝 创建mask图片成功,绘制区域像素数: {np.sum(mask_array > 0)}")
566
+ else:
567
+ mask_array = np.array(mask_image)
568
+ print(f"🎭 Example mask loaded successfully, mask pixels: {np.sum(mask_array > 0)}")
569
+
570
+ if progress_callback:
571
+ progress_callback("uploading original image...")
572
+
573
+ # Upload the original image directly from memory
574
+ uploaded_url = upload_user_img_r2(client_ip, time_id, base_image)
575
+ if not uploaded_url:
576
+ return None, "original image upload failed", None
577
+
578
+ # Extract the actual image URL from the upload URL
579
+ if "?" in uploaded_url:
580
+ uploaded_url = uploaded_url.split("?")[0]
581
+
582
+ if progress_callback:
583
+ progress_callback("uploading mask image...")
584
+
585
+ # Upload the mask image directly from memory
586
+ mask_url = upload_mask_image_r2(client_ip, time_id, mask_image)
587
+ if not mask_url:
588
+ return None, "mask image upload failed", None
589
+
590
+ # Extract the actual image URL from the upload URL
591
+ if "?" in mask_url:
592
+ mask_url = mask_url.split("?")[0]
593
+
594
+ reference_url = ""
595
+ if reference_image is not None:
596
+ try:
597
+ if progress_callback:
598
+ progress_callback("uploading reference image...")
599
+
600
+ if hasattr(reference_image, 'save'):
601
+ reference_pil = reference_image
602
+ else:
603
+ reference_pil = Image.open(reference_image)
604
+
605
+ reference_url = upload_user_img_r2(client_ip, time_id, reference_pil)
606
+ if not reference_url:
607
+ return None, "reference image upload failed", None
608
+
609
+ if "?" in reference_url:
610
+ reference_url = reference_url.split("?")[0]
611
+ except Exception as e:
612
+ return None, f"reference image processing failed: {str(e)}", None
613
+
614
+         print(f"📤 Images uploaded successfully:")
+         print(f"   Original image: {uploaded_url}")
+         print(f"   Mask image: {mask_url}")
+         if reference_url:
+             print(f"   Reference image: {reference_url}")
+
+         if progress_callback:
+             progress_callback("submitting local edit task...")
+
+         # Submit the local image edit task (task_type=81)
+         task_id, error = submit_image_edit_task(
+             uploaded_url,
+             prompt,
+             task_type="81",
+             mask_image_url=mask_url,
+             reference_image_url=reference_url
+         )
+         if error:
+             return None, error, None
+
+         if progress_callback:
+             progress_callback(f"task submitted, ID: {task_id}, processing...")
+
+         print(f"🚀 Local edit task submitted, task ID: {task_id}")
+
+         # Wait for task completion
+         max_attempts = 60  # Up to 60 polling attempts, sleeping 1 s between checks
+         task_uuid = None
+         for attempt in range(max_attempts):
+             status, output_url, task_data = check_task_status(task_id)
+
+             # Extract task_uuid from task_data
+             if task_data and isinstance(task_data, dict):
+                 task_uuid = task_data.get('uuid', None)
+
+             if status == 'completed':
+                 if output_url:
+                     print(f"✅ Local edit task completed, result: {output_url}")
+                     return output_url, "local image edit completed", task_uuid
+                 else:
+                     return None, "task completed but no result image returned", task_uuid
+             elif status == 'error' or status == 'failed':
+                 return None, f"task processing failed: {task_data}", task_uuid
+             elif status in ['queued', 'processing', 'running', 'created', 'working']:
+                 # Enhanced progress message with queue info and website promotion
+                 if progress_callback and task_data and isinstance(task_data, dict):
+                     queue_info = task_data.get('queue_info', {})
+                     if queue_info and status in ['queued', 'created']:
+                         tasks_ahead = queue_info.get('tasks_ahead', 0)
+                         current_priority = queue_info.get('current_priority', 0)
+                         if tasks_ahead > 0:
+                             progress_callback(f"⏳ Queue: {tasks_ahead} tasks ahead | Low priority | Visit website for instant processing → https://omnicreator.net/#generator")
+                         else:
+                             progress_callback(f"🚀 Processing your local editing request...")
+                     elif status == 'processing':
+                         progress_callback(f"🎨 AI is processing... Please wait")
+                     elif status in ['running', 'working']:
+                         progress_callback(f"⚡ Generating... Almost done")
+                     else:
+                         progress_callback(f"📋 Task status: {status}")
+                 else:
+                     if progress_callback:
+                         progress_callback(f"processing... (status: {status})")
+                 time.sleep(1)  # Wait 1 second before retry
+             else:
+                 if progress_callback:
+                     progress_callback(f"unknown status: {status}")
+                 time.sleep(1)
+
+         return None, "task processing timeout", task_uuid
+
+     except Exception as e:
+         print(f"❌ Exception during local edit processing: {str(e)}")
+         return None, f"error occurred during processing: {str(e)}", None
+
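`create_mask_from_layers`, called near the top of the function above, is defined elsewhere in this file and does not appear in this hunk. As a reading aid only, here is a hedged sketch of what such a helper typically does with Gradio ImageEditor layers; the name is suffixed with `_sketch` to make clear this is an assumption, not the file's actual implementation:

```python
import numpy as np
from PIL import Image

def create_mask_from_layers_sketch(base_image, layers):
    """Illustrative only: merge the alpha channels of drawn layers into a binary L-mode mask."""
    mask = np.zeros((base_image.height, base_image.width), dtype=np.uint8)
    for layer in layers or []:
        layer_img = layer if isinstance(layer, Image.Image) else Image.fromarray(layer)
        alpha = np.array(layer_img.convert("RGBA"))[:, :, 3]  # alpha marks where the user drew
        mask[alpha > 0] = 255
    return Image.fromarray(mask, mode="L")
```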
+
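A hedged example of how `process_local_image_edit` might be driven from a Gradio click handler; the handler name, the prompt, and the way the background and layers are pulled out of the ImageEditor value are illustrative assumptions, not code from this Space:

```python
def on_edit_click(editor_value, prompt):
    # Gradio's ImageEditor value is assumed here to be a dict with "background" and "layers"
    base_image = editor_value["background"]
    layers = editor_value["layers"]

    def report(msg):
        print(f"[progress] {msg}")  # forward progress messages to the console

    result_url, message, task_uuid = process_local_image_edit(
        base_image, layers, prompt, progress_callback=report
    )
    if result_url is None:
        raise RuntimeError(message)
    return result_url
```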
+ def download_and_check_result_nsfw(image_url, nsfw_detector=None):
+     """
+     Download the result image and run NSFW detection on it.
+
+     Args:
+         image_url (str): URL of the result image
+         nsfw_detector: NSFW detector instance
+
+     Returns:
+         tuple: (is_nsfw, error_message)
+     """
+     if nsfw_detector is None:
+         return False, None
+
+     try:
+         # Download the image
+         response = requests.get(image_url, timeout=30)
+         if response.status_code != 200:
+             return False, f"Failed to download result image: HTTP {response.status_code}"
+
+         # Convert the raw bytes into a PIL Image
+         image_data = io.BytesIO(response.content)
+         result_image = Image.open(image_data)
+
+         # Run NSFW detection
+         nsfw_result = nsfw_detector.predict_pil_label_only(result_image)
+
+         is_nsfw = nsfw_result.lower() == "nsfw"
+         print(f"🔍 Result image NSFW check: {'❌❌❌ ' + nsfw_result if is_nsfw else '✅✅✅ ' + nsfw_result}")
+
+         return is_nsfw, None
+
+     except Exception as e:
+         print(f"⚠️ NSFW check of result image failed: {e}")
+         return False, f"Failed to check result image: {str(e)}"
+
+
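A hedged sketch of how the two functions above could be chained so that a flagged result is never returned to the UI; `base_image`, `layers`, and `detector` are placeholders, and the detector is simply any object exposing `predict_pil_label_only(pil_image)` as used above:

```python
# Hypothetical glue code: run a local edit, then screen the result before showing it.
result_url, message, task_uuid = process_local_image_edit(base_image, layers, "remove the lamp")
if result_url:
    is_nsfw, err = download_and_check_result_nsfw(result_url, nsfw_detector=detector)
    if err:
        print(f"NSFW check skipped: {err}")  # detection failed; decide policy separately
    elif is_nsfw:
        result_url = None  # block flagged results instead of returning them
print(result_url or message)
```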
+ if __name__ == "__main__":
+
+     pass