frogleo committed
Commit fdf1179 · verified · 1 Parent(s): c260184

Update app.py

Files changed (1):
  1. app.py +118 -104
app.py CHANGED
@@ -26,6 +26,9 @@ from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
 from torchvision.transforms.functional import to_pil_image
 from PIL import Image, ImageDraw, ImageFont
 
+class GenerationError(Exception):
+    """Custom exception for generation errors"""
+    pass
 
 # Enhanced logging configuration
 logging.basicConfig(
@@ -164,119 +167,130 @@ def _infer(person,garment,denoise_steps,seed):
     progress(0,desc="Starting")
     device = "cuda"
 
-    openpose_model.preprocessor.body_estimation.model.to(device)
-    pipe.to(device)
-    pipe.unet_encoder.to(device)
-
-    personRGB = person.convert("RGB")
-    crop_size = personRGB.size
-
-    human_img = personRGB.resize((768,1024))
-    garm_img= garment.convert("RGB").resize((768,1024))
-
-    progress(0.1,desc="Mask generating")
-
-    keypoints = openpose_model(human_img.resize((384,512)))
-    model_parse, _ = parsing_model(human_img.resize((384,512)))
-    mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
-    mask = mask.resize((768,1024))
-
-    mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
-    mask_gray = to_pil_image((mask_gray+1.0)/2.0)
-
-    progress(0.3,desc="DensePose processing")
-
-    human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
-    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
-
-    args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
-    # verbosity = getattr(args, "verbosity", None)
-    pose_img = args.func(args,human_img_arg)
-    pose_img = pose_img[:,:,::-1]
-    pose_img = Image.fromarray(pose_img).resize((768,1024))
-
-    progress(0.5,desc="Image generating")
-
-    def callback(pipe, step, timestep, callback_kwargs):
-        progress_value = 0.5 + ((step+1.0)/denoise_steps)*(0.5/1.0)
-        progress(progress_value, desc=f"Image generating, {step + 1}/{denoise_steps} steps")
-        return callback_kwargs
-
-    with torch.no_grad():
-        # Extract the images
-        with torch.cuda.amp.autocast():
-            with torch.no_grad():
-                prompt = "model is wearing clothing"
-                negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-                with torch.inference_mode():
-                    (
-                        prompt_embeds,
-                        negative_prompt_embeds,
-                        pooled_prompt_embeds,
-                        negative_pooled_prompt_embeds,
-                    ) = pipe.encode_prompt(
-                        prompt,
-                        num_images_per_prompt=1,
-                        do_classifier_free_guidance=True,
-                        negative_prompt=negative_prompt,
-                    )
-
-                prompt = "a photo of clothing"
-                negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-                if not isinstance(prompt, List):
-                    prompt = [prompt] * 1
-                if not isinstance(negative_prompt, List):
-                    negative_prompt = [negative_prompt] * 1
-                with torch.inference_mode():
-                    (
-                        prompt_embeds_c,
-                        _,
-                        _,
-                        _,
-                    ) = pipe.encode_prompt(
-                        prompt,
-                        num_images_per_prompt=1,
-                        do_classifier_free_guidance=False,
-                        negative_prompt=negative_prompt,
-                    )
-
-                pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
-                garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
-                generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
-                images = pipe(
-                    prompt_embeds=prompt_embeds.to(device,torch.float16),
-                    negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
-                    pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
-                    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
-                    num_inference_steps=denoise_steps,
-                    generator=generator,
-                    strength = 1.0,
-                    pose_img = pose_img.to(device,torch.float16),
-                    text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
-                    cloth = garm_tensor.to(device,torch.float16),
-                    mask_image=mask,
-                    image=human_img,
-                    height=1024,
-                    width=768,
-                    ip_adapter_image = garm_img.resize((768,1024)),
-                    guidance_scale=2.0,
-                    callback_on_step_end=callback
-                )[0]
-                out_img = images[0].resize(crop_size)
-
-                # NSFW detection
-                if nsfw_model and nsfw_processor:
-                    if detect_nsfw(out_img):
-                        error_info = {
-                            "error": "Generated image contains NSFW content and cannot be displayed. Please modify your prompt and try again.",
-                            "status": "failed"
-                        }
-                        return None, error_info
-                info = {
-                    "status": "success"
-                }
-                progress(1,desc="Complete")
-                return out_img, info
+    try:
+        openpose_model.preprocessor.body_estimation.model.to(device)
+        pipe.to(device)
+        pipe.unet_encoder.to(device)
+
+        personRGB = person.convert("RGB")
+        crop_size = personRGB.size
+
+        human_img = personRGB.resize((768,1024))
+        garm_img= garment.convert("RGB").resize((768,1024))
+
+        progress(0.1,desc="Mask generating")
+
+        keypoints = openpose_model(human_img.resize((384,512)))
+        model_parse, _ = parsing_model(human_img.resize((384,512)))
+        mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
+        mask = mask.resize((768,1024))
+
+        mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
+        mask_gray = to_pil_image((mask_gray+1.0)/2.0)
+
+        progress(0.3,desc="DensePose processing")
+
+        human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
+        human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+
+        args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
+        # verbosity = getattr(args, "verbosity", None)
+        pose_img = args.func(args,human_img_arg)
+        pose_img = pose_img[:,:,::-1]
+        pose_img = Image.fromarray(pose_img).resize((768,1024))
+
+        progress(0.5,desc="Image generating")
+
+        def callback(pipe, step, timestep, callback_kwargs):
+            progress_value = 0.5 + ((step+1.0)/denoise_steps)*(0.5/1.0)
+            progress(progress_value, desc=f"Image generating, {step + 1}/{denoise_steps} steps")
+            return callback_kwargs
+
+        with torch.no_grad():
+            # Extract the images
+            with torch.cuda.amp.autocast():
+                with torch.no_grad():
+                    prompt = "model is wearing clothing"
+                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                    with torch.inference_mode():
+                        (
+                            prompt_embeds,
+                            negative_prompt_embeds,
+                            pooled_prompt_embeds,
+                            negative_pooled_prompt_embeds,
+                        ) = pipe.encode_prompt(
+                            prompt,
+                            num_images_per_prompt=1,
+                            do_classifier_free_guidance=True,
+                            negative_prompt=negative_prompt,
+                        )
+
+                    prompt = "a photo of clothing"
+                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                    if not isinstance(prompt, List):
+                        prompt = [prompt] * 1
+                    if not isinstance(negative_prompt, List):
+                        negative_prompt = [negative_prompt] * 1
+                    with torch.inference_mode():
+                        (
+                            prompt_embeds_c,
+                            _,
+                            _,
+                            _,
+                        ) = pipe.encode_prompt(
+                            prompt,
+                            num_images_per_prompt=1,
+                            do_classifier_free_guidance=False,
+                            negative_prompt=negative_prompt,
+                        )
+
+                    pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
+                    garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
+                    generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                    images = pipe(
+                        prompt_embeds=prompt_embeds.to(device,torch.float16),
+                        negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
+                        pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
+                        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
+                        num_inference_steps=denoise_steps,
+                        generator=generator,
+                        strength = 1.0,
+                        pose_img = pose_img.to(device,torch.float16),
+                        text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
+                        cloth = garm_tensor.to(device,torch.float16),
+                        mask_image=mask,
+                        image=human_img,
+                        height=1024,
+                        width=768,
+                        ip_adapter_image = garm_img.resize((768,1024)),
+                        guidance_scale=2.0,
+                        callback_on_step_end=callback
+                    )[0]
+                    out_img = images[0].resize(crop_size)
+
+                    # NSFW detection
+                    if nsfw_model and nsfw_processor:
+                        if detect_nsfw(out_img):
+                            msg = "Generated image contains NSFW content and cannot be displayed. Please modify your prompt and try again."
+                            raise GenerationError(msg)
+
+                    info = {
+                        "status": "success"
+                    }
+                    progress(1,desc="Complete")
+                    return out_img, info
+    except GenerationError as e:
+        error_info = {
+            "error": str(e),
+            "status": "failed",
+        }
+        return None, error_info
+    except Exception as e:
+        error_info = {
+            "error": str(e),
+            "status": "failed",
+        }
+        return None, error_info
 
 
 def infer(person,garment,denoise_steps,seed):
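
After this commit, _infer no longer lets exceptions escape: the success path returns (out_img, {"status": "success"}), and both except branches, including the GenerationError raised by the NSFW gate, convert failures into (None, {"error": ..., "status": "failed"}). A minimal sketch of a caller consuming that contract; the function name run_demo and the file paths are hypothetical, not part of the commit:

    from PIL import Image

    def run_demo(person_path, garment_path, denoise_steps=30, seed=42):
        # Hypothetical caller; names and paths are illustrative only.
        person = Image.open(person_path)
        garment = Image.open(garment_path)
        out_img, info = _infer(person, garment, denoise_steps, seed)
        if info.get("status") == "failed":
            # Both except branches in _infer produce this shape, so the caller
            # never has to distinguish an NSFW rejection from an unexpected error.
            print(f"generation failed: {info.get('error')}")
            return None
        return out_img

The step callback maps denoising progress onto the upper half of the progress bar, since mask generation and DensePose already account for 0.0-0.5. A small check of that arithmetic, assuming denoise_steps=30:

    def progress_value(step, denoise_steps):
        # Same formula as the callback in _infer: generation fills 0.5-1.0.
        return 0.5 + ((step + 1.0) / denoise_steps) * (0.5 / 1.0)

    print(progress_value(0, 30))   # ~0.5167 after the first step
    print(progress_value(29, 30))  # 1.0 on the final step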