Commit
·
0cd6025
0
Parent(s):
clean repo without raw binaries
Browse files- .gitattributes +36 -0
- .gitignore +33 -0
- Inference.py +635 -0
- Inference_with_status.py +410 -0
- README.md +189 -0
- app.py +671 -0
- console_capture.py +45 -0
- demo1/mix.mp4 +3 -0
- face_detection_utils.py +122 -0
- look2hear/datas/transform.py +191 -0
- look2hear/models/__init__.py +1 -0
- look2hear/models/dolphin.py +1376 -0
- look2hear/models/video_compoent.py +876 -0
- requirements.txt +14 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Logs
|
| 2 |
+
logs
|
| 3 |
+
*.log
|
| 4 |
+
npm-debug.log*
|
| 5 |
+
yarn-debug.log*
|
| 6 |
+
yarn-error.log*
|
| 7 |
+
pnpm-debug.log*
|
| 8 |
+
lerna-debug.log*
|
| 9 |
+
|
| 10 |
+
node_modules
|
| 11 |
+
.DS_Store
|
| 12 |
+
dist
|
| 13 |
+
dist-ssr
|
| 14 |
+
coverage
|
| 15 |
+
*.local
|
| 16 |
+
|
| 17 |
+
/cypress/videos/
|
| 18 |
+
/cypress/screenshots/
|
| 19 |
+
|
| 20 |
+
# Editor directories and files
|
| 21 |
+
.vscode/*
|
| 22 |
+
!.vscode/extensions.json
|
| 23 |
+
.idea
|
| 24 |
+
*.suo
|
| 25 |
+
*.ntvs*
|
| 26 |
+
*.njsproj
|
| 27 |
+
*.sln
|
| 28 |
+
*.sw?
|
| 29 |
+
yarn.lock
|
| 30 |
+
|
| 31 |
+
tmp/*
|
| 32 |
+
.gradio
|
| 33 |
+
*.pyc
|
Inference.py
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.filterwarnings("ignore")
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
import face_alignment
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio
|
| 8 |
+
import numpy as np
|
| 9 |
+
import cv2
|
| 10 |
+
from PIL import Image, ImageDraw
|
| 11 |
+
from moviepy import *
|
| 12 |
+
from collections import deque
|
| 13 |
+
from skimage import transform as tf
|
| 14 |
+
import yaml
|
| 15 |
+
|
| 16 |
+
from look2hear.models import Dolphin
|
| 17 |
+
from look2hear.datas.transform import get_preprocessing_pipelines
|
| 18 |
+
|
| 19 |
+
from face_detection_utils import detect_faces
|
| 20 |
+
|
| 21 |
+
# -- Landmark interpolation:
|
| 22 |
+
def linear_interpolate(landmarks, start_idx, stop_idx):
|
| 23 |
+
start_landmarks = landmarks[start_idx]
|
| 24 |
+
stop_landmarks = landmarks[stop_idx]
|
| 25 |
+
delta = stop_landmarks - start_landmarks
|
| 26 |
+
for idx in range(1, stop_idx-start_idx):
|
| 27 |
+
landmarks[start_idx+idx] = start_landmarks + idx/float(stop_idx-start_idx) * delta
|
| 28 |
+
return landmarks
|
| 29 |
+
|
| 30 |
+
# -- Face Transformation
|
| 31 |
+
def warp_img(src, dst, img, std_size):
|
| 32 |
+
tform = tf.estimate_transform('similarity', src, dst) # find the transformation matrix
|
| 33 |
+
warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size) # wrap the frame image
|
| 34 |
+
warped = warped * 255 # note output from wrap is double image (value range [0,1])
|
| 35 |
+
warped = warped.astype('uint8')
|
| 36 |
+
return warped, tform
|
| 37 |
+
|
| 38 |
+
def apply_transform(transform, img, std_size):
|
| 39 |
+
warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
|
| 40 |
+
warped = warped * 255 # note output from wrap is double image (value range [0,1])
|
| 41 |
+
warped = warped.astype('uint8')
|
| 42 |
+
return warped
|
| 43 |
+
|
| 44 |
+
# -- Crop
|
| 45 |
+
def cut_patch(img, landmarks, height, width, threshold=5):
|
| 46 |
+
|
| 47 |
+
center_x, center_y = np.mean(landmarks, axis=0)
|
| 48 |
+
|
| 49 |
+
if center_y - height < 0:
|
| 50 |
+
center_y = height
|
| 51 |
+
if center_y - height < 0 - threshold:
|
| 52 |
+
raise Exception('too much bias in height')
|
| 53 |
+
if center_x - width < 0:
|
| 54 |
+
center_x = width
|
| 55 |
+
if center_x - width < 0 - threshold:
|
| 56 |
+
raise Exception('too much bias in width')
|
| 57 |
+
|
| 58 |
+
if center_y + height > img.shape[0]:
|
| 59 |
+
center_y = img.shape[0] - height
|
| 60 |
+
if center_y + height > img.shape[0] + threshold:
|
| 61 |
+
raise Exception('too much bias in height')
|
| 62 |
+
if center_x + width > img.shape[1]:
|
| 63 |
+
center_x = img.shape[1] - width
|
| 64 |
+
if center_x + width > img.shape[1] + threshold:
|
| 65 |
+
raise Exception('too much bias in width')
|
| 66 |
+
|
| 67 |
+
cutted_img = np.copy(img[ int(round(center_y) - round(height)): int(round(center_y) + round(height)),
|
| 68 |
+
int(round(center_x) - round(width)): int(round(center_x) + round(width))])
|
| 69 |
+
return cutted_img
|
| 70 |
+
|
| 71 |
+
# -- RGB to GRAY
|
| 72 |
+
def convert_bgr2gray(data):
|
| 73 |
+
return np.stack([cv2.cvtColor(_, cv2.COLOR_BGR2GRAY) for _ in data], axis=0)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def save2npz(filename, data=None):
|
| 77 |
+
assert data is not None, "data is {}".format(data)
|
| 78 |
+
if not os.path.exists(os.path.dirname(filename)):
|
| 79 |
+
os.makedirs(os.path.dirname(filename))
|
| 80 |
+
np.savez_compressed(filename, data=data)
|
| 81 |
+
|
| 82 |
+
def read_video(filename):
|
| 83 |
+
"""Read video frames using MoviePy for better compatibility"""
|
| 84 |
+
try:
|
| 85 |
+
video_clip = VideoFileClip(filename)
|
| 86 |
+
for frame in video_clip.iter_frames():
|
| 87 |
+
# Convert RGB to BGR to match cv2 format
|
| 88 |
+
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 89 |
+
yield frame_bgr
|
| 90 |
+
video_clip.close()
|
| 91 |
+
except Exception as e:
|
| 92 |
+
print(f"Error reading video {filename}: {e}")
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
def face2head(boxes, scale=1.5):
|
| 96 |
+
new_boxes = []
|
| 97 |
+
for box in boxes:
|
| 98 |
+
width = box[2] - box[0]
|
| 99 |
+
height= box[3] - box[1]
|
| 100 |
+
width_center = (box[2] + box[0]) / 2
|
| 101 |
+
height_center = (box[3] + box[1]) / 2
|
| 102 |
+
square_width = int(max(width, height) * scale)
|
| 103 |
+
new_box = [width_center - square_width/2, height_center - square_width/2, width_center + square_width/2, height_center + square_width/2]
|
| 104 |
+
new_boxes.append(new_box)
|
| 105 |
+
return new_boxes
|
| 106 |
+
|
| 107 |
+
def bb_intersection_over_union(boxA, boxB):
|
| 108 |
+
# determine the (x, y)-coordinates of the intersection rectangle
|
| 109 |
+
xA = max(boxA[0], boxB[0])
|
| 110 |
+
yA = max(boxA[1], boxB[1])
|
| 111 |
+
xB = min(boxA[2], boxB[2])
|
| 112 |
+
yB = min(boxA[3], boxB[3])
|
| 113 |
+
# compute the area of intersection rectangle
|
| 114 |
+
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
|
| 115 |
+
# compute the area of both the prediction and ground-truth
|
| 116 |
+
# rectangles
|
| 117 |
+
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
|
| 118 |
+
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
|
| 119 |
+
# compute the intersection over union by taking the intersection
|
| 120 |
+
# area and dividing it by the sum of prediction + ground-truth
|
| 121 |
+
# areas - the interesection area
|
| 122 |
+
iou = interArea / float(boxAArea + boxBArea - interArea)
|
| 123 |
+
# return the intersection over union value
|
| 124 |
+
return iou
|
| 125 |
+
|
| 126 |
+
def detectface(video_input_path, output_path, detect_every_N_frame, scalar_face_detection, number_of_speakers):
|
| 127 |
+
device = torch.device('cuda' if torch.cuda.get_device_name() else 'cpu')
|
| 128 |
+
print('Running on device: {}'.format(device))
|
| 129 |
+
os.makedirs(os.path.join(output_path, 'faces'), exist_ok=True)
|
| 130 |
+
os.makedirs(os.path.join(output_path, 'landmark'), exist_ok=True)
|
| 131 |
+
|
| 132 |
+
landmarks_dic = {}
|
| 133 |
+
faces_dic = {}
|
| 134 |
+
boxes_dic = {}
|
| 135 |
+
|
| 136 |
+
for i in range(number_of_speakers):
|
| 137 |
+
landmarks_dic[i] = []
|
| 138 |
+
faces_dic[i] = []
|
| 139 |
+
boxes_dic[i] = []
|
| 140 |
+
|
| 141 |
+
video_clip = VideoFileClip(video_input_path)
|
| 142 |
+
print("Video statistics: ", video_clip.w, video_clip.h, (video_clip.w, video_clip.h), video_clip.fps)
|
| 143 |
+
frames = [Image.fromarray(frame) for frame in video_clip.iter_frames()]
|
| 144 |
+
print('Number of frames in video: ', len(frames))
|
| 145 |
+
video_clip.close()
|
| 146 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
|
| 147 |
+
|
| 148 |
+
for i, frame in enumerate(frames):
|
| 149 |
+
print('\rTracking frame: {}'.format(i + 1), end='')
|
| 150 |
+
|
| 151 |
+
# Detect faces every N frames
|
| 152 |
+
if i % detect_every_N_frame == 0:
|
| 153 |
+
frame_array = np.array(frame)
|
| 154 |
+
|
| 155 |
+
detected_boxes, _ = detect_faces(
|
| 156 |
+
frame_array,
|
| 157 |
+
threshold=0.9,
|
| 158 |
+
allow_upscaling=False,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
if detected_boxes is None or len(detected_boxes) == 0:
|
| 162 |
+
detected_boxes, _ = detect_faces(
|
| 163 |
+
frame_array,
|
| 164 |
+
threshold=0.7,
|
| 165 |
+
allow_upscaling=True,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if detected_boxes is not None and len(detected_boxes) > 0:
|
| 169 |
+
detected_boxes = detected_boxes[:number_of_speakers]
|
| 170 |
+
detected_boxes = face2head(detected_boxes, scalar_face_detection)
|
| 171 |
+
else:
|
| 172 |
+
detected_boxes = []
|
| 173 |
+
|
| 174 |
+
# Process the detection results
|
| 175 |
+
if i == 0:
|
| 176 |
+
# First frame - initialize tracking
|
| 177 |
+
if len(detected_boxes) < number_of_speakers:
|
| 178 |
+
raise ValueError(f"First frame must detect at least {number_of_speakers} faces, but only found {len(detected_boxes)}")
|
| 179 |
+
|
| 180 |
+
# Assign first detections to speakers in order
|
| 181 |
+
for j in range(number_of_speakers):
|
| 182 |
+
box = detected_boxes[j]
|
| 183 |
+
face = frame.crop((box[0], box[1], box[2], box[3])).resize((224,224))
|
| 184 |
+
preds = fa.get_landmarks(np.array(face))
|
| 185 |
+
|
| 186 |
+
if preds is None:
|
| 187 |
+
raise ValueError(f"Face landmarks not detected in initial frame for speaker {j}")
|
| 188 |
+
|
| 189 |
+
faces_dic[j].append(face)
|
| 190 |
+
landmarks_dic[j].append(preds)
|
| 191 |
+
boxes_dic[j].append(box)
|
| 192 |
+
else:
|
| 193 |
+
# For subsequent frames, match detected boxes to speakers
|
| 194 |
+
matched_speakers = set()
|
| 195 |
+
speaker_boxes = [None] * number_of_speakers
|
| 196 |
+
|
| 197 |
+
# Match each detected box to the most likely speaker
|
| 198 |
+
for box in detected_boxes:
|
| 199 |
+
iou_scores = []
|
| 200 |
+
for speaker_id in range(number_of_speakers):
|
| 201 |
+
if speaker_id in matched_speakers:
|
| 202 |
+
iou_scores.append(-1) # Already matched
|
| 203 |
+
else:
|
| 204 |
+
last_box = boxes_dic[speaker_id][-1]
|
| 205 |
+
iou_score = bb_intersection_over_union(box, last_box)
|
| 206 |
+
iou_scores.append(iou_score)
|
| 207 |
+
|
| 208 |
+
if max(iou_scores) > 0: # Valid match found
|
| 209 |
+
best_speaker = iou_scores.index(max(iou_scores))
|
| 210 |
+
speaker_boxes[best_speaker] = box
|
| 211 |
+
matched_speakers.add(best_speaker)
|
| 212 |
+
|
| 213 |
+
# Process each speaker
|
| 214 |
+
for speaker_id in range(number_of_speakers):
|
| 215 |
+
if speaker_boxes[speaker_id] is not None:
|
| 216 |
+
# Use detected box
|
| 217 |
+
box = speaker_boxes[speaker_id]
|
| 218 |
+
else:
|
| 219 |
+
# Use previous box for this speaker
|
| 220 |
+
box = boxes_dic[speaker_id][-1]
|
| 221 |
+
|
| 222 |
+
# Extract face and landmarks
|
| 223 |
+
face = frame.crop((box[0], box[1], box[2], box[3])).resize((224,224))
|
| 224 |
+
preds = fa.get_landmarks(np.array(face))
|
| 225 |
+
|
| 226 |
+
if preds is None:
|
| 227 |
+
# Use previous landmarks if detection fails
|
| 228 |
+
preds = landmarks_dic[speaker_id][-1]
|
| 229 |
+
|
| 230 |
+
faces_dic[speaker_id].append(face)
|
| 231 |
+
landmarks_dic[speaker_id].append(preds)
|
| 232 |
+
boxes_dic[speaker_id].append(box)
|
| 233 |
+
|
| 234 |
+
# Verify all speakers have same number of frames
|
| 235 |
+
frame_counts = [len(boxes_dic[s]) for s in range(number_of_speakers)]
|
| 236 |
+
print(f"\nFrame counts per speaker: {frame_counts}")
|
| 237 |
+
assert all(count == len(frames) for count in frame_counts), f"Inconsistent frame counts: {frame_counts}"
|
| 238 |
+
|
| 239 |
+
# Continue with saving videos and landmarks...
|
| 240 |
+
for s in range(number_of_speakers):
|
| 241 |
+
frames_tracked = []
|
| 242 |
+
for i, frame in enumerate(frames):
|
| 243 |
+
frame_draw = frame.copy()
|
| 244 |
+
draw = ImageDraw.Draw(frame_draw)
|
| 245 |
+
draw.rectangle(boxes_dic[s][i], outline=(255, 0, 0), width=6)
|
| 246 |
+
frames_tracked.append(frame_draw)
|
| 247 |
+
|
| 248 |
+
# Save tracked video
|
| 249 |
+
tracked_frames = [np.array(frame) for frame in frames_tracked]
|
| 250 |
+
if tracked_frames:
|
| 251 |
+
tracked_clip = ImageSequenceClip(tracked_frames, fps=25.0)
|
| 252 |
+
tracked_video_path = os.path.join(output_path, 'video_tracked' + str(s+1) + '.mp4')
|
| 253 |
+
tracked_clip.write_videofile(tracked_video_path, codec='libx264', audio=False, logger=None)
|
| 254 |
+
tracked_clip.close()
|
| 255 |
+
|
| 256 |
+
# Save landmarks
|
| 257 |
+
for i in range(number_of_speakers):
|
| 258 |
+
save2npz(os.path.join(output_path, 'landmark', 'speaker' + str(i+1)+'.npz'), data=landmarks_dic[i])
|
| 259 |
+
|
| 260 |
+
# Save face video
|
| 261 |
+
face_frames = [np.array(frame) for frame in faces_dic[i]]
|
| 262 |
+
if face_frames:
|
| 263 |
+
face_clip = ImageSequenceClip(face_frames, fps=25.0)
|
| 264 |
+
face_video_path = os.path.join(output_path, 'faces', 'speaker' + str(i+1) + '.mp4')
|
| 265 |
+
face_clip.write_videofile(face_video_path, codec='libx264', audio=False, logger=None)
|
| 266 |
+
face_clip.close()
|
| 267 |
+
|
| 268 |
+
# Output video path
|
| 269 |
+
parts = video_input_path.split('/')
|
| 270 |
+
video_name = parts[-1][:-4]
|
| 271 |
+
if not os.path.exists(os.path.join(output_path, 'filename_input')):
|
| 272 |
+
os.mkdir(os.path.join(output_path, 'filename_input'))
|
| 273 |
+
csvfile = open(os.path.join(output_path, 'filename_input', str(video_name) + '.csv'), 'w')
|
| 274 |
+
for i in range(number_of_speakers):
|
| 275 |
+
csvfile.write('speaker' + str(i+1)+ ',0\n')
|
| 276 |
+
csvfile.close()
|
| 277 |
+
return os.path.join(output_path, 'filename_input', str(video_name) + '.csv')
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def crop_patch(mean_face_landmarks, video_pathname, landmarks, window_margin, start_idx, stop_idx, crop_height, crop_width, STD_SIZE=(256, 256)):
|
| 281 |
+
|
| 282 |
+
"""Crop mouth patch
|
| 283 |
+
:param str video_pathname: pathname for the video_dieo
|
| 284 |
+
:param list landmarks: interpolated landmarks
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
stablePntsIDs = [33, 36, 39, 42, 45]
|
| 288 |
+
|
| 289 |
+
frame_idx = 0
|
| 290 |
+
frame_gen = read_video(video_pathname)
|
| 291 |
+
while True:
|
| 292 |
+
try:
|
| 293 |
+
frame = frame_gen.__next__() ## -- BGR
|
| 294 |
+
except StopIteration:
|
| 295 |
+
break
|
| 296 |
+
if frame_idx == 0:
|
| 297 |
+
q_frame, q_landmarks = deque(), deque()
|
| 298 |
+
sequence = []
|
| 299 |
+
|
| 300 |
+
q_landmarks.append(landmarks[frame_idx])
|
| 301 |
+
q_frame.append(frame)
|
| 302 |
+
if len(q_frame) == window_margin:
|
| 303 |
+
smoothed_landmarks = np.mean(q_landmarks, axis=0)
|
| 304 |
+
cur_landmarks = q_landmarks.popleft()
|
| 305 |
+
cur_frame = q_frame.popleft()
|
| 306 |
+
# -- affine transformation
|
| 307 |
+
trans_frame, trans = warp_img( smoothed_landmarks[stablePntsIDs, :],
|
| 308 |
+
mean_face_landmarks[stablePntsIDs, :],
|
| 309 |
+
cur_frame,
|
| 310 |
+
STD_SIZE)
|
| 311 |
+
trans_landmarks = trans(cur_landmarks)
|
| 312 |
+
# -- crop mouth patch
|
| 313 |
+
sequence.append( cut_patch( trans_frame,
|
| 314 |
+
trans_landmarks[start_idx:stop_idx],
|
| 315 |
+
crop_height//2,
|
| 316 |
+
crop_width//2,))
|
| 317 |
+
if frame_idx == len(landmarks)-1:
|
| 318 |
+
#deal with corner case with video too short
|
| 319 |
+
if len(landmarks) < window_margin:
|
| 320 |
+
smoothed_landmarks = np.mean(q_landmarks, axis=0)
|
| 321 |
+
cur_landmarks = q_landmarks.popleft()
|
| 322 |
+
cur_frame = q_frame.popleft()
|
| 323 |
+
|
| 324 |
+
# -- affine transformation
|
| 325 |
+
trans_frame, trans = warp_img(smoothed_landmarks[stablePntsIDs, :],
|
| 326 |
+
mean_face_landmarks[stablePntsIDs, :],
|
| 327 |
+
cur_frame,
|
| 328 |
+
STD_SIZE)
|
| 329 |
+
trans_landmarks = trans(cur_landmarks)
|
| 330 |
+
# -- crop mouth patch
|
| 331 |
+
sequence.append(cut_patch( trans_frame,
|
| 332 |
+
trans_landmarks[start_idx:stop_idx],
|
| 333 |
+
crop_height//2,
|
| 334 |
+
crop_width//2,))
|
| 335 |
+
|
| 336 |
+
while q_frame:
|
| 337 |
+
cur_frame = q_frame.popleft()
|
| 338 |
+
# -- transform frame
|
| 339 |
+
trans_frame = apply_transform( trans, cur_frame, STD_SIZE)
|
| 340 |
+
# -- transform landmarks
|
| 341 |
+
trans_landmarks = trans(q_landmarks.popleft())
|
| 342 |
+
# -- crop mouth patch
|
| 343 |
+
sequence.append( cut_patch( trans_frame,
|
| 344 |
+
trans_landmarks[start_idx:stop_idx],
|
| 345 |
+
crop_height//2,
|
| 346 |
+
crop_width//2,))
|
| 347 |
+
return np.array(sequence)
|
| 348 |
+
frame_idx += 1
|
| 349 |
+
return None
|
| 350 |
+
|
| 351 |
+
def landmarks_interpolate(landmarks):
|
| 352 |
+
|
| 353 |
+
"""Interpolate landmarks
|
| 354 |
+
param list landmarks: landmarks detected in raw videos
|
| 355 |
+
"""
|
| 356 |
+
|
| 357 |
+
valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
|
| 358 |
+
if not valid_frames_idx:
|
| 359 |
+
return None
|
| 360 |
+
for idx in range(1, len(valid_frames_idx)):
|
| 361 |
+
if valid_frames_idx[idx] - valid_frames_idx[idx-1] == 1:
|
| 362 |
+
continue
|
| 363 |
+
else:
|
| 364 |
+
landmarks = linear_interpolate(landmarks, valid_frames_idx[idx-1], valid_frames_idx[idx])
|
| 365 |
+
valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
|
| 366 |
+
# -- Corner case: keep frames at the beginning or at the end failed to be detected.
|
| 367 |
+
if valid_frames_idx:
|
| 368 |
+
landmarks[:valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
|
| 369 |
+
landmarks[valid_frames_idx[-1]:] = [landmarks[valid_frames_idx[-1]]] * (len(landmarks) - valid_frames_idx[-1])
|
| 370 |
+
valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
|
| 371 |
+
assert len(valid_frames_idx) == len(landmarks), "not every frame has landmark"
|
| 372 |
+
return landmarks
|
| 373 |
+
|
| 374 |
+
def crop_mouth(video_direc, landmark_direc, filename_path, save_direc, convert_gray=False, testset_only=False):
|
| 375 |
+
lines = open(filename_path).read().splitlines()
|
| 376 |
+
lines = list(filter(lambda x: 'test' in x, lines)) if testset_only else lines
|
| 377 |
+
|
| 378 |
+
for filename_idx, line in enumerate(lines):
|
| 379 |
+
|
| 380 |
+
filename, person_id = line.split(',')
|
| 381 |
+
print('idx: {} \tProcessing.\t{}'.format(filename_idx, filename))
|
| 382 |
+
|
| 383 |
+
video_pathname = os.path.join(video_direc, filename+'.mp4')
|
| 384 |
+
landmarks_pathname = os.path.join(landmark_direc, filename+'.npz')
|
| 385 |
+
dst_pathname = os.path.join( save_direc, filename+'.npz')
|
| 386 |
+
|
| 387 |
+
# if os.path.exists(dst_pathname):
|
| 388 |
+
# continue
|
| 389 |
+
|
| 390 |
+
multi_sub_landmarks = np.load(landmarks_pathname, allow_pickle=True)['data']
|
| 391 |
+
landmarks = [None] * len(multi_sub_landmarks)
|
| 392 |
+
for frame_idx in range(len(landmarks)):
|
| 393 |
+
try:
|
| 394 |
+
#landmarks[frame_idx] = multi_sub_landmarks[frame_idx][int(person_id)]['facial_landmarks'] #original for LRW
|
| 395 |
+
landmarks[frame_idx] = multi_sub_landmarks[frame_idx][int(person_id)] #VOXCELEB2
|
| 396 |
+
except (IndexError, TypeError):
|
| 397 |
+
continue
|
| 398 |
+
|
| 399 |
+
# -- pre-process landmarks: interpolate frames not being detected.
|
| 400 |
+
preprocessed_landmarks = landmarks_interpolate(landmarks)
|
| 401 |
+
if not preprocessed_landmarks:
|
| 402 |
+
continue
|
| 403 |
+
|
| 404 |
+
# -- crop
|
| 405 |
+
mean_face_landmarks = np.load('assets/20words_mean_face.npy')
|
| 406 |
+
sequence = crop_patch(mean_face_landmarks, video_pathname, preprocessed_landmarks, 12, 48, 68, 96, 96)
|
| 407 |
+
assert sequence is not None, "cannot crop from {}.".format(filename)
|
| 408 |
+
|
| 409 |
+
# -- save
|
| 410 |
+
data = convert_bgr2gray(sequence) if convert_gray else sequence[...,::-1]
|
| 411 |
+
save2npz(dst_pathname, data=data)
|
| 412 |
+
|
| 413 |
+
def convert_video_fps(input_file, output_file, target_fps=25):
|
| 414 |
+
"""Convert video to target FPS using moviepy"""
|
| 415 |
+
video = VideoFileClip(input_file)
|
| 416 |
+
video_fps = video.fps
|
| 417 |
+
|
| 418 |
+
if video_fps != target_fps:
|
| 419 |
+
video.write_videofile(
|
| 420 |
+
output_file,
|
| 421 |
+
fps=target_fps,
|
| 422 |
+
codec='libx264',
|
| 423 |
+
audio_codec='aac',
|
| 424 |
+
temp_audiofile='temp-audio.m4a',
|
| 425 |
+
remove_temp=True,
|
| 426 |
+
)
|
| 427 |
+
else:
|
| 428 |
+
# If already at target fps, just copy
|
| 429 |
+
import shutil
|
| 430 |
+
shutil.copy2(input_file, output_file)
|
| 431 |
+
|
| 432 |
+
video.close()
|
| 433 |
+
print(f'Video has been converted to {target_fps} fps and saved to {output_file}')
|
| 434 |
+
|
| 435 |
+
def extract_audio(video_file, audio_output_file, sample_rate=16000):
|
| 436 |
+
"""Extract audio from video using moviepy"""
|
| 437 |
+
video = VideoFileClip(video_file)
|
| 438 |
+
audio = video.audio
|
| 439 |
+
|
| 440 |
+
# Save audio with specified sample rate
|
| 441 |
+
audio.write_audiofile(audio_output_file, fps=sample_rate, nbytes=2, codec='pcm_s16le')
|
| 442 |
+
|
| 443 |
+
video.close()
|
| 444 |
+
audio.close()
|
| 445 |
+
|
| 446 |
+
def merge_video_audio(video_file, audio_file, output_file):
|
| 447 |
+
"""Merge video and audio using moviepy"""
|
| 448 |
+
video = VideoFileClip(video_file)
|
| 449 |
+
audio = AudioFileClip(audio_file)
|
| 450 |
+
|
| 451 |
+
# Attach audio (MoviePy v2 renamed set_audio to with_audio)
|
| 452 |
+
set_audio_fn = getattr(video, "set_audio", None)
|
| 453 |
+
if callable(set_audio_fn):
|
| 454 |
+
final_video = set_audio_fn(audio)
|
| 455 |
+
else:
|
| 456 |
+
with_audio_fn = getattr(video, "with_audio", None)
|
| 457 |
+
if not callable(with_audio_fn):
|
| 458 |
+
video.close()
|
| 459 |
+
audio.close()
|
| 460 |
+
raise AttributeError("VideoFileClip object lacks both set_audio and with_audio methods")
|
| 461 |
+
final_video = with_audio_fn(audio)
|
| 462 |
+
|
| 463 |
+
# Write the result
|
| 464 |
+
final_video.write_videofile(output_file, codec='libx264', audio_codec='aac', temp_audiofile='temp-audio.m4a', remove_temp=True)
|
| 465 |
+
|
| 466 |
+
# Clean up
|
| 467 |
+
video.close()
|
| 468 |
+
audio.close()
|
| 469 |
+
final_video.close()
|
| 470 |
+
|
| 471 |
+
def process_video(input_file, output_path, number_of_speakers=2,
|
| 472 |
+
detect_every_N_frame=8, scalar_face_detection=1.5,
|
| 473 |
+
config_path="checkpoints/vox2/conf.yml",
|
| 474 |
+
cuda_device=None):
|
| 475 |
+
"""Main processing function for video speaker separation"""
|
| 476 |
+
|
| 477 |
+
# Set CUDA device if specified
|
| 478 |
+
if cuda_device is not None:
|
| 479 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device)
|
| 480 |
+
|
| 481 |
+
# Create output directory
|
| 482 |
+
os.makedirs(output_path, exist_ok=True)
|
| 483 |
+
|
| 484 |
+
# Convert video to 25fps
|
| 485 |
+
temp_25fps_file = os.path.join(output_path, 'temp_25fps.mp4')
|
| 486 |
+
convert_video_fps(input_file, temp_25fps_file, target_fps=25)
|
| 487 |
+
|
| 488 |
+
# Detect faces
|
| 489 |
+
filename_path = detectface(video_input_path=temp_25fps_file,
|
| 490 |
+
output_path=output_path,
|
| 491 |
+
detect_every_N_frame=detect_every_N_frame,
|
| 492 |
+
scalar_face_detection=scalar_face_detection,
|
| 493 |
+
number_of_speakers=number_of_speakers)
|
| 494 |
+
|
| 495 |
+
# Extract audio
|
| 496 |
+
audio_output = os.path.join(output_path, 'audio.wav')
|
| 497 |
+
extract_audio(temp_25fps_file, audio_output, sample_rate=16000)
|
| 498 |
+
|
| 499 |
+
# Crop mouth
|
| 500 |
+
crop_mouth(video_direc=os.path.join(output_path, "faces"),
|
| 501 |
+
landmark_direc=os.path.join(output_path, "landmark"),
|
| 502 |
+
filename_path=filename_path,
|
| 503 |
+
save_direc=os.path.join(output_path, "mouthroi"),
|
| 504 |
+
convert_gray=True,
|
| 505 |
+
testset_only=False)
|
| 506 |
+
|
| 507 |
+
# Load model
|
| 508 |
+
audiomodel = Dolphin.from_pretrained("JusperLee/Dolphin")
|
| 509 |
+
|
| 510 |
+
audiomodel.cuda()
|
| 511 |
+
audiomodel.eval()
|
| 512 |
+
|
| 513 |
+
# Process each speaker
|
| 514 |
+
with torch.no_grad():
|
| 515 |
+
for i in range(number_of_speakers):
|
| 516 |
+
mouth_roi = np.load(os.path.join(output_path, "mouthroi", f"speaker{i+1}.npz"))["data"]
|
| 517 |
+
mouth_roi = get_preprocessing_pipelines()["val"](mouth_roi)
|
| 518 |
+
|
| 519 |
+
mix, sr = torchaudio.load(audio_output)
|
| 520 |
+
mix = mix.cuda().mean(dim=0)
|
| 521 |
+
|
| 522 |
+
window_size = 4 * sr
|
| 523 |
+
hop_size = 4 * sr
|
| 524 |
+
|
| 525 |
+
all_estimates = []
|
| 526 |
+
|
| 527 |
+
# 滑动窗口处理
|
| 528 |
+
start_idx = 0
|
| 529 |
+
while start_idx < len(mix):
|
| 530 |
+
end_idx = min(start_idx + window_size, len(mix))
|
| 531 |
+
window_mix = mix[start_idx:end_idx]
|
| 532 |
+
|
| 533 |
+
start_frame = int(start_idx / sr * 25)
|
| 534 |
+
end_frame = int(end_idx / sr * 25)
|
| 535 |
+
end_frame = min(end_frame, len(mouth_roi))
|
| 536 |
+
window_mouth_roi = mouth_roi[start_frame:end_frame]
|
| 537 |
+
|
| 538 |
+
est_sources = audiomodel(window_mix[None],
|
| 539 |
+
torch.from_numpy(window_mouth_roi[None, None]).float().cuda())
|
| 540 |
+
|
| 541 |
+
all_estimates.append({
|
| 542 |
+
'start': start_idx,
|
| 543 |
+
'end': end_idx,
|
| 544 |
+
'estimate': est_sources[0].cpu()
|
| 545 |
+
})
|
| 546 |
+
|
| 547 |
+
start_idx += hop_size
|
| 548 |
+
|
| 549 |
+
if start_idx >= len(mix):
|
| 550 |
+
break
|
| 551 |
+
|
| 552 |
+
output_length = len(mix)
|
| 553 |
+
merged_output = torch.zeros(1, output_length)
|
| 554 |
+
weights = torch.zeros(output_length)
|
| 555 |
+
|
| 556 |
+
for est in all_estimates:
|
| 557 |
+
window_len = est['end'] - est['start']
|
| 558 |
+
hann_window = torch.hann_window(window_len)
|
| 559 |
+
|
| 560 |
+
merged_output[0, est['start']:est['end']] += est['estimate'][0, :window_len] * hann_window
|
| 561 |
+
weights[est['start']:est['end']] += hann_window
|
| 562 |
+
|
| 563 |
+
merged_output[:, weights > 0] /= weights[weights > 0]
|
| 564 |
+
|
| 565 |
+
torchaudio.save(os.path.join(output_path, f"speaker{i+1}_est.wav"), merged_output, sr)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
# Merge video with separated audio for each speaker
|
| 569 |
+
output_files = []
|
| 570 |
+
for i in range(number_of_speakers):
|
| 571 |
+
video_input = os.path.join(output_path, f"video_tracked{i+1}.mp4")
|
| 572 |
+
audio_input = os.path.join(output_path, f"speaker{i+1}_est.wav")
|
| 573 |
+
video_output = os.path.join(output_path, f"s{i+1}.mp4")
|
| 574 |
+
|
| 575 |
+
merge_video_audio(video_input, audio_input, video_output)
|
| 576 |
+
output_files.append(video_output)
|
| 577 |
+
|
| 578 |
+
# Clean up temporary file
|
| 579 |
+
if os.path.exists(temp_25fps_file):
|
| 580 |
+
os.remove(temp_25fps_file)
|
| 581 |
+
|
| 582 |
+
return output_files
|
| 583 |
+
|
| 584 |
+
if __name__ == '__main__':
|
| 585 |
+
parser = argparse.ArgumentParser(description='Video Speaker Separation using Dolphin model')
|
| 586 |
+
parser.add_argument('--input', '-i', type=str, required=True,
|
| 587 |
+
help='Path to input video file')
|
| 588 |
+
parser.add_argument('--output', '-o', type=str, default=None,
|
| 589 |
+
help='Output directory path (default: creates directory based on input filename)')
|
| 590 |
+
parser.add_argument('--speakers', '-s', type=int, default=2,
|
| 591 |
+
help='Number of speakers to separate (default: 2)')
|
| 592 |
+
parser.add_argument('--detect-every-n', type=int, default=8,
|
| 593 |
+
help='Detect faces every N frames (default: 8)')
|
| 594 |
+
parser.add_argument('--face-scale', type=float, default=1.5,
|
| 595 |
+
help='Face detection bounding box scale factor (default: 1.5)')
|
| 596 |
+
parser.add_argument('--cuda-device', type=int, default=0,
|
| 597 |
+
help='CUDA device ID to use (default: 0, set to -1 for CPU)')
|
| 598 |
+
parser.add_argument('--config', type=str, default="checkpoints/vox2/conf.yml",
|
| 599 |
+
help='Path to model configuration file')
|
| 600 |
+
|
| 601 |
+
args = parser.parse_args()
|
| 602 |
+
|
| 603 |
+
# 验证输入文件是否存在
|
| 604 |
+
if not os.path.exists(args.input):
|
| 605 |
+
print(f"Error: Input file '{args.input}' does not exist")
|
| 606 |
+
exit(1)
|
| 607 |
+
|
| 608 |
+
# 如果没有指定输出路径,基于输入文件名创建输出目录
|
| 609 |
+
if args.output is None:
|
| 610 |
+
input_basename = os.path.splitext(os.path.basename(args.input))[0]
|
| 611 |
+
args.output = os.path.join(os.path.dirname(args.input), input_basename + "_output")
|
| 612 |
+
|
| 613 |
+
# 设置CUDA设备
|
| 614 |
+
cuda_device = args.cuda_device if args.cuda_device >= 0 else None
|
| 615 |
+
|
| 616 |
+
print(f"Processing video: {args.input}")
|
| 617 |
+
print(f"Output directory: {args.output}")
|
| 618 |
+
print(f"Number of speakers: {args.speakers}")
|
| 619 |
+
print(f"CUDA device: {cuda_device if cuda_device is not None else 'CPU'}")
|
| 620 |
+
|
| 621 |
+
# 处理视频
|
| 622 |
+
output_files = process_video(
|
| 623 |
+
input_file=args.input,
|
| 624 |
+
output_path=args.output,
|
| 625 |
+
number_of_speakers=args.speakers,
|
| 626 |
+
detect_every_N_frame=args.detect_every_n,
|
| 627 |
+
scalar_face_detection=args.face_scale,
|
| 628 |
+
config_path=args.config,
|
| 629 |
+
cuda_device=cuda_device
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
print("\nProcessing completed!")
|
| 633 |
+
print("Output files:")
|
| 634 |
+
for i, output_file in enumerate(output_files):
|
| 635 |
+
print(f" Speaker {i+1}: {output_file}")
|
Inference_with_status.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.filterwarnings("ignore")
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
import numpy as np
|
| 7 |
+
from moviepy import *
|
| 8 |
+
from PIL import Image, ImageDraw
|
| 9 |
+
import face_alignment
|
| 10 |
+
import cv2
|
| 11 |
+
|
| 12 |
+
from look2hear.models import Dolphin
|
| 13 |
+
from look2hear.datas.transform import get_preprocessing_pipelines
|
| 14 |
+
|
| 15 |
+
from face_detection_utils import detect_faces
|
| 16 |
+
|
| 17 |
+
# Import functions from original Inference.py
|
| 18 |
+
from Inference import (
|
| 19 |
+
linear_interpolate, warp_img, apply_transform, cut_patch, convert_bgr2gray,
|
| 20 |
+
save2npz, read_video, face2head, bb_intersection_over_union,
|
| 21 |
+
landmarks_interpolate, crop_patch, convert_video_fps, extract_audio, merge_video_audio
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def detectface_with_status(video_input_path, output_path, detect_every_N_frame, scalar_face_detection, number_of_speakers, status_callback=None):
|
| 25 |
+
"""Face detection with status updates"""
|
| 26 |
+
device = torch.device('cuda' if torch.cuda.get_device_name() else 'cpu')
|
| 27 |
+
if status_callback:
|
| 28 |
+
status_callback({'status': f'Running on device: {device}', 'progress': 0.0})
|
| 29 |
+
|
| 30 |
+
os.makedirs(os.path.join(output_path, 'faces'), exist_ok=True)
|
| 31 |
+
os.makedirs(os.path.join(output_path, 'landmark'), exist_ok=True)
|
| 32 |
+
|
| 33 |
+
landmarks_dic = {}
|
| 34 |
+
faces_dic = {}
|
| 35 |
+
boxes_dic = {}
|
| 36 |
+
|
| 37 |
+
for i in range(number_of_speakers):
|
| 38 |
+
landmarks_dic[i] = []
|
| 39 |
+
faces_dic[i] = []
|
| 40 |
+
boxes_dic[i] = []
|
| 41 |
+
|
| 42 |
+
video_clip = VideoFileClip(video_input_path)
|
| 43 |
+
if status_callback:
|
| 44 |
+
status_callback({'status': f"Video: {video_clip.w}x{video_clip.h}, {video_clip.fps}fps", 'progress': 0.05})
|
| 45 |
+
|
| 46 |
+
frames = [Image.fromarray(frame) for frame in video_clip.iter_frames()]
|
| 47 |
+
total_frames = len(frames)
|
| 48 |
+
if status_callback:
|
| 49 |
+
status_callback({'status': f'Processing {total_frames} frames', 'progress': 0.1})
|
| 50 |
+
|
| 51 |
+
video_clip.close()
|
| 52 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
|
| 53 |
+
|
| 54 |
+
for i, frame in enumerate(frames):
|
| 55 |
+
if status_callback and i % 10 == 0:
|
| 56 |
+
status_callback({'status': f'Tracking frame: {i+1}/{total_frames}', 'progress': 0.1 + 0.3 * (i / total_frames)})
|
| 57 |
+
|
| 58 |
+
# Detect faces every N frames
|
| 59 |
+
if i % detect_every_N_frame == 0:
|
| 60 |
+
frame_array = np.array(frame)
|
| 61 |
+
|
| 62 |
+
detected_boxes, _ = detect_faces(
|
| 63 |
+
frame_array,
|
| 64 |
+
threshold=0.9,
|
| 65 |
+
allow_upscaling=False,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
if detected_boxes is None or len(detected_boxes) == 0:
|
| 69 |
+
detected_boxes, _ = detect_faces(
|
| 70 |
+
frame_array,
|
| 71 |
+
threshold=0.7,
|
| 72 |
+
allow_upscaling=True,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if detected_boxes is not None and len(detected_boxes) > 0:
|
| 76 |
+
detected_boxes = np.asarray(detected_boxes, dtype=np.float32)
|
| 77 |
+
areas = (detected_boxes[:, 2] - detected_boxes[:, 0]) * (detected_boxes[:, 3] - detected_boxes[:, 1])
|
| 78 |
+
sort_idx = np.argsort(areas)[::-1]
|
| 79 |
+
detected_boxes = detected_boxes[sort_idx][:number_of_speakers]
|
| 80 |
+
detected_boxes = face2head(detected_boxes, scalar_face_detection)
|
| 81 |
+
detected_boxes = [box for box in detected_boxes]
|
| 82 |
+
else:
|
| 83 |
+
detected_boxes = []
|
| 84 |
+
|
| 85 |
+
# Process the detection results (same as original)
|
| 86 |
+
if i == 0:
|
| 87 |
+
# First frame - initialize tracking
|
| 88 |
+
if len(detected_boxes) < number_of_speakers:
|
| 89 |
+
raise ValueError(f"First frame must detect at least {number_of_speakers} faces, but only found {len(detected_boxes)}")
|
| 90 |
+
|
| 91 |
+
# Assign first detections to speakers in order
|
| 92 |
+
for j in range(number_of_speakers):
|
| 93 |
+
box = detected_boxes[j]
|
| 94 |
+
face = frame.crop((box[0], box[1], box[2], box[3])).resize((224,224))
|
| 95 |
+
preds = fa.get_landmarks(np.array(face))
|
| 96 |
+
|
| 97 |
+
if preds is None:
|
| 98 |
+
raise ValueError(f"Face landmarks not detected in initial frame for speaker {j}")
|
| 99 |
+
|
| 100 |
+
faces_dic[j].append(face)
|
| 101 |
+
landmarks_dic[j].append(preds)
|
| 102 |
+
boxes_dic[j].append(box)
|
| 103 |
+
else:
|
| 104 |
+
# For subsequent frames, match detected boxes to speakers
|
| 105 |
+
matched_speakers = set()
|
| 106 |
+
speaker_boxes = [None] * number_of_speakers
|
| 107 |
+
|
| 108 |
+
# Match each detected box to the most likely speaker
|
| 109 |
+
for box in detected_boxes:
|
| 110 |
+
iou_scores = []
|
| 111 |
+
for speaker_id in range(number_of_speakers):
|
| 112 |
+
if speaker_id in matched_speakers:
|
| 113 |
+
iou_scores.append(-1) # Already matched
|
| 114 |
+
else:
|
| 115 |
+
last_box = boxes_dic[speaker_id][-1]
|
| 116 |
+
iou_score = bb_intersection_over_union(box, last_box)
|
| 117 |
+
iou_scores.append(iou_score)
|
| 118 |
+
|
| 119 |
+
if max(iou_scores) > 0: # Valid match found
|
| 120 |
+
best_speaker = iou_scores.index(max(iou_scores))
|
| 121 |
+
speaker_boxes[best_speaker] = box
|
| 122 |
+
matched_speakers.add(best_speaker)
|
| 123 |
+
|
| 124 |
+
# Process each speaker
|
| 125 |
+
for speaker_id in range(number_of_speakers):
|
| 126 |
+
if speaker_boxes[speaker_id] is not None:
|
| 127 |
+
# Use detected box
|
| 128 |
+
box = speaker_boxes[speaker_id]
|
| 129 |
+
else:
|
| 130 |
+
# Use previous box for this speaker
|
| 131 |
+
box = boxes_dic[speaker_id][-1]
|
| 132 |
+
|
| 133 |
+
# Extract face and landmarks
|
| 134 |
+
face = frame.crop((box[0], box[1], box[2], box[3])).resize((224,224))
|
| 135 |
+
preds = fa.get_landmarks(np.array(face))
|
| 136 |
+
|
| 137 |
+
if preds is None:
|
| 138 |
+
# Use previous landmarks if detection fails
|
| 139 |
+
preds = landmarks_dic[speaker_id][-1]
|
| 140 |
+
|
| 141 |
+
faces_dic[speaker_id].append(face)
|
| 142 |
+
landmarks_dic[speaker_id].append(preds)
|
| 143 |
+
boxes_dic[speaker_id].append(box)
|
| 144 |
+
|
| 145 |
+
# Verify all speakers have same number of frames
|
| 146 |
+
frame_counts = [len(boxes_dic[s]) for s in range(number_of_speakers)]
|
| 147 |
+
if status_callback:
|
| 148 |
+
status_callback({'status': f"Frame counts per speaker: {frame_counts}", 'progress': 0.4})
|
| 149 |
+
|
| 150 |
+
assert all(count == len(frames) for count in frame_counts), f"Inconsistent frame counts: {frame_counts}"
|
| 151 |
+
|
| 152 |
+
# Continue with saving videos and landmarks...
|
| 153 |
+
for s in range(number_of_speakers):
|
| 154 |
+
if status_callback:
|
| 155 |
+
status_callback({'status': f'Saving tracked video for speaker {s+1}', 'progress': 0.4 + 0.1 * (s / number_of_speakers)})
|
| 156 |
+
|
| 157 |
+
frames_tracked = []
|
| 158 |
+
for i, frame in enumerate(frames):
|
| 159 |
+
frame_draw = frame.copy()
|
| 160 |
+
draw = ImageDraw.Draw(frame_draw)
|
| 161 |
+
draw.rectangle(boxes_dic[s][i], outline=(255, 0, 0), width=6)
|
| 162 |
+
frames_tracked.append(frame_draw)
|
| 163 |
+
|
| 164 |
+
# Save tracked video
|
| 165 |
+
tracked_frames = [np.array(frame) for frame in frames_tracked]
|
| 166 |
+
if tracked_frames:
|
| 167 |
+
tracked_clip = ImageSequenceClip(tracked_frames, fps=25.0)
|
| 168 |
+
tracked_video_path = os.path.join(output_path, 'video_tracked' + str(s+1) + '.mp4')
|
| 169 |
+
tracked_clip.write_videofile(tracked_video_path, codec='libx264', audio=False, logger=None)
|
| 170 |
+
tracked_clip.close()
|
| 171 |
+
|
| 172 |
+
# Save landmarks
|
| 173 |
+
for i in range(number_of_speakers):
|
| 174 |
+
# Create landmark directory if it doesn't exist
|
| 175 |
+
landmark_dir = os.path.join(output_path, 'landmark')
|
| 176 |
+
os.makedirs(landmark_dir, exist_ok=True)
|
| 177 |
+
save2npz(os.path.join(landmark_dir, 'speaker' + str(i+1)+'.npz'), data=landmarks_dic[i])
|
| 178 |
+
|
| 179 |
+
# Save face video
|
| 180 |
+
face_frames = [np.array(frame) for frame in faces_dic[i]]
|
| 181 |
+
if face_frames:
|
| 182 |
+
face_clip = ImageSequenceClip(face_frames, fps=25.0)
|
| 183 |
+
face_video_path = os.path.join(output_path, 'faces', 'speaker' + str(i+1) + '.mp4')
|
| 184 |
+
face_clip.write_videofile(face_video_path, codec='libx264', audio=False, logger=None)
|
| 185 |
+
face_clip.close()
|
| 186 |
+
|
| 187 |
+
# Output video path
|
| 188 |
+
parts = video_input_path.split('/')
|
| 189 |
+
video_name = parts[-1][:-4]
|
| 190 |
+
filename_dir = os.path.join(output_path, 'filename_input')
|
| 191 |
+
os.makedirs(filename_dir, exist_ok=True)
|
| 192 |
+
csvfile = open(os.path.join(filename_dir, str(video_name) + '.csv'), 'w')
|
| 193 |
+
for i in range(number_of_speakers):
|
| 194 |
+
csvfile.write('speaker' + str(i+1)+ ',0\n')
|
| 195 |
+
csvfile.close()
|
| 196 |
+
return os.path.join(filename_dir, str(video_name) + '.csv')
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def crop_mouth_with_status(video_direc, landmark_direc, filename_path, save_direc, status_callback=None, convert_gray=False, testset_only=False):
|
| 200 |
+
"""Crop mouth with status updates"""
|
| 201 |
+
lines = open(filename_path).read().splitlines()
|
| 202 |
+
lines = list(filter(lambda x: 'test' in x, lines)) if testset_only else lines
|
| 203 |
+
|
| 204 |
+
for filename_idx, line in enumerate(lines):
|
| 205 |
+
filename, person_id = line.split(',')
|
| 206 |
+
|
| 207 |
+
if status_callback:
|
| 208 |
+
status_callback({'status': f'Processing speaker{int(person_id)+1}', 'progress': 0.5 + 0.1 * filename_idx / len(lines)})
|
| 209 |
+
|
| 210 |
+
video_pathname = os.path.join(video_direc, filename+'.mp4')
|
| 211 |
+
landmarks_pathname = os.path.join(landmark_direc, filename+'.npz')
|
| 212 |
+
|
| 213 |
+
# Create mouthroi directory if it doesn't exist
|
| 214 |
+
os.makedirs(save_direc, exist_ok=True)
|
| 215 |
+
dst_pathname = os.path.join(save_direc, filename+'.npz')
|
| 216 |
+
|
| 217 |
+
multi_sub_landmarks = np.load(landmarks_pathname, allow_pickle=True)['data']
|
| 218 |
+
if len(multi_sub_landmarks) == 0:
|
| 219 |
+
print(f"No landmarks found for {filename}, skipping crop.")
|
| 220 |
+
continue
|
| 221 |
+
|
| 222 |
+
landmark_frame_count = len(multi_sub_landmarks)
|
| 223 |
+
cap = cv2.VideoCapture(video_pathname)
|
| 224 |
+
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
|
| 225 |
+
cap.release()
|
| 226 |
+
|
| 227 |
+
if frame_count > 0 and frame_count != landmark_frame_count:
|
| 228 |
+
print(
|
| 229 |
+
f"Frame count mismatch for {filename}: video has {frame_count} frames, "
|
| 230 |
+
f"landmarks have {landmark_frame_count} entries. Adjusting to match."
|
| 231 |
+
)
|
| 232 |
+
if frame_count < landmark_frame_count:
|
| 233 |
+
multi_sub_landmarks = multi_sub_landmarks[:frame_count]
|
| 234 |
+
else:
|
| 235 |
+
pad_count = frame_count - landmark_frame_count
|
| 236 |
+
pad = np.repeat(multi_sub_landmarks[-1:], pad_count, axis=0)
|
| 237 |
+
multi_sub_landmarks = np.concatenate((multi_sub_landmarks, pad), axis=0)
|
| 238 |
+
|
| 239 |
+
landmarks = [None] * len(multi_sub_landmarks)
|
| 240 |
+
for frame_idx in range(len(landmarks)):
|
| 241 |
+
try:
|
| 242 |
+
landmarks[frame_idx] = multi_sub_landmarks[frame_idx][int(person_id)]
|
| 243 |
+
except (IndexError, TypeError):
|
| 244 |
+
continue
|
| 245 |
+
|
| 246 |
+
# Pre-process landmarks: interpolate frames not being detected
|
| 247 |
+
preprocessed_landmarks = landmarks_interpolate(landmarks)
|
| 248 |
+
if not preprocessed_landmarks:
|
| 249 |
+
continue
|
| 250 |
+
|
| 251 |
+
# Crop
|
| 252 |
+
mean_face_landmarks = np.load('assets/20words_mean_face.npy')
|
| 253 |
+
sequence = crop_patch(mean_face_landmarks, video_pathname, preprocessed_landmarks, 12, 48, 68, 96, 96)
|
| 254 |
+
assert sequence is not None, "cannot crop from {}.".format(filename)
|
| 255 |
+
|
| 256 |
+
# Save
|
| 257 |
+
data = convert_bgr2gray(sequence) if convert_gray else sequence[...,::-1]
|
| 258 |
+
save2npz(dst_pathname, data=data)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def process_video_with_status(input_file, output_path, number_of_speakers=2,
|
| 262 |
+
detect_every_N_frame=8, scalar_face_detection=1.5,
|
| 263 |
+
config_path="checkpoints/vox2/conf.yml",
|
| 264 |
+
cuda_device=None, status_callback=None):
|
| 265 |
+
"""Main processing function with status updates"""
|
| 266 |
+
|
| 267 |
+
# Set CUDA device if specified
|
| 268 |
+
if cuda_device is not None:
|
| 269 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device)
|
| 270 |
+
|
| 271 |
+
# Create output directory
|
| 272 |
+
os.makedirs(output_path, exist_ok=True)
|
| 273 |
+
|
| 274 |
+
# Convert video to 25fps
|
| 275 |
+
if status_callback:
|
| 276 |
+
status_callback({'status': 'Converting video to 25fps', 'progress': 0.0})
|
| 277 |
+
|
| 278 |
+
temp_25fps_file = os.path.join(output_path, 'temp_25fps.mp4')
|
| 279 |
+
convert_video_fps(input_file, temp_25fps_file, target_fps=25)
|
| 280 |
+
|
| 281 |
+
# Detect faces
|
| 282 |
+
if status_callback:
|
| 283 |
+
status_callback({'status': 'Detecting faces and tracking speakers', 'progress': 0.1})
|
| 284 |
+
|
| 285 |
+
filename_path = detectface_with_status(
|
| 286 |
+
video_input_path=temp_25fps_file,
|
| 287 |
+
output_path=output_path,
|
| 288 |
+
detect_every_N_frame=detect_every_N_frame,
|
| 289 |
+
scalar_face_detection=scalar_face_detection,
|
| 290 |
+
number_of_speakers=number_of_speakers,
|
| 291 |
+
status_callback=status_callback
|
| 292 |
+
)
|
| 293 |
+
torch.cuda.empty_cache()
|
| 294 |
+
# Extract audio
|
| 295 |
+
if status_callback:
|
| 296 |
+
status_callback({'status': 'Extracting audio from video', 'progress': 0.5})
|
| 297 |
+
|
| 298 |
+
audio_output = os.path.join(output_path, 'audio.wav')
|
| 299 |
+
extract_audio(temp_25fps_file, audio_output, sample_rate=16000)
|
| 300 |
+
|
| 301 |
+
# Crop mouth
|
| 302 |
+
if status_callback:
|
| 303 |
+
status_callback({'status': 'Cropping mouth regions', 'progress': 0.55})
|
| 304 |
+
|
| 305 |
+
crop_mouth_with_status(
|
| 306 |
+
video_direc=os.path.join(output_path, "faces"),
|
| 307 |
+
landmark_direc=os.path.join(output_path, "landmark"),
|
| 308 |
+
filename_path=filename_path,
|
| 309 |
+
save_direc=os.path.join(output_path, "mouthroi"),
|
| 310 |
+
convert_gray=True,
|
| 311 |
+
testset_only=False,
|
| 312 |
+
status_callback=status_callback
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# Load model
|
| 316 |
+
if status_callback:
|
| 317 |
+
status_callback({'status': 'Loading Dolphin model', 'progress': 0.6})
|
| 318 |
+
torch.cuda.empty_cache()
|
| 319 |
+
audiomodel = Dolphin.from_pretrained("JusperLee/Dolphin")
|
| 320 |
+
# audiomodel.cuda()
|
| 321 |
+
audiomodel.eval()
|
| 322 |
+
|
| 323 |
+
# Process each speaker
|
| 324 |
+
with torch.no_grad():
|
| 325 |
+
for i in range(number_of_speakers):
|
| 326 |
+
if status_callback:
|
| 327 |
+
status_callback({'status': f'Processing audio for speaker {i+1}', 'progress': 0.65 + 0.25 * (i / number_of_speakers)})
|
| 328 |
+
|
| 329 |
+
mouth_roi_path = os.path.join(output_path, "mouthroi", f"speaker{i+1}.npz")
|
| 330 |
+
mouth_roi = np.load(mouth_roi_path)["data"]
|
| 331 |
+
mouth_roi = get_preprocessing_pipelines()["val"](mouth_roi)
|
| 332 |
+
|
| 333 |
+
mix, sr = torchaudio.load(audio_output)
|
| 334 |
+
mix = mix.mean(dim=0)
|
| 335 |
+
|
| 336 |
+
window_size = 4 * sr
|
| 337 |
+
hop_size = int(4 * sr)
|
| 338 |
+
|
| 339 |
+
all_estimates = []
|
| 340 |
+
|
| 341 |
+
# Sliding window processing
|
| 342 |
+
start_idx = 0
|
| 343 |
+
window_count = 0
|
| 344 |
+
while start_idx < len(mix):
|
| 345 |
+
end_idx = min(start_idx + window_size, len(mix))
|
| 346 |
+
window_mix = mix[start_idx:end_idx]
|
| 347 |
+
|
| 348 |
+
start_frame = int(start_idx / sr * 25)
|
| 349 |
+
end_frame = int(end_idx / sr * 25)
|
| 350 |
+
end_frame = min(end_frame, len(mouth_roi))
|
| 351 |
+
window_mouth_roi = mouth_roi[start_frame:end_frame]
|
| 352 |
+
|
| 353 |
+
est_sources = audiomodel(window_mix[None],
|
| 354 |
+
torch.from_numpy(window_mouth_roi[None, None]).float())
|
| 355 |
+
|
| 356 |
+
all_estimates.append({
|
| 357 |
+
'start': start_idx,
|
| 358 |
+
'end': end_idx,
|
| 359 |
+
'estimate': est_sources[0].cpu()
|
| 360 |
+
})
|
| 361 |
+
|
| 362 |
+
window_count += 1
|
| 363 |
+
if status_callback:
|
| 364 |
+
progress = 0.65 + 0.25 * (i / number_of_speakers) + 0.25 / number_of_speakers * (window_count * hop_size / len(mix))
|
| 365 |
+
status_callback({'status': f'Processing audio window {window_count} for speaker {i+1}', 'progress': min(progress, 0.9)})
|
| 366 |
+
|
| 367 |
+
start_idx += hop_size
|
| 368 |
+
|
| 369 |
+
if start_idx >= len(mix):
|
| 370 |
+
break
|
| 371 |
+
torch.cuda.empty_cache()
|
| 372 |
+
|
| 373 |
+
output_length = len(mix)
|
| 374 |
+
merged_output = torch.zeros(1, output_length)
|
| 375 |
+
weights = torch.zeros(output_length)
|
| 376 |
+
|
| 377 |
+
for est in all_estimates:
|
| 378 |
+
window_len = est['end'] - est['start']
|
| 379 |
+
hann_window = torch.hann_window(window_len)
|
| 380 |
+
|
| 381 |
+
merged_output[0, est['start']:est['end']] += est['estimate'][0, :window_len] * hann_window
|
| 382 |
+
weights[est['start']:est['end']] += hann_window
|
| 383 |
+
|
| 384 |
+
merged_output[:, weights > 0] /= weights[weights > 0]
|
| 385 |
+
|
| 386 |
+
audio_save_path = os.path.join(output_path, f"speaker{i+1}_est.wav")
|
| 387 |
+
torchaudio.save(audio_save_path, merged_output, sr)
|
| 388 |
+
|
| 389 |
+
# Merge video with separated audio for each speaker
|
| 390 |
+
torch.cuda.empty_cache()
|
| 391 |
+
if status_callback:
|
| 392 |
+
status_callback({'status': 'Merging videos with separated audio', 'progress': 0.9})
|
| 393 |
+
|
| 394 |
+
output_files = []
|
| 395 |
+
for i in range(number_of_speakers):
|
| 396 |
+
video_input = os.path.join(output_path, f"video_tracked{i+1}.mp4")
|
| 397 |
+
audio_input = os.path.join(output_path, f"speaker{i+1}_est.wav")
|
| 398 |
+
video_output = os.path.join(output_path, f"s{i+1}.mp4")
|
| 399 |
+
|
| 400 |
+
merge_video_audio(video_input, audio_input, video_output)
|
| 401 |
+
output_files.append(video_output)
|
| 402 |
+
|
| 403 |
+
# Clean up temporary file
|
| 404 |
+
if os.path.exists(temp_25fps_file):
|
| 405 |
+
os.remove(temp_25fps_file)
|
| 406 |
+
|
| 407 |
+
if status_callback:
|
| 408 |
+
status_callback({'status': 'Processing completed!', 'progress': 1.0})
|
| 409 |
+
|
| 410 |
+
return output_files
|
README.md
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img src="assets/icon.png" alt="Dolphin Logo" width="150"/>
|
| 3 |
+
</p>
|
| 4 |
+
<h3 align="center">Dolphin: Efficient Audio-Visual Speech Separation with Discrete Lip Semantics and Multi-Scale Global-Local Attention</h3>
|
| 5 |
+
<p align="center">
|
| 6 |
+
<strong>Kai Li*, Kejun Gao*, Xiaolin Hu </strong><br>
|
| 7 |
+
<strong>Tsinghua University</strong>
|
| 8 |
+
</p>
|
| 9 |
+
|
| 10 |
+
<p align="center">
|
| 11 |
+
<img src="https://visitor-badge.laobi.icu/badge?page_id=JusperLee.Dolphin" alt="访客统计" />
|
| 12 |
+
<img src="https://img.shields.io/github/stars/JusperLee/Dolphin?style=social" alt="GitHub stars" />
|
| 13 |
+
<img alt="Static Badge" src="https://img.shields.io/badge/license-Apache%202.0-blue.svg" />
|
| 14 |
+
<a href="https://arxiv.org/abs/2509.23610" target="_blank" rel="noreferrer noopener">
|
| 15 |
+
<img alt="arXiv Paper" src="https://img.shields.io/badge/arXiv-2509.23610-b31b1b.svg?logo=arxiv&logoColor=white" />
|
| 16 |
+
</a>
|
| 17 |
+
<a href="https://huggingface.co/JusperLee/Dolphin" target="_blank" rel="noreferrer noopener">
|
| 18 |
+
<img alt="Hugging Face Models" src="https://img.shields.io/badge/Hugging%20Face-Models-ff9d2c?logo=huggingface&logoColor=white" />
|
| 19 |
+
</a>
|
| 20 |
+
<a href="https://dolphin.cslikai.cn/" target="_blank" rel="noreferrer noopener">
|
| 21 |
+
<img alt="Gradio Live Demo" src="https://img.shields.io/badge/Gradio-Live%20Demo-00a67e?logo=gradio&logoColor=white" />
|
| 22 |
+
</a>
|
| 23 |
+
</p>
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
|
| 27 |
+
> Dolphin is an efficient audio-visual speech separation framework that leverages discrete lip semantics and global–local attention to achieve state-of-the-art performance with significantly reduced computational complexity.
|
| 28 |
+
|
| 29 |
+
## 🎯 Highlights
|
| 30 |
+
|
| 31 |
+
- **Balanced Quality & Efficiency**: Single-pass separator achieves state-of-the-art AVSS performance without iterative refinement.
|
| 32 |
+
- **DP-LipCoder**: Dual-path, vector-quantized video encoder produces discrete audio-aligned semantic tokens while staying lightweight.
|
| 33 |
+
- **Global–Local Attention**: TDANet-based separator augments each layer with coarse global self-attention and heat diffusion local attention.
|
| 34 |
+
- **Edge-Friendly Deployment**: Delivers >50% parameter reduction, >2.4× lower MACs, and >6× faster GPU inference versus IIANet.
|
| 35 |
+
|
| 36 |
+
## 💥 News
|
| 37 |
+
|
| 38 |
+
- **[2025-09-28]** Code and pre-trained models are released! 📦
|
| 39 |
+
|
| 40 |
+
## 📜 Abstract
|
| 41 |
+
|
| 42 |
+
Audio-visual speech separation (AVSS) methods leverage visual cues to extract target speech in noisy acoustic environments, but most existing systems remain computationally heavy. Dolphin tackles this tension by combining a lightweight, dual-path video encoder with a single-pass global–local collaborative separator. The video pathway, DP-LipCoder, maps lip movements into discrete semantic tokens that remain tightly aligned with audio through vector quantization and distillation from AV-HuBERT. The audio separator builds upon TDANet and injects global–local attention (GLA) blocks—coarse-grained self-attention for long-range context and heat diffusion attention for denoising fine details. Across three public AVSS benchmarks, Dolphin not only outperforms the state-of-the-art IIANet on separation metrics but also delivers over 50% fewer parameters, more than 2.4× lower MACs, and over 6× faster GPU inference, making it practical for edge deployment.
|
| 43 |
+
|
| 44 |
+
## 🌍 Motivation
|
| 45 |
+
|
| 46 |
+
In real-world environments, target speech is often masked by background noise and interfering speakers. This phenomenon reflects the classic “cocktail party effect,” where listeners selectively attend to a single speaker within a noisy scene (Cherry, 1953). These challenges have spurred extensive research on speech separation.
|
| 47 |
+
|
| 48 |
+
Audio-only approaches tend to struggle in complex acoustic conditions, while the integration of synchronous visual cues offers greater robustness. Recent deep learning-based AVSS systems achieve strong performance, yet many rely on computationally intensive separators or heavy iterative refinement, limiting their practicality.
|
| 49 |
+
|
| 50 |
+
Beyond the separator itself, AVSS models frequently inherit high computational cost from their video encoders. Large-scale lip-reading backbones provide rich semantic alignment but bring prohibitive parameter counts. Compressing them often erodes lip semantics, whereas designing new lightweight encoders from scratch risks losing semantic fidelity and degrading separation quality. Building a video encoder that balances compactness with semantic alignment therefore remains a central challenge for AVSS.
|
| 51 |
+
|
| 52 |
+
## 🧠 Method Overview
|
| 53 |
+
|
| 54 |
+
To address these limitations, Dolphin introduces a novel AVSS pipeline centered on two components:
|
| 55 |
+
|
| 56 |
+
- **DP-LipCoder**: A dual-path, vector-quantized video encoder that separates compressed visual structure from audio-aligned semantics. By combining vector quantization with knowledge distillation from AV-HuBERT, it converts continuous lip motion into discrete semantic tokens without sacrificing representational capacity.
|
| 57 |
+
- **Single-Pass GLA Separator**: A lightweight TDANet-based audio separator that removes the need for iterative refinement. Each layer hosts a global–local attention block: coarse-grained self-attention captures long-range dependencies at low resolution, while heat diffusion attention smooths features across channels to suppress noise and retain detail.
|
| 58 |
+
|
| 59 |
+
Together, these components strike a balance between separation quality and computational efficiency, enabling deployment in resource-constrained scenarios.
|
| 60 |
+
|
| 61 |
+
## 🧪 Experimental Highlights
|
| 62 |
+
|
| 63 |
+
We evaluate Dolphin on LRS2, LRS3, and VoxCeleb2. Compared with the state-of-the-art IIANet, Dolphin achieves higher scores across all separation metrics while dramatically reducing resource consumption:
|
| 64 |
+
|
| 65 |
+
- **Parameters**: >50% reduction
|
| 66 |
+
- **Computation**: >2.4× decrease in MACs
|
| 67 |
+
- **Inference**: >6× speedup on GPU
|
| 68 |
+
|
| 69 |
+
These results demonstrate that Dolphin provides competitive AVSS quality on edge hardware without heavy iterative processing.
|
| 70 |
+
|
| 71 |
+
## 🏗️ Architecture
|
| 72 |
+
|
| 73 |
+

|
| 74 |
+
|
| 75 |
+
> The overall architecture of Dolphin.
|
| 76 |
+
|
| 77 |
+
### Video Encoder
|
| 78 |
+
|
| 79 |
+

|
| 80 |
+
|
| 81 |
+
> The video encoder of Dolphin.
|
| 82 |
+
|
| 83 |
+
### Dolphin Model Overview
|
| 84 |
+
|
| 85 |
+

|
| 86 |
+
|
| 87 |
+
> The overall architecture of Dolphin's separator.
|
| 88 |
+
|
| 89 |
+
### Key Components
|
| 90 |
+
|
| 91 |
+

|
| 92 |
+
|
| 93 |
+
1. **Global Attention (GA) Block**
|
| 94 |
+
- Applies coarse-grained self-attention to capture long-range structure
|
| 95 |
+
- Operates at low spatial resolution for efficiency
|
| 96 |
+
- Enhances robustness to complex acoustic mixtures
|
| 97 |
+
|
| 98 |
+
2. **Local Attention (LA) Block**
|
| 99 |
+
- Uses heat diffusion attention to smooth features across channels
|
| 100 |
+
- Suppresses background noise while preserving details
|
| 101 |
+
- Complements GA to balance global context and local fidelity
|
| 102 |
+
|
| 103 |
+
## 📊 Results
|
| 104 |
+
|
| 105 |
+
### Performance Comparison
|
| 106 |
+
|
| 107 |
+
Performance metrics on three public AVSS benchmark datasets. Bold indicates best performance.
|
| 108 |
+
|
| 109 |
+

|
| 110 |
+
|
| 111 |
+
### Efficiency Analysis
|
| 112 |
+
|
| 113 |
+

|
| 114 |
+
|
| 115 |
+
Dolphin achieves:
|
| 116 |
+
- ✅ **>50%** parameter reduction
|
| 117 |
+
- ✅ **2.4×** lower computational cost (MACs)
|
| 118 |
+
- ✅ **6×** faster GPU inference speed
|
| 119 |
+
- ✅ Superior separation quality across all metrics
|
| 120 |
+
|
| 121 |
+
## 📦 Installation
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
git clone https://github.com/JusperLee/Dolphin.git
|
| 125 |
+
cd Dolphin
|
| 126 |
+
pip install torch torchvision
|
| 127 |
+
pip install -r requirements.txt
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
### Requirements
|
| 131 |
+
|
| 132 |
+
- Python >= 3.10
|
| 133 |
+
- PyTorch >= 2.5.0
|
| 134 |
+
- CUDA >= 12.4
|
| 135 |
+
- Other dependencies in requirements.txt
|
| 136 |
+
|
| 137 |
+
## 🚀 Quick Start
|
| 138 |
+
|
| 139 |
+
### Inference with Pre-trained Model
|
| 140 |
+
|
| 141 |
+
```python
|
| 142 |
+
# Single audio-visual separation
|
| 143 |
+
python inference.py \
|
| 144 |
+
--input /path/to/video.mp4 \
|
| 145 |
+
--output /path/to/output/directory \
|
| 146 |
+
--speakers 2 \
|
| 147 |
+
--detect-every-n 8 \
|
| 148 |
+
--face-scale 1.5 \
|
| 149 |
+
--cuda-device 0 \
|
| 150 |
+
--config checkpoints/vox2/conf.yml
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## 📁 Model Zoo
|
| 154 |
+
|
| 155 |
+
| Model | Training Data | SI-SNRi | PESQ | Download |
|
| 156 |
+
|-------|--------------|---------|------|----------|
|
| 157 |
+
| Dolphin | VoxCeleb2 | 16.1 dB | 3.45 | [Link](https://huggingface.co/JusperLee/Dolphin) |
|
| 158 |
+
|
| 159 |
+
## 📖 Citation
|
| 160 |
+
|
| 161 |
+
If you find Dolphin useful in your research, please cite:
|
| 162 |
+
|
| 163 |
+
```bibtex
|
| 164 |
+
@misc{li2025efficientaudiovisualspeechseparation,
|
| 165 |
+
title={Efficient Audio-Visual Speech Separation with Discrete Lip Semantics and Multi-Scale Global-Local Attention},
|
| 166 |
+
author={Kai Li and Kejun Gao and Xiaolin Hu},
|
| 167 |
+
year={2025},
|
| 168 |
+
eprint={2509.23610},
|
| 169 |
+
archivePrefix={arXiv},
|
| 170 |
+
primaryClass={cs.SD},
|
| 171 |
+
url={https://arxiv.org/abs/2509.23610},
|
| 172 |
+
}
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
## 🤝 Acknowledgments
|
| 176 |
+
|
| 177 |
+
We thank the authors of [IIANet](https://github.com/JusperLee/IIANet) and [SepReformer](https://github.com/dmlguq456/SepReformer) for providing parts of the code used in this project.
|
| 178 |
+
|
| 179 |
+
## 📧 Contact
|
| 180 |
+
|
| 181 |
+
For questions and feedback, please open an issue on GitHub or contact us at: [tsinghua.kaili@gmail.com](tsinghua.kaili@gmail.com)
|
| 182 |
+
|
| 183 |
+
## 📄 License
|
| 184 |
+
|
| 185 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 186 |
+
|
| 187 |
+
<p align="center">
|
| 188 |
+
Made with stars ⭐️ for efficient audio-visual speech separation
|
| 189 |
+
</p>
|
app.py
ADDED
|
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Audio-visual Speech Separation Gradio App - Hugging Face Space Version
|
| 4 |
+
Automatically detects and separates all speakers in videos
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import warnings
|
| 8 |
+
warnings.filterwarnings("ignore")
|
| 9 |
+
import os
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import numpy as np
|
| 12 |
+
import shutil
|
| 13 |
+
import tempfile
|
| 14 |
+
import time
|
| 15 |
+
import sys
|
| 16 |
+
import threading
|
| 17 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 18 |
+
from moviepy import *
|
| 19 |
+
import spaces
|
| 20 |
+
|
| 21 |
+
from face_detection_utils import detect_faces
|
| 22 |
+
|
| 23 |
+
# Use HF Space's temp directory
|
| 24 |
+
TEMP_DIR = os.environ.get('TMPDIR', '/tmp')
|
| 25 |
+
|
| 26 |
+
# Shared state for relaying GPU-side status back to the UI thread.
|
| 27 |
+
GPU_PROGRESS_STATE = {"progress": 0.0, "status": "Processing on GPU..."}
|
| 28 |
+
GPU_PROGRESS_LOCK = threading.Lock()
|
| 29 |
+
|
| 30 |
+
class LogCollector:
|
| 31 |
+
"""Collect logs in a list"""
|
| 32 |
+
def __init__(self):
|
| 33 |
+
self.logs = []
|
| 34 |
+
|
| 35 |
+
def add(self, message):
|
| 36 |
+
if message and message.strip():
|
| 37 |
+
timestamp = time.strftime("%H:%M:%S")
|
| 38 |
+
self.logs.append(f"[{timestamp}] {message.strip()}")
|
| 39 |
+
|
| 40 |
+
def get_text(self, last_n=None):
|
| 41 |
+
if last_n:
|
| 42 |
+
return "\n".join(self.logs[-last_n:])
|
| 43 |
+
return "\n".join(self.logs)
|
| 44 |
+
|
| 45 |
+
# Global log collector for capturing print statements
|
| 46 |
+
GLOBAL_LOG = LogCollector()
|
| 47 |
+
|
| 48 |
+
class StdoutCapture:
|
| 49 |
+
"""Capture stdout and add to log"""
|
| 50 |
+
def __init__(self, original):
|
| 51 |
+
self.original = original
|
| 52 |
+
|
| 53 |
+
def write(self, text):
|
| 54 |
+
self.original.write(text)
|
| 55 |
+
if text.strip():
|
| 56 |
+
GLOBAL_LOG.add(text.strip())
|
| 57 |
+
|
| 58 |
+
def flush(self):
|
| 59 |
+
self.original.flush()
|
| 60 |
+
|
| 61 |
+
def remove_duplicate_faces(boxes, probs, iou_threshold=0.5):
|
| 62 |
+
"""Remove duplicate face detections using IoU (Intersection over Union)"""
|
| 63 |
+
if len(boxes) <= 1:
|
| 64 |
+
return boxes, probs
|
| 65 |
+
|
| 66 |
+
# Calculate IoU between all pairs of boxes
|
| 67 |
+
def calculate_iou(box1, box2):
|
| 68 |
+
x1 = max(box1[0], box2[0])
|
| 69 |
+
y1 = max(box1[1], box2[1])
|
| 70 |
+
x2 = min(box1[2], box2[2])
|
| 71 |
+
y2 = min(box1[3], box2[3])
|
| 72 |
+
|
| 73 |
+
intersection = max(0, x2 - x1) * max(0, y2 - y1)
|
| 74 |
+
area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
| 75 |
+
area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
| 76 |
+
union = area1 + area2 - intersection
|
| 77 |
+
|
| 78 |
+
return intersection / union if union > 0 else 0
|
| 79 |
+
|
| 80 |
+
# Keep track of which boxes to keep
|
| 81 |
+
keep = []
|
| 82 |
+
used = set()
|
| 83 |
+
|
| 84 |
+
# Sort by confidence (if available) or by area
|
| 85 |
+
if probs is not None:
|
| 86 |
+
sorted_indices = np.argsort(probs)[::-1]
|
| 87 |
+
else:
|
| 88 |
+
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
| 89 |
+
sorted_indices = np.argsort(areas)[::-1]
|
| 90 |
+
|
| 91 |
+
for i in sorted_indices:
|
| 92 |
+
if i in used:
|
| 93 |
+
continue
|
| 94 |
+
|
| 95 |
+
keep.append(i)
|
| 96 |
+
used.add(i)
|
| 97 |
+
|
| 98 |
+
# Mark overlapping boxes as used
|
| 99 |
+
for j in range(len(boxes)):
|
| 100 |
+
if j != i and j not in used:
|
| 101 |
+
iou = calculate_iou(boxes[i], boxes[j])
|
| 102 |
+
if iou > iou_threshold:
|
| 103 |
+
used.add(j)
|
| 104 |
+
|
| 105 |
+
# Return filtered boxes and probs
|
| 106 |
+
keep = sorted(keep) # Maintain original order
|
| 107 |
+
filtered_boxes = boxes[keep]
|
| 108 |
+
filtered_probs = probs[keep] if probs is not None else None
|
| 109 |
+
|
| 110 |
+
return filtered_boxes, filtered_probs
|
| 111 |
+
|
| 112 |
+
def process_detected_faces(boxes, probs, frame_rgb, frame_pil):
|
| 113 |
+
"""Process detected faces and return face images"""
|
| 114 |
+
face_images = []
|
| 115 |
+
full_frame_annotated = frame_rgb.copy()
|
| 116 |
+
|
| 117 |
+
if boxes is None or len(boxes) == 0:
|
| 118 |
+
return [], 0, full_frame_annotated, "No faces detected"
|
| 119 |
+
|
| 120 |
+
boxes = np.asarray(boxes, dtype=np.float32)
|
| 121 |
+
|
| 122 |
+
# Filter by confidence if available
|
| 123 |
+
if probs is not None:
|
| 124 |
+
# Keep faces with confidence > 0.9
|
| 125 |
+
confident_indices = probs > 0.9
|
| 126 |
+
boxes = boxes[confident_indices]
|
| 127 |
+
probs = probs[confident_indices]
|
| 128 |
+
print(f"After filtering by confidence: {len(boxes)} faces")
|
| 129 |
+
|
| 130 |
+
if len(boxes) == 0:
|
| 131 |
+
return [], 0, full_frame_annotated, "No faces passed the confidence filter"
|
| 132 |
+
|
| 133 |
+
# Remove duplicate detections
|
| 134 |
+
boxes, probs = remove_duplicate_faces(boxes, probs, iou_threshold=0.3)
|
| 135 |
+
print(f"After removing duplicates: {len(boxes)} faces")
|
| 136 |
+
|
| 137 |
+
if len(boxes) == 0:
|
| 138 |
+
return [], 0, full_frame_annotated, "No faces remained after duplicate removal"
|
| 139 |
+
|
| 140 |
+
# Sort boxes by area (larger faces first)
|
| 141 |
+
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
| 142 |
+
sorted_indices = np.argsort(areas)[::-1]
|
| 143 |
+
boxes = boxes[sorted_indices]
|
| 144 |
+
|
| 145 |
+
# Annotate full frame
|
| 146 |
+
full_frame_pil = Image.fromarray(full_frame_annotated)
|
| 147 |
+
draw = ImageDraw.Draw(full_frame_pil)
|
| 148 |
+
|
| 149 |
+
# Try to use a better font
|
| 150 |
+
try:
|
| 151 |
+
font = ImageFont.load_default()
|
| 152 |
+
except:
|
| 153 |
+
font = None
|
| 154 |
+
|
| 155 |
+
# Extract face images and annotate
|
| 156 |
+
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
|
| 157 |
+
|
| 158 |
+
for i, box in enumerate(boxes):
|
| 159 |
+
color = colors[i % len(colors)]
|
| 160 |
+
|
| 161 |
+
# Draw bounding box
|
| 162 |
+
draw.rectangle(box.tolist(), outline=color, width=4)
|
| 163 |
+
label = f"Speaker {i+1}"
|
| 164 |
+
|
| 165 |
+
# Draw label
|
| 166 |
+
if font:
|
| 167 |
+
draw.text((box[0] + 5, box[1] - 20), label, fill=color, font=font)
|
| 168 |
+
|
| 169 |
+
# Extract face with margin
|
| 170 |
+
margin = 50
|
| 171 |
+
x1 = max(0, int(box[0] - margin))
|
| 172 |
+
y1 = max(0, int(box[1] - margin))
|
| 173 |
+
x2 = min(frame_rgb.shape[1], int(box[2] + margin))
|
| 174 |
+
y2 = min(frame_rgb.shape[0], int(box[3] + margin))
|
| 175 |
+
|
| 176 |
+
face_crop = frame_rgb[y1:y2, x1:x2]
|
| 177 |
+
# Resize maintaining aspect ratio
|
| 178 |
+
face_crop = Image.fromarray(face_crop)
|
| 179 |
+
face_crop.thumbnail((250, 250), Image.Resampling.LANCZOS)
|
| 180 |
+
face_crop = np.array(face_crop)
|
| 181 |
+
|
| 182 |
+
face_images.append(face_crop)
|
| 183 |
+
|
| 184 |
+
full_frame_annotated = np.array(full_frame_pil)
|
| 185 |
+
return face_images, len(boxes), full_frame_annotated, None
|
| 186 |
+
|
| 187 |
+
@spaces.GPU(duration=60, enable_queue=True)
|
| 188 |
+
def detect_faces_gpu(frame_pil):
|
| 189 |
+
"""GPU-accelerated face detection"""
|
| 190 |
+
print("Detecting faces with RetinaFace")
|
| 191 |
+
|
| 192 |
+
frame_array = np.array(frame_pil)
|
| 193 |
+
|
| 194 |
+
boxes, probs = detect_faces(
|
| 195 |
+
frame_array,
|
| 196 |
+
threshold=0.9,
|
| 197 |
+
allow_upscaling=False,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
if boxes is None or len(boxes) == 0:
|
| 201 |
+
print("No faces detected at high threshold, relaxing criteria...")
|
| 202 |
+
boxes, probs = detect_faces(
|
| 203 |
+
frame_array,
|
| 204 |
+
threshold=0.7,
|
| 205 |
+
allow_upscaling=True,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
return boxes, probs
|
| 209 |
+
|
| 210 |
+
def detect_and_extract_all_faces(video_path):
|
| 211 |
+
"""Detect all faces in the first frame and extract them"""
|
| 212 |
+
print("Starting face detection...")
|
| 213 |
+
|
| 214 |
+
# Check if video file exists
|
| 215 |
+
if not os.path.exists(video_path):
|
| 216 |
+
print(f"Error: Video file does not exist at path: {video_path}")
|
| 217 |
+
return [], 0, None, f"Video file not found: {video_path}"
|
| 218 |
+
|
| 219 |
+
print(f"Video path: {video_path}")
|
| 220 |
+
print(f"File size: {os.path.getsize(video_path) / 1024 / 1024:.2f} MB")
|
| 221 |
+
|
| 222 |
+
# Use moviepy to read video
|
| 223 |
+
print("Opening video with moviepy...")
|
| 224 |
+
try:
|
| 225 |
+
clip = VideoFileClip(video_path)
|
| 226 |
+
|
| 227 |
+
# Get video properties
|
| 228 |
+
fps = clip.fps
|
| 229 |
+
duration = clip.duration
|
| 230 |
+
total_frames = int(fps * duration)
|
| 231 |
+
|
| 232 |
+
print(f"Video info: FPS: {fps}, Duration: {duration}s, Total frames: {total_frames}")
|
| 233 |
+
|
| 234 |
+
# Get first frame
|
| 235 |
+
frame = clip.get_frame(0) # MoviePy returns RGB
|
| 236 |
+
frame_rgb = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
|
| 237 |
+
|
| 238 |
+
print(f"Successfully read frame with moviepy: {frame_rgb.shape}")
|
| 239 |
+
|
| 240 |
+
# Close the clip to free resources
|
| 241 |
+
clip.close()
|
| 242 |
+
|
| 243 |
+
# Convert to PIL for downstream processing
|
| 244 |
+
frame_pil = Image.fromarray(frame_rgb)
|
| 245 |
+
|
| 246 |
+
# Detect faces using RetinaFace
|
| 247 |
+
print("Detecting faces with RetinaFace...")
|
| 248 |
+
boxes, probs = detect_faces(
|
| 249 |
+
frame_rgb,
|
| 250 |
+
threshold=0.9,
|
| 251 |
+
allow_upscaling=False,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
if boxes is None or len(boxes) == 0:
|
| 255 |
+
print("No faces detected at high threshold, trying relaxed settings...")
|
| 256 |
+
boxes, probs = detect_faces(
|
| 257 |
+
frame_rgb,
|
| 258 |
+
threshold=0.7,
|
| 259 |
+
allow_upscaling=True,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
if boxes is not None and len(boxes) > 0:
|
| 263 |
+
print(f"Detected {len(boxes)} faces")
|
| 264 |
+
return process_detected_faces(boxes, probs, frame_rgb, frame_pil)
|
| 265 |
+
else:
|
| 266 |
+
return [], 0, frame_rgb, "No faces detected in the first frame"
|
| 267 |
+
|
| 268 |
+
except Exception as e:
|
| 269 |
+
print(f"MoviePy failed: {e}")
|
| 270 |
+
import traceback
|
| 271 |
+
traceback.print_exc()
|
| 272 |
+
return [], 0, None, f"Failed to open video file. Error: {str(e)}"
|
| 273 |
+
|
| 274 |
+
@spaces.GPU(duration=300, enable_queue=True)
|
| 275 |
+
def process_video_gpu(video_file, temp_dir, num_speakers):
|
| 276 |
+
"""GPU-accelerated video processing"""
|
| 277 |
+
try:
|
| 278 |
+
from Inference_with_status import process_video_with_status
|
| 279 |
+
|
| 280 |
+
# Define status callback inside GPU function
|
| 281 |
+
def gpu_status_callback(message):
|
| 282 |
+
status_text = message.get('status', 'Processing...')
|
| 283 |
+
print(f"GPU Processing: {status_text}")
|
| 284 |
+
progress_value = message.get('progress')
|
| 285 |
+
with GPU_PROGRESS_LOCK:
|
| 286 |
+
GPU_PROGRESS_STATE["status"] = status_text
|
| 287 |
+
if progress_value is not None:
|
| 288 |
+
try:
|
| 289 |
+
numeric_progress = float(progress_value)
|
| 290 |
+
GPU_PROGRESS_STATE["progress"] = min(max(numeric_progress, 0.0), 1.0)
|
| 291 |
+
except (TypeError, ValueError):
|
| 292 |
+
pass
|
| 293 |
+
|
| 294 |
+
output_files = process_video_with_status(
|
| 295 |
+
input_file=video_file,
|
| 296 |
+
output_path=temp_dir,
|
| 297 |
+
number_of_speakers=num_speakers,
|
| 298 |
+
detect_every_N_frame=8,
|
| 299 |
+
scalar_face_detection=1.5,
|
| 300 |
+
status_callback=gpu_status_callback
|
| 301 |
+
)
|
| 302 |
+
return output_files
|
| 303 |
+
except ImportError:
|
| 304 |
+
from Inference import process_video
|
| 305 |
+
print("Using standard process_video (status callbacks not available)")
|
| 306 |
+
output_files = process_video(
|
| 307 |
+
input_file=video_file,
|
| 308 |
+
output_path=temp_dir,
|
| 309 |
+
number_of_speakers=num_speakers,
|
| 310 |
+
detect_every_N_frame=8,
|
| 311 |
+
scalar_face_detection=1.5
|
| 312 |
+
)
|
| 313 |
+
return output_files
|
| 314 |
+
|
| 315 |
+
def process_video_auto(video_file, progress=gr.Progress()):
|
| 316 |
+
"""Process video with automatic speaker detection and stream status updates"""
|
| 317 |
+
global GLOBAL_LOG
|
| 318 |
+
GLOBAL_LOG = LogCollector()
|
| 319 |
+
|
| 320 |
+
old_stdout = sys.stdout
|
| 321 |
+
sys.stdout = StdoutCapture(old_stdout)
|
| 322 |
+
|
| 323 |
+
status_value = "⏳ Ready to process..."
|
| 324 |
+
detected_info_output = gr.update(visible=False)
|
| 325 |
+
face_gallery_output = gr.update(visible=False)
|
| 326 |
+
output_video_output = gr.update(visible=False)
|
| 327 |
+
video_dict_value = None
|
| 328 |
+
annotated_frame_output = gr.update(visible=False)
|
| 329 |
+
|
| 330 |
+
def snapshot():
|
| 331 |
+
return (
|
| 332 |
+
status_value,
|
| 333 |
+
detected_info_output,
|
| 334 |
+
face_gallery_output,
|
| 335 |
+
output_video_output,
|
| 336 |
+
video_dict_value,
|
| 337 |
+
annotated_frame_output,
|
| 338 |
+
GLOBAL_LOG.get_text()
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
try:
|
| 342 |
+
if video_file is None:
|
| 343 |
+
status_value = "⚠️ Please upload a video file"
|
| 344 |
+
yield snapshot()
|
| 345 |
+
return
|
| 346 |
+
|
| 347 |
+
progress(0, desc="Starting processing...")
|
| 348 |
+
status_value = "🔄 Starting processing..."
|
| 349 |
+
GLOBAL_LOG.add("Starting video processing...")
|
| 350 |
+
yield snapshot()
|
| 351 |
+
|
| 352 |
+
temp_dir = None
|
| 353 |
+
try:
|
| 354 |
+
temp_dir = tempfile.mkdtemp(dir=TEMP_DIR)
|
| 355 |
+
print(f"Created temporary directory: {temp_dir}")
|
| 356 |
+
|
| 357 |
+
progress(0.1, desc="Detecting speakers in video...")
|
| 358 |
+
status_value = "🔍 Detecting speakers in video..."
|
| 359 |
+
print("Starting face detection in video...")
|
| 360 |
+
yield snapshot()
|
| 361 |
+
|
| 362 |
+
face_images, num_speakers, annotated_frame, error_msg = detect_and_extract_all_faces(video_file)
|
| 363 |
+
print(f"Face detection completed. Found {num_speakers} speakers.")
|
| 364 |
+
|
| 365 |
+
if error_msg:
|
| 366 |
+
print(f"Error: {error_msg}")
|
| 367 |
+
status_value = f"❌ {error_msg}"
|
| 368 |
+
if annotated_frame is not None:
|
| 369 |
+
annotated_frame_output = gr.update(value=annotated_frame, visible=True)
|
| 370 |
+
yield snapshot()
|
| 371 |
+
return
|
| 372 |
+
|
| 373 |
+
if num_speakers == 0:
|
| 374 |
+
print("No speakers detected in the video.")
|
| 375 |
+
status_value = "❌ No speakers detected in the video. Please ensure faces are visible in the first frame."
|
| 376 |
+
if annotated_frame is not None:
|
| 377 |
+
annotated_frame_output = gr.update(value=annotated_frame, visible=True)
|
| 378 |
+
yield snapshot()
|
| 379 |
+
return
|
| 380 |
+
|
| 381 |
+
face_gallery_images = [(img, f"Speaker {i+1}") for i, img in enumerate(face_images)]
|
| 382 |
+
detected_info = f"🎯 Detected {num_speakers} speaker{'s' if num_speakers > 1 else ''} in the video"
|
| 383 |
+
detected_info_output = gr.update(value=detected_info, visible=True)
|
| 384 |
+
face_gallery_output = gr.update(value=face_gallery_images, visible=True)
|
| 385 |
+
if annotated_frame is not None:
|
| 386 |
+
annotated_frame_output = gr.update(value=annotated_frame, visible=True)
|
| 387 |
+
|
| 388 |
+
progress(0.3, desc=f"Separating {num_speakers} speakers...")
|
| 389 |
+
status_value = f"🎬 Separating {num_speakers} speakers..."
|
| 390 |
+
print(f"Starting audio-visual separation for {num_speakers} speakers...")
|
| 391 |
+
yield snapshot()
|
| 392 |
+
|
| 393 |
+
try:
|
| 394 |
+
print("Starting GPU-accelerated video processing...")
|
| 395 |
+
with GPU_PROGRESS_LOCK:
|
| 396 |
+
GPU_PROGRESS_STATE["progress"] = 0.0
|
| 397 |
+
GPU_PROGRESS_STATE["status"] = "Processing on GPU..."
|
| 398 |
+
|
| 399 |
+
progress(0.4, desc="Processing on GPU...")
|
| 400 |
+
status_value = "Processing on GPU..."
|
| 401 |
+
yield snapshot()
|
| 402 |
+
|
| 403 |
+
gpu_result = {"output_files": None, "exception": None}
|
| 404 |
+
|
| 405 |
+
def run_gpu_processing():
|
| 406 |
+
try:
|
| 407 |
+
gpu_result["output_files"] = process_video_gpu(
|
| 408 |
+
video_file=video_file,
|
| 409 |
+
temp_dir=temp_dir,
|
| 410 |
+
num_speakers=num_speakers
|
| 411 |
+
)
|
| 412 |
+
except Exception as exc:
|
| 413 |
+
gpu_result["exception"] = exc
|
| 414 |
+
|
| 415 |
+
gpu_thread = threading.Thread(target=run_gpu_processing, daemon=True)
|
| 416 |
+
gpu_thread.start()
|
| 417 |
+
|
| 418 |
+
last_reported_progress = 0.4
|
| 419 |
+
last_status_message = "Processing on GPU..."
|
| 420 |
+
|
| 421 |
+
while gpu_thread.is_alive():
|
| 422 |
+
time.sleep(0.5)
|
| 423 |
+
with GPU_PROGRESS_LOCK:
|
| 424 |
+
gpu_status = GPU_PROGRESS_STATE.get("status", "Processing on GPU...")
|
| 425 |
+
gpu_progress_value = GPU_PROGRESS_STATE.get("progress", 0.0)
|
| 426 |
+
|
| 427 |
+
mapped_progress = 0.4 + 0.5 * gpu_progress_value
|
| 428 |
+
mapped_progress = min(mapped_progress, 0.89)
|
| 429 |
+
|
| 430 |
+
if (
|
| 431 |
+
mapped_progress > last_reported_progress + 0.01
|
| 432 |
+
or gpu_status != last_status_message
|
| 433 |
+
):
|
| 434 |
+
progress(mapped_progress, desc=gpu_status)
|
| 435 |
+
last_reported_progress = mapped_progress
|
| 436 |
+
last_status_message = gpu_status
|
| 437 |
+
status_value = gpu_status
|
| 438 |
+
yield snapshot()
|
| 439 |
+
|
| 440 |
+
gpu_thread.join()
|
| 441 |
+
|
| 442 |
+
if gpu_result["exception"] is not None:
|
| 443 |
+
raise gpu_result["exception"]
|
| 444 |
+
|
| 445 |
+
output_files = gpu_result["output_files"]
|
| 446 |
+
|
| 447 |
+
progress(0.9, desc="Preparing results...")
|
| 448 |
+
status_value = "📦 Preparing results..."
|
| 449 |
+
print("Processing completed successfully!")
|
| 450 |
+
print(f"Generated {num_speakers} output videos")
|
| 451 |
+
yield snapshot()
|
| 452 |
+
|
| 453 |
+
video_dict_value = {i: output_files[i] for i in range(num_speakers)}
|
| 454 |
+
video_dict_value['temp_dir'] = temp_dir
|
| 455 |
+
output_video_output = gr.update(value=output_files[0], visible=True)
|
| 456 |
+
|
| 457 |
+
progress(1.0, desc="Complete!")
|
| 458 |
+
status_value = f"✅ Successfully separated {num_speakers} speakers! Click on any face below to view their video."
|
| 459 |
+
yield snapshot()
|
| 460 |
+
|
| 461 |
+
except Exception as e:
|
| 462 |
+
print(f"Processing failed: {str(e)}")
|
| 463 |
+
import traceback
|
| 464 |
+
traceback.print_exc()
|
| 465 |
+
status_value = f"❌ Processing failed: {str(e)}"
|
| 466 |
+
output_video_output = gr.update(visible=False)
|
| 467 |
+
video_dict_value = None
|
| 468 |
+
yield snapshot()
|
| 469 |
+
return
|
| 470 |
+
|
| 471 |
+
except Exception as e:
|
| 472 |
+
if temp_dir and os.path.exists(temp_dir):
|
| 473 |
+
try:
|
| 474 |
+
shutil.rmtree(temp_dir)
|
| 475 |
+
except Exception:
|
| 476 |
+
pass
|
| 477 |
+
|
| 478 |
+
print(f"Error: {str(e)}")
|
| 479 |
+
import traceback
|
| 480 |
+
traceback.print_exc()
|
| 481 |
+
status_value = f"❌ Error: {str(e)}"
|
| 482 |
+
detected_info_output = gr.update(visible=False)
|
| 483 |
+
face_gallery_output = gr.update(visible=False)
|
| 484 |
+
output_video_output = gr.update(visible=False)
|
| 485 |
+
annotated_frame_output = gr.update(visible=False)
|
| 486 |
+
video_dict_value = None
|
| 487 |
+
yield snapshot()
|
| 488 |
+
return
|
| 489 |
+
finally:
|
| 490 |
+
sys.stdout = old_stdout
|
| 491 |
+
|
| 492 |
+
def on_face_click(evt: gr.SelectData, video_dict):
|
| 493 |
+
"""Handle face gallery click events"""
|
| 494 |
+
if video_dict is None or evt.index not in video_dict:
|
| 495 |
+
return None
|
| 496 |
+
|
| 497 |
+
return video_dict[evt.index]
|
| 498 |
+
|
| 499 |
+
# Create the Gradio interface
|
| 500 |
+
custom_css = """
|
| 501 |
+
.face-gallery {
|
| 502 |
+
border-radius: 10px;
|
| 503 |
+
overflow: hidden;
|
| 504 |
+
}
|
| 505 |
+
.face-gallery img {
|
| 506 |
+
border-radius: 8px;
|
| 507 |
+
transition: transform 0.2s ease-in-out;
|
| 508 |
+
}
|
| 509 |
+
.face-gallery img:hover {
|
| 510 |
+
transform: scale(1.05);
|
| 511 |
+
cursor: pointer;
|
| 512 |
+
box-shadow: 0 4px 8px rgba(0,0,0,0.3);
|
| 513 |
+
}
|
| 514 |
+
.detected-info {
|
| 515 |
+
background-color: #f0f0f0;
|
| 516 |
+
padding: 10px;
|
| 517 |
+
border-radius: 5px;
|
| 518 |
+
margin: 10px 0;
|
| 519 |
+
}
|
| 520 |
+
"""
|
| 521 |
+
|
| 522 |
+
with gr.Blocks(
|
| 523 |
+
title="Video Speaker Auto-Separation",
|
| 524 |
+
theme=gr.themes.Soft(),
|
| 525 |
+
css=custom_css
|
| 526 |
+
) as demo:
|
| 527 |
+
gr.Markdown(
|
| 528 |
+
"""
|
| 529 |
+
# 🎥 Dolphin: Efficient Audio-Visual Speech Separation with Discrete Lip Semantics and Hierarchical Top-Down Attention
|
| 530 |
+
<p align="left">
|
| 531 |
+
<img src="https://visitor-badge.laobi.icu/badge?page_id=JusperLee.Dolphin" alt="访客统计" /><img src="https://img.shields.io/github/stars/JusperLee/Dolphin?style=social" alt="GitHub stars" /><img alt="Static Badge" src="https://img.shields.io/badge/license-Apache%202.0-blue.svg" />
|
| 532 |
+
</p>
|
| 533 |
+
|
| 534 |
+
### Automatically detect and separate ALL speakers in your video
|
| 535 |
+
|
| 536 |
+
Simply upload a video and the system will:
|
| 537 |
+
1. 🔍 Automatically detect all speakers in the video
|
| 538 |
+
2. 🎭 Show you each detected speaker's face
|
| 539 |
+
3. 🎬 Generate individual videos for each speaker with their isolated audio
|
| 540 |
+
"""
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
with gr.Row():
|
| 544 |
+
with gr.Column(scale=2):
|
| 545 |
+
video_input = gr.Video(
|
| 546 |
+
label="📹 Upload Your Video",
|
| 547 |
+
height=300,
|
| 548 |
+
interactive=True
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
# Add example video section
|
| 552 |
+
gr.Markdown("### 🎬 Try with Example Video")
|
| 553 |
+
gr.Examples(
|
| 554 |
+
examples=[["demo1/mix.mp4"]],
|
| 555 |
+
inputs=video_input,
|
| 556 |
+
label="Click to load example video",
|
| 557 |
+
cache_examples=False
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
process_btn = gr.Button(
|
| 561 |
+
"🚀 Auto-Detect and Process",
|
| 562 |
+
variant="primary",
|
| 563 |
+
size="lg"
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
status = gr.Textbox(
|
| 567 |
+
label="Status",
|
| 568 |
+
interactive=False,
|
| 569 |
+
value="⏳ Ready to process..."
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
processing_log = gr.Textbox(
|
| 573 |
+
label="📋 Processing Details",
|
| 574 |
+
lines=10,
|
| 575 |
+
max_lines=15,
|
| 576 |
+
interactive=False,
|
| 577 |
+
value=""
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
with gr.Column(scale=3):
|
| 581 |
+
annotated_frame = gr.Image(
|
| 582 |
+
label="📸 Detected Speakers in First Frame",
|
| 583 |
+
visible=False,
|
| 584 |
+
height=300
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
detected_info = gr.Markdown(
|
| 588 |
+
visible=False,
|
| 589 |
+
elem_classes="detected-info"
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
gr.Markdown("### 👇 Click on any face below to view that speaker's video")
|
| 593 |
+
|
| 594 |
+
face_gallery = gr.Gallery(
|
| 595 |
+
label="Detected Speaker Faces",
|
| 596 |
+
show_label=False,
|
| 597 |
+
columns=5,
|
| 598 |
+
rows=1,
|
| 599 |
+
height=200,
|
| 600 |
+
visible=False,
|
| 601 |
+
object_fit="contain",
|
| 602 |
+
elem_classes="face-gallery"
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
output_video = gr.Video(
|
| 606 |
+
label="🎬 Selected Speaker's Video",
|
| 607 |
+
height=300,
|
| 608 |
+
visible=False,
|
| 609 |
+
autoplay=True
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
# Hidden state
|
| 613 |
+
video_dict = gr.State()
|
| 614 |
+
|
| 615 |
+
gr.Markdown(
|
| 616 |
+
"""
|
| 617 |
+
---
|
| 618 |
+
### 📖 How it works:
|
| 619 |
+
|
| 620 |
+
1. **Upload** - Select any video file
|
| 621 |
+
2. **Process** - Click the button to start automatic detection
|
| 622 |
+
3. **Review** - See all detected speakers and their positions
|
| 623 |
+
4. **Select** - Click on any face to watch that speaker's separated video
|
| 624 |
+
|
| 625 |
+
### 💡 Tips for best results:
|
| 626 |
+
|
| 627 |
+
- ✅ Ensure all speakers' faces are visible in the first frame
|
| 628 |
+
- ✅ Use videos with good lighting and clear face views
|
| 629 |
+
- ✅ Works best with frontal or near-frontal face angles
|
| 630 |
+
- ⏱️ Processing time depends on video length and number of speakers
|
| 631 |
+
|
| 632 |
+
### 🚀 Powered by:
|
| 633 |
+
- RetinaFace for face detection
|
| 634 |
+
- Dolphin model for audio-visual separation
|
| 635 |
+
- GPU acceleration when available
|
| 636 |
+
<footer style="display:none;">
|
| 637 |
+
<a href='https://clustrmaps.com/site/1c828' title='Visit tracker'>
|
| 638 |
+
<img src='//clustrmaps.com/map_v2.png?cl=080808&w=300&t=tt&d=XYmTC4S_SxuX7G06iJ16lU43VCNkCBFRLXMfEM5zvmo&co=ffffff&ct=808080'/>
|
| 639 |
+
</a>
|
| 640 |
+
</footer>
|
| 641 |
+
"""
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
# Event handlers
|
| 645 |
+
outputs_list = [
|
| 646 |
+
status,
|
| 647 |
+
detected_info,
|
| 648 |
+
face_gallery,
|
| 649 |
+
output_video,
|
| 650 |
+
video_dict,
|
| 651 |
+
annotated_frame,
|
| 652 |
+
processing_log
|
| 653 |
+
]
|
| 654 |
+
|
| 655 |
+
process_btn.click(
|
| 656 |
+
fn=process_video_auto,
|
| 657 |
+
inputs=[video_input],
|
| 658 |
+
outputs=outputs_list,
|
| 659 |
+
show_progress=True
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
face_gallery.select(
|
| 663 |
+
fn=on_face_click,
|
| 664 |
+
inputs=[video_dict],
|
| 665 |
+
outputs=output_video
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
# Launch the demo - HF Space will handle this automatically
|
| 669 |
+
if __name__ == "__main__":
|
| 670 |
+
import os
|
| 671 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
console_capture.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import io
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
|
| 5 |
+
class TeeOutput:
|
| 6 |
+
"""Capture stdout/stderr while still printing to console"""
|
| 7 |
+
def __init__(self, stream, callback=None):
|
| 8 |
+
self.stream = stream
|
| 9 |
+
self.callback = callback
|
| 10 |
+
self.buffer = []
|
| 11 |
+
|
| 12 |
+
def write(self, data):
|
| 13 |
+
# Write to original stream
|
| 14 |
+
self.stream.write(data)
|
| 15 |
+
self.stream.flush()
|
| 16 |
+
|
| 17 |
+
# Capture the data
|
| 18 |
+
if data.strip(): # Only capture non-empty lines
|
| 19 |
+
self.buffer.append(data.rstrip())
|
| 20 |
+
if self.callback:
|
| 21 |
+
self.callback(data.rstrip())
|
| 22 |
+
|
| 23 |
+
def flush(self):
|
| 24 |
+
self.stream.flush()
|
| 25 |
+
|
| 26 |
+
def get_captured(self):
|
| 27 |
+
return '\n'.join(self.buffer)
|
| 28 |
+
|
| 29 |
+
@contextmanager
|
| 30 |
+
def capture_console(stdout_callback=None, stderr_callback=None):
|
| 31 |
+
"""Context manager to capture console output"""
|
| 32 |
+
old_stdout = sys.stdout
|
| 33 |
+
old_stderr = sys.stderr
|
| 34 |
+
|
| 35 |
+
stdout_capture = TeeOutput(old_stdout, stdout_callback)
|
| 36 |
+
stderr_capture = TeeOutput(old_stderr, stderr_callback)
|
| 37 |
+
|
| 38 |
+
sys.stdout = stdout_capture
|
| 39 |
+
sys.stderr = stderr_capture
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
yield stdout_capture, stderr_capture
|
| 43 |
+
finally:
|
| 44 |
+
sys.stdout = old_stdout
|
| 45 |
+
sys.stderr = old_stderr
|
demo1/mix.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fbf7577afd8b8ebc4a70d88e5d6a8216dd9ca07a1e26a83a0c677074510ec39c
|
| 3 |
+
size 3387273
|
face_detection_utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility helpers for RetinaFace-based face detection."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from retinaface import RetinaFace # type: ignore
|
| 11 |
+
except ImportError as import_error: # pragma: no cover - handled at runtime
|
| 12 |
+
RetinaFace = None # type: ignore
|
| 13 |
+
_RETINAFACE_IMPORT_ERROR = import_error
|
| 14 |
+
else:
|
| 15 |
+
_RETINAFACE_IMPORT_ERROR = None
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from PIL import Image
|
| 19 |
+
except ImportError: # pragma: no cover
|
| 20 |
+
Image = None # type: ignore
|
| 21 |
+
|
| 22 |
+
import spaces
|
| 23 |
+
|
| 24 |
+
def _ensure_retinaface_available() -> None:
|
| 25 |
+
if RetinaFace is None: # pragma: no cover - runtime safeguard
|
| 26 |
+
raise ImportError(
|
| 27 |
+
"RetinaFace package is required but not installed. "
|
| 28 |
+
"Install it with `pip install retinaface`."
|
| 29 |
+
) from _RETINAFACE_IMPORT_ERROR
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _to_rgb_array(image: np.ndarray, *, assume_bgr: bool = False) -> np.ndarray:
|
| 33 |
+
"""Convert input to an RGB numpy array."""
|
| 34 |
+
if isinstance(image, np.ndarray):
|
| 35 |
+
array = image
|
| 36 |
+
elif Image is not None and isinstance(image, Image.Image):
|
| 37 |
+
array = np.array(image.convert("RGB"))
|
| 38 |
+
else:
|
| 39 |
+
raise TypeError("Expected an ndarray or PIL.Image.Image for face detection")
|
| 40 |
+
|
| 41 |
+
if array.ndim != 3 or array.shape[2] != 3:
|
| 42 |
+
raise ValueError("Face detection expects an image with shape (H, W, 3)")
|
| 43 |
+
|
| 44 |
+
if array.dtype != np.uint8:
|
| 45 |
+
array = array.astype(np.uint8)
|
| 46 |
+
|
| 47 |
+
if assume_bgr:
|
| 48 |
+
return cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
|
| 49 |
+
return array
|
| 50 |
+
|
| 51 |
+
@spaces.GPU(duration=360)
|
| 52 |
+
def detect_faces(
|
| 53 |
+
image: np.ndarray,
|
| 54 |
+
*,
|
| 55 |
+
threshold: float = 0.9,
|
| 56 |
+
allow_upscaling: bool = False,
|
| 57 |
+
model: Optional[str] = None,
|
| 58 |
+
assume_bgr: bool = False,
|
| 59 |
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
| 60 |
+
"""Run RetinaFace detection on an image.
|
| 61 |
+
|
| 62 |
+
Returns bounding boxes shaped (N, 4) and confidence scores shaped (N,).
|
| 63 |
+
If no face is detected, both values are ``None``.
|
| 64 |
+
"""
|
| 65 |
+
_ensure_retinaface_available()
|
| 66 |
+
|
| 67 |
+
rgb_image = _to_rgb_array(image, assume_bgr=assume_bgr)
|
| 68 |
+
bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
|
| 69 |
+
|
| 70 |
+
detections = RetinaFace.detect_faces(
|
| 71 |
+
bgr_image,
|
| 72 |
+
threshold=threshold,
|
| 73 |
+
model=model,
|
| 74 |
+
allow_upscaling=allow_upscaling,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
if not isinstance(detections, dict) or not detections:
|
| 78 |
+
return None, None
|
| 79 |
+
|
| 80 |
+
boxes, scores = [], []
|
| 81 |
+
for face_data in detections.values():
|
| 82 |
+
facial_area = face_data.get("facial_area")
|
| 83 |
+
if facial_area is None:
|
| 84 |
+
continue
|
| 85 |
+
boxes.append(facial_area)
|
| 86 |
+
scores.append(face_data.get("score", 0.0))
|
| 87 |
+
|
| 88 |
+
if not boxes:
|
| 89 |
+
return None, None
|
| 90 |
+
|
| 91 |
+
boxes_array = np.asarray(boxes, dtype=np.float32)
|
| 92 |
+
scores_array = np.asarray(scores, dtype=np.float32) if scores else None
|
| 93 |
+
|
| 94 |
+
return boxes_array, scores_array
|
| 95 |
+
|
| 96 |
+
@spaces.GPU(duration=360)
|
| 97 |
+
def extract_faces(
|
| 98 |
+
image: np.ndarray,
|
| 99 |
+
*,
|
| 100 |
+
align: bool = True,
|
| 101 |
+
threshold: float = 0.9,
|
| 102 |
+
allow_upscaling: bool = False,
|
| 103 |
+
model: Optional[str] = None,
|
| 104 |
+
assume_bgr: bool = False,
|
| 105 |
+
) -> Optional[np.ndarray]:
|
| 106 |
+
"""Extract faces using RetinaFace.extract_faces for convenience."""
|
| 107 |
+
_ensure_retinaface_available()
|
| 108 |
+
|
| 109 |
+
rgb_image = _to_rgb_array(image, assume_bgr=assume_bgr)
|
| 110 |
+
bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
|
| 111 |
+
|
| 112 |
+
faces = RetinaFace.extract_faces(
|
| 113 |
+
bgr_image,
|
| 114 |
+
align=align,
|
| 115 |
+
threshold=threshold,
|
| 116 |
+
model=model,
|
| 117 |
+
allow_upscaling=allow_upscaling,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
if not faces:
|
| 121 |
+
return None
|
| 122 |
+
return np.asarray([np.asarray(face, dtype=np.uint8) for face in faces])
|
look2hear/datas/transform.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###
|
| 2 |
+
# Author: Kai Li
|
| 3 |
+
# Date: 2021-06-19 22:34:13
|
| 4 |
+
# LastEditors: Kai Li
|
| 5 |
+
# LastEditTime: 2021-08-30 20:01:43
|
| 6 |
+
###
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import random
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torchvision
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"Compose",
|
| 15 |
+
"Normalize",
|
| 16 |
+
"CenterCrop",
|
| 17 |
+
"RgbToGray",
|
| 18 |
+
"RandomCrop",
|
| 19 |
+
"HorizontalFlip",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Compose(object):
|
| 24 |
+
"""Compose several preprocess together.
|
| 25 |
+
Args:
|
| 26 |
+
preprocess (list of ``Preprocess`` objects): list of preprocess to compose.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, preprocess):
|
| 30 |
+
self.preprocess = preprocess
|
| 31 |
+
|
| 32 |
+
def __call__(self, sample):
|
| 33 |
+
for t in self.preprocess:
|
| 34 |
+
sample = t(sample)
|
| 35 |
+
return sample
|
| 36 |
+
|
| 37 |
+
def __repr__(self):
|
| 38 |
+
format_string = self.__class__.__name__ + "("
|
| 39 |
+
for t in self.preprocess:
|
| 40 |
+
format_string += "\n"
|
| 41 |
+
format_string += " {0}".format(t)
|
| 42 |
+
format_string += "\n)"
|
| 43 |
+
return format_string
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class RgbToGray(object):
|
| 47 |
+
"""Convert image to grayscale.
|
| 48 |
+
Converts a numpy.ndarray (H x W x C) in the range
|
| 49 |
+
[0, 255] to a numpy.ndarray of shape (H x W x C) in the range [0.0, 1.0].
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __call__(self, frames):
|
| 53 |
+
"""
|
| 54 |
+
Args:
|
| 55 |
+
img (numpy.ndarray): Image to be converted to gray.
|
| 56 |
+
Returns:
|
| 57 |
+
numpy.ndarray: grey image
|
| 58 |
+
"""
|
| 59 |
+
frames = np.stack([cv2.cvtColor(_, cv2.COLOR_RGB2GRAY) for _ in frames], axis=0)
|
| 60 |
+
return frames
|
| 61 |
+
|
| 62 |
+
def __repr__(self):
|
| 63 |
+
return self.__class__.__name__ + "()"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Normalize(object):
|
| 67 |
+
"""Normalize a ndarray image with mean and standard deviation."""
|
| 68 |
+
|
| 69 |
+
def __init__(self, mean, std):
|
| 70 |
+
self.mean = mean
|
| 71 |
+
self.std = std
|
| 72 |
+
|
| 73 |
+
def __call__(self, frames):
|
| 74 |
+
"""
|
| 75 |
+
Args:
|
| 76 |
+
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
| 77 |
+
Returns:
|
| 78 |
+
Tensor: Normalized Tensor image.
|
| 79 |
+
"""
|
| 80 |
+
frames = (frames - self.mean) / self.std
|
| 81 |
+
return frames
|
| 82 |
+
|
| 83 |
+
def __repr__(self):
|
| 84 |
+
return self.__class__.__name__ + "(mean={0}, std={1})".format(
|
| 85 |
+
self.mean, self.std
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class CenterCrop(object):
|
| 90 |
+
"""Crop the given image at the center"""
|
| 91 |
+
|
| 92 |
+
def __init__(self, size):
|
| 93 |
+
self.size = size
|
| 94 |
+
|
| 95 |
+
def __call__(self, frames):
|
| 96 |
+
"""
|
| 97 |
+
Args:
|
| 98 |
+
img (numpy.ndarray): Images to be cropped.
|
| 99 |
+
Returns:
|
| 100 |
+
numpy.ndarray: Cropped image.
|
| 101 |
+
"""
|
| 102 |
+
t, h, w = frames.shape
|
| 103 |
+
th, tw = self.size
|
| 104 |
+
delta_w = int(round((w - tw)) / 2.0)
|
| 105 |
+
delta_h = int(round((h - th)) / 2.0)
|
| 106 |
+
frames = frames[:, delta_h : delta_h + th, delta_w : delta_w + tw]
|
| 107 |
+
return frames
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class RandomCrop(object):
|
| 111 |
+
"""Crop the given image at the center"""
|
| 112 |
+
|
| 113 |
+
def __init__(self, size):
|
| 114 |
+
self.size = size
|
| 115 |
+
|
| 116 |
+
def __call__(self, frames):
|
| 117 |
+
"""
|
| 118 |
+
Args:
|
| 119 |
+
img (numpy.ndarray): Images to be cropped.
|
| 120 |
+
Returns:
|
| 121 |
+
numpy.ndarray: Cropped image.
|
| 122 |
+
"""
|
| 123 |
+
t, h, w = frames.shape
|
| 124 |
+
th, tw = self.size
|
| 125 |
+
delta_w = random.randint(0, w - tw)
|
| 126 |
+
delta_h = random.randint(0, h - th)
|
| 127 |
+
frames = frames[:, delta_h : delta_h + th, delta_w : delta_w + tw]
|
| 128 |
+
return frames
|
| 129 |
+
|
| 130 |
+
def __repr__(self):
|
| 131 |
+
return self.__class__.__name__ + "(size={0})".format(self.size)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class HorizontalFlip(object):
|
| 135 |
+
"""Flip image horizontally."""
|
| 136 |
+
|
| 137 |
+
def __init__(self, flip_ratio):
|
| 138 |
+
self.flip_ratio = flip_ratio
|
| 139 |
+
|
| 140 |
+
def __call__(self, frames):
|
| 141 |
+
"""
|
| 142 |
+
Args:
|
| 143 |
+
img (numpy.ndarray): Images to be flipped with a probability flip_ratio
|
| 144 |
+
Returns:
|
| 145 |
+
numpy.ndarray: Cropped image.
|
| 146 |
+
"""
|
| 147 |
+
t, h, w = frames.shape
|
| 148 |
+
if random.random() < self.flip_ratio:
|
| 149 |
+
for index in range(t):
|
| 150 |
+
frames[index] = cv2.flip(frames[index], 1)
|
| 151 |
+
return frames
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_preprocessing_pipelines():
|
| 155 |
+
# -- preprocess for the video stream
|
| 156 |
+
preprocessing = {}
|
| 157 |
+
# -- LRW config
|
| 158 |
+
crop_size = (88, 88)
|
| 159 |
+
(mean, std) = (0.421, 0.165)
|
| 160 |
+
preprocessing["train"] = Compose(
|
| 161 |
+
[
|
| 162 |
+
Normalize(0.0, 255.0),
|
| 163 |
+
RandomCrop(crop_size),
|
| 164 |
+
HorizontalFlip(0.5),
|
| 165 |
+
Normalize(mean, std),
|
| 166 |
+
]
|
| 167 |
+
)
|
| 168 |
+
preprocessing["val"] = Compose(
|
| 169 |
+
[Normalize(0.0, 255.0), CenterCrop(crop_size), Normalize(mean, std)]
|
| 170 |
+
)
|
| 171 |
+
preprocessing["test"] = preprocessing["val"]
|
| 172 |
+
return preprocessing
|
| 173 |
+
|
| 174 |
+
def get_preprocessing_opt_pipelines():
|
| 175 |
+
preprocessing = {}
|
| 176 |
+
# -- LRW config
|
| 177 |
+
crop_size = (88, 88)
|
| 178 |
+
(mean, std) = (0.421, 0.165)
|
| 179 |
+
preprocessing["train"] = torchvision.transforms.Compose([
|
| 180 |
+
torchvision.transforms.Normalize(0.0, 255.0),
|
| 181 |
+
torchvision.transforms.RandomCrop(crop_size),
|
| 182 |
+
torchvision.transforms.RandomHorizontalFlip(0.5),
|
| 183 |
+
torchvision.transforms.Normalize(mean, std)
|
| 184 |
+
])
|
| 185 |
+
preprocessing["val"] = torchvision.transforms.Compose([
|
| 186 |
+
torchvision.transforms.Normalize(0.0, 255.0),
|
| 187 |
+
torchvision.transforms.CenterCrop(crop_size),
|
| 188 |
+
torchvision.transforms.Normalize(mean, std)
|
| 189 |
+
])
|
| 190 |
+
preprocessing["test"] = preprocessing["val"]
|
| 191 |
+
return preprocessing
|
look2hear/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .dolphin import Dolphin
|
look2hear/models/dolphin.py
ADDED
|
@@ -0,0 +1,1376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dolphin Model
|
| 3 |
+
|
| 4 |
+
This implementation is inspired by and borrows concepts from Sepformer.
|
| 5 |
+
The original Sepformer work is licensed under the Apache-2.0 License.
|
| 6 |
+
|
| 7 |
+
References:
|
| 8 |
+
- SepReformer: https://github.com/dmlguq456/SepReformer
|
| 9 |
+
- Apache-2.0 License: https://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from re import S
|
| 14 |
+
import torch
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import math
|
| 19 |
+
from vector_quantize_pytorch import ResidualVQ
|
| 20 |
+
from .video_compoent import *
|
| 21 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 22 |
+
|
| 23 |
+
class LayerScale(torch.nn.Module):
|
| 24 |
+
def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
|
| 25 |
+
super().__init__()
|
| 26 |
+
if dims == 1:
|
| 27 |
+
self.layer_scale = torch.nn.Parameter(torch.ones(input_size)*Layer_scale_init, requires_grad=True)
|
| 28 |
+
elif dims == 2:
|
| 29 |
+
self.layer_scale = torch.nn.Parameter(torch.ones(1,input_size)*Layer_scale_init, requires_grad=True)
|
| 30 |
+
elif dims == 3:
|
| 31 |
+
self.layer_scale = torch.nn.Parameter(torch.ones(1,1,input_size)*Layer_scale_init, requires_grad=True)
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
return x*self.layer_scale
|
| 35 |
+
|
| 36 |
+
class Masking(torch.nn.Module):
|
| 37 |
+
def __init__(self, input_dim):
|
| 38 |
+
super(Masking, self).__init__()
|
| 39 |
+
self.gate_act = torch.nn.ReLU()
|
| 40 |
+
|
| 41 |
+
def forward(self, x, skip):
|
| 42 |
+
return self.gate_act(x) * skip
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class FFN(torch.nn.Module):
|
| 46 |
+
def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
|
| 47 |
+
super().__init__()
|
| 48 |
+
expand_factor = 3
|
| 49 |
+
self.net1 = torch.nn.Sequential(
|
| 50 |
+
torch.nn.LayerNorm(in_channels),
|
| 51 |
+
torch.nn.Linear(in_channels, in_channels * expand_factor))
|
| 52 |
+
self.depthwise = torch.nn.Conv1d(in_channels * expand_factor, in_channels * expand_factor, 3, padding=1, groups=in_channels * expand_factor)
|
| 53 |
+
self.net2 = torch.nn.Sequential(
|
| 54 |
+
torch.nn.GLU(),
|
| 55 |
+
torch.nn.Dropout(dropout_rate),
|
| 56 |
+
torch.nn.Linear(in_channels * expand_factor // 2, in_channels),
|
| 57 |
+
torch.nn.Dropout(dropout_rate))
|
| 58 |
+
self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
y = self.net1(x)
|
| 62 |
+
y = y.permute(0, 2, 1).contiguous()
|
| 63 |
+
y = self.depthwise(y)
|
| 64 |
+
y = y.permute(0, 2, 1).contiguous()
|
| 65 |
+
y = self.net2(y)
|
| 66 |
+
return x + self.Layer_scale(y)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MultiHeadAttention(torch.nn.Module):
|
| 70 |
+
"""
|
| 71 |
+
Multi-Head Attention layer.
|
| 72 |
+
:param int n_head: the number of head s
|
| 73 |
+
:param int n_feat: the number of features
|
| 74 |
+
:param float dropout_rate: dropout rate
|
| 75 |
+
"""
|
| 76 |
+
def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
|
| 77 |
+
super().__init__()
|
| 78 |
+
assert in_channels % n_head == 0
|
| 79 |
+
self.d_k = in_channels // n_head # We assume d_v always equals d_k
|
| 80 |
+
self.h = n_head
|
| 81 |
+
self.layer_norm = torch.nn.LayerNorm(in_channels)
|
| 82 |
+
self.linear_q = torch.nn.Linear(in_channels, in_channels)
|
| 83 |
+
self.linear_k = torch.nn.Linear(in_channels, in_channels)
|
| 84 |
+
self.linear_v = torch.nn.Linear(in_channels, in_channels)
|
| 85 |
+
self.linear_out = torch.nn.Linear(in_channels, in_channels)
|
| 86 |
+
self.attn = None
|
| 87 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
| 88 |
+
self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)
|
| 89 |
+
|
| 90 |
+
def forward(self, x, pos_k, mask):
|
| 91 |
+
"""
|
| 92 |
+
Compute 'Scaled Dot Product Attention'.
|
| 93 |
+
:param torch.Tensor mask: (batch, time1, time2)
|
| 94 |
+
:param torch.nn.Dropout dropout:
|
| 95 |
+
:return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
|
| 96 |
+
weighted by the query dot key attention (batch, head, time1, time2)
|
| 97 |
+
"""
|
| 98 |
+
n_batch = x.size(0)
|
| 99 |
+
x = self.layer_norm(x)
|
| 100 |
+
q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
|
| 101 |
+
k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
|
| 102 |
+
v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
|
| 103 |
+
q = q.transpose(1, 2)
|
| 104 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
| 105 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
| 106 |
+
A = torch.matmul(q, k.transpose(-2, -1))
|
| 107 |
+
reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
|
| 108 |
+
if pos_k is not None:
|
| 109 |
+
B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
|
| 110 |
+
B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
|
| 111 |
+
scores = (A + B) / math.sqrt(self.d_k)
|
| 112 |
+
else:
|
| 113 |
+
scores = A / math.sqrt(self.d_k)
|
| 114 |
+
if mask is not None:
|
| 115 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
|
| 116 |
+
min_value = float(np.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
| 117 |
+
scores = scores.masked_fill(mask, min_value)
|
| 118 |
+
self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
|
| 119 |
+
else:
|
| 120 |
+
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
| 121 |
+
p_attn = self.dropout(self.attn)
|
| 122 |
+
x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
|
| 123 |
+
x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
|
| 124 |
+
return self.Layer_scale(self.dropout(self.linear_out(x))) # (batch, time1, d_model)
|
| 125 |
+
|
| 126 |
+
class DU_MHSA(torch.nn.Module):
|
| 127 |
+
def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
|
| 128 |
+
super().__init__()
|
| 129 |
+
self.block = torch.nn.ModuleDict({
|
| 130 |
+
'self_attn': MultiHeadAttention(
|
| 131 |
+
n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
|
| 132 |
+
'linear': torch.nn.Sequential(
|
| 133 |
+
torch.nn.LayerNorm(normalized_shape=in_channels),
|
| 134 |
+
torch.nn.Linear(in_features=in_channels, out_features=in_channels),
|
| 135 |
+
torch.nn.Sigmoid())
|
| 136 |
+
})
|
| 137 |
+
|
| 138 |
+
def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
|
| 139 |
+
"""
|
| 140 |
+
Compute encoded features.
|
| 141 |
+
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
|
| 142 |
+
:param torch.Tensor mask: mask for x (batch, max_time_in)
|
| 143 |
+
:rtype: Tuple[torch.Tensor, torch.Tensor]
|
| 144 |
+
"""
|
| 145 |
+
down_len = pos_k.shape[0]
|
| 146 |
+
x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
|
| 147 |
+
x = x.permute([0, 2, 1])
|
| 148 |
+
x_down = x_down.permute([0, 2, 1])
|
| 149 |
+
x_down = self.block['self_attn'](x_down, pos_k, None)
|
| 150 |
+
x_down = x_down.permute([0, 2, 1])
|
| 151 |
+
x_downup = torch.nn.functional.upsample(input=x_down, size=x.shape[1])
|
| 152 |
+
x_downup = x_downup.permute([0, 2, 1])
|
| 153 |
+
x = x + self.block['linear'](x) * x_downup
|
| 154 |
+
|
| 155 |
+
return x
|
| 156 |
+
|
| 157 |
+
class Heat1D(nn.Module):
|
| 158 |
+
"""
|
| 159 |
+
1D Heat Equation Adaptation:
|
| 160 |
+
du/dt - k d²u/dx² = 0;
|
| 161 |
+
du/dx_{x=0, x=a} = 0
|
| 162 |
+
=>
|
| 163 |
+
A_n = C(a, n==0) * sum_{0}^{a} { \phi(x) cos(n π / a x) dx }
|
| 164 |
+
core = cos(n π / a x) exp(- (n π / a)^2 k t)
|
| 165 |
+
u_{x, t} = sum_{0}^{\infinite} { core }
|
| 166 |
+
|
| 167 |
+
Assume a = T; x in [0, T]; n in [0, T]; with some slight changes
|
| 168 |
+
=>
|
| 169 |
+
(\phi(x) = linear(dwconv(input(x))))
|
| 170 |
+
A(n) = DCT1D(\phi(x))
|
| 171 |
+
u(x, t) = IDCT1D(A(n) * exp(- (n π / a)^2 kt))
|
| 172 |
+
"""
|
| 173 |
+
def __init__(self, dim=96, hidden_dim=96, **kwargs):
|
| 174 |
+
super().__init__()
|
| 175 |
+
self.dwconv = nn.Conv1d(dim, hidden_dim, kernel_size=3, padding=1, groups=hidden_dim)
|
| 176 |
+
self.hidden_dim = hidden_dim
|
| 177 |
+
self.linear = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=3, padding=1, groups=hidden_dim)
|
| 178 |
+
self.out_norm = nn.LayerNorm(hidden_dim)
|
| 179 |
+
self.out_linear = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, groups=hidden_dim)
|
| 180 |
+
self.to_k = nn.Sequential(
|
| 181 |
+
nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, groups=hidden_dim),
|
| 182 |
+
nn.GELU(),
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
self.k = nn.Parameter(torch.ones(hidden_dim))
|
| 186 |
+
|
| 187 |
+
@staticmethod
|
| 188 |
+
def get_cos_map(N=224, device=torch.device("cpu"), dtype=torch.float):
|
| 189 |
+
# cos((x + 0.5) / N * n * π) which is also the form of DCT and IDCT
|
| 190 |
+
# DCT: F(n) = sum( (sqrt(2/N) if n > 0 else sqrt(1/N)) * cos((x + 0.5) / N * n * π) * f(x) )
|
| 191 |
+
# IDCT: f(x) = sum( (sqrt(2/N) if n > 0 else sqrt(1/N)) * cos((x + 0.5) / N * n * π) * F(n) )
|
| 192 |
+
# returns: (Res_n, Res_x)
|
| 193 |
+
weight_x = (torch.linspace(0, N - 1, N, device=device, dtype=dtype).view(1, -1) + 0.5) / N
|
| 194 |
+
weight_n = torch.linspace(0, N - 1, N, device=device, dtype=dtype).view(-1, 1)
|
| 195 |
+
weight = torch.cos(weight_n * weight_x * torch.pi) * math.sqrt(2 / N)
|
| 196 |
+
weight[0, :] = weight[0, :] / math.sqrt(2)
|
| 197 |
+
return weight
|
| 198 |
+
|
| 199 |
+
@staticmethod
|
| 200 |
+
def get_decay_map(resolution=224, device=torch.device("cpu"), dtype=torch.float):
|
| 201 |
+
# exp(- (n π / T)^2) for 1D
|
| 202 |
+
# returns: (Res_t,)
|
| 203 |
+
res_t = resolution
|
| 204 |
+
weight_n = torch.linspace(0, torch.pi, res_t + 1, device=device, dtype=dtype)[:res_t]
|
| 205 |
+
weight = torch.pow(weight_n, 2)
|
| 206 |
+
weight = torch.exp(-weight)
|
| 207 |
+
return weight
|
| 208 |
+
|
| 209 |
+
def forward(self, x: torch.Tensor, freq_embed=None):
|
| 210 |
+
B, T, C = x.shape
|
| 211 |
+
x = x.transpose(1, 2) # [B, T, C] -> [B, C, T]
|
| 212 |
+
x = self.dwconv(x) # [B, hidden_dim, T]
|
| 213 |
+
|
| 214 |
+
x = self.linear(x) # [B, 2 * hidden_dim, T]
|
| 215 |
+
x, z = x.chunk(chunks=2, dim=1) # [B, hidden_dim, T], [B, hidden_dim, T]
|
| 216 |
+
|
| 217 |
+
if (T == getattr(self, "__RES__", 0)) and (getattr(self, "__WEIGHT_COSN__", None).device == x.device):
|
| 218 |
+
weight_cosn = getattr(self, "__WEIGHT_COSN__", None)
|
| 219 |
+
weight_exp = getattr(self, "__WEIGHT_EXP__", None)
|
| 220 |
+
assert weight_cosn is not None
|
| 221 |
+
assert weight_exp is not None
|
| 222 |
+
else:
|
| 223 |
+
weight_cosn = self.get_cos_map(T, device=x.device).detach_()
|
| 224 |
+
weight_exp = self.get_decay_map(T, device=x.device).detach_()
|
| 225 |
+
setattr(self, "__RES__", T)
|
| 226 |
+
setattr(self, "__WEIGHT_COSN__", weight_cosn)
|
| 227 |
+
setattr(self, "__WEIGHT_EXP__", weight_exp)
|
| 228 |
+
|
| 229 |
+
N = weight_cosn.shape[0] # N == T
|
| 230 |
+
|
| 231 |
+
x = x.transpose(1, 2).contiguous() # [B, T, hidden_dim]
|
| 232 |
+
|
| 233 |
+
x = F.conv1d(x.contiguous().view(B, T, -1), weight_cosn.contiguous().view(N, T, 1)) # [B, N, hidden_dim]
|
| 234 |
+
|
| 235 |
+
weight_exp = torch.pow(weight_exp[:, None], self.k)
|
| 236 |
+
x = torch.einsum("bnc,nc->bnc", x, weight_exp) # exp decay
|
| 237 |
+
|
| 238 |
+
x = F.conv1d(x.contiguous().view(B, N, -1), weight_cosn.t().contiguous().view(T, N, 1)) # [B, T, hidden_dim]
|
| 239 |
+
|
| 240 |
+
x = self.out_norm(x) # [B, T, hidden_dim]
|
| 241 |
+
|
| 242 |
+
z = z.transpose(1, 2).contiguous() # [B, T, hidden_dim]
|
| 243 |
+
x = x * nn.functional.silu(z) # [B, T, hidden_dim]
|
| 244 |
+
|
| 245 |
+
x = x.transpose(1, 2).contiguous() # [B, hidden_dim, T]
|
| 246 |
+
x = self.out_linear(x) # [B, hidden_dim, T]
|
| 247 |
+
|
| 248 |
+
x = x.transpose(1, 2).contiguous() # [B, T, hidden_dim]
|
| 249 |
+
|
| 250 |
+
return x
|
| 251 |
+
|
| 252 |
+
class CLA(torch.nn.Module):
|
| 253 |
+
def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
|
| 254 |
+
super().__init__()
|
| 255 |
+
# self.layer_norm = torch.nn.LayerNorm(in_channels)
|
| 256 |
+
self.heat1d = Heat1D(in_channels, in_channels)
|
| 257 |
+
self.GN1 = torch.nn.GroupNorm(1, in_channels)
|
| 258 |
+
self.dw_conv_1d = torch.nn.Conv1d(in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
|
| 259 |
+
self.GN2 = torch.nn.GroupNorm(1, in_channels)
|
| 260 |
+
self.linear3 = torch.nn.Sequential(
|
| 261 |
+
torch.nn.GELU(),
|
| 262 |
+
torch.nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels),
|
| 263 |
+
torch.nn.Dropout(dropout_rate))
|
| 264 |
+
self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)
|
| 265 |
+
|
| 266 |
+
def forward(self, x):
|
| 267 |
+
# y = self.layer_norm(x)
|
| 268 |
+
y = self.heat1d(x)
|
| 269 |
+
y = y.permute([0, 2, 1]) # B, F, T
|
| 270 |
+
y = self.GN1(y)
|
| 271 |
+
y = self.dw_conv_1d(y)
|
| 272 |
+
y = self.linear3(y)
|
| 273 |
+
y = self.GN2(y) # [B, in_channels, T]
|
| 274 |
+
y = y.permute(0, 2, 1) # B, T, in_channels
|
| 275 |
+
return x + self.Layer_scale(y)
|
| 276 |
+
|
| 277 |
+
class GlobalBlock(torch.nn.Module):
|
| 278 |
+
def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
|
| 279 |
+
super().__init__()
|
| 280 |
+
self.block = torch.nn.ModuleDict({
|
| 281 |
+
'DU_MHSA': DU_MHSA(
|
| 282 |
+
num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
|
| 283 |
+
'FFN': FFN(in_channels=in_channels, dropout_rate=dropout_rate)
|
| 284 |
+
})
|
| 285 |
+
|
| 286 |
+
def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
|
| 287 |
+
"""
|
| 288 |
+
Compute encoded features.
|
| 289 |
+
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
|
| 290 |
+
:param torch.Tensor mask: mask for x (batch, max_time_in)
|
| 291 |
+
:rtype: Tuple[torch.Tensor, torch.Tensor]
|
| 292 |
+
"""
|
| 293 |
+
x = self.block['DU_MHSA'](x, pos_k)
|
| 294 |
+
x = self.block['FFN'](x)
|
| 295 |
+
x = x.permute([0, 2, 1])
|
| 296 |
+
|
| 297 |
+
return x
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class LocalBlock(torch.nn.Module):
|
| 301 |
+
def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
|
| 302 |
+
super().__init__()
|
| 303 |
+
self.block = torch.nn.ModuleDict({
|
| 304 |
+
'CLA': CLA(in_channels, kernel_size, dropout_rate),
|
| 305 |
+
'FFN': FFN(in_channels, dropout_rate)
|
| 306 |
+
})
|
| 307 |
+
|
| 308 |
+
def forward(self, x: torch.Tensor):
|
| 309 |
+
x = self.block['CLA'](x)
|
| 310 |
+
x = self.block['FFN'](x)
|
| 311 |
+
|
| 312 |
+
return x
|
| 313 |
+
|
| 314 |
+
class AudioEncoder(torch.nn.Module):
|
| 315 |
+
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
|
| 316 |
+
super().__init__()
|
| 317 |
+
self.conv1d = torch.nn.Conv1d(
|
| 318 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, groups=groups, bias=bias)
|
| 319 |
+
self.gelu = torch.nn.GELU()
|
| 320 |
+
|
| 321 |
+
def forward(self, x: torch.Tensor):
|
| 322 |
+
x = torch.unsqueeze(x, dim=0) if len(x.shape) == 1 else torch.unsqueeze(x, dim=1) # [T] - >[1, T] OR [B, T] -> [B, 1, T]
|
| 323 |
+
x = self.conv1d(x)
|
| 324 |
+
x = self.gelu(x)
|
| 325 |
+
return x
|
| 326 |
+
|
| 327 |
+
class FeatureProjector(torch.nn.Module):
|
| 328 |
+
def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
|
| 329 |
+
super().__init__()
|
| 330 |
+
self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
|
| 331 |
+
self.conv1d = torch.nn.Conv1d(
|
| 332 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)
|
| 333 |
+
|
| 334 |
+
def forward(self, x: torch.Tensor):
|
| 335 |
+
x = self.norm(x)
|
| 336 |
+
x = self.conv1d(x)
|
| 337 |
+
return x
|
| 338 |
+
|
| 339 |
+
class HeatConvNorm(nn.Module):
|
| 340 |
+
"""
|
| 341 |
+
This class defines the convolution layer with normalization and PReLU activation
|
| 342 |
+
"""
|
| 343 |
+
|
| 344 |
+
def __init__(
|
| 345 |
+
self, nIn, nOut, kSize, stride=1, groups=1, bias=True, norm_type="gLN"
|
| 346 |
+
):
|
| 347 |
+
"""
|
| 348 |
+
:param nIn: number of input channels
|
| 349 |
+
:param nOut: number of output channels
|
| 350 |
+
:param kSize: kernel size
|
| 351 |
+
:param stride: stride rate for down-sampling. Default is 1
|
| 352 |
+
"""
|
| 353 |
+
super().__init__()
|
| 354 |
+
padding = int((kSize - 1) / 2)
|
| 355 |
+
self.conv = Heat1D(
|
| 356 |
+
nIn, nOut, groups=groups
|
| 357 |
+
)
|
| 358 |
+
if norm_type == "gLN":
|
| 359 |
+
self.norm = nn.GroupNorm(1, nOut, eps=1e-8)
|
| 360 |
+
if norm_type == "BN":
|
| 361 |
+
self.norm = nn.BatchNorm1d(nOut)
|
| 362 |
+
|
| 363 |
+
def forward(self, input):
|
| 364 |
+
input = input.permute(0, 2, 1)
|
| 365 |
+
output = self.conv(input).permute(0, 2, 1)
|
| 366 |
+
return self.norm(output)
|
| 367 |
+
|
| 368 |
+
class ConvNorm(nn.Module):
|
| 369 |
+
"""
|
| 370 |
+
This class defines the convolution layer with normalization and PReLU activation
|
| 371 |
+
"""
|
| 372 |
+
|
| 373 |
+
def __init__(
|
| 374 |
+
self, nIn, nOut, kSize, stride=1, groups=1, bias=True, norm_type="gLN"
|
| 375 |
+
):
|
| 376 |
+
"""
|
| 377 |
+
:param nIn: number of input channels
|
| 378 |
+
:param nOut: number of output channels
|
| 379 |
+
:param kSize: kernel size
|
| 380 |
+
:param stride: stride rate for down-sampling. Default is 1
|
| 381 |
+
"""
|
| 382 |
+
super().__init__()
|
| 383 |
+
padding = int((kSize - 1) / 2)
|
| 384 |
+
self.conv = nn.Conv1d(
|
| 385 |
+
nIn, nOut, kSize, stride=stride, padding=padding, bias=bias, groups=groups
|
| 386 |
+
)
|
| 387 |
+
if norm_type == "gLN":
|
| 388 |
+
self.norm = nn.GroupNorm(1, nOut, eps=1e-8)
|
| 389 |
+
if norm_type == "BN":
|
| 390 |
+
self.norm = nn.BatchNorm1d(nOut)
|
| 391 |
+
|
| 392 |
+
def forward(self, input):
|
| 393 |
+
output = self.conv(input)
|
| 394 |
+
return self.norm(output)
|
| 395 |
+
|
| 396 |
+
class AVFModule(nn.Module):
|
| 397 |
+
"""
|
| 398 |
+
1D Attention Fusion Cell,将 tensor_b 导引 tensor_a 的 key & value:
|
| 399 |
+
Input:
|
| 400 |
+
tensor_a: [B, Ca, T]
|
| 401 |
+
tensor_b: [B, Cb, Tb]
|
| 402 |
+
Output:
|
| 403 |
+
[B, Ca, T]
|
| 404 |
+
"""
|
| 405 |
+
def __init__(self,
|
| 406 |
+
in_chan_a: int,
|
| 407 |
+
in_chan_b: int,
|
| 408 |
+
kernel_size: int = 1):
|
| 409 |
+
super().__init__()
|
| 410 |
+
self.in_chan_a = in_chan_a
|
| 411 |
+
self.in_chan_b = in_chan_b
|
| 412 |
+
self.kernel_size = kernel_size
|
| 413 |
+
# audio key embedding (depthwise 1×1)
|
| 414 |
+
self.key_embed = ConvNormAct(
|
| 415 |
+
nIn=in_chan_a, nOut=in_chan_a, kSize=1,
|
| 416 |
+
groups=in_chan_a, norm_type="gLN"
|
| 417 |
+
)
|
| 418 |
+
# audio value embedding (depthwise 1×1)
|
| 419 |
+
self.value_embed = ConvNormAct(
|
| 420 |
+
nIn=in_chan_a, nOut=in_chan_a, kSize=1,
|
| 421 |
+
groups=in_chan_a, norm_type="gLN"
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
self.resize = ConvNormAct(
|
| 425 |
+
nIn=in_chan_b, nOut=in_chan_a, kSize=1,
|
| 426 |
+
norm_type="gLN"
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
self.attention_embed = ConvNormAct(
|
| 430 |
+
nIn=in_chan_b,
|
| 431 |
+
nOut=in_chan_a * kernel_size,
|
| 432 |
+
kSize=1,
|
| 433 |
+
groups=in_chan_b,
|
| 434 |
+
norm_type="gLN"
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
def forward(self, tensor_a: torch.Tensor, tensor_b: torch.Tensor):
|
| 438 |
+
"""
|
| 439 |
+
tensor_a: [B, Ca, T]
|
| 440 |
+
tensor_b: [B, Cb, Tb]
|
| 441 |
+
"""
|
| 442 |
+
B, Ca, T = tensor_a.shape
|
| 443 |
+
# 1) Use video to guide key_embed
|
| 444 |
+
b2a = self.resize(tensor_b) # [B, Ca, Tb]
|
| 445 |
+
b2a = F.interpolate(b2a, size=T, mode="nearest") # [B, Ca, T]
|
| 446 |
+
k1 = self.key_embed(tensor_a) * b2a # [B, Ca, T]
|
| 447 |
+
# 2) audio value
|
| 448 |
+
v = self.value_embed(tensor_a) # [B, Ca, T]
|
| 449 |
+
# 3) Calculate attention scores
|
| 450 |
+
att = self.attention_embed(tensor_b) # [B, Ca*kernel, Tb]
|
| 451 |
+
# reshape → [B, Ca, kernel, Tb]
|
| 452 |
+
att = att.view(B, Ca, self.kernel_size, -1)
|
| 453 |
+
att = att.mean(dim=2)
|
| 454 |
+
att = torch.softmax(att, dim=-1) # [B, Ca, Tb]
|
| 455 |
+
att = F.interpolate(att, size=T, mode="nearest") # [B, Ca, T]
|
| 456 |
+
# 4) k2 = attention * value
|
| 457 |
+
k2 = att * v
|
| 458 |
+
|
| 459 |
+
fused = k1 + k2 # [B, Ca, T]
|
| 460 |
+
return fused
|
| 461 |
+
|
| 462 |
+
class RelativePositionalEncoding(torch.nn.Module):
|
| 463 |
+
def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
|
| 464 |
+
super().__init__()
|
| 465 |
+
self.in_channels = in_channels
|
| 466 |
+
self.num_heads = num_heads
|
| 467 |
+
self.embedding_dim = self.in_channels // self.num_heads
|
| 468 |
+
self.maxlen = maxlen
|
| 469 |
+
self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
|
| 470 |
+
self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None
|
| 471 |
+
|
| 472 |
+
def forward(self, pos_seq: torch.Tensor):
|
| 473 |
+
pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
|
| 474 |
+
pos_seq += self.maxlen
|
| 475 |
+
pe_k_output = self.pe_k(pos_seq)
|
| 476 |
+
pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
|
| 477 |
+
return pe_k_output, pe_v_output
|
| 478 |
+
|
| 479 |
+
class DownConvLayer(torch.nn.Module):
|
| 480 |
+
def __init__(self, in_channels: int, samp_kernel_size: int):
|
| 481 |
+
"""Construct an EncoderLayer object."""
|
| 482 |
+
super().__init__()
|
| 483 |
+
self.down_conv = torch.nn.Conv1d(
|
| 484 |
+
in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
|
| 485 |
+
self.GN = nn.GroupNorm(1, num_channels=in_channels)
|
| 486 |
+
self.gelu = torch.nn.GELU()
|
| 487 |
+
|
| 488 |
+
def forward(self, x: torch.Tensor):
|
| 489 |
+
x = x.permute([0, 2, 1])
|
| 490 |
+
x = self.down_conv(x)
|
| 491 |
+
x = self.GN(x)
|
| 492 |
+
x = self.gelu(x)
|
| 493 |
+
x = x.permute([0, 2, 1])
|
| 494 |
+
return x
|
| 495 |
+
|
| 496 |
+
class ConvNormAct(nn.Module):
|
| 497 |
+
"""
|
| 498 |
+
This class defines the convolution layer with normalization and a PReLU
|
| 499 |
+
activation
|
| 500 |
+
"""
|
| 501 |
+
|
| 502 |
+
def __init__(self, nIn, nOut, kSize, stride=1, groups=1, norm_type="gLN"):
|
| 503 |
+
"""
|
| 504 |
+
:param nIn: number of input channels
|
| 505 |
+
:param nOut: number of output channels
|
| 506 |
+
:param kSize: kernel size
|
| 507 |
+
:param stride: stride rate for down-sampling. Default is 1
|
| 508 |
+
"""
|
| 509 |
+
super().__init__()
|
| 510 |
+
padding = int((kSize - 1) / 2)
|
| 511 |
+
self.conv = nn.Conv1d(
|
| 512 |
+
nIn, nOut, kSize, stride=stride, padding=padding, bias=True, groups=groups
|
| 513 |
+
)
|
| 514 |
+
if norm_type == "gLN":
|
| 515 |
+
self.norm = nn.GroupNorm(1, nOut, eps=1e-8)
|
| 516 |
+
if norm_type == "BN":
|
| 517 |
+
self.norm = nn.BatchNorm1d(nOut)
|
| 518 |
+
self.act = nn.PReLU()
|
| 519 |
+
|
| 520 |
+
def forward(self, input):
|
| 521 |
+
output = self.conv(input)
|
| 522 |
+
output = self.norm(output)
|
| 523 |
+
return self.act(output)
|
| 524 |
+
|
| 525 |
+
class DilatedConvNorm(nn.Module):
|
| 526 |
+
"""
|
| 527 |
+
This class defines the dilated convolution with normalized output.
|
| 528 |
+
"""
|
| 529 |
+
|
| 530 |
+
def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1, norm_type="gLN"):
|
| 531 |
+
"""
|
| 532 |
+
:param nIn: number of input channels
|
| 533 |
+
:param nOut: number of output channels
|
| 534 |
+
:param kSize: kernel size
|
| 535 |
+
:param stride: optional stride rate for down-sampling
|
| 536 |
+
:param d: optional dilation rate
|
| 537 |
+
"""
|
| 538 |
+
super().__init__()
|
| 539 |
+
self.conv = nn.Conv1d(
|
| 540 |
+
nIn,
|
| 541 |
+
nOut,
|
| 542 |
+
kSize,
|
| 543 |
+
stride=stride,
|
| 544 |
+
dilation=d,
|
| 545 |
+
padding=((kSize - 1) // 2) * d,
|
| 546 |
+
groups=groups,
|
| 547 |
+
)
|
| 548 |
+
# self.norm = nn.GroupNorm(1, nOut, eps=1e-08)
|
| 549 |
+
if norm_type == "gLN":
|
| 550 |
+
self.norm = nn.GroupNorm(1, nOut, eps=1e-8)
|
| 551 |
+
if norm_type == "BN":
|
| 552 |
+
self.norm = nn.BatchNorm1d(nOut)
|
| 553 |
+
|
| 554 |
+
def forward(self, input):
|
| 555 |
+
output = self.conv(input)
|
| 556 |
+
return self.norm(output)
|
| 557 |
+
|
| 558 |
+
class Mlp(nn.Module):
|
| 559 |
+
def __init__(self, in_features, hidden_size, drop=0.1, norm_type="gLN"):
|
| 560 |
+
super().__init__()
|
| 561 |
+
self.fc1 = ConvNorm(
|
| 562 |
+
in_features, hidden_size, 1, bias=False, norm_type=norm_type
|
| 563 |
+
)
|
| 564 |
+
self.dwconv = nn.Conv1d(
|
| 565 |
+
hidden_size, hidden_size, 5, 1, 2, bias=True, groups=hidden_size
|
| 566 |
+
)
|
| 567 |
+
self.act = nn.ReLU()
|
| 568 |
+
self.fc2 = ConvNorm(
|
| 569 |
+
hidden_size, in_features, 1, bias=False, norm_type=norm_type
|
| 570 |
+
)
|
| 571 |
+
self.drop = nn.Dropout(drop)
|
| 572 |
+
|
| 573 |
+
def forward(self, x):
|
| 574 |
+
x = self.fc1(x)
|
| 575 |
+
x = self.dwconv(x)
|
| 576 |
+
x = self.act(x)
|
| 577 |
+
x = self.drop(x)
|
| 578 |
+
x = self.fc2(x)
|
| 579 |
+
x = self.drop(x)
|
| 580 |
+
return x
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
class InjectionMultiSum(nn.Module):
|
| 584 |
+
def __init__(self, inp: int, oup: int, kernel: int = 1, norm_type="gLN") -> None:
|
| 585 |
+
super().__init__()
|
| 586 |
+
groups = 1
|
| 587 |
+
if inp == oup:
|
| 588 |
+
groups = inp
|
| 589 |
+
self.local_embedding = HeatConvNorm(
|
| 590 |
+
inp, oup, kernel, groups=groups, bias=False, norm_type=norm_type
|
| 591 |
+
)
|
| 592 |
+
self.global_embedding = HeatConvNorm(
|
| 593 |
+
inp, oup, kernel, groups=groups, bias=False, norm_type=norm_type
|
| 594 |
+
)
|
| 595 |
+
self.global_act = HeatConvNorm(
|
| 596 |
+
inp, oup, kernel, groups=groups, bias=False, norm_type=norm_type
|
| 597 |
+
)
|
| 598 |
+
self.act = nn.Sigmoid()
|
| 599 |
+
|
| 600 |
+
def forward(self, x_l, x_g):
|
| 601 |
+
"""
|
| 602 |
+
x_g: global features
|
| 603 |
+
x_l: local features
|
| 604 |
+
"""
|
| 605 |
+
B, N, T = x_l.shape
|
| 606 |
+
local_feat = self.local_embedding(x_l)
|
| 607 |
+
|
| 608 |
+
global_act = self.global_act(x_g)
|
| 609 |
+
sig_act = torch.nn.functional.interpolate(self.act(global_act), size=T, mode="nearest")
|
| 610 |
+
# sig_act = self.act(global_act)
|
| 611 |
+
|
| 612 |
+
global_feat = self.global_embedding(x_g)
|
| 613 |
+
global_feat = torch.nn.functional.interpolate(global_feat, size=T, mode="nearest")
|
| 614 |
+
|
| 615 |
+
out = local_feat * sig_act + global_feat
|
| 616 |
+
return out
|
| 617 |
+
|
| 618 |
+
class UConvBlock(nn.Module):
|
| 619 |
+
"""
|
| 620 |
+
This class defines the block which performs successive downsampling and
|
| 621 |
+
upsampling in order to be able to analyze the input features in multiple
|
| 622 |
+
resolutions.
|
| 623 |
+
"""
|
| 624 |
+
|
| 625 |
+
def __init__(
|
| 626 |
+
self, out_channels=128, in_channels=512, upsampling_depth=4, norm_type="gLN"
|
| 627 |
+
):
|
| 628 |
+
super().__init__()
|
| 629 |
+
self.proj_1x1 = ConvNormAct(out_channels, in_channels, 1, stride=1, groups=1, norm_type=norm_type)
|
| 630 |
+
self.depth = upsampling_depth
|
| 631 |
+
self.spp_dw = nn.ModuleList()
|
| 632 |
+
self.spp_dw.append(
|
| 633 |
+
DilatedConvNorm(
|
| 634 |
+
in_channels, in_channels, kSize=5, stride=1, groups=in_channels, d=1, norm_type=norm_type
|
| 635 |
+
)
|
| 636 |
+
)
|
| 637 |
+
for i in range(1, upsampling_depth):
|
| 638 |
+
self.spp_dw.append(
|
| 639 |
+
DilatedConvNorm(
|
| 640 |
+
in_channels,
|
| 641 |
+
in_channels,
|
| 642 |
+
kSize=5,
|
| 643 |
+
stride=2,
|
| 644 |
+
groups=in_channels,
|
| 645 |
+
d=1,
|
| 646 |
+
norm_type=norm_type
|
| 647 |
+
)
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
self.loc_glo_fus = nn.ModuleList([])
|
| 651 |
+
for i in range(upsampling_depth):
|
| 652 |
+
self.loc_glo_fus.append(InjectionMultiSum(in_channels, in_channels, norm_type=norm_type))
|
| 653 |
+
|
| 654 |
+
self.res_conv = nn.Conv1d(in_channels, out_channels, 1)
|
| 655 |
+
|
| 656 |
+
self.globalatt = Mlp(in_channels, in_channels, drop=0.1)
|
| 657 |
+
|
| 658 |
+
self.last_layer = nn.ModuleList([])
|
| 659 |
+
for i in range(self.depth - 1):
|
| 660 |
+
self.last_layer.append(InjectionMultiSum(in_channels, in_channels, 5, norm_type=norm_type))
|
| 661 |
+
|
| 662 |
+
def forward(self, x):
|
| 663 |
+
"""
|
| 664 |
+
:param x: input feature map
|
| 665 |
+
:return: transformed feature map
|
| 666 |
+
"""
|
| 667 |
+
residual = x.clone()
|
| 668 |
+
# Reduce --> project high-dimensional feature maps to low-dimensional space
|
| 669 |
+
output1 = self.proj_1x1(x)
|
| 670 |
+
output = [self.spp_dw[0](output1)]
|
| 671 |
+
|
| 672 |
+
# Do the downsampling process from the previous level
|
| 673 |
+
for k in range(1, self.depth):
|
| 674 |
+
out_k = self.spp_dw[k](output[-1])
|
| 675 |
+
output.append(out_k)
|
| 676 |
+
|
| 677 |
+
# global features
|
| 678 |
+
global_f = torch.zeros(
|
| 679 |
+
output[-1].shape, requires_grad=True, device=output1.device
|
| 680 |
+
)
|
| 681 |
+
for fea in output:
|
| 682 |
+
global_f = global_f + torch.nn.functional.adaptive_avg_pool1d(
|
| 683 |
+
fea, output_size=output[-1].shape[-1]
|
| 684 |
+
)
|
| 685 |
+
# global_f = global_f + fea
|
| 686 |
+
global_f = self.globalatt(global_f) # [B, N, T]
|
| 687 |
+
|
| 688 |
+
x_fused = []
|
| 689 |
+
# Gather them now in reverse order
|
| 690 |
+
for idx in range(self.depth):
|
| 691 |
+
local = output[idx]
|
| 692 |
+
x_fused.append(self.loc_glo_fus[idx](local, global_f))
|
| 693 |
+
|
| 694 |
+
expanded = None
|
| 695 |
+
for i in range(self.depth - 2, -1, -1):
|
| 696 |
+
if i == self.depth - 2:
|
| 697 |
+
expanded = self.last_layer[i](x_fused[i], x_fused[i - 1])
|
| 698 |
+
else:
|
| 699 |
+
expanded = self.last_layer[i](x_fused[i], expanded)
|
| 700 |
+
# import pdb; pdb.set_trace()
|
| 701 |
+
return self.res_conv(expanded) + residual
|
| 702 |
+
|
| 703 |
+
class EncoderLayer(torch.nn.Module):
|
| 704 |
+
def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
|
| 705 |
+
super().__init__()
|
| 706 |
+
|
| 707 |
+
self.g_block_1 = GlobalBlock(**global_blocks)
|
| 708 |
+
self.l_block_1 = LocalBlock(**local_blocks)
|
| 709 |
+
|
| 710 |
+
self.g_block_2 = GlobalBlock(**global_blocks)
|
| 711 |
+
self.l_block_2 = LocalBlock(**local_blocks)
|
| 712 |
+
|
| 713 |
+
self.downconv = DownConvLayer(**down_conv_layer) if down_conv == True else None
|
| 714 |
+
|
| 715 |
+
def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
|
| 716 |
+
'''
|
| 717 |
+
x: [B, N, T]
|
| 718 |
+
'''
|
| 719 |
+
x = self.g_block_1(x, pos_k)
|
| 720 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 721 |
+
x = self.l_block_1(x)
|
| 722 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 723 |
+
|
| 724 |
+
x = self.g_block_2(x, pos_k)
|
| 725 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 726 |
+
x = self.l_block_2(x)
|
| 727 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 728 |
+
|
| 729 |
+
skip = x
|
| 730 |
+
if self.downconv:
|
| 731 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 732 |
+
x = self.downconv(x)
|
| 733 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 734 |
+
# [BK, S, N]
|
| 735 |
+
return x, skip
|
| 736 |
+
|
| 737 |
+
class DecoderLayer(torch.nn.Module):
|
| 738 |
+
def __init__(self, global_blocks: dict, local_blocks: dict, spk_attention: dict):
|
| 739 |
+
super().__init__()
|
| 740 |
+
|
| 741 |
+
self.g_block_1 = GlobalBlock(**global_blocks)
|
| 742 |
+
self.l_block_1 = LocalBlock(**local_blocks)
|
| 743 |
+
|
| 744 |
+
self.g_block_2 = GlobalBlock(**global_blocks)
|
| 745 |
+
self.l_block_2 = LocalBlock(**local_blocks)
|
| 746 |
+
|
| 747 |
+
self.g_block_3 = GlobalBlock(**global_blocks)
|
| 748 |
+
self.l_block_3 = LocalBlock(**local_blocks)
|
| 749 |
+
|
| 750 |
+
def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
|
| 751 |
+
'''
|
| 752 |
+
x: [B, N, T]
|
| 753 |
+
'''
|
| 754 |
+
# [BS, K, H]
|
| 755 |
+
x = self.g_block_1(x, pos_k)
|
| 756 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 757 |
+
x = self.l_block_1(x)
|
| 758 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 759 |
+
|
| 760 |
+
x = self.g_block_2(x, pos_k)
|
| 761 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 762 |
+
x = self.l_block_2(x)
|
| 763 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 764 |
+
|
| 765 |
+
x = self.g_block_3(x, pos_k)
|
| 766 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 767 |
+
x = self.l_block_3(x)
|
| 768 |
+
x = x.permute(0, 2, 1).contiguous()
|
| 769 |
+
|
| 770 |
+
skip = x
|
| 771 |
+
|
| 772 |
+
return x, skip
|
| 773 |
+
|
| 774 |
+
class Separator(torch.nn.Module):
|
| 775 |
+
def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, simple_fusion:dict, dec_stage: dict):
|
| 776 |
+
super().__init__()
|
| 777 |
+
|
| 778 |
+
self.num_stages = num_stages
|
| 779 |
+
self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)
|
| 780 |
+
|
| 781 |
+
# Temporal Contracting Part
|
| 782 |
+
self.enc_stages = torch.nn.ModuleList([])
|
| 783 |
+
for _ in range(self.num_stages):
|
| 784 |
+
self.enc_stages.append(EncoderLayer(**enc_stage, down_conv=True))
|
| 785 |
+
|
| 786 |
+
self.bottleneck_G = nn.ModuleList([
|
| 787 |
+
MultiHeadAttention(
|
| 788 |
+
n_head=enc_stage['global_blocks']['num_mha_heads'],
|
| 789 |
+
in_channels=enc_stage['global_blocks']['in_channels'],
|
| 790 |
+
dropout_rate=enc_stage['global_blocks']['dropout_rate']
|
| 791 |
+
),
|
| 792 |
+
FFN(
|
| 793 |
+
in_channels=enc_stage['global_blocks']['in_channels'],
|
| 794 |
+
dropout_rate=enc_stage['global_blocks']['dropout_rate']
|
| 795 |
+
)
|
| 796 |
+
])
|
| 797 |
+
|
| 798 |
+
# top-down fusion
|
| 799 |
+
self.loc_glo_fus = nn.ModuleList([])
|
| 800 |
+
for i in range(self.num_stages):
|
| 801 |
+
self.loc_glo_fus.append(InjectionMultiSum(simple_fusion['out_channels'], simple_fusion['out_channels']))
|
| 802 |
+
|
| 803 |
+
# Temporal Expanding Part
|
| 804 |
+
self.simple_fusion = torch.nn.ModuleList([])
|
| 805 |
+
self.dec_stages = torch.nn.ModuleList([])
|
| 806 |
+
for _ in range(self.num_stages):
|
| 807 |
+
self.simple_fusion.append(InjectionMultiSum(simple_fusion['out_channels'], simple_fusion['out_channels'], kernel=5))
|
| 808 |
+
self.dec_stages.append(DecoderLayer(**dec_stage))
|
| 809 |
+
|
| 810 |
+
def forward(self, input: torch.Tensor):
|
| 811 |
+
'''input: [B, N, L]'''
|
| 812 |
+
# feature projection
|
| 813 |
+
x, _ = self.pad_signal(input)
|
| 814 |
+
len_x = x.shape[-1]
|
| 815 |
+
# Temporal Contracting Part
|
| 816 |
+
min_len = len_x//2**(self.num_stages-1)
|
| 817 |
+
pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
|
| 818 |
+
pos_seq = pos_seq[:, None] - pos_seq[None, :]
|
| 819 |
+
pos_k, _ = self.pos_emb(pos_seq)
|
| 820 |
+
|
| 821 |
+
skip = []
|
| 822 |
+
fusion_x = torch.zeros([x.shape[0], x.shape[1], min_len], requires_grad=True, device=x.device)
|
| 823 |
+
for idx in range(self.num_stages):
|
| 824 |
+
x, skip_ = self.enc_stages[idx](x, pos_k)
|
| 825 |
+
skip.append(skip_)
|
| 826 |
+
fusion_x = fusion_x + F.adaptive_avg_pool1d(x, min_len)
|
| 827 |
+
|
| 828 |
+
global_x = self.bottleneck_G[0](fusion_x.permute(0, 2, 1).contiguous(), None, None)
|
| 829 |
+
global_x = self.bottleneck_G[1](global_x).permute(0, 2, 1).contiguous()
|
| 830 |
+
|
| 831 |
+
# Global topdown attention
|
| 832 |
+
fusion_skip = []
|
| 833 |
+
for idx in range(self.num_stages):
|
| 834 |
+
fusion_skip.append(self.loc_glo_fus[idx](skip[idx], global_x))
|
| 835 |
+
|
| 836 |
+
each_stage_outputs = []
|
| 837 |
+
# Temporal Expanding Part
|
| 838 |
+
for idx in range(self.num_stages):
|
| 839 |
+
each_stage_outputs.append(x)
|
| 840 |
+
idx_en = self.num_stages - (idx + 1)
|
| 841 |
+
x = self.simple_fusion[idx](fusion_skip[idx_en], x)
|
| 842 |
+
x, _ = self.dec_stages[idx](x, pos_k)
|
| 843 |
+
|
| 844 |
+
last_stage_output = x
|
| 845 |
+
return last_stage_output, each_stage_outputs
|
| 846 |
+
|
| 847 |
+
def pad_signal(self, input: torch.Tensor):
|
| 848 |
+
# (B, T) or (B, 1, T)
|
| 849 |
+
if input.dim() == 1: input = input.unsqueeze(0)
|
| 850 |
+
elif input.dim() not in [2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
|
| 851 |
+
elif input.dim() == 2: input = input.unsqueeze(1)
|
| 852 |
+
L = 2**self.num_stages
|
| 853 |
+
batch_size = input.size(0)
|
| 854 |
+
ndim = input.size(1)
|
| 855 |
+
nframe = input.size(2)
|
| 856 |
+
padded_len = (nframe//L + 1)*L
|
| 857 |
+
rest = 0 if nframe%L == 0 else padded_len - nframe
|
| 858 |
+
if rest > 0:
|
| 859 |
+
pad = torch.autograd.Variable(torch.zeros(batch_size, ndim, rest)).type(input.type()).to(input.device)
|
| 860 |
+
input = torch.cat([input, pad], dim=-1)
|
| 861 |
+
return input, rest
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
class OutputLayer(torch.nn.Module):
|
| 865 |
+
def __init__(self, in_channels: int, out_channels: int, masking: bool = False):
|
| 866 |
+
super().__init__()
|
| 867 |
+
# feature expansion back
|
| 868 |
+
self.masking = masking
|
| 869 |
+
self.spe_block = Masking(in_channels)
|
| 870 |
+
self.end_conv1x1 = torch.nn.Sequential(
|
| 871 |
+
torch.nn.Linear(out_channels, 4*out_channels),
|
| 872 |
+
torch.nn.GLU(),
|
| 873 |
+
torch.nn.Linear(2*out_channels, in_channels))
|
| 874 |
+
|
| 875 |
+
def forward(self, x: torch.Tensor, input: torch.Tensor):
|
| 876 |
+
x = x[...,:input.shape[-1]]
|
| 877 |
+
x = x.permute([0, 2, 1])
|
| 878 |
+
x = self.end_conv1x1(x)
|
| 879 |
+
x = x.permute([0, 2, 1])
|
| 880 |
+
|
| 881 |
+
if self.masking:
|
| 882 |
+
x = self.spe_block(x, input)
|
| 883 |
+
|
| 884 |
+
return x
|
| 885 |
+
|
| 886 |
+
class AudioDecoder(torch.nn.ConvTranspose1d):
|
| 887 |
+
def __init__(self, *args, **kwargs):
|
| 888 |
+
super().__init__(*args, **kwargs)
|
| 889 |
+
|
| 890 |
+
def forward(self, x):
|
| 891 |
+
# x: [B, N, L]
|
| 892 |
+
if x.dim() not in [2, 3]:
|
| 893 |
+
raise RuntimeError("{} accept 2/3D tensor as input".format(self.__class__.__name__))
|
| 894 |
+
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
|
| 895 |
+
x = torch.squeeze(x, dim=1) if torch.squeeze(x).dim() == 1 else torch.squeeze(x)
|
| 896 |
+
return x
|
| 897 |
+
|
| 898 |
+
class ReconstructionPath(nn.Module):
|
| 899 |
+
def __init__(
|
| 900 |
+
self,
|
| 901 |
+
layers = [
|
| 902 |
+
'residual',
|
| 903 |
+
'residual',
|
| 904 |
+
'residual'
|
| 905 |
+
],
|
| 906 |
+
image_size=88,
|
| 907 |
+
in_channel=1,
|
| 908 |
+
init_channel=16,
|
| 909 |
+
max_dim=128,
|
| 910 |
+
# conv相关
|
| 911 |
+
input_conv_kernel_size = [7, 7, 7],
|
| 912 |
+
output_conv_kernel_size = [3, 3, 3],
|
| 913 |
+
residual_conv_kernel_size=3,
|
| 914 |
+
pad_mode="constant",
|
| 915 |
+
# attn相关
|
| 916 |
+
attn_dim_head = 32,
|
| 917 |
+
attn_heads = 8,
|
| 918 |
+
attn_dropout = 0.,
|
| 919 |
+
flash_attn = True,
|
| 920 |
+
linear_attn_dim_head = 8,
|
| 921 |
+
linear_attn_heads = 16,
|
| 922 |
+
fuse_dim=32,
|
| 923 |
+
# quantizer相关
|
| 924 |
+
num_quantizers = 1,
|
| 925 |
+
codebook_size = 256,
|
| 926 |
+
codebook_dim=64,
|
| 927 |
+
commitment_cost=0.25,
|
| 928 |
+
):
|
| 929 |
+
super().__init__()
|
| 930 |
+
input_conv_kernel_size=tuple(input_conv_kernel_size)
|
| 931 |
+
|
| 932 |
+
self.conv_in = nn.Conv3d(in_channel, init_channel, input_conv_kernel_size,padding='same')
|
| 933 |
+
|
| 934 |
+
layer_fmap_size=image_size
|
| 935 |
+
self.encoder_layers = nn.ModuleList([])
|
| 936 |
+
dim=init_channel
|
| 937 |
+
dim_out=dim
|
| 938 |
+
time_downsample_factor=1
|
| 939 |
+
|
| 940 |
+
for layer_type in layers:
|
| 941 |
+
if layer_type == 'residual':
|
| 942 |
+
encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
|
| 943 |
+
|
| 944 |
+
elif layer_type == 'consecutive_residual':
|
| 945 |
+
num_consecutive = 2
|
| 946 |
+
encoder_layer = Sequential(*[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)])
|
| 947 |
+
|
| 948 |
+
elif layer_type == 'compress_space':
|
| 949 |
+
dim_out = dim * 2
|
| 950 |
+
dim_out = min(dim_out, max_dim)
|
| 951 |
+
|
| 952 |
+
encoder_layer = SpatialDownsample2x(dim, dim_out)
|
| 953 |
+
|
| 954 |
+
assert layer_fmap_size > 1
|
| 955 |
+
layer_fmap_size //= 2
|
| 956 |
+
|
| 957 |
+
elif layer_type == 'compress_time':
|
| 958 |
+
dim_out = dim * 2
|
| 959 |
+
dim_out = min(dim_out, max_dim)
|
| 960 |
+
|
| 961 |
+
encoder_layer = TimeDownsample2x(dim, dim_out)
|
| 962 |
+
|
| 963 |
+
time_downsample_factor *= 2
|
| 964 |
+
|
| 965 |
+
elif layer_type == 'attend_space':
|
| 966 |
+
attn_kwargs = dict(
|
| 967 |
+
dim = dim,
|
| 968 |
+
dim_head = attn_dim_head,
|
| 969 |
+
heads = attn_heads,
|
| 970 |
+
dropout = attn_dropout,
|
| 971 |
+
flash = flash_attn
|
| 972 |
+
)
|
| 973 |
+
|
| 974 |
+
encoder_layer = Sequential(
|
| 975 |
+
Residual(SpaceAttention(**attn_kwargs)),
|
| 976 |
+
Residual(FeedForward(dim))
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
elif layer_type == 'linear_attend_space':
|
| 980 |
+
linear_attn_kwargs = dict(
|
| 981 |
+
dim = dim,
|
| 982 |
+
dim_head = linear_attn_dim_head,
|
| 983 |
+
heads = linear_attn_heads
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
encoder_layer = Sequential(
|
| 987 |
+
Residual(LinearSpaceAttention(**linear_attn_kwargs)),
|
| 988 |
+
Residual(FeedForward(dim))
|
| 989 |
+
)
|
| 990 |
+
|
| 991 |
+
else:
|
| 992 |
+
raise ValueError(f'unknown layer type {layer_type}')
|
| 993 |
+
|
| 994 |
+
self.encoder_layers.append(encoder_layer)
|
| 995 |
+
|
| 996 |
+
dim = dim_out
|
| 997 |
+
|
| 998 |
+
self.encoder_layers.append(Sequential(
|
| 999 |
+
Rearrange('b c ... -> b ... c'),
|
| 1000 |
+
nn.LayerNorm(dim),
|
| 1001 |
+
Rearrange('b ... c -> b c ...'),
|
| 1002 |
+
))
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
def forward(self, x, semantic_fea=None):
|
| 1006 |
+
x = self.conv_in(x)
|
| 1007 |
+
for layer in self.encoder_layers:
|
| 1008 |
+
x = layer(x)
|
| 1009 |
+
z_e = x
|
| 1010 |
+
|
| 1011 |
+
z_q=z_e
|
| 1012 |
+
|
| 1013 |
+
if semantic_fea!=None:
|
| 1014 |
+
B,C,T,H,W=z_q.shape
|
| 1015 |
+
z_q=z_q.contiguous().permute(0,2,1,3,4)
|
| 1016 |
+
z_q=z_q.contiguous().view(B,T,-1)
|
| 1017 |
+
z_q=z_q + semantic_fea
|
| 1018 |
+
|
| 1019 |
+
return z_q
|
| 1020 |
+
|
| 1021 |
+
class SemanticPath(nn.Module):
|
| 1022 |
+
def __init__(
|
| 1023 |
+
self,
|
| 1024 |
+
layers = [
|
| 1025 |
+
'residual',
|
| 1026 |
+
'residual',
|
| 1027 |
+
'residual'
|
| 1028 |
+
],
|
| 1029 |
+
image_size=88,
|
| 1030 |
+
in_channel=1,
|
| 1031 |
+
init_channel=4,
|
| 1032 |
+
max_dim=32,
|
| 1033 |
+
input_conv_kernel_size = [7, 7, 7],
|
| 1034 |
+
output_conv_kernel_size = [3, 3, 3],
|
| 1035 |
+
residual_conv_kernel_size=3,
|
| 1036 |
+
pad_mode="constant",
|
| 1037 |
+
attn_dim_head = 32,
|
| 1038 |
+
attn_heads = 8,
|
| 1039 |
+
attn_dropout = 0.,
|
| 1040 |
+
flash_attn = True,
|
| 1041 |
+
linear_attn_dim_head = 8,
|
| 1042 |
+
linear_attn_heads = 16,
|
| 1043 |
+
num_quantizers = 1,
|
| 1044 |
+
codebook_size = 256,
|
| 1045 |
+
codebook_dim= 64,
|
| 1046 |
+
commitment_cost=0.25,
|
| 1047 |
+
distill_dim=1024,
|
| 1048 |
+
config=None,
|
| 1049 |
+
pretrain=None
|
| 1050 |
+
):
|
| 1051 |
+
super().__init__()
|
| 1052 |
+
input_conv_kernel_size=tuple(input_conv_kernel_size)
|
| 1053 |
+
|
| 1054 |
+
self.conv_in = nn.Conv3d(in_channel, init_channel, input_conv_kernel_size,padding='same')
|
| 1055 |
+
|
| 1056 |
+
layer_fmap_size=image_size
|
| 1057 |
+
self.encoder_layers = nn.ModuleList([])
|
| 1058 |
+
dim=init_channel
|
| 1059 |
+
dim_out=dim
|
| 1060 |
+
time_downsample_factor=1
|
| 1061 |
+
|
| 1062 |
+
for layer_type in layers:
|
| 1063 |
+
if layer_type == 'residual':
|
| 1064 |
+
encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
|
| 1065 |
+
|
| 1066 |
+
elif layer_type == 'consecutive_residual':
|
| 1067 |
+
num_consecutive = 2
|
| 1068 |
+
encoder_layer = Sequential(*[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)])
|
| 1069 |
+
|
| 1070 |
+
elif layer_type == 'compress_space':
|
| 1071 |
+
dim_out = dim * 2
|
| 1072 |
+
dim_out = min(dim_out, max_dim)
|
| 1073 |
+
|
| 1074 |
+
encoder_layer = SpatialDownsample2x(dim, dim_out)
|
| 1075 |
+
|
| 1076 |
+
assert layer_fmap_size > 1
|
| 1077 |
+
layer_fmap_size //= 2
|
| 1078 |
+
|
| 1079 |
+
elif layer_type == 'compress_time':
|
| 1080 |
+
dim_out = dim * 2
|
| 1081 |
+
dim_out = min(dim_out, max_dim)
|
| 1082 |
+
|
| 1083 |
+
encoder_layer = TimeDownsample2x(dim, dim_out)
|
| 1084 |
+
|
| 1085 |
+
time_downsample_factor *= 2
|
| 1086 |
+
|
| 1087 |
+
elif layer_type == 'attend_space':
|
| 1088 |
+
attn_kwargs = dict(
|
| 1089 |
+
dim = dim,
|
| 1090 |
+
dim_head = attn_dim_head,
|
| 1091 |
+
heads = attn_heads,
|
| 1092 |
+
dropout = attn_dropout,
|
| 1093 |
+
flash = flash_attn
|
| 1094 |
+
)
|
| 1095 |
+
|
| 1096 |
+
encoder_layer = Sequential(
|
| 1097 |
+
Residual(SpaceAttention(**attn_kwargs)),
|
| 1098 |
+
Residual(FeedForward(dim))
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
elif layer_type == 'linear_attend_space':
|
| 1102 |
+
linear_attn_kwargs = dict(
|
| 1103 |
+
dim = dim,
|
| 1104 |
+
dim_head = linear_attn_dim_head,
|
| 1105 |
+
heads = linear_attn_heads
|
| 1106 |
+
)
|
| 1107 |
+
|
| 1108 |
+
encoder_layer = Sequential(
|
| 1109 |
+
Residual(LinearSpaceAttention(**linear_attn_kwargs)),
|
| 1110 |
+
Residual(FeedForward(dim))
|
| 1111 |
+
)
|
| 1112 |
+
|
| 1113 |
+
else:
|
| 1114 |
+
raise ValueError(f'unknown layer type {layer_type}')
|
| 1115 |
+
|
| 1116 |
+
self.encoder_layers.append(encoder_layer)
|
| 1117 |
+
|
| 1118 |
+
dim = dim_out
|
| 1119 |
+
|
| 1120 |
+
self.encoder_layers.append(Sequential(
|
| 1121 |
+
Rearrange('b c ... -> b ... c'),
|
| 1122 |
+
nn.LayerNorm(dim),
|
| 1123 |
+
Rearrange('b ... c -> b c ...'),
|
| 1124 |
+
))
|
| 1125 |
+
|
| 1126 |
+
# layer_fmap_size = 3
|
| 1127 |
+
self.quantizer = ResidualVQ(
|
| 1128 |
+
dim = dim*layer_fmap_size*layer_fmap_size,
|
| 1129 |
+
num_quantizers = num_quantizers,
|
| 1130 |
+
codebook_size = codebook_size,
|
| 1131 |
+
codebook_dim = codebook_dim,
|
| 1132 |
+
quantize_dropout=False,
|
| 1133 |
+
stochastic_sample_codes = True,
|
| 1134 |
+
sample_codebook_temp = 0.1,
|
| 1135 |
+
kmeans_init = True,
|
| 1136 |
+
kmeans_iters = 10
|
| 1137 |
+
)
|
| 1138 |
+
|
| 1139 |
+
def forward(self, x):
|
| 1140 |
+
x = self.conv_in(x)
|
| 1141 |
+
for layer in self.encoder_layers:
|
| 1142 |
+
x = layer(x)
|
| 1143 |
+
b,c,t,h,w=x.shape
|
| 1144 |
+
x = x.contiguous().permute(0,2,1,3,4)
|
| 1145 |
+
z_e = x.contiguous().view(b,t,-1)
|
| 1146 |
+
|
| 1147 |
+
z_q,_,_=self.quantizer(z_e)
|
| 1148 |
+
|
| 1149 |
+
return z_q
|
| 1150 |
+
|
| 1151 |
+
class VideoEncoder(nn.Module):
|
| 1152 |
+
def __init__(
|
| 1153 |
+
self,
|
| 1154 |
+
layers,
|
| 1155 |
+
image_size=88,
|
| 1156 |
+
in_channel=1,
|
| 1157 |
+
init_channel=16,
|
| 1158 |
+
max_dim=128,
|
| 1159 |
+
input_conv_kernel_size = [7, 7, 7],
|
| 1160 |
+
output_conv_kernel_size = [3, 3, 3],
|
| 1161 |
+
residual_conv_kernel_size=3,
|
| 1162 |
+
pad_mode="constant",
|
| 1163 |
+
# attn相关
|
| 1164 |
+
attn_dim_head = 32,
|
| 1165 |
+
attn_heads = 8,
|
| 1166 |
+
attn_dropout = 0.,
|
| 1167 |
+
flash_attn = True,
|
| 1168 |
+
linear_attn_dim_head = 8,
|
| 1169 |
+
linear_attn_heads = 16,
|
| 1170 |
+
num_quantizers = 1,
|
| 1171 |
+
codebook_size = 256,
|
| 1172 |
+
codebook_dim=64,
|
| 1173 |
+
commitment_cost=0.25,
|
| 1174 |
+
distill_cost=1.0,
|
| 1175 |
+
):
|
| 1176 |
+
super().__init__()
|
| 1177 |
+
self.semantic_model=SemanticPath(
|
| 1178 |
+
layers=layers,
|
| 1179 |
+
image_size=image_size,
|
| 1180 |
+
in_channel=in_channel,
|
| 1181 |
+
init_channel=init_channel,
|
| 1182 |
+
max_dim=max_dim,
|
| 1183 |
+
input_conv_kernel_size=input_conv_kernel_size,
|
| 1184 |
+
output_conv_kernel_size=output_conv_kernel_size,
|
| 1185 |
+
residual_conv_kernel_size=residual_conv_kernel_size,
|
| 1186 |
+
pad_mode=pad_mode,
|
| 1187 |
+
attn_dim_head = attn_dim_head,
|
| 1188 |
+
attn_heads = attn_heads,
|
| 1189 |
+
attn_dropout = attn_dropout,
|
| 1190 |
+
flash_attn = flash_attn,
|
| 1191 |
+
linear_attn_dim_head = linear_attn_dim_head,
|
| 1192 |
+
linear_attn_heads = linear_attn_heads,
|
| 1193 |
+
num_quantizers = num_quantizers,
|
| 1194 |
+
codebook_size = codebook_size,
|
| 1195 |
+
codebook_dim = codebook_dim,
|
| 1196 |
+
commitment_cost = commitment_cost,
|
| 1197 |
+
)
|
| 1198 |
+
self.recon_model=ReconstructionPath(
|
| 1199 |
+
layers=layers,
|
| 1200 |
+
image_size=image_size,
|
| 1201 |
+
in_channel=in_channel,
|
| 1202 |
+
init_channel=init_channel,
|
| 1203 |
+
max_dim=max_dim,
|
| 1204 |
+
input_conv_kernel_size=input_conv_kernel_size,
|
| 1205 |
+
output_conv_kernel_size=output_conv_kernel_size,
|
| 1206 |
+
residual_conv_kernel_size=residual_conv_kernel_size,
|
| 1207 |
+
pad_mode=pad_mode,
|
| 1208 |
+
attn_dim_head = attn_dim_head,
|
| 1209 |
+
attn_heads = attn_heads,
|
| 1210 |
+
attn_dropout = attn_dropout,
|
| 1211 |
+
flash_attn = flash_attn,
|
| 1212 |
+
linear_attn_dim_head = linear_attn_dim_head,
|
| 1213 |
+
linear_attn_heads = linear_attn_heads,
|
| 1214 |
+
num_quantizers = num_quantizers,
|
| 1215 |
+
codebook_size = codebook_size,
|
| 1216 |
+
codebook_dim = codebook_dim,
|
| 1217 |
+
commitment_cost = commitment_cost,
|
| 1218 |
+
)
|
| 1219 |
+
|
| 1220 |
+
def forward(self, x):
|
| 1221 |
+
semantic_fea = self.semantic_model(x)
|
| 1222 |
+
return self.recon_model(x,semantic_fea)
|
| 1223 |
+
|
| 1224 |
+
class Dolphin(nn.Module, PyTorchModelHubMixin):
|
| 1225 |
+
def __init__(self,
|
| 1226 |
+
num_stages: int,
|
| 1227 |
+
sample_rate: int,
|
| 1228 |
+
module_audio_enc: dict,
|
| 1229 |
+
module_feature_projector: dict,
|
| 1230 |
+
module_separator: dict,
|
| 1231 |
+
module_output_layer: dict,
|
| 1232 |
+
module_audio_dec: dict,
|
| 1233 |
+
video_encoder_params: dict,
|
| 1234 |
+
vpre_channels=512,
|
| 1235 |
+
vmid_channels=512,
|
| 1236 |
+
vin_channels=64,
|
| 1237 |
+
vout_channels=64,):
|
| 1238 |
+
super(Dolphin, self).__init__()
|
| 1239 |
+
|
| 1240 |
+
self.pre_v1 = ConvNormAct(vpre_channels, vin_channels, kSize=3, norm_type="BN")
|
| 1241 |
+
|
| 1242 |
+
self.num_stages = num_stages
|
| 1243 |
+
self.audio_encoder = AudioEncoder(**module_audio_enc)
|
| 1244 |
+
self.feature_projector = FeatureProjector(**module_feature_projector)
|
| 1245 |
+
self.separator = Separator(**module_separator)
|
| 1246 |
+
self.out_layer = OutputLayer(**module_output_layer)
|
| 1247 |
+
self.audio_decoder = AudioDecoder(**module_audio_dec)
|
| 1248 |
+
|
| 1249 |
+
self.video_blocks = UConvBlock(vin_channels, vout_channels, 3, norm_type="BN")
|
| 1250 |
+
self.modalfuse = AVFModule(module_feature_projector["out_channels"], vout_channels)
|
| 1251 |
+
|
| 1252 |
+
self.video_encoder = VideoEncoder(**video_encoder_params)
|
| 1253 |
+
|
| 1254 |
+
@classmethod
|
| 1255 |
+
def _from_pretrained(
|
| 1256 |
+
cls,
|
| 1257 |
+
*,
|
| 1258 |
+
model_id: str,
|
| 1259 |
+
revision: str,
|
| 1260 |
+
cache_dir: str,
|
| 1261 |
+
force_download: bool,
|
| 1262 |
+
proxies: dict,
|
| 1263 |
+
resume_download: bool,
|
| 1264 |
+
local_files_only: bool,
|
| 1265 |
+
token: str,
|
| 1266 |
+
map_location: str = "cpu",
|
| 1267 |
+
strict: bool = False,
|
| 1268 |
+
**model_kwargs,
|
| 1269 |
+
):
|
| 1270 |
+
"""Load model from HuggingFace Hub with proper configuration handling."""
|
| 1271 |
+
import json
|
| 1272 |
+
from huggingface_hub import hf_hub_download
|
| 1273 |
+
|
| 1274 |
+
# Download config file
|
| 1275 |
+
config_file = hf_hub_download(
|
| 1276 |
+
repo_id=model_id,
|
| 1277 |
+
filename="config.json",
|
| 1278 |
+
revision=revision,
|
| 1279 |
+
cache_dir=cache_dir,
|
| 1280 |
+
force_download=force_download,
|
| 1281 |
+
proxies=proxies,
|
| 1282 |
+
resume_download=resume_download,
|
| 1283 |
+
local_files_only=local_files_only,
|
| 1284 |
+
token=token,
|
| 1285 |
+
)
|
| 1286 |
+
|
| 1287 |
+
# Load configuration
|
| 1288 |
+
with open(config_file, "r") as f:
|
| 1289 |
+
config = json.load(f)
|
| 1290 |
+
|
| 1291 |
+
# Extract only the model parameters, excluding HF metadata
|
| 1292 |
+
hf_metadata_keys = {
|
| 1293 |
+
"model_type", "task", "framework", "license", "tags",
|
| 1294 |
+
"architectures", "auto_map"
|
| 1295 |
+
}
|
| 1296 |
+
model_config = {k: v for k, v in config.items() if k not in hf_metadata_keys}
|
| 1297 |
+
|
| 1298 |
+
# Create model instance with config
|
| 1299 |
+
model = cls(**model_config)
|
| 1300 |
+
|
| 1301 |
+
# Try to download different possible model file formats
|
| 1302 |
+
import torch
|
| 1303 |
+
model_files_to_try = [
|
| 1304 |
+
"model.safetensors",
|
| 1305 |
+
]
|
| 1306 |
+
|
| 1307 |
+
state_dict = None
|
| 1308 |
+
for filename in model_files_to_try:
|
| 1309 |
+
try:
|
| 1310 |
+
model_file = hf_hub_download(
|
| 1311 |
+
repo_id=model_id,
|
| 1312 |
+
filename=filename,
|
| 1313 |
+
revision=revision,
|
| 1314 |
+
cache_dir=cache_dir,
|
| 1315 |
+
force_download=force_download,
|
| 1316 |
+
proxies=proxies,
|
| 1317 |
+
resume_download=resume_download,
|
| 1318 |
+
local_files_only=local_files_only,
|
| 1319 |
+
token=token,
|
| 1320 |
+
)
|
| 1321 |
+
|
| 1322 |
+
# Try to load the state dict
|
| 1323 |
+
if filename.endswith('.safetensors'):
|
| 1324 |
+
# Handle safetensors format
|
| 1325 |
+
try:
|
| 1326 |
+
from safetensors.torch import load_file
|
| 1327 |
+
state_dict = load_file(model_file, device=map_location)
|
| 1328 |
+
except ImportError:
|
| 1329 |
+
print("safetensors not available, skipping .safetensors files")
|
| 1330 |
+
continue
|
| 1331 |
+
else:
|
| 1332 |
+
# Handle PyTorch format
|
| 1333 |
+
checkpoint = torch.load(model_file, map_location=map_location, weights_only=False)
|
| 1334 |
+
|
| 1335 |
+
# Handle different checkpoint formats
|
| 1336 |
+
if isinstance(checkpoint, dict):
|
| 1337 |
+
if 'state_dict' in checkpoint:
|
| 1338 |
+
state_dict = checkpoint['state_dict']
|
| 1339 |
+
elif 'model_state_dict' in checkpoint:
|
| 1340 |
+
state_dict = checkpoint['model_state_dict']
|
| 1341 |
+
else:
|
| 1342 |
+
state_dict = checkpoint
|
| 1343 |
+
else:
|
| 1344 |
+
state_dict = checkpoint
|
| 1345 |
+
|
| 1346 |
+
# If we successfully loaded a state dict, break
|
| 1347 |
+
if state_dict is not None:
|
| 1348 |
+
break
|
| 1349 |
+
|
| 1350 |
+
except Exception as e:
|
| 1351 |
+
print(f"Failed to load {filename}: {e}")
|
| 1352 |
+
continue
|
| 1353 |
+
|
| 1354 |
+
if state_dict is None:
|
| 1355 |
+
raise RuntimeError(f"Could not load model weights from any of the tried files: {model_files_to_try}")
|
| 1356 |
+
|
| 1357 |
+
model.load_state_dict(state_dict, strict=strict)
|
| 1358 |
+
|
| 1359 |
+
return model
|
| 1360 |
+
|
| 1361 |
+
def forward(self, input, mouth):
|
| 1362 |
+
mouth = self.video_encoder(mouth).permute(0, 2, 1).contiguous()
|
| 1363 |
+
v=self.pre_v1(mouth)
|
| 1364 |
+
v=self.video_blocks(v)
|
| 1365 |
+
|
| 1366 |
+
encoder_output = self.audio_encoder(input)
|
| 1367 |
+
projected_feature = self.feature_projector(encoder_output)
|
| 1368 |
+
|
| 1369 |
+
projected_feature = self.modalfuse(projected_feature,v)
|
| 1370 |
+
|
| 1371 |
+
last_stage_output, each_stage_outputs = self.separator(projected_feature)
|
| 1372 |
+
|
| 1373 |
+
out_layer_output = self.out_layer(last_stage_output, encoder_output)
|
| 1374 |
+
audio=self.audio_decoder(out_layer_output)
|
| 1375 |
+
|
| 1376 |
+
return audio.unsqueeze(dim=1)
|
look2hear/models/video_compoent.py
ADDED
|
@@ -0,0 +1,876 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import einsum,Tensor
|
| 6 |
+
from functools import partial
|
| 7 |
+
from taylor_series_linear_attention import TaylorSeriesLinearAttn
|
| 8 |
+
|
| 9 |
+
from beartype import beartype
|
| 10 |
+
from beartype.typing import Tuple, List
|
| 11 |
+
|
| 12 |
+
from einops import rearrange, repeat, reduce, pack, unpack
|
| 13 |
+
from einops.layers.torch import Rearrange
|
| 14 |
+
|
| 15 |
+
from typing import Union
|
| 16 |
+
|
| 17 |
+
from functools import partial
|
| 18 |
+
from typing import Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torch import nn, einsum, Tensor
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
from collections import namedtuple
|
| 25 |
+
from functools import wraps
|
| 26 |
+
from packaging import version
|
| 27 |
+
|
| 28 |
+
# constants
|
| 29 |
+
|
| 30 |
+
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
|
| 31 |
+
|
| 32 |
+
# helpers
|
| 33 |
+
|
| 34 |
+
def exists(val):
|
| 35 |
+
return val is not None
|
| 36 |
+
|
| 37 |
+
def default(val, d):
|
| 38 |
+
return val if exists(val) else d
|
| 39 |
+
|
| 40 |
+
def compact(arr):
|
| 41 |
+
return [*filter(exists, arr)]
|
| 42 |
+
|
| 43 |
+
def once(fn):
|
| 44 |
+
called = False
|
| 45 |
+
@wraps(fn)
|
| 46 |
+
def inner(x):
|
| 47 |
+
nonlocal called
|
| 48 |
+
if called:
|
| 49 |
+
return
|
| 50 |
+
called = True
|
| 51 |
+
return fn(x)
|
| 52 |
+
return inner
|
| 53 |
+
|
| 54 |
+
print_once = once(print)
|
| 55 |
+
|
| 56 |
+
# functions for creating causal mask
|
| 57 |
+
# need a special one for onnx cpu (no support for .triu)
|
| 58 |
+
|
| 59 |
+
def create_causal_mask(i, j, device):
|
| 60 |
+
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
|
| 61 |
+
|
| 62 |
+
def onnx_create_causal_mask(i, j, device):
|
| 63 |
+
r = torch.arange(i, device = device)
|
| 64 |
+
causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
|
| 65 |
+
causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
|
| 66 |
+
return causal_mask
|
| 67 |
+
|
| 68 |
+
# main class
|
| 69 |
+
|
| 70 |
+
class Attend(nn.Module):
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
*,
|
| 74 |
+
dropout = 0.,
|
| 75 |
+
causal = False,
|
| 76 |
+
heads = None,
|
| 77 |
+
scale = None,
|
| 78 |
+
flash = False,
|
| 79 |
+
onnxable = False,
|
| 80 |
+
sdp_kwargs: dict = dict(
|
| 81 |
+
enable_flash = True,
|
| 82 |
+
enable_math = True,
|
| 83 |
+
enable_mem_efficient = True
|
| 84 |
+
)
|
| 85 |
+
):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.scale = scale
|
| 88 |
+
|
| 89 |
+
self.causal = causal
|
| 90 |
+
self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
|
| 91 |
+
|
| 92 |
+
self.dropout = dropout
|
| 93 |
+
self.attn_dropout = nn.Dropout(dropout)
|
| 94 |
+
|
| 95 |
+
# flash attention
|
| 96 |
+
|
| 97 |
+
self.flash = flash and torch.cuda.is_available()
|
| 98 |
+
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
|
| 99 |
+
|
| 100 |
+
self.sdp_kwargs = sdp_kwargs
|
| 101 |
+
|
| 102 |
+
def flash_attn(
|
| 103 |
+
self,
|
| 104 |
+
q, k, v,
|
| 105 |
+
mask = None,
|
| 106 |
+
attn_bias = None
|
| 107 |
+
):
|
| 108 |
+
batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
|
| 109 |
+
|
| 110 |
+
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 111 |
+
|
| 112 |
+
# manage scale, since scale is not customizable in sdp, hack around it
|
| 113 |
+
|
| 114 |
+
if exists(self.scale):
|
| 115 |
+
q = q * self.scale / (q.shape[-1] ** -0.5)
|
| 116 |
+
|
| 117 |
+
# Check if mask exists and expand to compatible shape
|
| 118 |
+
# The mask is B L, so it would have to be expanded to B H N L
|
| 119 |
+
|
| 120 |
+
causal = self.causal
|
| 121 |
+
|
| 122 |
+
# in the case of kv caching with one token (q_len == 1), just turn off causal masking
|
| 123 |
+
# in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
|
| 124 |
+
|
| 125 |
+
if q_len == 1 and causal:
|
| 126 |
+
causal = False
|
| 127 |
+
|
| 128 |
+
# expand key padding mask
|
| 129 |
+
|
| 130 |
+
if exists(mask):
|
| 131 |
+
assert mask.ndim == 4
|
| 132 |
+
mask = mask.expand(batch, heads, q_len, k_len)
|
| 133 |
+
|
| 134 |
+
# handle kv cache - this should be bypassable in updated flash attention 2
|
| 135 |
+
|
| 136 |
+
if k_len > q_len and causal:
|
| 137 |
+
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
| 138 |
+
if not exists(mask):
|
| 139 |
+
mask = ~causal_mask
|
| 140 |
+
else:
|
| 141 |
+
mask = mask & ~causal_mask
|
| 142 |
+
causal = False
|
| 143 |
+
|
| 144 |
+
# manually handle causal mask, if another mask was given
|
| 145 |
+
|
| 146 |
+
row_is_entirely_masked = None
|
| 147 |
+
|
| 148 |
+
if exists(mask) and causal:
|
| 149 |
+
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
| 150 |
+
mask = mask & ~causal_mask
|
| 151 |
+
|
| 152 |
+
# protect against an entire row being masked out
|
| 153 |
+
|
| 154 |
+
row_is_entirely_masked = ~mask.any(dim = -1)
|
| 155 |
+
mask[..., 0] = mask[..., 0] | row_is_entirely_masked
|
| 156 |
+
|
| 157 |
+
causal = False
|
| 158 |
+
|
| 159 |
+
# handle alibi positional bias
|
| 160 |
+
# convert from bool to float
|
| 161 |
+
|
| 162 |
+
if exists(attn_bias):
|
| 163 |
+
attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
|
| 164 |
+
|
| 165 |
+
# if mask given, the mask would already contain the causal mask from above logic
|
| 166 |
+
# otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
|
| 167 |
+
|
| 168 |
+
mask_value = -torch.finfo(q.dtype).max
|
| 169 |
+
|
| 170 |
+
if exists(mask):
|
| 171 |
+
attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
|
| 172 |
+
elif causal:
|
| 173 |
+
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
| 174 |
+
attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
|
| 175 |
+
causal = False
|
| 176 |
+
|
| 177 |
+
# scaled_dot_product_attention handles attn_mask either as bool or additive bias
|
| 178 |
+
# make it an additive bias here
|
| 179 |
+
|
| 180 |
+
mask = attn_bias
|
| 181 |
+
|
| 182 |
+
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
| 183 |
+
|
| 184 |
+
with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
|
| 185 |
+
out = F.scaled_dot_product_attention(
|
| 186 |
+
q, k, v,
|
| 187 |
+
attn_mask = mask,
|
| 188 |
+
dropout_p = self.dropout if self.training else 0.,
|
| 189 |
+
is_causal = causal
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
# for a row that is entirely masked out, should zero out the output of that row token
|
| 193 |
+
|
| 194 |
+
if exists(row_is_entirely_masked):
|
| 195 |
+
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
|
| 196 |
+
|
| 197 |
+
return out
|
| 198 |
+
|
| 199 |
+
def forward(
|
| 200 |
+
self,
|
| 201 |
+
q, k, v,
|
| 202 |
+
mask = None,
|
| 203 |
+
attn_bias = None,
|
| 204 |
+
prev_attn = None
|
| 205 |
+
):
|
| 206 |
+
"""
|
| 207 |
+
einstein notation
|
| 208 |
+
b - batch
|
| 209 |
+
h - heads
|
| 210 |
+
n, i, j - sequence length (base sequence length, source, target)
|
| 211 |
+
d - feature dimension
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
|
| 215 |
+
|
| 216 |
+
scale = default(self.scale, q.shape[-1] ** -0.5)
|
| 217 |
+
|
| 218 |
+
causal = self.causal
|
| 219 |
+
|
| 220 |
+
# handle kv cached decoding
|
| 221 |
+
|
| 222 |
+
if n == 1 and causal:
|
| 223 |
+
causal = False
|
| 224 |
+
|
| 225 |
+
# handle zero kv, as means for allowing network to attend to nothing
|
| 226 |
+
|
| 227 |
+
if self.flash:
|
| 228 |
+
assert not exists(prev_attn), 'residual attention not compatible with flash attention'
|
| 229 |
+
return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
|
| 230 |
+
|
| 231 |
+
dots = einsum(f'b h i d, b h j d -> b h i j', q, k) * scale
|
| 232 |
+
|
| 233 |
+
if exists(prev_attn):
|
| 234 |
+
dots = dots + prev_attn
|
| 235 |
+
|
| 236 |
+
if exists(attn_bias):
|
| 237 |
+
dots = dots + attn_bias
|
| 238 |
+
|
| 239 |
+
i, j, dtype = *dots.shape[-2:], dots.dtype
|
| 240 |
+
|
| 241 |
+
mask_value = -torch.finfo(dots.dtype).max
|
| 242 |
+
|
| 243 |
+
if exists(mask):
|
| 244 |
+
dots = dots.masked_fill(~mask, mask_value)
|
| 245 |
+
|
| 246 |
+
if causal:
|
| 247 |
+
causal_mask = self.create_causal_mask(i, j, device = device)
|
| 248 |
+
dots = dots.masked_fill(causal_mask, mask_value)
|
| 249 |
+
|
| 250 |
+
attn = dots.softmax(dim = -1)
|
| 251 |
+
|
| 252 |
+
attn = self.attn_dropout(attn)
|
| 253 |
+
|
| 254 |
+
out = einsum(f'b h i j, b h j d -> b h i d', attn, v)
|
| 255 |
+
|
| 256 |
+
return out
|
| 257 |
+
|
| 258 |
+
def exists(v):
|
| 259 |
+
return v is not None
|
| 260 |
+
|
| 261 |
+
def default(v, d):
|
| 262 |
+
return v if exists(v) else d
|
| 263 |
+
|
| 264 |
+
def safe_get_index(it, ind, default = None):
|
| 265 |
+
if ind < len(it):
|
| 266 |
+
return it[ind]
|
| 267 |
+
return default
|
| 268 |
+
|
| 269 |
+
def pair(t):
|
| 270 |
+
return t if isinstance(t, tuple) else (t, t)
|
| 271 |
+
|
| 272 |
+
def identity(t, *args, **kwargs):
|
| 273 |
+
return t
|
| 274 |
+
|
| 275 |
+
def divisible_by(num, den):
|
| 276 |
+
return (num % den) == 0
|
| 277 |
+
|
| 278 |
+
def pack_one(t, pattern):
|
| 279 |
+
return pack([t], pattern)
|
| 280 |
+
|
| 281 |
+
def unpack_one(t, ps, pattern):
|
| 282 |
+
return unpack(t, ps, pattern)[0]
|
| 283 |
+
|
| 284 |
+
def append_dims(t, ndims: int):
|
| 285 |
+
return t.reshape(*t.shape, *((1,) * ndims))
|
| 286 |
+
|
| 287 |
+
def is_odd(n):
|
| 288 |
+
return not divisible_by(n, 2)
|
| 289 |
+
|
| 290 |
+
def maybe_del_attr_(o, attr):
|
| 291 |
+
if hasattr(o, attr):
|
| 292 |
+
delattr(o, attr)
|
| 293 |
+
|
| 294 |
+
def cast_tuple(t, length = 1):
|
| 295 |
+
return t if isinstance(t, tuple) else ((t,) * length)
|
| 296 |
+
class ResBlock(nn.Module):
|
| 297 |
+
def __init__(self, in_channel, channel):
|
| 298 |
+
super().__init__()
|
| 299 |
+
|
| 300 |
+
self.conv = nn.Sequential(
|
| 301 |
+
nn.ReLU(),
|
| 302 |
+
nn.Conv2d(in_channel, channel, 3, padding=1),
|
| 303 |
+
nn.ReLU(inplace=True),
|
| 304 |
+
nn.Conv2d(channel, in_channel, 1),
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
def forward(self, input):
|
| 308 |
+
out = self.conv(input)
|
| 309 |
+
out += input
|
| 310 |
+
|
| 311 |
+
return out
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class EncoderAE(nn.Module):
|
| 315 |
+
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
|
| 316 |
+
super().__init__()
|
| 317 |
+
|
| 318 |
+
if stride == 4:
|
| 319 |
+
blocks = [
|
| 320 |
+
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
|
| 321 |
+
nn.ReLU(inplace=True),
|
| 322 |
+
nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
|
| 323 |
+
nn.ReLU(inplace=True),
|
| 324 |
+
nn.Conv2d(channel, channel, 3, padding=1),
|
| 325 |
+
]
|
| 326 |
+
|
| 327 |
+
elif stride == 2:
|
| 328 |
+
blocks = [
|
| 329 |
+
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
|
| 330 |
+
nn.ReLU(inplace=True),
|
| 331 |
+
nn.Conv2d(channel // 2, channel, 3, padding=1),
|
| 332 |
+
]
|
| 333 |
+
|
| 334 |
+
for i in range(n_res_block):
|
| 335 |
+
blocks.append(ResBlock(channel, n_res_channel))
|
| 336 |
+
|
| 337 |
+
blocks.append(nn.ReLU(inplace=True))
|
| 338 |
+
|
| 339 |
+
self.blocks = nn.Sequential(*blocks)
|
| 340 |
+
|
| 341 |
+
def forward(self, input):
|
| 342 |
+
return self.blocks(input)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class DecoderAE(nn.Module):
|
| 346 |
+
def __init__(
|
| 347 |
+
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
|
| 348 |
+
):
|
| 349 |
+
super().__init__()
|
| 350 |
+
|
| 351 |
+
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
|
| 352 |
+
|
| 353 |
+
for i in range(n_res_block):
|
| 354 |
+
blocks.append(ResBlock(channel, n_res_channel))
|
| 355 |
+
|
| 356 |
+
blocks.append(nn.ReLU(inplace=True))
|
| 357 |
+
|
| 358 |
+
if stride == 4:
|
| 359 |
+
blocks.extend(
|
| 360 |
+
[
|
| 361 |
+
nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
|
| 362 |
+
nn.ReLU(inplace=True),
|
| 363 |
+
nn.ConvTranspose2d(
|
| 364 |
+
channel // 2, out_channel, 4, stride=2, padding=1
|
| 365 |
+
),
|
| 366 |
+
]
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
elif stride == 2:
|
| 370 |
+
blocks.append(
|
| 371 |
+
nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
self.blocks = nn.Sequential(*blocks)
|
| 375 |
+
|
| 376 |
+
def forward(self, input):
|
| 377 |
+
return self.blocks(input)
|
| 378 |
+
|
| 379 |
+
class CausalConv3d(nn.Module):
|
| 380 |
+
# 因果三维卷积,实则和直接三维卷积区别不大
|
| 381 |
+
@beartype
|
| 382 |
+
def __init__(
|
| 383 |
+
self,
|
| 384 |
+
chan_in,
|
| 385 |
+
chan_out,
|
| 386 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
| 387 |
+
pad_mode = 'constant',
|
| 388 |
+
**kwargs
|
| 389 |
+
):
|
| 390 |
+
super().__init__()
|
| 391 |
+
kernel_size = cast_tuple(kernel_size, 3)
|
| 392 |
+
|
| 393 |
+
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
| 394 |
+
|
| 395 |
+
# 这里以及下文中的height_pad,weight_pad的设置都是为了最后的HW size不变
|
| 396 |
+
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
|
| 397 |
+
|
| 398 |
+
dilation = kwargs.pop('dilation', 1)
|
| 399 |
+
stride = kwargs.pop('stride', 1)
|
| 400 |
+
|
| 401 |
+
self.pad_mode = pad_mode
|
| 402 |
+
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
|
| 403 |
+
height_pad = height_kernel_size // 2
|
| 404 |
+
width_pad = width_kernel_size // 2
|
| 405 |
+
|
| 406 |
+
self.time_pad = time_pad
|
| 407 |
+
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
|
| 408 |
+
|
| 409 |
+
stride = (stride, 1, 1)
|
| 410 |
+
dilation = (dilation, 1, 1)
|
| 411 |
+
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
|
| 412 |
+
|
| 413 |
+
def forward(self, x):
|
| 414 |
+
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
|
| 415 |
+
|
| 416 |
+
x = F.pad(x, self.time_causal_padding, mode = pad_mode)
|
| 417 |
+
return self.conv(x)
|
| 418 |
+
|
| 419 |
+
class SqueezeExcite(nn.Module):
|
| 420 |
+
# global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375)
|
| 421 |
+
# 一个轻量化的 channel-wise attn
|
| 422 |
+
def __init__(
|
| 423 |
+
self,
|
| 424 |
+
dim,
|
| 425 |
+
*,
|
| 426 |
+
dim_out = None,
|
| 427 |
+
dim_hidden_min = 16,
|
| 428 |
+
init_bias = -10
|
| 429 |
+
):
|
| 430 |
+
super().__init__()
|
| 431 |
+
dim_out = default(dim_out, dim)
|
| 432 |
+
|
| 433 |
+
self.to_k = nn.Conv2d(dim, 1, 1)
|
| 434 |
+
dim_hidden = max(dim_hidden_min, dim_out // 2)
|
| 435 |
+
|
| 436 |
+
self.net = nn.Sequential(
|
| 437 |
+
nn.Conv2d(dim, dim_hidden, 1),
|
| 438 |
+
nn.LeakyReLU(0.1),
|
| 439 |
+
nn.Conv2d(dim_hidden, dim_out, 1),
|
| 440 |
+
nn.Sigmoid()
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
nn.init.zeros_(self.net[-2].weight)
|
| 444 |
+
nn.init.constant_(self.net[-2].bias, init_bias)
|
| 445 |
+
|
| 446 |
+
def forward(self, x):
|
| 447 |
+
orig_input, batch = x, x.shape[0]
|
| 448 |
+
is_video = x.ndim == 5
|
| 449 |
+
|
| 450 |
+
if is_video:
|
| 451 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 452 |
+
|
| 453 |
+
# 根据HW经过conv得到context特征图
|
| 454 |
+
context = self.to_k(x)
|
| 455 |
+
|
| 456 |
+
context = rearrange(context, 'b c h w -> b c (h w)').softmax(dim = -1)
|
| 457 |
+
spatial_flattened_input = rearrange(x, 'b c h w -> b c (h w)')
|
| 458 |
+
|
| 459 |
+
out = einsum('b i n, b c n -> b c i', context, spatial_flattened_input)
|
| 460 |
+
out = rearrange(out, '... -> ... 1')
|
| 461 |
+
gates = self.net(out)
|
| 462 |
+
|
| 463 |
+
if is_video:
|
| 464 |
+
gates = rearrange(gates, '(b f) c h w -> b c f h w', b = batch)
|
| 465 |
+
|
| 466 |
+
return gates * orig_input
|
| 467 |
+
|
| 468 |
+
class Residual(nn.Module):
|
| 469 |
+
@beartype
|
| 470 |
+
def __init__(self, fn: nn.Module):
|
| 471 |
+
super().__init__()
|
| 472 |
+
self.fn = fn
|
| 473 |
+
|
| 474 |
+
def forward(self, x, **kwargs):
|
| 475 |
+
return self.fn(x, **kwargs) + x
|
| 476 |
+
|
| 477 |
+
def ResidualUnit(
|
| 478 |
+
dim,
|
| 479 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
| 480 |
+
pad_mode: str = 'constant'
|
| 481 |
+
):
|
| 482 |
+
net = nn.Sequential(
|
| 483 |
+
# 因果3D卷积
|
| 484 |
+
# CausalConv3d(dim, dim, kernel_size, pad_mode = pad_mode),
|
| 485 |
+
nn.Conv3d(dim, dim, kernel_size,padding='same'),
|
| 486 |
+
nn.ELU(),
|
| 487 |
+
nn.Conv3d(dim, dim, 1),
|
| 488 |
+
nn.ELU(),
|
| 489 |
+
# 一个channel wise的conv1d+softmax的global context attn
|
| 490 |
+
SqueezeExcite(dim)
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
return Residual(net)
|
| 494 |
+
|
| 495 |
+
# strided conv downsamples
|
| 496 |
+
|
| 497 |
+
class SpatialDownsample2x(nn.Module):
|
| 498 |
+
def __init__(
|
| 499 |
+
self,
|
| 500 |
+
dim,
|
| 501 |
+
dim_out = None,
|
| 502 |
+
kernel_size = 3,
|
| 503 |
+
antialias = False
|
| 504 |
+
):
|
| 505 |
+
super().__init__()
|
| 506 |
+
dim_out = default(dim_out, dim)
|
| 507 |
+
self.conv = nn.Conv2d(dim, dim_out, kernel_size, stride = 2, padding = kernel_size // 2)
|
| 508 |
+
|
| 509 |
+
def forward(self, x):
|
| 510 |
+
|
| 511 |
+
x = rearrange(x, 'b c t h w -> b t c h w')
|
| 512 |
+
x, ps = pack_one(x, '* c h w')
|
| 513 |
+
|
| 514 |
+
out = self.conv(x)
|
| 515 |
+
|
| 516 |
+
out = unpack_one(out, ps, '* c h w')
|
| 517 |
+
out = rearrange(out, 'b t c h w -> b c t h w')
|
| 518 |
+
return out
|
| 519 |
+
|
| 520 |
+
class TimeDownsample2x(nn.Module):
|
| 521 |
+
def __init__(
|
| 522 |
+
self,
|
| 523 |
+
dim,
|
| 524 |
+
dim_out = None,
|
| 525 |
+
kernel_size = 3,
|
| 526 |
+
antialias = False
|
| 527 |
+
):
|
| 528 |
+
super().__init__()
|
| 529 |
+
dim_out = default(dim_out, dim)
|
| 530 |
+
self.time_causal_padding = (kernel_size - 1, 0)
|
| 531 |
+
self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride = 2)
|
| 532 |
+
|
| 533 |
+
def forward(self, x):
|
| 534 |
+
x = rearrange(x, 'b c t h w -> b h w c t')
|
| 535 |
+
x, ps = pack_one(x, '* c t')
|
| 536 |
+
|
| 537 |
+
x = F.pad(x, self.time_causal_padding)
|
| 538 |
+
out = self.conv(x)
|
| 539 |
+
|
| 540 |
+
out = unpack_one(out, ps, '* c t')
|
| 541 |
+
out = rearrange(out, 'b h w c t -> b c t h w')
|
| 542 |
+
return out
|
| 543 |
+
|
| 544 |
+
# depth to space upsamples
|
| 545 |
+
|
| 546 |
+
class SpatialUpsample2x(nn.Module):
|
| 547 |
+
def __init__(
|
| 548 |
+
self,
|
| 549 |
+
dim,
|
| 550 |
+
dim_out = None
|
| 551 |
+
):
|
| 552 |
+
super().__init__()
|
| 553 |
+
dim_out = default(dim_out, dim)
|
| 554 |
+
conv = nn.Conv2d(dim, dim_out * 4, 1)
|
| 555 |
+
|
| 556 |
+
self.net = nn.Sequential(
|
| 557 |
+
conv,
|
| 558 |
+
nn.SiLU(),
|
| 559 |
+
Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
self.init_conv_(conv)
|
| 563 |
+
|
| 564 |
+
def init_conv_(self, conv):
|
| 565 |
+
o, i, h, w = conv.weight.shape
|
| 566 |
+
conv_weight = torch.empty(o // 4, i, h, w)
|
| 567 |
+
nn.init.kaiming_uniform_(conv_weight)
|
| 568 |
+
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
|
| 569 |
+
|
| 570 |
+
conv.weight.data.copy_(conv_weight)
|
| 571 |
+
nn.init.zeros_(conv.bias.data)
|
| 572 |
+
|
| 573 |
+
def forward(self, x):
|
| 574 |
+
x = rearrange(x, 'b c t h w -> b t c h w')
|
| 575 |
+
x, ps = pack_one(x, '* c h w')
|
| 576 |
+
|
| 577 |
+
out = self.net(x)
|
| 578 |
+
|
| 579 |
+
out = unpack_one(out, ps, '* c h w')
|
| 580 |
+
out = rearrange(out, 'b t c h w -> b c t h w')
|
| 581 |
+
return out
|
| 582 |
+
|
| 583 |
+
class TimeUpsample2x(nn.Module):
|
| 584 |
+
def __init__(
|
| 585 |
+
self,
|
| 586 |
+
dim,
|
| 587 |
+
dim_out = None
|
| 588 |
+
):
|
| 589 |
+
super().__init__()
|
| 590 |
+
dim_out = default(dim_out, dim)
|
| 591 |
+
conv = nn.Conv1d(dim, dim_out * 2, 1)
|
| 592 |
+
|
| 593 |
+
self.net = nn.Sequential(
|
| 594 |
+
conv,
|
| 595 |
+
nn.SiLU(),
|
| 596 |
+
Rearrange('b (c p) t -> b c (t p)', p = 2)
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
self.init_conv_(conv)
|
| 600 |
+
|
| 601 |
+
def init_conv_(self, conv):
|
| 602 |
+
o, i, t = conv.weight.shape
|
| 603 |
+
conv_weight = torch.empty(o // 2, i, t)
|
| 604 |
+
nn.init.kaiming_uniform_(conv_weight)
|
| 605 |
+
conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...')
|
| 606 |
+
|
| 607 |
+
conv.weight.data.copy_(conv_weight)
|
| 608 |
+
nn.init.zeros_(conv.bias.data)
|
| 609 |
+
|
| 610 |
+
def forward(self, x):
|
| 611 |
+
x = rearrange(x, 'b c t h w -> b h w c t')
|
| 612 |
+
x, ps = pack_one(x, '* c t')
|
| 613 |
+
|
| 614 |
+
out = self.net(x)
|
| 615 |
+
|
| 616 |
+
out = unpack_one(out, ps, '* c t')
|
| 617 |
+
out = rearrange(out, 'b h w c t -> b c t h w')
|
| 618 |
+
return out
|
| 619 |
+
|
| 620 |
+
class RMSNorm(nn.Module):
|
| 621 |
+
def __init__(
|
| 622 |
+
self,
|
| 623 |
+
dim,
|
| 624 |
+
channel_first = False,
|
| 625 |
+
images = False,
|
| 626 |
+
bias = False
|
| 627 |
+
):
|
| 628 |
+
super().__init__()
|
| 629 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 630 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 631 |
+
|
| 632 |
+
self.channel_first = channel_first
|
| 633 |
+
self.scale = dim ** 0.5
|
| 634 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
| 635 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
| 636 |
+
|
| 637 |
+
def forward(self, x):
|
| 638 |
+
return F.normalize(x, dim = (1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
| 639 |
+
|
| 640 |
+
class AdaptiveRMSNorm(nn.Module):
|
| 641 |
+
def __init__(
|
| 642 |
+
self,
|
| 643 |
+
dim,
|
| 644 |
+
*,
|
| 645 |
+
dim_cond,
|
| 646 |
+
channel_first = False,
|
| 647 |
+
images = False,
|
| 648 |
+
bias = False
|
| 649 |
+
):
|
| 650 |
+
super().__init__()
|
| 651 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 652 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 653 |
+
|
| 654 |
+
self.dim_cond = dim_cond
|
| 655 |
+
self.channel_first = channel_first
|
| 656 |
+
self.scale = dim ** 0.5
|
| 657 |
+
|
| 658 |
+
self.to_gamma = nn.Linear(dim_cond, dim)
|
| 659 |
+
self.to_bias = nn.Linear(dim_cond, dim) if bias else None
|
| 660 |
+
|
| 661 |
+
nn.init.zeros_(self.to_gamma.weight)
|
| 662 |
+
nn.init.ones_(self.to_gamma.bias)
|
| 663 |
+
|
| 664 |
+
if bias:
|
| 665 |
+
nn.init.zeros_(self.to_bias.weight)
|
| 666 |
+
nn.init.zeros_(self.to_bias.bias)
|
| 667 |
+
|
| 668 |
+
@beartype
|
| 669 |
+
def forward(self, x: Tensor, *, cond: Tensor):
|
| 670 |
+
batch = x.shape[0]
|
| 671 |
+
assert cond.shape == (batch, self.dim_cond)
|
| 672 |
+
|
| 673 |
+
gamma = self.to_gamma(cond)
|
| 674 |
+
|
| 675 |
+
bias = 0.
|
| 676 |
+
if exists(self.to_bias):
|
| 677 |
+
bias = self.to_bias(cond)
|
| 678 |
+
|
| 679 |
+
if self.channel_first:
|
| 680 |
+
gamma = append_dims(gamma, x.ndim - 2)
|
| 681 |
+
|
| 682 |
+
if exists(self.to_bias):
|
| 683 |
+
bias = append_dims(bias, x.ndim - 2)
|
| 684 |
+
|
| 685 |
+
return F.normalize(x, dim = (1 if self.channel_first else -1)) * self.scale * gamma + bias
|
| 686 |
+
|
| 687 |
+
class Attention(nn.Module):
|
| 688 |
+
@beartype
|
| 689 |
+
def __init__(
|
| 690 |
+
self,
|
| 691 |
+
*,
|
| 692 |
+
dim,
|
| 693 |
+
dim_cond: Union[int,None] = None,
|
| 694 |
+
causal = False,
|
| 695 |
+
dim_head = 32,
|
| 696 |
+
heads = 8,
|
| 697 |
+
flash = False,
|
| 698 |
+
dropout = 0.,
|
| 699 |
+
num_memory_kv = 4
|
| 700 |
+
):
|
| 701 |
+
super().__init__()
|
| 702 |
+
dim_inner = dim_head * heads
|
| 703 |
+
|
| 704 |
+
self.need_cond = exists(dim_cond)
|
| 705 |
+
|
| 706 |
+
if self.need_cond:
|
| 707 |
+
self.norm = AdaptiveRMSNorm(dim, dim_cond = dim_cond)
|
| 708 |
+
else:
|
| 709 |
+
self.norm = RMSNorm(dim)
|
| 710 |
+
|
| 711 |
+
self.to_qkv = nn.Sequential(
|
| 712 |
+
nn.Linear(dim, dim_inner * 3, bias = False),
|
| 713 |
+
Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
assert num_memory_kv > 0
|
| 717 |
+
self.mem_kv = nn.Parameter(torch.randn(2, heads, num_memory_kv, dim_head))
|
| 718 |
+
|
| 719 |
+
self.attend = Attend(
|
| 720 |
+
causal = causal,
|
| 721 |
+
dropout = dropout,
|
| 722 |
+
flash = flash
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
self.to_out = nn.Sequential(
|
| 726 |
+
Rearrange('b h n d -> b n (h d)'),
|
| 727 |
+
nn.Linear(dim_inner, dim, bias = False)
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
@beartype
|
| 731 |
+
def forward(
|
| 732 |
+
self,
|
| 733 |
+
x,
|
| 734 |
+
mask: Union[Tensor,None] = None,
|
| 735 |
+
cond: Union[Tensor,None] = None
|
| 736 |
+
):
|
| 737 |
+
maybe_cond_kwargs = dict(cond = cond) if self.need_cond else dict()
|
| 738 |
+
|
| 739 |
+
x = self.norm(x, **maybe_cond_kwargs)
|
| 740 |
+
|
| 741 |
+
q, k, v = self.to_qkv(x)
|
| 742 |
+
|
| 743 |
+
mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = q.shape[0]), self.mem_kv)
|
| 744 |
+
k = torch.cat((mk, k), dim = -2)
|
| 745 |
+
v = torch.cat((mv, v), dim = -2)
|
| 746 |
+
|
| 747 |
+
out = self.attend(q, k, v, mask = mask)
|
| 748 |
+
return self.to_out(out)
|
| 749 |
+
|
| 750 |
+
class LinearAttention(nn.Module):
|
| 751 |
+
"""
|
| 752 |
+
using the specific linear attention proposed in https://arxiv.org/abs/2106.09681
|
| 753 |
+
"""
|
| 754 |
+
|
| 755 |
+
@beartype
|
| 756 |
+
def __init__(
|
| 757 |
+
self,
|
| 758 |
+
*,
|
| 759 |
+
dim,
|
| 760 |
+
dim_cond: Union[int,None] = None,
|
| 761 |
+
dim_head = 8,
|
| 762 |
+
heads = 8,
|
| 763 |
+
dropout = 0.
|
| 764 |
+
):
|
| 765 |
+
super().__init__()
|
| 766 |
+
dim_inner = dim_head * heads
|
| 767 |
+
|
| 768 |
+
self.need_cond = exists(dim_cond)
|
| 769 |
+
|
| 770 |
+
if self.need_cond:
|
| 771 |
+
self.norm = AdaptiveRMSNorm(dim, dim_cond = dim_cond)
|
| 772 |
+
else:
|
| 773 |
+
self.norm = RMSNorm(dim)
|
| 774 |
+
|
| 775 |
+
self.attn = TaylorSeriesLinearAttn(
|
| 776 |
+
dim = dim,
|
| 777 |
+
dim_head = dim_head,
|
| 778 |
+
heads = heads
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
def forward(
|
| 782 |
+
self,
|
| 783 |
+
x,
|
| 784 |
+
cond: Union[Tensor,None] = None
|
| 785 |
+
):
|
| 786 |
+
maybe_cond_kwargs = dict(cond = cond) if self.need_cond else dict()
|
| 787 |
+
|
| 788 |
+
x = self.norm(x, **maybe_cond_kwargs)
|
| 789 |
+
|
| 790 |
+
return self.attn(x)
|
| 791 |
+
|
| 792 |
+
class LinearSpaceAttention(LinearAttention):
|
| 793 |
+
def forward(self, x, *args, **kwargs):
|
| 794 |
+
x = rearrange(x, 'b c ... h w -> b ... h w c')
|
| 795 |
+
x, batch_ps = pack_one(x, '* h w c')
|
| 796 |
+
x, seq_ps = pack_one(x, 'b * c')
|
| 797 |
+
|
| 798 |
+
x = super().forward(x, *args, **kwargs)
|
| 799 |
+
|
| 800 |
+
x = unpack_one(x, seq_ps, 'b * c')
|
| 801 |
+
x = unpack_one(x, batch_ps, '* h w c')
|
| 802 |
+
return rearrange(x, 'b ... h w c -> b c ... h w')
|
| 803 |
+
|
| 804 |
+
class SpaceAttention(Attention):
|
| 805 |
+
def forward(self, x, *args, **kwargs):
|
| 806 |
+
x = rearrange(x, 'b c t h w -> b t h w c')
|
| 807 |
+
x, batch_ps = pack_one(x, '* h w c')
|
| 808 |
+
x, seq_ps = pack_one(x, 'b * c')
|
| 809 |
+
|
| 810 |
+
x = super().forward(x, *args, **kwargs)
|
| 811 |
+
|
| 812 |
+
x = unpack_one(x, seq_ps, 'b * c')
|
| 813 |
+
x = unpack_one(x, batch_ps, '* h w c')
|
| 814 |
+
return rearrange(x, 'b t h w c -> b c t h w')
|
| 815 |
+
|
| 816 |
+
class TimeAttention(Attention):
|
| 817 |
+
def forward(self, x, *args, **kwargs):
|
| 818 |
+
x = rearrange(x, 'b c t h w -> b h w t c')
|
| 819 |
+
x, batch_ps = pack_one(x, '* t c')
|
| 820 |
+
|
| 821 |
+
x = super().forward(x, *args, **kwargs)
|
| 822 |
+
|
| 823 |
+
x = unpack_one(x, batch_ps, '* t c')
|
| 824 |
+
return rearrange(x, 'b h w t c -> b c t h w')
|
| 825 |
+
|
| 826 |
+
class GEGLU(nn.Module):
|
| 827 |
+
def forward(self, x):
|
| 828 |
+
x, gate = x.chunk(2, dim = 1)
|
| 829 |
+
return F.gelu(gate) * x
|
| 830 |
+
|
| 831 |
+
class FeedForward(nn.Module):
|
| 832 |
+
@beartype
|
| 833 |
+
def __init__(
|
| 834 |
+
self,
|
| 835 |
+
dim,
|
| 836 |
+
*,
|
| 837 |
+
dim_cond: Union[int,None] = None,
|
| 838 |
+
mult = 4,
|
| 839 |
+
images = False
|
| 840 |
+
):
|
| 841 |
+
super().__init__()
|
| 842 |
+
conv_klass = nn.Conv2d if images else nn.Conv3d
|
| 843 |
+
|
| 844 |
+
rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond = dim_cond)
|
| 845 |
+
|
| 846 |
+
maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first = True, images = images)
|
| 847 |
+
|
| 848 |
+
dim_inner = int(dim * mult * 2 / 3)
|
| 849 |
+
|
| 850 |
+
self.norm = maybe_adaptive_norm_klass(dim)
|
| 851 |
+
|
| 852 |
+
self.net = Sequential(
|
| 853 |
+
conv_klass(dim, dim_inner * 2, 1),
|
| 854 |
+
GEGLU(),
|
| 855 |
+
conv_klass(dim_inner, dim, 1)
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
@beartype
|
| 859 |
+
def forward(
|
| 860 |
+
self,
|
| 861 |
+
x: Tensor,
|
| 862 |
+
*,
|
| 863 |
+
cond: Union[Tensor,None] = None
|
| 864 |
+
):
|
| 865 |
+
maybe_cond_kwargs = dict(cond = cond) if exists(cond) else dict()
|
| 866 |
+
|
| 867 |
+
x = self.norm(x, **maybe_cond_kwargs)
|
| 868 |
+
return self.net(x)
|
| 869 |
+
|
| 870 |
+
def Sequential(*modules):
|
| 871 |
+
modules = [*filter(exists, modules)]
|
| 872 |
+
|
| 873 |
+
if len(modules) == 0:
|
| 874 |
+
return nn.Identity()
|
| 875 |
+
|
| 876 |
+
return nn.Sequential(*modules)
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchaudio
|
| 3 |
+
torchvision
|
| 4 |
+
moviepy
|
| 5 |
+
face_alignment
|
| 6 |
+
beartype
|
| 7 |
+
taylor_series_linear_attention
|
| 8 |
+
huggingface_hub
|
| 9 |
+
einops
|
| 10 |
+
vector_quantize_pytorch
|
| 11 |
+
spaces
|
| 12 |
+
tf-keras
|
| 13 |
+
retina-face
|
| 14 |
+
safetensors
|