rahul7star committed on
Commit
8f03f67
·
verified ·
1 Parent(s): d46cfdd

Update app_quant_latent1.py

Browse files
Files changed (1) hide show
  1. app_quant_latent1.py +77 -62
app_quant_latent1.py CHANGED
@@ -250,53 +250,6 @@ log_system_stats("AFTER PIPELINE BUILD")
250
  from PIL import Image
251
  import torch
252
 
253
def safe_generate_with_latents(
    transformer,
    vae,
    text_encoder,
    tokenizer,
    scheduler,
    pipe,
    prompt,
    height,
    width,
    steps,
    guidance_scale,
    negative_prompt,
    num_images_per_prompt,
    generator,
    cfg_normalization,
    cfg_truncation,
    max_sequence_length,
):
    """Call the low-level `generate()` helper without letting it raise.

    Returns ``(latents, None)`` on success or ``(None, exception)`` on any
    failure, so the caller can decide whether to fall back to another path.
    `pipe` is accepted for call-site symmetry but is not forwarded to
    `generate()`.
    """
    call_kwargs = dict(
        transformer=transformer,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        cfg_normalization=cfg_normalization,
        cfg_truncation=cfg_truncation,
        max_sequence_length=max_sequence_length,
        output_type="latent",  # IMPORTANT
    )
    try:
        return generate(**call_kwargs), None
    except Exception as err:
        # Deliberate best-effort wrapper: surface the error as a value
        # instead of raising, so the caller can choose a fallback.
        return None, err
298
-
299
-
300
 
301
 
302
 
@@ -364,22 +317,84 @@ def safe_get_latents(pipe, height, width, generator, device, LOGS):
364
  # --------------------------
365
  @spaces.GPU
366
def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
    """Generate an image for ``prompt``.

    Tries the advanced latent-space generator first and decodes its latents
    through the VAE; on any failure it falls back to the standard pipeline.
    Returns ``(image, latents, LOGS)`` where ``image`` may be ``None`` if
    both paths fail and ``latents`` is ``None`` on the fallback path.
    """
    LOGS = []

    def log(msg):
        # FIX: `log()` was called throughout this function but never defined.
        # Collect messages for the caller and mirror them to stdout.
        LOGS.append(msg)
        print(msg)

    # FIX: `generator` was passed to both generation paths below but was
    # never constructed from `seed`.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device).manual_seed(int(seed))

    print(prompt)

    # ----------------------------------------------------------------
    # Advanced latent generator (the project's own generate() helper)
    # ----------------------------------------------------------------
    latents, latent_err = safe_generate_with_latents(
        transformer=transformer,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        pipe=pipe,
        prompt=prompt,
        height=height,
        width=width,
        steps=steps,
        guidance_scale=guidance_scale,
        negative_prompt="",
        num_images_per_prompt=1,
        generator=generator,
        cfg_normalization=False,
        cfg_truncation=1.0,
        max_sequence_length=4096,
    )

    if latent_err is None:
        log("βœ… Latent generator succeeded.")
        try:
            # Undo the VAE's latent scaling/shift before decoding.
            shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
            dec = (latents.to(vae.dtype) / vae.config.scaling_factor) + shift_factor
            image = vae.decode(dec, return_dict=False)[0]
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()
            image = (image * 255).round().astype("uint8")
            image = Image.fromarray(image[0])
            log("🟒 Final image decoded from latent generator.")
            return image, latents, LOGS
        except Exception as decode_error:
            log(f"⚠️ Latent decode failed: {decode_error}")
            log("πŸ” Falling back to standard pipeline...")
    else:
        log(f"⚠️ Latent generator failed: {latent_err}")
        log("πŸ” Switching to standard pipeline...")

    # ----------------------------------------------------------------
    # Standard pipeline fallback
    # ----------------------------------------------------------------
    try:
        output = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        image = output.images[0]
        log("🟒 Standard pipeline succeeded.")
        return image, None, LOGS
    except Exception as e:
        log(f"❌ Standard pipeline failed: {e}")
        return None, None, LOGS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
 
 
 
383
 
384
  # --------------------------
385
  # Helper: Safe latent extractor
 
250
  from PIL import Image
251
  import torch
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
 
255
 
 
317
  # --------------------------
318
  @spaces.GPU
319
def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
    """Produce an image for ``prompt``, preferring the latent-space generator.

    Falls back to the plain pipeline when the latent path fails at either
    the generation or the decode stage.  Returns ``(image, latents, LOGS)``;
    ``latents`` is ``None`` on the fallback path and ``image`` is ``None``
    only when both paths fail.
    """
    LOGS = []

    def log(msg):
        # Keep a copy for the caller and echo to stdout.
        LOGS.append(msg)
        print(msg)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device).manual_seed(int(seed))

    log("🎨 START IMAGE GENERATION")

    # First attempt: the advanced generator that yields raw latents.
    gen_kwargs = dict(
        transformer=transformer,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        pipe=pipe,
        prompt=prompt,
        height=height,
        width=width,
        steps=steps,
        guidance_scale=guidance_scale,
        negative_prompt="",
        num_images_per_prompt=1,
        generator=generator,
        cfg_normalization=False,
        cfg_truncation=1.0,
        max_sequence_length=4096,
    )
    latents, latent_err = safe_generate_with_latents(**gen_kwargs)

    if latent_err is not None:
        log(f"⚠️ Latent generator failed: {latent_err}")
        log("πŸ” Switching to standard pipeline...")
    else:
        log("βœ… Latent generator succeeded.")
        try:
            # Reverse the VAE's latent scaling/shift, then decode to pixels.
            shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
            scaled = latents.to(vae.dtype) / vae.config.scaling_factor
            decoded = vae.decode(scaled + shift_factor, return_dict=False)[0]

            pixels = (decoded / 2 + 0.5).clamp(0, 1)
            pixels = pixels.cpu().permute(0, 2, 3, 1).numpy()
            pixels = (pixels * 255).round().astype("uint8")
            from PIL import Image
            result = Image.fromarray(pixels[0])

            log("🟒 Final image decoded from latent generator.")
            return result, latents, LOGS
        except Exception as decode_error:
            log(f"⚠️ Latent decode failed: {decode_error}")
            log("πŸ” Falling back to standard pipeline...")

    # Last resort: the standard pipeline call.
    try:
        output = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        log("🟒 Standard pipeline succeeded.")
        return output.images[0], None, LOGS
    except Exception as e:
        log(f"❌ Standard pipeline failed: {e}")
        return None, None, LOGS
398
 
399
  # --------------------------
400
  # Helper: Safe latent extractor