Update src/pipelines/pipeline_echo_mimic.py
src/pipelines/pipeline_echo_mimic.py
CHANGED
@@ -34,6 +34,7 @@ from transformers import CLIPImageProcessor
 from src.models.mutual_self_attention import ReferenceAttentionControl
 from src.pipelines.context import get_context_scheduler
 from src.pipelines.utils import get_tensor_interpolation_method
+from src.utils.step_func import origin_by_velocity_and_sample, psuedo_velocity_wrt_noisy_and_timestep, get_alpha
 
 @dataclass
 class Audio2VideoPipelineOutput(BaseOutput):
@@ -417,9 +418,9 @@ class Audio2VideoPipeline(DiffusionPipeline):
             generator
         )
         # print(video_length, latents.shape)
-
-        uc_face_locator_tensor = torch.zeros_like(
-        face_locator_tensor = torch.cat([uc_face_locator_tensor,
+        c_face_locator_tensor = self.face_locator(face_mask_tensor)
+        uc_face_locator_tensor = torch.zeros_like(c_face_locator_tensor)
+        face_locator_tensor = torch.cat([uc_face_locator_tensor, c_face_locator_tensor], dim=0)
         # Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
@@ -474,7 +475,7 @@ class Audio2VideoPipeline(DiffusionPipeline):
                     encoder_hidden_states=None,
                     return_dict=False,
                 )
-                reference_control_reader.update(reference_control_writer, do_classifier_free_guidance=
+                reference_control_reader.update(reference_control_writer, do_classifier_free_guidance=do_classifier_free_guidance)
 
 
             num_context_batches = math.ceil(len(context_queue) / context_batch_size)
@@ -498,8 +499,8 @@ class Audio2VideoPipeline(DiffusionPipeline):
                     .to(device)
                     .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
                 )
-
-                audio_latents = torch.cat([torch.zeros_like(
+                c_audio_latents = torch.cat([audio_fea_final[:, c] for c in new_context]).to(device)
+                audio_latents = torch.cat([torch.zeros_like(c_audio_latents), c_audio_latents], 0)
 
                 latent_model_input = self.scheduler.scale_model_input(
                     latent_model_input, t
@@ -508,11 +509,15 @@ class Audio2VideoPipeline(DiffusionPipeline):
                     latent_model_input,
                     t,
                     encoder_hidden_states=None,
-                    audio_cond_fea=audio_latents,
-                    face_musk_fea=face_locator_tensor,
+                    audio_cond_fea=audio_latents if do_classifier_free_guidance else c_audio_latents,
+                    face_musk_fea=face_locator_tensor if do_classifier_free_guidance else c_face_locator_tensor,
                     return_dict=False,
                 )[0]
 
+                alphas_cumprod = self.scheduler.alphas_cumprod.to(latent_model_input.device)
+                x_pred = origin_by_velocity_and_sample(pred, latent_model_input, alphas_cumprod, t)
+                pred = psuedo_velocity_wrt_noisy_and_timestep(latent_model_input, x_pred, alphas_cumprod, t, torch.ones_like(t) * (-1))
+
                 for j, c in enumerate(new_context):
                     noise_pred[:, :, c] = noise_pred[:, :, c] + pred
                     counter[:, :, c] = counter[:, :, c] + 1
@@ -523,6 +528,8 @@ class Audio2VideoPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (
                     noise_pred_text - noise_pred_uncond
                 )
+            else:
+                noise_pred = noise_pred / counter
 
             latents = self.scheduler.step(
                 noise_pred, t, latents, **extra_step_kwargs
@@ -583,4 +590,4 @@ class Audio2VideoPipeline(DiffusionPipeline):
         smoothed_tensor = torch.cat(
             [tensor[:, :, 0:1, :, :], internal_frames, tensor[:, :, -1:, :, :]], dim=2)
 
-        return smoothed_tensor
+        return smoothed_tensor
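Note on the conditioning changes: the diff follows the usual classifier-free-guidance batching pattern. Each new conditioning tensor (the face locator features and the audio latents) gets a zero-filled unconditional copy concatenated in front of it along the batch dimension, and the doubled prediction is later split back into unconditional and conditional halves. A minimal, self-contained sketch of that pattern, with illustrative helper names that do not appear in the repository:

import torch

def cfg_duplicate(cond: torch.Tensor, do_cfg: bool) -> torch.Tensor:
    # Prepend a zero-filled (unconditional) copy along the batch dim,
    # mirroring how face_locator_tensor and audio_latents are built in the diff.
    return torch.cat([torch.zeros_like(cond), cond], dim=0) if do_cfg else cond

def cfg_combine(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Split the doubled batch back into (uncond, cond) halves and blend them,
    # as in the existing noise_pred_uncond / noise_pred_text code path.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

When do_classifier_free_guidance is off, the diff passes the conditional tensors directly and instead normalizes noise_pred by counter, which is what the added else: branch does.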
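Note on the new step helpers: origin_by_velocity_and_sample and psuedo_velocity_wrt_noisy_and_timestep come from src/utils/step_func, which is not shown in this diff. Assuming they implement the standard v-prediction relations (x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps and v = sqrt(a_t) * eps - sqrt(1 - a_t) * x_0, with a_t = alphas_cumprod[t]), the round trip from a velocity estimate to a predicted clean sample and back would look roughly like the sketch below. The extra torch.ones_like(t) * (-1) timestep argument used in the diff is specific to the repository helper and is not modelled here.

import torch

# Hedged sketch of the usual v-prediction conversions; NOT the repository's code.
def origin_from_velocity(v, x_t, alphas_cumprod, t):
    # x_0 = sqrt(a_t) * x_t - sqrt(1 - a_t) * v
    a_t = alphas_cumprod[t].reshape(-1, *([1] * (x_t.dim() - 1)))
    return a_t.sqrt() * x_t - (1.0 - a_t).sqrt() * v

def velocity_from_origin(x_t, x_0, alphas_cumprod, t):
    # Recover eps from (x_t, x_0), then v = sqrt(a_t) * eps - sqrt(1 - a_t) * x_0
    a_t = alphas_cumprod[t].reshape(-1, *([1] * (x_t.dim() - 1)))
    eps = (x_t - a_t.sqrt() * x_0) / (1.0 - a_t).sqrt()
    return a_t.sqrt() * eps - (1.0 - a_t).sqrt() * x_0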