JusperLee committed on
Commit
0cd6025
0 Parent(s):

clean repo without raw binaries

.gitattributes ADDED
@@ -0,0 +1,36 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,33 @@
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ .DS_Store
12
+ dist
13
+ dist-ssr
14
+ coverage
15
+ *.local
16
+
17
+ /cypress/videos/
18
+ /cypress/screenshots/
19
+
20
+ # Editor directories and files
21
+ .vscode/*
22
+ !.vscode/extensions.json
23
+ .idea
24
+ *.suo
25
+ *.ntvs*
26
+ *.njsproj
27
+ *.sln
28
+ *.sw?
29
+ yarn.lock
30
+
31
+ tmp/*
32
+ .gradio
33
+ *.pyc
Inference.py ADDED
@@ -0,0 +1,635 @@
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+ import os
4
+ import argparse
5
+ import face_alignment
6
+ import torch
7
+ import torchaudio
8
+ import numpy as np
9
+ import cv2
10
+ from PIL import Image, ImageDraw
11
+ from moviepy import *
12
+ from collections import deque
13
+ from skimage import transform as tf
14
+ import yaml
15
+
16
+ from look2hear.models import Dolphin
17
+ from look2hear.datas.transform import get_preprocessing_pipelines
18
+
19
+ from face_detection_utils import detect_faces
20
+
21
+ # -- Landmark interpolation:
22
+ def linear_interpolate(landmarks, start_idx, stop_idx):
23
+ start_landmarks = landmarks[start_idx]
24
+ stop_landmarks = landmarks[stop_idx]
25
+ delta = stop_landmarks - start_landmarks
26
+ for idx in range(1, stop_idx-start_idx):
27
+ landmarks[start_idx+idx] = start_landmarks + idx/float(stop_idx-start_idx) * delta
28
+ return landmarks
29
+
30
+ # -- Face Transformation
31
+ def warp_img(src, dst, img, std_size):
32
+ tform = tf.estimate_transform('similarity', src, dst) # find the transformation matrix
33
+ warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size) # warp the frame image
34
+ warped = warped * 255 # note: output from warp is a double image (value range [0,1])
35
+ warped = warped.astype('uint8')
36
+ return warped, tform
37
+
38
+ def apply_transform(transform, img, std_size):
39
+ warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
40
+ warped = warped * 255 # note: output from warp is a double image (value range [0,1])
41
+ warped = warped.astype('uint8')
42
+ return warped
43
+
44
+ # -- Crop
45
+ def cut_patch(img, landmarks, height, width, threshold=5):
46
+
47
+ center_x, center_y = np.mean(landmarks, axis=0)
48
+
49
+ if center_y - height < 0:
50
+ center_y = height
51
+ if center_y - height < 0 - threshold:
52
+ raise Exception('too much bias in height')
53
+ if center_x - width < 0:
54
+ center_x = width
55
+ if center_x - width < 0 - threshold:
56
+ raise Exception('too much bias in width')
57
+
58
+ if center_y + height > img.shape[0]:
59
+ center_y = img.shape[0] - height
60
+ if center_y + height > img.shape[0] + threshold:
61
+ raise Exception('too much bias in height')
62
+ if center_x + width > img.shape[1]:
63
+ center_x = img.shape[1] - width
64
+ if center_x + width > img.shape[1] + threshold:
65
+ raise Exception('too much bias in width')
66
+
67
+ cutted_img = np.copy(img[ int(round(center_y) - round(height)): int(round(center_y) + round(height)),
68
+ int(round(center_x) - round(width)): int(round(center_x) + round(width))])
69
+ return cutted_img
70
+
71
+ # -- RGB to GRAY
72
+ def convert_bgr2gray(data):
73
+ return np.stack([cv2.cvtColor(_, cv2.COLOR_BGR2GRAY) for _ in data], axis=0)
74
+
75
+
76
+ def save2npz(filename, data=None):
77
+ assert data is not None, "data is {}".format(data)
78
+ if not os.path.exists(os.path.dirname(filename)):
79
+ os.makedirs(os.path.dirname(filename))
80
+ np.savez_compressed(filename, data=data)
81
+
82
+ def read_video(filename):
83
+ """Read video frames using MoviePy for better compatibility"""
84
+ try:
85
+ video_clip = VideoFileClip(filename)
86
+ for frame in video_clip.iter_frames():
87
+ # Convert RGB to BGR to match cv2 format
88
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
89
+ yield frame_bgr
90
+ video_clip.close()
91
+ except Exception as e:
92
+ print(f"Error reading video {filename}: {e}")
93
+ return
94
+
95
+ def face2head(boxes, scale=1.5):
96
+ new_boxes = []
97
+ for box in boxes:
98
+ width = box[2] - box[0]
99
+ height= box[3] - box[1]
100
+ width_center = (box[2] + box[0]) / 2
101
+ height_center = (box[3] + box[1]) / 2
102
+ square_width = int(max(width, height) * scale)
103
+ new_box = [width_center - square_width/2, height_center - square_width/2, width_center + square_width/2, height_center + square_width/2]
104
+ new_boxes.append(new_box)
105
+ return new_boxes
106
+
107
+ def bb_intersection_over_union(boxA, boxB):
108
+ # determine the (x, y)-coordinates of the intersection rectangle
109
+ xA = max(boxA[0], boxB[0])
110
+ yA = max(boxA[1], boxB[1])
111
+ xB = min(boxA[2], boxB[2])
112
+ yB = min(boxA[3], boxB[3])
113
+ # compute the area of intersection rectangle
114
+ interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
115
+ # compute the area of both the prediction and ground-truth
116
+ # rectangles
117
+ boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
118
+ boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
119
+ # compute the intersection over union by taking the intersection
120
+ # area and dividing it by the sum of prediction + ground-truth
121
+ # areas - the interesection area
122
+ iou = interArea / float(boxAArea + boxBArea - interArea)
123
+ # return the intersection over union value
124
+ return iou
125
+
126
+ def detectface(video_input_path, output_path, detect_every_N_frame, scalar_face_detection, number_of_speakers):
127
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
128
+ print('Running on device: {}'.format(device))
129
+ os.makedirs(os.path.join(output_path, 'faces'), exist_ok=True)
130
+ os.makedirs(os.path.join(output_path, 'landmark'), exist_ok=True)
131
+
132
+ landmarks_dic = {}
133
+ faces_dic = {}
134
+ boxes_dic = {}
135
+
136
+ for i in range(number_of_speakers):
137
+ landmarks_dic[i] = []
138
+ faces_dic[i] = []
139
+ boxes_dic[i] = []
140
+
141
+ video_clip = VideoFileClip(video_input_path)
142
+ print("Video statistics: ", video_clip.w, video_clip.h, (video_clip.w, video_clip.h), video_clip.fps)
143
+ frames = [Image.fromarray(frame) for frame in video_clip.iter_frames()]
144
+ print('Number of frames in video: ', len(frames))
145
+ video_clip.close()
146
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
147
+
148
+ for i, frame in enumerate(frames):
149
+ print('\rTracking frame: {}'.format(i + 1), end='')
150
+
151
+ # Detect faces every N frames
152
+ if i % detect_every_N_frame == 0:
153
+ frame_array = np.array(frame)
154
+
155
+ detected_boxes, _ = detect_faces(
156
+ frame_array,
157
+ threshold=0.9,
158
+ allow_upscaling=False,
159
+ )
160
+
161
+ if detected_boxes is None or len(detected_boxes) == 0:
162
+ detected_boxes, _ = detect_faces(
163
+ frame_array,
164
+ threshold=0.7,
165
+ allow_upscaling=True,
166
+ )
167
+
168
+ if detected_boxes is not None and len(detected_boxes) > 0:
169
+ detected_boxes = detected_boxes[:number_of_speakers]
170
+ detected_boxes = face2head(detected_boxes, scalar_face_detection)
171
+ else:
172
+ detected_boxes = []
173
+
174
+ # Process the detection results
175
+ if i == 0:
176
+ # First frame - initialize tracking
177
+ if len(detected_boxes) < number_of_speakers:
178
+ raise ValueError(f"First frame must detect at least {number_of_speakers} faces, but only found {len(detected_boxes)}")
179
+
180
+ # Assign first detections to speakers in order
181
+ for j in range(number_of_speakers):
182
+ box = detected_boxes[j]
183
+ face = frame.crop((box[0], box[1], box[2], box[3])).resize((224,224))
184
+ preds = fa.get_landmarks(np.array(face))
185
+
186
+ if preds is None:
187
+ raise ValueError(f"Face landmarks not detected in initial frame for speaker {j}")
188
+
189
+ faces_dic[j].append(face)
190
+ landmarks_dic[j].append(preds)
191
+ boxes_dic[j].append(box)
192
+ else:
193
+ # For subsequent frames, match detected boxes to speakers
194
+ matched_speakers = set()
195
+ speaker_boxes = [None] * number_of_speakers
196
+
197
+ # Match each detected box to the most likely speaker
198
+ for box in detected_boxes:
199
+ iou_scores = []
200
+ for speaker_id in range(number_of_speakers):
201
+ if speaker_id in matched_speakers:
202
+ iou_scores.append(-1) # Already matched
203
+ else:
204
+ last_box = boxes_dic[speaker_id][-1]
205
+ iou_score = bb_intersection_over_union(box, last_box)
206
+ iou_scores.append(iou_score)
207
+
208
+ if max(iou_scores) > 0: # Valid match found
209
+ best_speaker = iou_scores.index(max(iou_scores))
210
+ speaker_boxes[best_speaker] = box
211
+ matched_speakers.add(best_speaker)
212
+
213
+ # Process each speaker
214
+ for speaker_id in range(number_of_speakers):
215
+ if speaker_boxes[speaker_id] is not None:
216
+ # Use detected box
217
+ box = speaker_boxes[speaker_id]
218
+ else:
219
+ # Use previous box for this speaker
220
+ box = boxes_dic[speaker_id][-1]
221
+
222
+ # Extract face and landmarks
223
+ face = frame.crop((box[0], box[1], box[2], box[3])).resize((224,224))
224
+ preds = fa.get_landmarks(np.array(face))
225
+
226
+ if preds is None:
227
+ # Use previous landmarks if detection fails
228
+ preds = landmarks_dic[speaker_id][-1]
229
+
230
+ faces_dic[speaker_id].append(face)
231
+ landmarks_dic[speaker_id].append(preds)
232
+ boxes_dic[speaker_id].append(box)
233
+
234
+ # Verify all speakers have same number of frames
235
+ frame_counts = [len(boxes_dic[s]) for s in range(number_of_speakers)]
236
+ print(f"\nFrame counts per speaker: {frame_counts}")
237
+ assert all(count == len(frames) for count in frame_counts), f"Inconsistent frame counts: {frame_counts}"
238
+
239
+ # Continue with saving videos and landmarks...
240
+ for s in range(number_of_speakers):
241
+ frames_tracked = []
242
+ for i, frame in enumerate(frames):
243
+ frame_draw = frame.copy()
244
+ draw = ImageDraw.Draw(frame_draw)
245
+ draw.rectangle(boxes_dic[s][i], outline=(255, 0, 0), width=6)
246
+ frames_tracked.append(frame_draw)
247
+
248
+ # Save tracked video
249
+ tracked_frames = [np.array(frame) for frame in frames_tracked]
250
+ if tracked_frames:
251
+ tracked_clip = ImageSequenceClip(tracked_frames, fps=25.0)
252
+ tracked_video_path = os.path.join(output_path, 'video_tracked' + str(s+1) + '.mp4')
253
+ tracked_clip.write_videofile(tracked_video_path, codec='libx264', audio=False, logger=None)
254
+ tracked_clip.close()
255
+
256
+ # Save landmarks
257
+ for i in range(number_of_speakers):
258
+ save2npz(os.path.join(output_path, 'landmark', 'speaker' + str(i+1)+'.npz'), data=landmarks_dic[i])
259
+
260
+ # Save face video
261
+ face_frames = [np.array(frame) for frame in faces_dic[i]]
262
+ if face_frames:
263
+ face_clip = ImageSequenceClip(face_frames, fps=25.0)
264
+ face_video_path = os.path.join(output_path, 'faces', 'speaker' + str(i+1) + '.mp4')
265
+ face_clip.write_videofile(face_video_path, codec='libx264', audio=False, logger=None)
266
+ face_clip.close()
267
+
268
+ # Output video path
269
+ parts = video_input_path.split('/')
270
+ video_name = parts[-1][:-4]
271
+ if not os.path.exists(os.path.join(output_path, 'filename_input')):
272
+ os.mkdir(os.path.join(output_path, 'filename_input'))
273
+ csvfile = open(os.path.join(output_path, 'filename_input', str(video_name) + '.csv'), 'w')
274
+ for i in range(number_of_speakers):
275
+ csvfile.write('speaker' + str(i+1)+ ',0\n')
276
+ csvfile.close()
277
+ return os.path.join(output_path, 'filename_input', str(video_name) + '.csv')
278
+
279
+
280
+ def crop_patch(mean_face_landmarks, video_pathname, landmarks, window_margin, start_idx, stop_idx, crop_height, crop_width, STD_SIZE=(256, 256)):
281
+
282
+ """Crop mouth patch
283
+ :param str video_pathname: pathname for the video
284
+ :param list landmarks: interpolated landmarks
285
+ """
286
+
287
+ stablePntsIDs = [33, 36, 39, 42, 45]
288
+
289
+ frame_idx = 0
290
+ frame_gen = read_video(video_pathname)
291
+ while True:
292
+ try:
293
+ frame = frame_gen.__next__() ## -- BGR
294
+ except StopIteration:
295
+ break
296
+ if frame_idx == 0:
297
+ q_frame, q_landmarks = deque(), deque()
298
+ sequence = []
299
+
300
+ q_landmarks.append(landmarks[frame_idx])
301
+ q_frame.append(frame)
302
+ if len(q_frame) == window_margin:
303
+ smoothed_landmarks = np.mean(q_landmarks, axis=0)
304
+ cur_landmarks = q_landmarks.popleft()
305
+ cur_frame = q_frame.popleft()
306
+ # -- affine transformation
307
+ trans_frame, trans = warp_img( smoothed_landmarks[stablePntsIDs, :],
308
+ mean_face_landmarks[stablePntsIDs, :],
309
+ cur_frame,
310
+ STD_SIZE)
311
+ trans_landmarks = trans(cur_landmarks)
312
+ # -- crop mouth patch
313
+ sequence.append( cut_patch( trans_frame,
314
+ trans_landmarks[start_idx:stop_idx],
315
+ crop_height//2,
316
+ crop_width//2,))
317
+ if frame_idx == len(landmarks)-1:
318
+ #deal with corner case with video too short
319
+ if len(landmarks) < window_margin:
320
+ smoothed_landmarks = np.mean(q_landmarks, axis=0)
321
+ cur_landmarks = q_landmarks.popleft()
322
+ cur_frame = q_frame.popleft()
323
+
324
+ # -- affine transformation
325
+ trans_frame, trans = warp_img(smoothed_landmarks[stablePntsIDs, :],
326
+ mean_face_landmarks[stablePntsIDs, :],
327
+ cur_frame,
328
+ STD_SIZE)
329
+ trans_landmarks = trans(cur_landmarks)
330
+ # -- crop mouth patch
331
+ sequence.append(cut_patch( trans_frame,
332
+ trans_landmarks[start_idx:stop_idx],
333
+ crop_height//2,
334
+ crop_width//2,))
335
+
336
+ while q_frame:
337
+ cur_frame = q_frame.popleft()
338
+ # -- transform frame
339
+ trans_frame = apply_transform( trans, cur_frame, STD_SIZE)
340
+ # -- transform landmarks
341
+ trans_landmarks = trans(q_landmarks.popleft())
342
+ # -- crop mouth patch
343
+ sequence.append( cut_patch( trans_frame,
344
+ trans_landmarks[start_idx:stop_idx],
345
+ crop_height//2,
346
+ crop_width//2,))
347
+ return np.array(sequence)
348
+ frame_idx += 1
349
+ return None
350
+
351
+ def landmarks_interpolate(landmarks):
352
+
353
+ """Interpolate landmarks
354
+ param list landmarks: landmarks detected in raw videos
355
+ """
356
+
357
+ valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
358
+ if not valid_frames_idx:
359
+ return None
360
+ for idx in range(1, len(valid_frames_idx)):
361
+ if valid_frames_idx[idx] - valid_frames_idx[idx-1] == 1:
362
+ continue
363
+ else:
364
+ landmarks = linear_interpolate(landmarks, valid_frames_idx[idx-1], valid_frames_idx[idx])
365
+ valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
366
+ # -- Corner case: keep frames at the beginning or at the end failed to be detected.
367
+ if valid_frames_idx:
368
+ landmarks[:valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
369
+ landmarks[valid_frames_idx[-1]:] = [landmarks[valid_frames_idx[-1]]] * (len(landmarks) - valid_frames_idx[-1])
370
+ valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
371
+ assert len(valid_frames_idx) == len(landmarks), "not every frame has landmark"
372
+ return landmarks
373
+
374
+ def crop_mouth(video_direc, landmark_direc, filename_path, save_direc, convert_gray=False, testset_only=False):
375
+ lines = open(filename_path).read().splitlines()
376
+ lines = list(filter(lambda x: 'test' in x, lines)) if testset_only else lines
377
+
378
+ for filename_idx, line in enumerate(lines):
379
+
380
+ filename, person_id = line.split(',')
381
+ print('idx: {} \tProcessing.\t{}'.format(filename_idx, filename))
382
+
383
+ video_pathname = os.path.join(video_direc, filename+'.mp4')
384
+ landmarks_pathname = os.path.join(landmark_direc, filename+'.npz')
385
+ dst_pathname = os.path.join( save_direc, filename+'.npz')
386
+
387
+ # if os.path.exists(dst_pathname):
388
+ # continue
389
+
390
+ multi_sub_landmarks = np.load(landmarks_pathname, allow_pickle=True)['data']
391
+ landmarks = [None] * len(multi_sub_landmarks)
392
+ for frame_idx in range(len(landmarks)):
393
+ try:
394
+ #landmarks[frame_idx] = multi_sub_landmarks[frame_idx][int(person_id)]['facial_landmarks'] #original for LRW
395
+ landmarks[frame_idx] = multi_sub_landmarks[frame_idx][int(person_id)] #VOXCELEB2
396
+ except (IndexError, TypeError):
397
+ continue
398
+
399
+ # -- pre-process landmarks: interpolate frames not being detected.
400
+ preprocessed_landmarks = landmarks_interpolate(landmarks)
401
+ if not preprocessed_landmarks:
402
+ continue
403
+
404
+ # -- crop
405
+ mean_face_landmarks = np.load('assets/20words_mean_face.npy')
406
+ sequence = crop_patch(mean_face_landmarks, video_pathname, preprocessed_landmarks, 12, 48, 68, 96, 96)
407
+ assert sequence is not None, "cannot crop from {}.".format(filename)
408
+
409
+ # -- save
410
+ data = convert_bgr2gray(sequence) if convert_gray else sequence[...,::-1]
411
+ save2npz(dst_pathname, data=data)
412
+
413
+ def convert_video_fps(input_file, output_file, target_fps=25):
414
+ """Convert video to target FPS using moviepy"""
415
+ video = VideoFileClip(input_file)
416
+ video_fps = video.fps
417
+
418
+ if video_fps != target_fps:
419
+ video.write_videofile(
420
+ output_file,
421
+ fps=target_fps,
422
+ codec='libx264',
423
+ audio_codec='aac',
424
+ temp_audiofile='temp-audio.m4a',
425
+ remove_temp=True,
426
+ )
427
+ else:
428
+ # If already at target fps, just copy
429
+ import shutil
430
+ shutil.copy2(input_file, output_file)
431
+
432
+ video.close()
433
+ print(f'Video has been converted to {target_fps} fps and saved to {output_file}')
434
+
435
+ def extract_audio(video_file, audio_output_file, sample_rate=16000):
436
+ """Extract audio from video using moviepy"""
437
+ video = VideoFileClip(video_file)
438
+ audio = video.audio
439
+
440
+ # Save audio with specified sample rate
441
+ audio.write_audiofile(audio_output_file, fps=sample_rate, nbytes=2, codec='pcm_s16le')
442
+
443
+ video.close()
444
+ audio.close()
445
+
446
+ def merge_video_audio(video_file, audio_file, output_file):
447
+ """Merge video and audio using moviepy"""
448
+ video = VideoFileClip(video_file)
449
+ audio = AudioFileClip(audio_file)
450
+
451
+ # Attach audio (MoviePy v2 renamed set_audio to with_audio)
452
+ set_audio_fn = getattr(video, "set_audio", None)
453
+ if callable(set_audio_fn):
454
+ final_video = set_audio_fn(audio)
455
+ else:
456
+ with_audio_fn = getattr(video, "with_audio", None)
457
+ if not callable(with_audio_fn):
458
+ video.close()
459
+ audio.close()
460
+ raise AttributeError("VideoFileClip object lacks both set_audio and with_audio methods")
461
+ final_video = with_audio_fn(audio)
462
+
463
+ # Write the result
464
+ final_video.write_videofile(output_file, codec='libx264', audio_codec='aac', temp_audiofile='temp-audio.m4a', remove_temp=True)
465
+
466
+ # Clean up
467
+ video.close()
468
+ audio.close()
469
+ final_video.close()
470
+
471
+ def process_video(input_file, output_path, number_of_speakers=2,
472
+ detect_every_N_frame=8, scalar_face_detection=1.5,
473
+ config_path="checkpoints/vox2/conf.yml",
474
+ cuda_device=None):
475
+ """Main processing function for video speaker separation"""
476
+
477
+ # Set CUDA device if specified
478
+ if cuda_device is not None:
479
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device)
480
+
481
+ # Create output directory
482
+ os.makedirs(output_path, exist_ok=True)
483
+
484
+ # Convert video to 25fps
485
+ temp_25fps_file = os.path.join(output_path, 'temp_25fps.mp4')
486
+ convert_video_fps(input_file, temp_25fps_file, target_fps=25)
487
+
488
+ # Detect faces
489
+ filename_path = detectface(video_input_path=temp_25fps_file,
490
+ output_path=output_path,
491
+ detect_every_N_frame=detect_every_N_frame,
492
+ scalar_face_detection=scalar_face_detection,
493
+ number_of_speakers=number_of_speakers)
494
+
495
+ # Extract audio
496
+ audio_output = os.path.join(output_path, 'audio.wav')
497
+ extract_audio(temp_25fps_file, audio_output, sample_rate=16000)
498
+
499
+ # Crop mouth
500
+ crop_mouth(video_direc=os.path.join(output_path, "faces"),
501
+ landmark_direc=os.path.join(output_path, "landmark"),
502
+ filename_path=filename_path,
503
+ save_direc=os.path.join(output_path, "mouthroi"),
504
+ convert_gray=True,
505
+ testset_only=False)
506
+
507
+ # Load model
508
+ audiomodel = Dolphin.from_pretrained("JusperLee/Dolphin")
509
+
510
+ audiomodel.cuda()
511
+ audiomodel.eval()
512
+
513
+ # Process each speaker
514
+ with torch.no_grad():
515
+ for i in range(number_of_speakers):
516
+ mouth_roi = np.load(os.path.join(output_path, "mouthroi", f"speaker{i+1}.npz"))["data"]
517
+ mouth_roi = get_preprocessing_pipelines()["val"](mouth_roi)
518
+
519
+ mix, sr = torchaudio.load(audio_output)
520
+ mix = mix.cuda().mean(dim=0)
521
+
522
+ window_size = 4 * sr
523
+ hop_size = 4 * sr
524
+
525
+ all_estimates = []
526
+
527
+ # Sliding-window processing: 4 s audio chunks aligned with the 25 fps mouth-ROI frames
528
+ start_idx = 0
529
+ while start_idx < len(mix):
530
+ end_idx = min(start_idx + window_size, len(mix))
531
+ window_mix = mix[start_idx:end_idx]
532
+
533
+ start_frame = int(start_idx / sr * 25)
534
+ end_frame = int(end_idx / sr * 25)
535
+ end_frame = min(end_frame, len(mouth_roi))
536
+ window_mouth_roi = mouth_roi[start_frame:end_frame]
537
+
538
+ est_sources = audiomodel(window_mix[None],
539
+ torch.from_numpy(window_mouth_roi[None, None]).float().cuda())
540
+
541
+ all_estimates.append({
542
+ 'start': start_idx,
543
+ 'end': end_idx,
544
+ 'estimate': est_sources[0].cpu()
545
+ })
546
+
547
+ start_idx += hop_size
548
+
549
+ if start_idx >= len(mix):
550
+ break
551
+
552
+ output_length = len(mix)
553
+ merged_output = torch.zeros(1, output_length)
554
+ weights = torch.zeros(output_length)
555
+
556
+ for est in all_estimates:
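+ # Overlap-add merge: taper each chunk with a Hann window, accumulate into the
+ # full-length buffer, then normalize by the summed window weights.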
557
+ window_len = est['end'] - est['start']
558
+ hann_window = torch.hann_window(window_len)
559
+
560
+ merged_output[0, est['start']:est['end']] += est['estimate'][0, :window_len] * hann_window
561
+ weights[est['start']:est['end']] += hann_window
562
+
563
+ merged_output[:, weights > 0] /= weights[weights > 0]
564
+
565
+ torchaudio.save(os.path.join(output_path, f"speaker{i+1}_est.wav"), merged_output, sr)
566
+
567
+
568
+ # Merge video with separated audio for each speaker
569
+ output_files = []
570
+ for i in range(number_of_speakers):
571
+ video_input = os.path.join(output_path, f"video_tracked{i+1}.mp4")
572
+ audio_input = os.path.join(output_path, f"speaker{i+1}_est.wav")
573
+ video_output = os.path.join(output_path, f"s{i+1}.mp4")
574
+
575
+ merge_video_audio(video_input, audio_input, video_output)
576
+ output_files.append(video_output)
577
+
578
+ # Clean up temporary file
579
+ if os.path.exists(temp_25fps_file):
580
+ os.remove(temp_25fps_file)
581
+
582
+ return output_files
583
+
584
+ if __name__ == '__main__':
585
+ parser = argparse.ArgumentParser(description='Video Speaker Separation using Dolphin model')
586
+ parser.add_argument('--input', '-i', type=str, required=True,
587
+ help='Path to input video file')
588
+ parser.add_argument('--output', '-o', type=str, default=None,
589
+ help='Output directory path (default: creates directory based on input filename)')
590
+ parser.add_argument('--speakers', '-s', type=int, default=2,
591
+ help='Number of speakers to separate (default: 2)')
592
+ parser.add_argument('--detect-every-n', type=int, default=8,
593
+ help='Detect faces every N frames (default: 8)')
594
+ parser.add_argument('--face-scale', type=float, default=1.5,
595
+ help='Face detection bounding box scale factor (default: 1.5)')
596
+ parser.add_argument('--cuda-device', type=int, default=0,
597
+ help='CUDA device ID to use (default: 0, set to -1 for CPU)')
598
+ parser.add_argument('--config', type=str, default="checkpoints/vox2/conf.yml",
599
+ help='Path to model configuration file')
600
+
601
+ args = parser.parse_args()
602
+
603
+ # Verify that the input file exists
604
+ if not os.path.exists(args.input):
605
+ print(f"Error: Input file '{args.input}' does not exist")
606
+ exit(1)
607
+
608
+ # If no output path is specified, create an output directory based on the input filename
609
+ if args.output is None:
610
+ input_basename = os.path.splitext(os.path.basename(args.input))[0]
611
+ args.output = os.path.join(os.path.dirname(args.input), input_basename + "_output")
612
+
613
+ # Set the CUDA device
614
+ cuda_device = args.cuda_device if args.cuda_device >= 0 else None
615
+
616
+ print(f"Processing video: {args.input}")
617
+ print(f"Output directory: {args.output}")
618
+ print(f"Number of speakers: {args.speakers}")
619
+ print(f"CUDA device: {cuda_device if cuda_device is not None else 'CPU'}")
620
+
621
+ # Process the video
622
+ output_files = process_video(
623
+ input_file=args.input,
624
+ output_path=args.output,
625
+ number_of_speakers=args.speakers,
626
+ detect_every_N_frame=args.detect_every_n,
627
+ scalar_face_detection=args.face_scale,
628
+ config_path=args.config,
629
+ cuda_device=cuda_device
630
+ )
631
+
632
+ print("\nProcessing completed!")
633
+ print("Output files:")
634
+ for i, output_file in enumerate(output_files):
635
+ print(f" Speaker {i+1}: {output_file}")
Inference_with_status.py ADDED
@@ -0,0 +1,410 @@
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+ import os
4
+ import torch
5
+ import torchaudio
6
+ import numpy as np
7
+ from moviepy import *
8
+ from PIL import Image, ImageDraw
9
+ import face_alignment
10
+ import cv2
11
+
12
+ from look2hear.models import Dolphin
13
+ from look2hear.datas.transform import get_preprocessing_pipelines
14
+
15
+ from face_detection_utils import detect_faces
16
+
17
+ # Import functions from original Inference.py
18
+ from Inference import (
19
+ linear_interpolate, warp_img, apply_transform, cut_patch, convert_bgr2gray,
20
+ save2npz, read_video, face2head, bb_intersection_over_union,
21
+ landmarks_interpolate, crop_patch, convert_video_fps, extract_audio, merge_video_audio
22
+ )
23
+
24
+ def detectface_with_status(video_input_path, output_path, detect_every_N_frame, scalar_face_detection, number_of_speakers, status_callback=None):
25
+ """Face detection with status updates"""
26
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+ if status_callback:
28
+ status_callback({'status': f'Running on device: {device}', 'progress': 0.0})
29
+
30
+ os.makedirs(os.path.join(output_path, 'faces'), exist_ok=True)
31
+ os.makedirs(os.path.join(output_path, 'landmark'), exist_ok=True)
32
+
33
+ landmarks_dic = {}
34
+ faces_dic = {}
35
+ boxes_dic = {}
36
+
37
+ for i in range(number_of_speakers):
38
+ landmarks_dic[i] = []
39
+ faces_dic[i] = []
40
+ boxes_dic[i] = []
41
+
42
+ video_clip = VideoFileClip(video_input_path)
43
+ if status_callback:
44
+ status_callback({'status': f"Video: {video_clip.w}x{video_clip.h}, {video_clip.fps}fps", 'progress': 0.05})
45
+
46
+ frames = [Image.fromarray(frame) for frame in video_clip.iter_frames()]
47
+ total_frames = len(frames)
48
+ if status_callback:
49
+ status_callback({'status': f'Processing {total_frames} frames', 'progress': 0.1})
50
+
51
+ video_clip.close()
52
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
53
+
54
+ for i, frame in enumerate(frames):
55
+ if status_callback and i % 10 == 0:
56
+ status_callback({'status': f'Tracking frame: {i+1}/{total_frames}', 'progress': 0.1 + 0.3 * (i / total_frames)})
57
+
58
+ # Detect faces every N frames
59
+ if i % detect_every_N_frame == 0:
60
+ frame_array = np.array(frame)
61
+
62
+ detected_boxes, _ = detect_faces(
63
+ frame_array,
64
+ threshold=0.9,
65
+ allow_upscaling=False,
66
+ )
67
+
68
+ if detected_boxes is None or len(detected_boxes) == 0:
69
+ detected_boxes, _ = detect_faces(
70
+ frame_array,
71
+ threshold=0.7,
72
+ allow_upscaling=True,
73
+ )
74
+
75
+ if detected_boxes is not None and len(detected_boxes) > 0:
76
+ detected_boxes = np.asarray(detected_boxes, dtype=np.float32)
77
+ areas = (detected_boxes[:, 2] - detected_boxes[:, 0]) * (detected_boxes[:, 3] - detected_boxes[:, 1])
78
+ sort_idx = np.argsort(areas)[::-1]
79
+ detected_boxes = detected_boxes[sort_idx][:number_of_speakers]
80
+ detected_boxes = face2head(detected_boxes, scalar_face_detection)
81
+ detected_boxes = [box for box in detected_boxes]
82
+ else:
83
+ detected_boxes = []
84
+
85
+ # Process the detection results (same as original)
86
+ if i == 0:
87
+ # First frame - initialize tracking
88
+ if len(detected_boxes) < number_of_speakers:
89
+ raise ValueError(f"First frame must detect at least {number_of_speakers} faces, but only found {len(detected_boxes)}")
90
+
91
+ # Assign first detections to speakers in order
92
+ for j in range(number_of_speakers):
93
+ box = detected_boxes[j]
94
+ face = frame.crop((box[0], box[1], box[2], box[3])).resize((224,224))
95
+ preds = fa.get_landmarks(np.array(face))
96
+
97
+ if preds is None:
98
+ raise ValueError(f"Face landmarks not detected in initial frame for speaker {j}")
99
+
100
+ faces_dic[j].append(face)
101
+ landmarks_dic[j].append(preds)
102
+ boxes_dic[j].append(box)
103
+ else:
104
+ # For subsequent frames, match detected boxes to speakers
105
+ matched_speakers = set()
106
+ speaker_boxes = [None] * number_of_speakers
107
+
108
+ # Match each detected box to the most likely speaker
109
+ for box in detected_boxes:
110
+ iou_scores = []
111
+ for speaker_id in range(number_of_speakers):
112
+ if speaker_id in matched_speakers:
113
+ iou_scores.append(-1) # Already matched
114
+ else:
115
+ last_box = boxes_dic[speaker_id][-1]
116
+ iou_score = bb_intersection_over_union(box, last_box)
117
+ iou_scores.append(iou_score)
118
+
119
+ if max(iou_scores) > 0: # Valid match found
120
+ best_speaker = iou_scores.index(max(iou_scores))
121
+ speaker_boxes[best_speaker] = box
122
+ matched_speakers.add(best_speaker)
123
+
124
+ # Process each speaker
125
+ for speaker_id in range(number_of_speakers):
126
+ if speaker_boxes[speaker_id] is not None:
127
+ # Use detected box
128
+ box = speaker_boxes[speaker_id]
129
+ else:
130
+ # Use previous box for this speaker
131
+ box = boxes_dic[speaker_id][-1]
132
+
133
+ # Extract face and landmarks
134
+ face = frame.crop((box[0], box[1], box[2], box[3])).resize((224,224))
135
+ preds = fa.get_landmarks(np.array(face))
136
+
137
+ if preds is None:
138
+ # Use previous landmarks if detection fails
139
+ preds = landmarks_dic[speaker_id][-1]
140
+
141
+ faces_dic[speaker_id].append(face)
142
+ landmarks_dic[speaker_id].append(preds)
143
+ boxes_dic[speaker_id].append(box)
144
+
145
+ # Verify all speakers have same number of frames
146
+ frame_counts = [len(boxes_dic[s]) for s in range(number_of_speakers)]
147
+ if status_callback:
148
+ status_callback({'status': f"Frame counts per speaker: {frame_counts}", 'progress': 0.4})
149
+
150
+ assert all(count == len(frames) for count in frame_counts), f"Inconsistent frame counts: {frame_counts}"
151
+
152
+ # Continue with saving videos and landmarks...
153
+ for s in range(number_of_speakers):
154
+ if status_callback:
155
+ status_callback({'status': f'Saving tracked video for speaker {s+1}', 'progress': 0.4 + 0.1 * (s / number_of_speakers)})
156
+
157
+ frames_tracked = []
158
+ for i, frame in enumerate(frames):
159
+ frame_draw = frame.copy()
160
+ draw = ImageDraw.Draw(frame_draw)
161
+ draw.rectangle(boxes_dic[s][i], outline=(255, 0, 0), width=6)
162
+ frames_tracked.append(frame_draw)
163
+
164
+ # Save tracked video
165
+ tracked_frames = [np.array(frame) for frame in frames_tracked]
166
+ if tracked_frames:
167
+ tracked_clip = ImageSequenceClip(tracked_frames, fps=25.0)
168
+ tracked_video_path = os.path.join(output_path, 'video_tracked' + str(s+1) + '.mp4')
169
+ tracked_clip.write_videofile(tracked_video_path, codec='libx264', audio=False, logger=None)
170
+ tracked_clip.close()
171
+
172
+ # Save landmarks
173
+ for i in range(number_of_speakers):
174
+ # Create landmark directory if it doesn't exist
175
+ landmark_dir = os.path.join(output_path, 'landmark')
176
+ os.makedirs(landmark_dir, exist_ok=True)
177
+ save2npz(os.path.join(landmark_dir, 'speaker' + str(i+1)+'.npz'), data=landmarks_dic[i])
178
+
179
+ # Save face video
180
+ face_frames = [np.array(frame) for frame in faces_dic[i]]
181
+ if face_frames:
182
+ face_clip = ImageSequenceClip(face_frames, fps=25.0)
183
+ face_video_path = os.path.join(output_path, 'faces', 'speaker' + str(i+1) + '.mp4')
184
+ face_clip.write_videofile(face_video_path, codec='libx264', audio=False, logger=None)
185
+ face_clip.close()
186
+
187
+ # Output video path
188
+ parts = video_input_path.split('/')
189
+ video_name = parts[-1][:-4]
190
+ filename_dir = os.path.join(output_path, 'filename_input')
191
+ os.makedirs(filename_dir, exist_ok=True)
192
+ csvfile = open(os.path.join(filename_dir, str(video_name) + '.csv'), 'w')
193
+ for i in range(number_of_speakers):
194
+ csvfile.write('speaker' + str(i+1)+ ',0\n')
195
+ csvfile.close()
196
+ return os.path.join(filename_dir, str(video_name) + '.csv')
197
+
198
+
199
+ def crop_mouth_with_status(video_direc, landmark_direc, filename_path, save_direc, status_callback=None, convert_gray=False, testset_only=False):
200
+ """Crop mouth with status updates"""
201
+ lines = open(filename_path).read().splitlines()
202
+ lines = list(filter(lambda x: 'test' in x, lines)) if testset_only else lines
203
+
204
+ for filename_idx, line in enumerate(lines):
205
+ filename, person_id = line.split(',')
206
+
207
+ if status_callback:
208
+ status_callback({'status': f'Processing speaker{int(person_id)+1}', 'progress': 0.5 + 0.1 * filename_idx / len(lines)})
209
+
210
+ video_pathname = os.path.join(video_direc, filename+'.mp4')
211
+ landmarks_pathname = os.path.join(landmark_direc, filename+'.npz')
212
+
213
+ # Create mouthroi directory if it doesn't exist
214
+ os.makedirs(save_direc, exist_ok=True)
215
+ dst_pathname = os.path.join(save_direc, filename+'.npz')
216
+
217
+ multi_sub_landmarks = np.load(landmarks_pathname, allow_pickle=True)['data']
218
+ if len(multi_sub_landmarks) == 0:
219
+ print(f"No landmarks found for {filename}, skipping crop.")
220
+ continue
221
+
222
+ landmark_frame_count = len(multi_sub_landmarks)
223
+ cap = cv2.VideoCapture(video_pathname)
224
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
225
+ cap.release()
226
+
227
+ if frame_count > 0 and frame_count != landmark_frame_count:
228
+ print(
229
+ f"Frame count mismatch for {filename}: video has {frame_count} frames, "
230
+ f"landmarks have {landmark_frame_count} entries. Adjusting to match."
231
+ )
232
+ if frame_count < landmark_frame_count:
233
+ multi_sub_landmarks = multi_sub_landmarks[:frame_count]
234
+ else:
235
+ pad_count = frame_count - landmark_frame_count
236
+ pad = np.repeat(multi_sub_landmarks[-1:], pad_count, axis=0)
237
+ multi_sub_landmarks = np.concatenate((multi_sub_landmarks, pad), axis=0)
238
+
239
+ landmarks = [None] * len(multi_sub_landmarks)
240
+ for frame_idx in range(len(landmarks)):
241
+ try:
242
+ landmarks[frame_idx] = multi_sub_landmarks[frame_idx][int(person_id)]
243
+ except (IndexError, TypeError):
244
+ continue
245
+
246
+ # Pre-process landmarks: interpolate frames not being detected
247
+ preprocessed_landmarks = landmarks_interpolate(landmarks)
248
+ if not preprocessed_landmarks:
249
+ continue
250
+
251
+ # Crop
252
+ mean_face_landmarks = np.load('assets/20words_mean_face.npy')
253
+ sequence = crop_patch(mean_face_landmarks, video_pathname, preprocessed_landmarks, 12, 48, 68, 96, 96)
254
+ assert sequence is not None, "cannot crop from {}.".format(filename)
255
+
256
+ # Save
257
+ data = convert_bgr2gray(sequence) if convert_gray else sequence[...,::-1]
258
+ save2npz(dst_pathname, data=data)
259
+
260
+
261
+ def process_video_with_status(input_file, output_path, number_of_speakers=2,
262
+ detect_every_N_frame=8, scalar_face_detection=1.5,
263
+ config_path="checkpoints/vox2/conf.yml",
264
+ cuda_device=None, status_callback=None):
265
+ """Main processing function with status updates"""
266
+
267
+ # Set CUDA device if specified
268
+ if cuda_device is not None:
269
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device)
270
+
271
+ # Create output directory
272
+ os.makedirs(output_path, exist_ok=True)
273
+
274
+ # Convert video to 25fps
275
+ if status_callback:
276
+ status_callback({'status': 'Converting video to 25fps', 'progress': 0.0})
277
+
278
+ temp_25fps_file = os.path.join(output_path, 'temp_25fps.mp4')
279
+ convert_video_fps(input_file, temp_25fps_file, target_fps=25)
280
+
281
+ # Detect faces
282
+ if status_callback:
283
+ status_callback({'status': 'Detecting faces and tracking speakers', 'progress': 0.1})
284
+
285
+ filename_path = detectface_with_status(
286
+ video_input_path=temp_25fps_file,
287
+ output_path=output_path,
288
+ detect_every_N_frame=detect_every_N_frame,
289
+ scalar_face_detection=scalar_face_detection,
290
+ number_of_speakers=number_of_speakers,
291
+ status_callback=status_callback
292
+ )
293
+ torch.cuda.empty_cache()
294
+ # Extract audio
295
+ if status_callback:
296
+ status_callback({'status': 'Extracting audio from video', 'progress': 0.5})
297
+
298
+ audio_output = os.path.join(output_path, 'audio.wav')
299
+ extract_audio(temp_25fps_file, audio_output, sample_rate=16000)
300
+
301
+ # Crop mouth
302
+ if status_callback:
303
+ status_callback({'status': 'Cropping mouth regions', 'progress': 0.55})
304
+
305
+ crop_mouth_with_status(
306
+ video_direc=os.path.join(output_path, "faces"),
307
+ landmark_direc=os.path.join(output_path, "landmark"),
308
+ filename_path=filename_path,
309
+ save_direc=os.path.join(output_path, "mouthroi"),
310
+ convert_gray=True,
311
+ testset_only=False,
312
+ status_callback=status_callback
313
+ )
314
+
315
+ # Load model
316
+ if status_callback:
317
+ status_callback({'status': 'Loading Dolphin model', 'progress': 0.6})
318
+ torch.cuda.empty_cache()
319
+ audiomodel = Dolphin.from_pretrained("JusperLee/Dolphin")
320
+ # audiomodel.cuda()
321
+ audiomodel.eval()
322
+
323
+ # Process each speaker
324
+ with torch.no_grad():
325
+ for i in range(number_of_speakers):
326
+ if status_callback:
327
+ status_callback({'status': f'Processing audio for speaker {i+1}', 'progress': 0.65 + 0.25 * (i / number_of_speakers)})
328
+
329
+ mouth_roi_path = os.path.join(output_path, "mouthroi", f"speaker{i+1}.npz")
330
+ mouth_roi = np.load(mouth_roi_path)["data"]
331
+ mouth_roi = get_preprocessing_pipelines()["val"](mouth_roi)
332
+
333
+ mix, sr = torchaudio.load(audio_output)
334
+ mix = mix.mean(dim=0)
335
+
336
+ window_size = 4 * sr
337
+ hop_size = int(4 * sr)
338
+
339
+ all_estimates = []
340
+
341
+ # Sliding window processing
342
+ start_idx = 0
343
+ window_count = 0
344
+ while start_idx < len(mix):
345
+ end_idx = min(start_idx + window_size, len(mix))
346
+ window_mix = mix[start_idx:end_idx]
347
+
348
+ start_frame = int(start_idx / sr * 25)
349
+ end_frame = int(end_idx / sr * 25)
350
+ end_frame = min(end_frame, len(mouth_roi))
351
+ window_mouth_roi = mouth_roi[start_frame:end_frame]
352
+
353
+ est_sources = audiomodel(window_mix[None],
354
+ torch.from_numpy(window_mouth_roi[None, None]).float())
355
+
356
+ all_estimates.append({
357
+ 'start': start_idx,
358
+ 'end': end_idx,
359
+ 'estimate': est_sources[0].cpu()
360
+ })
361
+
362
+ window_count += 1
363
+ if status_callback:
364
+ progress = 0.65 + 0.25 * (i / number_of_speakers) + 0.25 / number_of_speakers * (window_count * hop_size / len(mix))
365
+ status_callback({'status': f'Processing audio window {window_count} for speaker {i+1}', 'progress': min(progress, 0.9)})
366
+
367
+ start_idx += hop_size
368
+
369
+ if start_idx >= len(mix):
370
+ break
371
+ torch.cuda.empty_cache()
372
+
373
+ output_length = len(mix)
374
+ merged_output = torch.zeros(1, output_length)
375
+ weights = torch.zeros(output_length)
376
+
377
+ for est in all_estimates:
378
+ window_len = est['end'] - est['start']
379
+ hann_window = torch.hann_window(window_len)
380
+
381
+ merged_output[0, est['start']:est['end']] += est['estimate'][0, :window_len] * hann_window
382
+ weights[est['start']:est['end']] += hann_window
383
+
384
+ merged_output[:, weights > 0] /= weights[weights > 0]
385
+
386
+ audio_save_path = os.path.join(output_path, f"speaker{i+1}_est.wav")
387
+ torchaudio.save(audio_save_path, merged_output, sr)
388
+
389
+ # Merge video with separated audio for each speaker
390
+ torch.cuda.empty_cache()
391
+ if status_callback:
392
+ status_callback({'status': 'Merging videos with separated audio', 'progress': 0.9})
393
+
394
+ output_files = []
395
+ for i in range(number_of_speakers):
396
+ video_input = os.path.join(output_path, f"video_tracked{i+1}.mp4")
397
+ audio_input = os.path.join(output_path, f"speaker{i+1}_est.wav")
398
+ video_output = os.path.join(output_path, f"s{i+1}.mp4")
399
+
400
+ merge_video_audio(video_input, audio_input, video_output)
401
+ output_files.append(video_output)
402
+
403
+ # Clean up temporary file
404
+ if os.path.exists(temp_25fps_file):
405
+ os.remove(temp_25fps_file)
406
+
407
+ if status_callback:
408
+ status_callback({'status': 'Processing completed!', 'progress': 1.0})
409
+
410
+ return output_files
README.md ADDED
@@ -0,0 +1,189 @@
1
+ <p align="center">
2
+ <img src="assets/icon.png" alt="Dolphin Logo" width="150"/>
3
+ </p>
4
+ <h3 align="center">Dolphin: Efficient Audio-Visual Speech Separation with Discrete Lip Semantics and Multi-Scale Global-Local Attention</h3>
5
+ <p align="center">
6
+ <strong>Kai Li*, Kejun Gao*, Xiaolin Hu </strong><br>
7
+ <strong>Tsinghua University</strong>
8
+ </p>
9
+
10
+ <p align="center">
11
+ <img src="https://visitor-badge.laobi.icu/badge?page_id=JusperLee.Dolphin" alt="访客统计" />
12
+ <img src="https://img.shields.io/github/stars/JusperLee/Dolphin?style=social" alt="GitHub stars" />
13
+ <img alt="Static Badge" src="https://img.shields.io/badge/license-Apache%202.0-blue.svg" />
14
+ <a href="https://arxiv.org/abs/2509.23610" target="_blank" rel="noreferrer noopener">
15
+ <img alt="arXiv Paper" src="https://img.shields.io/badge/arXiv-2509.23610-b31b1b.svg?logo=arxiv&logoColor=white" />
16
+ </a>
17
+ <a href="https://huggingface.co/JusperLee/Dolphin" target="_blank" rel="noreferrer noopener">
18
+ <img alt="Hugging Face Models" src="https://img.shields.io/badge/Hugging%20Face-Models-ff9d2c?logo=huggingface&logoColor=white" />
19
+ </a>
20
+ <a href="https://dolphin.cslikai.cn/" target="_blank" rel="noreferrer noopener">
21
+ <img alt="Gradio Live Demo" src="https://img.shields.io/badge/Gradio-Live%20Demo-00a67e?logo=gradio&logoColor=white" />
22
+ </a>
23
+ </p>
24
+
25
+ <p align="center">
26
+
27
+ > Dolphin is an efficient audio-visual speech separation framework that leverages discrete lip semantics and global–local attention to achieve state-of-the-art performance with significantly reduced computational complexity.
28
+
29
+ ## 🎯 Highlights
30
+
31
+ - **Balanced Quality & Efficiency**: Single-pass separator achieves state-of-the-art AVSS performance without iterative refinement.
32
+ - **DP-LipCoder**: Dual-path, vector-quantized video encoder produces discrete audio-aligned semantic tokens while staying lightweight.
33
+ - **Global–Local Attention**: TDANet-based separator augments each layer with coarse global self-attention and heat diffusion local attention.
34
+ - **Edge-Friendly Deployment**: Delivers >50% parameter reduction, >2.4× lower MACs, and >6× faster GPU inference versus IIANet.
35
+
36
+ ## 💥 News
37
+
38
+ - **[2025-09-28]** Code and pre-trained models are released! 📦
39
+
40
+ ## 📜 Abstract
41
+
42
+ Audio-visual speech separation (AVSS) methods leverage visual cues to extract target speech in noisy acoustic environments, but most existing systems remain computationally heavy. Dolphin tackles this tension by combining a lightweight, dual-path video encoder with a single-pass global–local collaborative separator. The video pathway, DP-LipCoder, maps lip movements into discrete semantic tokens that remain tightly aligned with audio through vector quantization and distillation from AV-HuBERT. The audio separator builds upon TDANet and injects global–local attention (GLA) blocks—coarse-grained self-attention for long-range context and heat diffusion attention for denoising fine details. Across three public AVSS benchmarks, Dolphin not only outperforms the state-of-the-art IIANet on separation metrics but also delivers over 50% fewer parameters, more than 2.4× lower MACs, and over 6× faster GPU inference, making it practical for edge deployment.
43
+
44
+ ## 🌍 Motivation
45
+
46
+ In real-world environments, target speech is often masked by background noise and interfering speakers. This phenomenon reflects the classic “cocktail party effect,” where listeners selectively attend to a single speaker within a noisy scene (Cherry, 1953). These challenges have spurred extensive research on speech separation.
47
+
48
+ Audio-only approaches tend to struggle in complex acoustic conditions, while the integration of synchronous visual cues offers greater robustness. Recent deep learning-based AVSS systems achieve strong performance, yet many rely on computationally intensive separators or heavy iterative refinement, limiting their practicality.
49
+
50
+ Beyond the separator itself, AVSS models frequently inherit high computational cost from their video encoders. Large-scale lip-reading backbones provide rich semantic alignment but bring prohibitive parameter counts. Compressing them often erodes lip semantics, whereas designing new lightweight encoders from scratch risks losing semantic fidelity and degrading separation quality. Building a video encoder that balances compactness with semantic alignment therefore remains a central challenge for AVSS.
51
+
52
+ ## 🧠 Method Overview
53
+
54
+ To address these limitations, Dolphin introduces a novel AVSS pipeline centered on two components:
55
+
56
+ - **DP-LipCoder**: A dual-path, vector-quantized video encoder that separates compressed visual structure from audio-aligned semantics. By combining vector quantization with knowledge distillation from AV-HuBERT, it converts continuous lip motion into discrete semantic tokens without sacrificing representational capacity.
57
+ - **Single-Pass GLA Separator**: A lightweight TDANet-based audio separator that removes the need for iterative refinement. Each layer hosts a global–local attention block: coarse-grained self-attention captures long-range dependencies at low resolution, while heat diffusion attention smooths features across channels to suppress noise and retain detail.
58
+
59
+ Together, these components strike a balance between separation quality and computational efficiency, enabling deployment in resource-constrained scenarios.
60
+
61
+ ## 🧪 Experimental Highlights
62
+
63
+ We evaluate Dolphin on LRS2, LRS3, and VoxCeleb2. Compared with the state-of-the-art IIANet, Dolphin achieves higher scores across all separation metrics while dramatically reducing resource consumption:
64
+
65
+ - **Parameters**: >50% reduction
66
+ - **Computation**: >2.4× decrease in MACs
67
+ - **Inference**: >6× speedup on GPU
68
+
69
+ These results demonstrate that Dolphin provides competitive AVSS quality on edge hardware without heavy iterative processing.
70
+
71
+ ## 🏗️ Architecture
72
+
73
+ ![Dolphin Architecture](assets/overall-pipeline.png)
74
+
75
+ > The overall architecture of Dolphin.
76
+
77
+ ### Video Encoder
78
+
79
+ ![Dolphin Architecture](assets/video-ae.png)
80
+
81
+ > The video encoder of Dolphin.
82
+
83
+ ### Dolphin Model Overview
84
+
85
+ ![Dolphin Architecture](assets/separator.png)
86
+
87
+ > The overall architecture of Dolphin's separator.
88
+
89
+ ### Key Components
90
+
91
+ ![Dolphin Architecture](assets/ga-msa.png)
92
+
93
+ 1. **Global Attention (GA) Block**
94
+ - Applies coarse-grained self-attention to capture long-range structure
95
+ - Operates at low spatial resolution for efficiency
96
+ - Enhances robustness to complex acoustic mixtures
97
+
98
+ 2. **Local Attention (LA) Block**
99
+ - Uses heat diffusion attention to smooth features across channels
100
+ - Suppresses background noise while preserving details
101
+ - Complements GA to balance global context and local fidelity
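+
+ As a rough illustration of the GA idea above (downsample, attend globally, upsample, add back), here is a minimal, hypothetical PyTorch sketch. It is not the released GA/LA code, the layer sizes are invented, and the heat-diffusion local attention is not reproduced here.
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class CoarseGlobalAttention(nn.Module):
+     """Illustrative global-attention block: pool the feature sequence to a low
+     temporal resolution, run multi-head self-attention there, then upsample."""
+     def __init__(self, channels: int = 128, heads: int = 4, pool: int = 4):
+         super().__init__()
+         self.pool = pool
+         self.attn = nn.MultiheadAttention(channels, heads, batch_first=True)
+         self.norm = nn.LayerNorm(channels)
+
+     def forward(self, x):                        # x: (batch, channels, time)
+         coarse = F.avg_pool1d(x, self.pool)      # cheap long-range context
+         tokens = coarse.transpose(1, 2)          # (batch, time // pool, channels)
+         attended, _ = self.attn(tokens, tokens, tokens)
+         tokens = self.norm(tokens + attended)
+         fine = F.interpolate(tokens.transpose(1, 2), size=x.shape[-1], mode="nearest")
+         return x + fine                          # residual connection
+
+ block = CoarseGlobalAttention()
+ print(block(torch.randn(2, 128, 400)).shape)     # torch.Size([2, 128, 400])
+ ```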
102
+
103
+ ## 📊 Results
104
+
105
+ ### Performance Comparison
106
+
107
+ Performance metrics on three public AVSS benchmark datasets. Bold indicates best performance.
108
+
109
+ ![Results Table](assets/results.png)
110
+
111
+ ### Efficiency Analysis
112
+
113
+ ![Efficiency Comparison](assets/efficiency_comparison.png)
114
+
115
+ Dolphin achieves:
116
+ - ✅ **>50%** parameter reduction
117
+ - ✅ **2.4×** lower computational cost (MACs)
118
+ - ✅ **6×** faster GPU inference speed
119
+ - ✅ Superior separation quality across all metrics
120
+
121
+ ## 📦 Installation
122
+
123
+ ```bash
124
+ git clone https://github.com/JusperLee/Dolphin.git
125
+ cd Dolphin
126
+ pip install torch torchvision
127
+ pip install -r requirements.txt
128
+ ```
129
+
130
+ ### Requirements
131
+
132
+ - Python >= 3.10
133
+ - PyTorch >= 2.5.0
134
+ - CUDA >= 12.4
135
+ - Other dependencies in requirements.txt
136
+
137
+ ## 🚀 Quick Start
138
+
139
+ ### Inference with Pre-trained Model
140
+
141
+ ```bash
142
+ # Single audio-visual separation
143
+ python Inference.py \
144
+ --input /path/to/video.mp4 \
145
+ --output /path/to/output/directory \
146
+ --speakers 2 \
147
+ --detect-every-n 8 \
148
+ --face-scale 1.5 \
149
+ --cuda-device 0 \
150
+ --config checkpoints/vox2/conf.yml
151
+ ```
152
+
153
+ ## 📁 Model Zoo
154
+
155
+ | Model | Training Data | SI-SNRi | PESQ | Download |
156
+ |-------|--------------|---------|------|----------|
157
+ | Dolphin | VoxCeleb2 | 16.1 dB | 3.45 | [Link](https://huggingface.co/JusperLee/Dolphin) |
158
+
159
+ ## 📖 Citation
160
+
161
+ If you find Dolphin useful in your research, please cite:
162
+
163
+ ```bibtex
164
+ @misc{li2025efficientaudiovisualspeechseparation,
165
+ title={Efficient Audio-Visual Speech Separation with Discrete Lip Semantics and Multi-Scale Global-Local Attention},
166
+ author={Kai Li and Kejun Gao and Xiaolin Hu},
167
+ year={2025},
168
+ eprint={2509.23610},
169
+ archivePrefix={arXiv},
170
+ primaryClass={cs.SD},
171
+ url={https://arxiv.org/abs/2509.23610},
172
+ }
173
+ ```
174
+
175
+ ## 🤝 Acknowledgments
176
+
177
+ We thank the authors of [IIANet](https://github.com/JusperLee/IIANet) and [SepReformer](https://github.com/dmlguq456/SepReformer) for providing parts of the code used in this project.
178
+
179
+ ## 📧 Contact
180
+
181
+ For questions and feedback, please open an issue on GitHub or contact us at: [tsinghua.kaili@gmail.com](mailto:tsinghua.kaili@gmail.com)
182
+
183
+ ## 📄 License
184
+
185
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
186
+
187
+ <p align="center">
188
+ Made with stars ⭐️ for efficient audio-visual speech separation
189
+ </p>
app.py ADDED
@@ -0,0 +1,671 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Audio-visual Speech Separation Gradio App - Hugging Face Space Version
4
+ Automatically detects and separates all speakers in videos
5
+ """
6
+
7
+ import warnings
8
+ warnings.filterwarnings("ignore")
9
+ import os
10
+ import gradio as gr
11
+ import numpy as np
12
+ import shutil
13
+ import tempfile
14
+ import time
15
+ import sys
16
+ import threading
17
+ from PIL import Image, ImageDraw, ImageFont
18
+ from moviepy import *
19
+ import spaces
20
+
21
+ from face_detection_utils import detect_faces
22
+
23
+ # Use HF Space's temp directory
24
+ TEMP_DIR = os.environ.get('TMPDIR', '/tmp')
25
+
26
+ # Shared state for relaying GPU-side status back to the UI thread.
27
+ GPU_PROGRESS_STATE = {"progress": 0.0, "status": "Processing on GPU..."}
28
+ GPU_PROGRESS_LOCK = threading.Lock()
29
+
30
+ class LogCollector:
31
+ """Collect logs in a list"""
32
+ def __init__(self):
33
+ self.logs = []
34
+
35
+ def add(self, message):
36
+ if message and message.strip():
37
+ timestamp = time.strftime("%H:%M:%S")
38
+ self.logs.append(f"[{timestamp}] {message.strip()}")
39
+
40
+ def get_text(self, last_n=None):
41
+ if last_n:
42
+ return "\n".join(self.logs[-last_n:])
43
+ return "\n".join(self.logs)
44
+
45
+ # Global log collector for capturing print statements
46
+ GLOBAL_LOG = LogCollector()
47
+
48
+ class StdoutCapture:
49
+ """Capture stdout and add to log"""
50
+ def __init__(self, original):
51
+ self.original = original
52
+
53
+ def write(self, text):
54
+ self.original.write(text)
55
+ if text.strip():
56
+ GLOBAL_LOG.add(text.strip())
57
+
58
+ def flush(self):
59
+ self.original.flush()
60
+
61
+ def remove_duplicate_faces(boxes, probs, iou_threshold=0.5):
62
+ """Remove duplicate face detections using IoU (Intersection over Union)"""
63
+ if len(boxes) <= 1:
64
+ return boxes, probs
65
+
66
+ # Calculate IoU between all pairs of boxes
67
+ def calculate_iou(box1, box2):
68
+ x1 = max(box1[0], box2[0])
69
+ y1 = max(box1[1], box2[1])
70
+ x2 = min(box1[2], box2[2])
71
+ y2 = min(box1[3], box2[3])
72
+
73
+ intersection = max(0, x2 - x1) * max(0, y2 - y1)
74
+ area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
75
+ area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
76
+ union = area1 + area2 - intersection
77
+
78
+ return intersection / union if union > 0 else 0
79
+
80
+ # Keep track of which boxes to keep
81
+ keep = []
82
+ used = set()
83
+
84
+ # Sort by confidence (if available) or by area
85
+ if probs is not None:
86
+ sorted_indices = np.argsort(probs)[::-1]
87
+ else:
88
+ areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
89
+ sorted_indices = np.argsort(areas)[::-1]
90
+
91
+ for i in sorted_indices:
92
+ if i in used:
93
+ continue
94
+
95
+ keep.append(i)
96
+ used.add(i)
97
+
98
+ # Mark overlapping boxes as used
99
+ for j in range(len(boxes)):
100
+ if j != i and j not in used:
101
+ iou = calculate_iou(boxes[i], boxes[j])
102
+ if iou > iou_threshold:
103
+ used.add(j)
104
+
105
+ # Return filtered boxes and probs
106
+ keep = sorted(keep) # Maintain original order
107
+ filtered_boxes = boxes[keep]
108
+ filtered_probs = probs[keep] if probs is not None else None
109
+
110
+ return filtered_boxes, filtered_probs
111
+
112
+ def process_detected_faces(boxes, probs, frame_rgb, frame_pil):
113
+ """Process detected faces and return face images"""
114
+ face_images = []
115
+ full_frame_annotated = frame_rgb.copy()
116
+
117
+ if boxes is None or len(boxes) == 0:
118
+ return [], 0, full_frame_annotated, "No faces detected"
119
+
120
+ boxes = np.asarray(boxes, dtype=np.float32)
121
+
122
+ # Filter by confidence if available
123
+ if probs is not None:
124
+ # Keep faces with confidence > 0.9
125
+ confident_indices = probs > 0.9
126
+ boxes = boxes[confident_indices]
127
+ probs = probs[confident_indices]
128
+ print(f"After filtering by confidence: {len(boxes)} faces")
129
+
130
+ if len(boxes) == 0:
131
+ return [], 0, full_frame_annotated, "No faces passed the confidence filter"
132
+
133
+ # Remove duplicate detections
134
+ boxes, probs = remove_duplicate_faces(boxes, probs, iou_threshold=0.3)
135
+ print(f"After removing duplicates: {len(boxes)} faces")
136
+
137
+ if len(boxes) == 0:
138
+ return [], 0, full_frame_annotated, "No faces remained after duplicate removal"
139
+
140
+ # Sort boxes by area (larger faces first)
141
+ areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
142
+ sorted_indices = np.argsort(areas)[::-1]
143
+ boxes = boxes[sorted_indices]
144
+
145
+ # Annotate full frame
146
+ full_frame_pil = Image.fromarray(full_frame_annotated)
147
+ draw = ImageDraw.Draw(full_frame_pil)
148
+
149
+ # Try to use a better font
150
+ try:
151
+ font = ImageFont.load_default()
152
+ except:
153
+ font = None
154
+
155
+ # Extract face images and annotate
156
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
157
+
158
+ for i, box in enumerate(boxes):
159
+ color = colors[i % len(colors)]
160
+
161
+ # Draw bounding box
162
+ draw.rectangle(box.tolist(), outline=color, width=4)
163
+ label = f"Speaker {i+1}"
164
+
165
+ # Draw label
166
+ if font:
167
+ draw.text((box[0] + 5, box[1] - 20), label, fill=color, font=font)
168
+
169
+ # Extract face with margin
170
+ margin = 50
171
+ x1 = max(0, int(box[0] - margin))
172
+ y1 = max(0, int(box[1] - margin))
173
+ x2 = min(frame_rgb.shape[1], int(box[2] + margin))
174
+ y2 = min(frame_rgb.shape[0], int(box[3] + margin))
175
+
176
+ face_crop = frame_rgb[y1:y2, x1:x2]
177
+ # Resize maintaining aspect ratio
178
+ face_crop = Image.fromarray(face_crop)
179
+ face_crop.thumbnail((250, 250), Image.Resampling.LANCZOS)
180
+ face_crop = np.array(face_crop)
181
+
182
+ face_images.append(face_crop)
183
+
184
+ full_frame_annotated = np.array(full_frame_pil)
185
+ return face_images, len(boxes), full_frame_annotated, None
186
+
187
+ @spaces.GPU(duration=60, enable_queue=True)
188
+ def detect_faces_gpu(frame_pil):
189
+ """GPU-accelerated face detection"""
190
+ print("Detecting faces with RetinaFace")
191
+
192
+ frame_array = np.array(frame_pil)
193
+
194
+ boxes, probs = detect_faces(
195
+ frame_array,
196
+ threshold=0.9,
197
+ allow_upscaling=False,
198
+ )
199
+
200
+ if boxes is None or len(boxes) == 0:
201
+ print("No faces detected at high threshold, relaxing criteria...")
202
+ boxes, probs = detect_faces(
203
+ frame_array,
204
+ threshold=0.7,
205
+ allow_upscaling=True,
206
+ )
207
+
208
+ return boxes, probs
209
+
210
+ def detect_and_extract_all_faces(video_path):
211
+ """Detect all faces in the first frame and extract them"""
212
+ print("Starting face detection...")
213
+
214
+ # Check if video file exists
215
+ if not os.path.exists(video_path):
216
+ print(f"Error: Video file does not exist at path: {video_path}")
217
+ return [], 0, None, f"Video file not found: {video_path}"
218
+
219
+ print(f"Video path: {video_path}")
220
+ print(f"File size: {os.path.getsize(video_path) / 1024 / 1024:.2f} MB")
221
+
222
+ # Use moviepy to read video
223
+ print("Opening video with moviepy...")
224
+ try:
225
+ clip = VideoFileClip(video_path)
226
+
227
+ # Get video properties
228
+ fps = clip.fps
229
+ duration = clip.duration
230
+ total_frames = int(fps * duration)
231
+
232
+ print(f"Video info: FPS: {fps}, Duration: {duration}s, Total frames: {total_frames}")
233
+
234
+ # Get first frame
235
+ frame = clip.get_frame(0) # MoviePy returns RGB
236
+ frame_rgb = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
237
+
238
+ print(f"Successfully read frame with moviepy: {frame_rgb.shape}")
239
+
240
+ # Close the clip to free resources
241
+ clip.close()
242
+
243
+ # Convert to PIL for downstream processing
244
+ frame_pil = Image.fromarray(frame_rgb)
245
+
246
+ # Detect faces using RetinaFace
247
+ print("Detecting faces with RetinaFace...")
248
+ boxes, probs = detect_faces(
249
+ frame_rgb,
250
+ threshold=0.9,
251
+ allow_upscaling=False,
252
+ )
253
+
254
+ if boxes is None or len(boxes) == 0:
255
+ print("No faces detected at high threshold, trying relaxed settings...")
256
+ boxes, probs = detect_faces(
257
+ frame_rgb,
258
+ threshold=0.7,
259
+ allow_upscaling=True,
260
+ )
261
+
262
+ if boxes is not None and len(boxes) > 0:
263
+ print(f"Detected {len(boxes)} faces")
264
+ return process_detected_faces(boxes, probs, frame_rgb, frame_pil)
265
+ else:
266
+ return [], 0, frame_rgb, "No faces detected in the first frame"
267
+
268
+ except Exception as e:
269
+ print(f"MoviePy failed: {e}")
270
+ import traceback
271
+ traceback.print_exc()
272
+ return [], 0, None, f"Failed to open video file. Error: {str(e)}"
273
+
274
+ @spaces.GPU(duration=300, enable_queue=True)
275
+ def process_video_gpu(video_file, temp_dir, num_speakers):
276
+ """GPU-accelerated video processing"""
277
+ try:
278
+ from Inference_with_status import process_video_with_status
279
+
280
+ # Define status callback inside GPU function
281
+ def gpu_status_callback(message):
282
+ status_text = message.get('status', 'Processing...')
283
+ print(f"GPU Processing: {status_text}")
284
+ progress_value = message.get('progress')
285
+ with GPU_PROGRESS_LOCK:
286
+ GPU_PROGRESS_STATE["status"] = status_text
287
+ if progress_value is not None:
288
+ try:
289
+ numeric_progress = float(progress_value)
290
+ GPU_PROGRESS_STATE["progress"] = min(max(numeric_progress, 0.0), 1.0)
291
+ except (TypeError, ValueError):
292
+ pass
293
+
294
+ output_files = process_video_with_status(
295
+ input_file=video_file,
296
+ output_path=temp_dir,
297
+ number_of_speakers=num_speakers,
298
+ detect_every_N_frame=8,
299
+ scalar_face_detection=1.5,
300
+ status_callback=gpu_status_callback
301
+ )
302
+ return output_files
303
+ except ImportError:
304
+ from Inference import process_video
305
+ print("Using standard process_video (status callbacks not available)")
306
+ output_files = process_video(
307
+ input_file=video_file,
308
+ output_path=temp_dir,
309
+ number_of_speakers=num_speakers,
310
+ detect_every_N_frame=8,
311
+ scalar_face_detection=1.5
312
+ )
313
+ return output_files
314
+
315
+ def process_video_auto(video_file, progress=gr.Progress()):
316
+ """Process video with automatic speaker detection and stream status updates"""
317
+ global GLOBAL_LOG
318
+ GLOBAL_LOG = LogCollector()
319
+
320
+ old_stdout = sys.stdout
321
+ sys.stdout = StdoutCapture(old_stdout)
322
+
323
+ status_value = "⏳ Ready to process..."
324
+ detected_info_output = gr.update(visible=False)
325
+ face_gallery_output = gr.update(visible=False)
326
+ output_video_output = gr.update(visible=False)
327
+ video_dict_value = None
328
+ annotated_frame_output = gr.update(visible=False)
329
+
330
+ def snapshot():
331
+ return (
332
+ status_value,
333
+ detected_info_output,
334
+ face_gallery_output,
335
+ output_video_output,
336
+ video_dict_value,
337
+ annotated_frame_output,
338
+ GLOBAL_LOG.get_text()
339
+ )
340
+
341
+ try:
342
+ if video_file is None:
343
+ status_value = "⚠️ Please upload a video file"
344
+ yield snapshot()
345
+ return
346
+
347
+ progress(0, desc="Starting processing...")
348
+ status_value = "🔄 Starting processing..."
349
+ GLOBAL_LOG.add("Starting video processing...")
350
+ yield snapshot()
351
+
352
+ temp_dir = None
353
+ try:
354
+ temp_dir = tempfile.mkdtemp(dir=TEMP_DIR)
355
+ print(f"Created temporary directory: {temp_dir}")
356
+
357
+ progress(0.1, desc="Detecting speakers in video...")
358
+ status_value = "🔍 Detecting speakers in video..."
359
+ print("Starting face detection in video...")
360
+ yield snapshot()
361
+
362
+ face_images, num_speakers, annotated_frame, error_msg = detect_and_extract_all_faces(video_file)
363
+ print(f"Face detection completed. Found {num_speakers} speakers.")
364
+
365
+ if error_msg:
366
+ print(f"Error: {error_msg}")
367
+ status_value = f"❌ {error_msg}"
368
+ if annotated_frame is not None:
369
+ annotated_frame_output = gr.update(value=annotated_frame, visible=True)
370
+ yield snapshot()
371
+ return
372
+
373
+ if num_speakers == 0:
374
+ print("No speakers detected in the video.")
375
+ status_value = "❌ No speakers detected in the video. Please ensure faces are visible in the first frame."
376
+ if annotated_frame is not None:
377
+ annotated_frame_output = gr.update(value=annotated_frame, visible=True)
378
+ yield snapshot()
379
+ return
380
+
381
+ face_gallery_images = [(img, f"Speaker {i+1}") for i, img in enumerate(face_images)]
382
+ detected_info = f"🎯 Detected {num_speakers} speaker{'s' if num_speakers > 1 else ''} in the video"
383
+ detected_info_output = gr.update(value=detected_info, visible=True)
384
+ face_gallery_output = gr.update(value=face_gallery_images, visible=True)
385
+ if annotated_frame is not None:
386
+ annotated_frame_output = gr.update(value=annotated_frame, visible=True)
387
+
388
+ progress(0.3, desc=f"Separating {num_speakers} speakers...")
389
+ status_value = f"🎬 Separating {num_speakers} speakers..."
390
+ print(f"Starting audio-visual separation for {num_speakers} speakers...")
391
+ yield snapshot()
392
+
393
+ try:
394
+ print("Starting GPU-accelerated video processing...")
395
+ with GPU_PROGRESS_LOCK:
396
+ GPU_PROGRESS_STATE["progress"] = 0.0
397
+ GPU_PROGRESS_STATE["status"] = "Processing on GPU..."
398
+
399
+ progress(0.4, desc="Processing on GPU...")
400
+ status_value = "Processing on GPU..."
401
+ yield snapshot()
402
+
403
+ gpu_result = {"output_files": None, "exception": None}
404
+
405
+ def run_gpu_processing():
406
+ try:
407
+ gpu_result["output_files"] = process_video_gpu(
408
+ video_file=video_file,
409
+ temp_dir=temp_dir,
410
+ num_speakers=num_speakers
411
+ )
412
+ except Exception as exc:
413
+ gpu_result["exception"] = exc
414
+
415
+ gpu_thread = threading.Thread(target=run_gpu_processing, daemon=True)
416
+ gpu_thread.start()
417
+
418
+ last_reported_progress = 0.4
419
+ last_status_message = "Processing on GPU..."
420
+
421
+ while gpu_thread.is_alive():
422
+ time.sleep(0.5)
423
+ with GPU_PROGRESS_LOCK:
424
+ gpu_status = GPU_PROGRESS_STATE.get("status", "Processing on GPU...")
425
+ gpu_progress_value = GPU_PROGRESS_STATE.get("progress", 0.0)
426
+
427
+ mapped_progress = 0.4 + 0.5 * gpu_progress_value
428
+ mapped_progress = min(mapped_progress, 0.89)
429
+
430
+ if (
431
+ mapped_progress > last_reported_progress + 0.01
432
+ or gpu_status != last_status_message
433
+ ):
434
+ progress(mapped_progress, desc=gpu_status)
435
+ last_reported_progress = mapped_progress
436
+ last_status_message = gpu_status
437
+ status_value = gpu_status
438
+ yield snapshot()
439
+
440
+ gpu_thread.join()
441
+
442
+ if gpu_result["exception"] is not None:
443
+ raise gpu_result["exception"]
444
+
445
+ output_files = gpu_result["output_files"]
446
+
447
+ progress(0.9, desc="Preparing results...")
448
+ status_value = "📦 Preparing results..."
449
+ print("Processing completed successfully!")
450
+ print(f"Generated {num_speakers} output videos")
451
+ yield snapshot()
452
+
453
+ video_dict_value = {i: output_files[i] for i in range(num_speakers)}
454
+ video_dict_value['temp_dir'] = temp_dir
455
+ output_video_output = gr.update(value=output_files[0], visible=True)
456
+
457
+ progress(1.0, desc="Complete!")
458
+ status_value = f"✅ Successfully separated {num_speakers} speakers! Click on any face below to view their video."
459
+ yield snapshot()
460
+
461
+ except Exception as e:
462
+ print(f"Processing failed: {str(e)}")
463
+ import traceback
464
+ traceback.print_exc()
465
+ status_value = f"❌ Processing failed: {str(e)}"
466
+ output_video_output = gr.update(visible=False)
467
+ video_dict_value = None
468
+ yield snapshot()
469
+ return
470
+
471
+ except Exception as e:
472
+ if temp_dir and os.path.exists(temp_dir):
473
+ try:
474
+ shutil.rmtree(temp_dir)
475
+ except Exception:
476
+ pass
477
+
478
+ print(f"Error: {str(e)}")
479
+ import traceback
480
+ traceback.print_exc()
481
+ status_value = f"❌ Error: {str(e)}"
482
+ detected_info_output = gr.update(visible=False)
483
+ face_gallery_output = gr.update(visible=False)
484
+ output_video_output = gr.update(visible=False)
485
+ annotated_frame_output = gr.update(visible=False)
486
+ video_dict_value = None
487
+ yield snapshot()
488
+ return
489
+ finally:
490
+ sys.stdout = old_stdout
491
+
492
+ def on_face_click(evt: gr.SelectData, video_dict):
493
+ """Handle face gallery click events"""
494
+ if video_dict is None or evt.index not in video_dict:
495
+ return None
496
+
497
+ return video_dict[evt.index]
498
+
499
+ # Create the Gradio interface
500
+ custom_css = """
501
+ .face-gallery {
502
+ border-radius: 10px;
503
+ overflow: hidden;
504
+ }
505
+ .face-gallery img {
506
+ border-radius: 8px;
507
+ transition: transform 0.2s ease-in-out;
508
+ }
509
+ .face-gallery img:hover {
510
+ transform: scale(1.05);
511
+ cursor: pointer;
512
+ box-shadow: 0 4px 8px rgba(0,0,0,0.3);
513
+ }
514
+ .detected-info {
515
+ background-color: #f0f0f0;
516
+ padding: 10px;
517
+ border-radius: 5px;
518
+ margin: 10px 0;
519
+ }
520
+ """
521
+
522
+ with gr.Blocks(
523
+ title="Video Speaker Auto-Separation",
524
+ theme=gr.themes.Soft(),
525
+ css=custom_css
526
+ ) as demo:
527
+ gr.Markdown(
528
+ """
529
+ # 🎥 Dolphin: Efficient Audio-Visual Speech Separation with Discrete Lip Semantics and Hierarchical Top-Down Attention
530
+ <p align="left">
531
+ <img src="https://visitor-badge.laobi.icu/badge?page_id=JusperLee.Dolphin" alt="访客统计" /><img src="https://img.shields.io/github/stars/JusperLee/Dolphin?style=social" alt="GitHub stars" /><img alt="Static Badge" src="https://img.shields.io/badge/license-Apache%202.0-blue.svg" />
532
+ </p>
533
+
534
+ ### Automatically detect and separate ALL speakers in your video
535
+
536
+ Simply upload a video and the system will:
537
+ 1. 🔍 Automatically detect all speakers in the video
538
+ 2. 🎭 Show you each detected speaker's face
539
+ 3. 🎬 Generate individual videos for each speaker with their isolated audio
540
+ """
541
+ )
542
+
543
+ with gr.Row():
544
+ with gr.Column(scale=2):
545
+ video_input = gr.Video(
546
+ label="📹 Upload Your Video",
547
+ height=300,
548
+ interactive=True
549
+ )
550
+
551
+ # Add example video section
552
+ gr.Markdown("### 🎬 Try with Example Video")
553
+ gr.Examples(
554
+ examples=[["demo1/mix.mp4"]],
555
+ inputs=video_input,
556
+ label="Click to load example video",
557
+ cache_examples=False
558
+ )
559
+
560
+ process_btn = gr.Button(
561
+ "🚀 Auto-Detect and Process",
562
+ variant="primary",
563
+ size="lg"
564
+ )
565
+
566
+ status = gr.Textbox(
567
+ label="Status",
568
+ interactive=False,
569
+ value="⏳ Ready to process..."
570
+ )
571
+
572
+ processing_log = gr.Textbox(
573
+ label="📋 Processing Details",
574
+ lines=10,
575
+ max_lines=15,
576
+ interactive=False,
577
+ value=""
578
+ )
579
+
580
+ with gr.Column(scale=3):
581
+ annotated_frame = gr.Image(
582
+ label="📸 Detected Speakers in First Frame",
583
+ visible=False,
584
+ height=300
585
+ )
586
+
587
+ detected_info = gr.Markdown(
588
+ visible=False,
589
+ elem_classes="detected-info"
590
+ )
591
+
592
+ gr.Markdown("### 👇 Click on any face below to view that speaker's video")
593
+
594
+ face_gallery = gr.Gallery(
595
+ label="Detected Speaker Faces",
596
+ show_label=False,
597
+ columns=5,
598
+ rows=1,
599
+ height=200,
600
+ visible=False,
601
+ object_fit="contain",
602
+ elem_classes="face-gallery"
603
+ )
604
+
605
+ output_video = gr.Video(
606
+ label="🎬 Selected Speaker's Video",
607
+ height=300,
608
+ visible=False,
609
+ autoplay=True
610
+ )
611
+
612
+ # Hidden state
613
+ video_dict = gr.State()
614
+
615
+ gr.Markdown(
616
+ """
617
+ ---
618
+ ### 📖 How it works:
619
+
620
+ 1. **Upload** - Select any video file
621
+ 2. **Process** - Click the button to start automatic detection
622
+ 3. **Review** - See all detected speakers and their positions
623
+ 4. **Select** - Click on any face to watch that speaker's separated video
624
+
625
+ ### 💡 Tips for best results:
626
+
627
+ - ✅ Ensure all speakers' faces are visible in the first frame
628
+ - ✅ Use videos with good lighting and clear face views
629
+ - ✅ Works best with frontal or near-frontal face angles
630
+ - ⏱️ Processing time depends on video length and number of speakers
631
+
632
+ ### 🚀 Powered by:
633
+ - RetinaFace for face detection
634
+ - Dolphin model for audio-visual separation
635
+ - GPU acceleration when available
636
+ <footer style="display:none;">
637
+ <a href='https://clustrmaps.com/site/1c828' title='Visit tracker'>
638
+ <img src='//clustrmaps.com/map_v2.png?cl=080808&w=300&t=tt&d=XYmTC4S_SxuX7G06iJ16lU43VCNkCBFRLXMfEM5zvmo&co=ffffff&ct=808080'/>
639
+ </a>
640
+ </footer>
641
+ """
642
+ )
643
+
644
+ # Event handlers
645
+ outputs_list = [
646
+ status,
647
+ detected_info,
648
+ face_gallery,
649
+ output_video,
650
+ video_dict,
651
+ annotated_frame,
652
+ processing_log
653
+ ]
654
+
655
+ process_btn.click(
656
+ fn=process_video_auto,
657
+ inputs=[video_input],
658
+ outputs=outputs_list,
659
+ show_progress=True
660
+ )
661
+
662
+ face_gallery.select(
663
+ fn=on_face_click,
664
+ inputs=[video_dict],
665
+ outputs=output_video
666
+ )
667
+
668
+ # Launch the demo - HF Space will handle this automatically
669
+ if __name__ == "__main__":
670
+ import os
671
+ demo.launch(server_name="0.0.0.0", server_port=7860)
console_capture.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import io
3
+ from contextlib import contextmanager
4
+
5
+ class TeeOutput:
6
+ """Capture stdout/stderr while still printing to console"""
7
+ def __init__(self, stream, callback=None):
8
+ self.stream = stream
9
+ self.callback = callback
10
+ self.buffer = []
11
+
12
+ def write(self, data):
13
+ # Write to original stream
14
+ self.stream.write(data)
15
+ self.stream.flush()
16
+
17
+ # Capture the data
18
+ if data.strip(): # Only capture non-empty lines
19
+ self.buffer.append(data.rstrip())
20
+ if self.callback:
21
+ self.callback(data.rstrip())
22
+
23
+ def flush(self):
24
+ self.stream.flush()
25
+
26
+ def get_captured(self):
27
+ return '\n'.join(self.buffer)
28
+
29
+ @contextmanager
30
+ def capture_console(stdout_callback=None, stderr_callback=None):
31
+ """Context manager to capture console output"""
32
+ old_stdout = sys.stdout
33
+ old_stderr = sys.stderr
34
+
35
+ stdout_capture = TeeOutput(old_stdout, stdout_callback)
36
+ stderr_capture = TeeOutput(old_stderr, stderr_callback)
37
+
38
+ sys.stdout = stdout_capture
39
+ sys.stderr = stderr_capture
40
+
41
+ try:
42
+ yield stdout_capture, stderr_capture
43
+ finally:
44
+ sys.stdout = old_stdout
45
+ sys.stderr = old_stderr
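For reference, a minimal usage sketch of the `capture_console` helper above (the list-append callback is illustrative and not part of the commit):

```python
# Sketch only: tee stdout into a list while still printing to the console.
from console_capture import capture_console

lines = []
with capture_console(stdout_callback=lines.append) as (out_cap, err_cap):
    print("hello from inside the context")   # printed normally and captured

print(out_cap.get_captured())   # "hello from inside the context"
print(lines)                    # ['hello from inside the context']
```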
demo1/mix.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbf7577afd8b8ebc4a70d88e5d6a8216dd9ca07a1e26a83a0c677074510ec39c
3
+ size 3387273
face_detection_utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility helpers for RetinaFace-based face detection."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import numpy as np
7
+ import cv2
8
+
9
+ try:
10
+ from retinaface import RetinaFace # type: ignore
11
+ except ImportError as import_error: # pragma: no cover - handled at runtime
12
+ RetinaFace = None # type: ignore
13
+ _RETINAFACE_IMPORT_ERROR = import_error
14
+ else:
15
+ _RETINAFACE_IMPORT_ERROR = None
16
+
17
+ try:
18
+ from PIL import Image
19
+ except ImportError: # pragma: no cover
20
+ Image = None # type: ignore
21
+
22
+ import spaces
23
+
24
+ def _ensure_retinaface_available() -> None:
25
+ if RetinaFace is None: # pragma: no cover - runtime safeguard
26
+ raise ImportError(
27
+ "RetinaFace package is required but not installed. "
28
+ "Install it with `pip install retinaface`."
29
+ ) from _RETINAFACE_IMPORT_ERROR
30
+
31
+
32
+ def _to_rgb_array(image: np.ndarray, *, assume_bgr: bool = False) -> np.ndarray:
33
+ """Convert input to an RGB numpy array."""
34
+ if isinstance(image, np.ndarray):
35
+ array = image
36
+ elif Image is not None and isinstance(image, Image.Image):
37
+ array = np.array(image.convert("RGB"))
38
+ else:
39
+ raise TypeError("Expected an ndarray or PIL.Image.Image for face detection")
40
+
41
+ if array.ndim != 3 or array.shape[2] != 3:
42
+ raise ValueError("Face detection expects an image with shape (H, W, 3)")
43
+
44
+ if array.dtype != np.uint8:
45
+ array = array.astype(np.uint8)
46
+
47
+ if assume_bgr:
48
+ return cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
49
+ return array
50
+
51
+ @spaces.GPU(duration=360)
52
+ def detect_faces(
53
+ image: np.ndarray,
54
+ *,
55
+ threshold: float = 0.9,
56
+ allow_upscaling: bool = False,
57
+ model: Optional[str] = None,
58
+ assume_bgr: bool = False,
59
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
60
+ """Run RetinaFace detection on an image.
61
+
62
+ Returns bounding boxes shaped (N, 4) and confidence scores shaped (N,).
63
+ If no face is detected, both values are ``None``.
64
+ """
65
+ _ensure_retinaface_available()
66
+
67
+ rgb_image = _to_rgb_array(image, assume_bgr=assume_bgr)
68
+ bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
69
+
70
+ detections = RetinaFace.detect_faces(
71
+ bgr_image,
72
+ threshold=threshold,
73
+ model=model,
74
+ allow_upscaling=allow_upscaling,
75
+ )
76
+
77
+ if not isinstance(detections, dict) or not detections:
78
+ return None, None
79
+
80
+ boxes, scores = [], []
81
+ for face_data in detections.values():
82
+ facial_area = face_data.get("facial_area")
83
+ if facial_area is None:
84
+ continue
85
+ boxes.append(facial_area)
86
+ scores.append(face_data.get("score", 0.0))
87
+
88
+ if not boxes:
89
+ return None, None
90
+
91
+ boxes_array = np.asarray(boxes, dtype=np.float32)
92
+ scores_array = np.asarray(scores, dtype=np.float32) if scores else None
93
+
94
+ return boxes_array, scores_array
95
+
96
+ @spaces.GPU(duration=360)
97
+ def extract_faces(
98
+ image: np.ndarray,
99
+ *,
100
+ align: bool = True,
101
+ threshold: float = 0.9,
102
+ allow_upscaling: bool = False,
103
+ model: Optional[str] = None,
104
+ assume_bgr: bool = False,
105
+ ) -> Optional[np.ndarray]:
106
+ """Extract faces using RetinaFace.extract_faces for convenience."""
107
+ _ensure_retinaface_available()
108
+
109
+ rgb_image = _to_rgb_array(image, assume_bgr=assume_bgr)
110
+ bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
111
+
112
+ faces = RetinaFace.extract_faces(
113
+ bgr_image,
114
+ align=align,
115
+ threshold=threshold,
116
+ model=model,
117
+ allow_upscaling=allow_upscaling,
118
+ )
119
+
120
+ if not faces:
121
+ return None
122
+ return np.asarray([np.asarray(face, dtype=np.uint8) for face in faces])
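A minimal standalone sketch of `detect_faces` on a single frame (the image path is illustrative; it assumes the `spaces.GPU` decorator falls back gracefully outside a Space):

```python
# Sketch only: run RetinaFace detection on one RGB frame.
import numpy as np
from PIL import Image
from face_detection_utils import detect_faces

frame = np.array(Image.open("first_frame.jpg").convert("RGB"))   # (H, W, 3) RGB
boxes, scores = detect_faces(frame, threshold=0.9, allow_upscaling=False)
if boxes is None:
    print("no faces found")
else:
    print(boxes.shape, scores.shape)   # (N, 4) boxes and (N,) confidence scores
```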
look2hear/datas/transform.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###
2
+ # Author: Kai Li
3
+ # Date: 2021-06-19 22:34:13
4
+ # LastEditors: Kai Li
5
+ # LastEditTime: 2021-08-30 20:01:43
6
+ ###
7
+
8
+ import cv2
9
+ import random
10
+ import numpy as np
11
+ import torchvision
12
+
13
+ __all__ = [
14
+ "Compose",
15
+ "Normalize",
16
+ "CenterCrop",
17
+ "RgbToGray",
18
+ "RandomCrop",
19
+ "HorizontalFlip",
20
+ ]
21
+
22
+
23
+ class Compose(object):
24
+ """Compose several preprocess together.
25
+ Args:
26
+ preprocess (list of ``Preprocess`` objects): list of preprocess to compose.
27
+ """
28
+
29
+ def __init__(self, preprocess):
30
+ self.preprocess = preprocess
31
+
32
+ def __call__(self, sample):
33
+ for t in self.preprocess:
34
+ sample = t(sample)
35
+ return sample
36
+
37
+ def __repr__(self):
38
+ format_string = self.__class__.__name__ + "("
39
+ for t in self.preprocess:
40
+ format_string += "\n"
41
+ format_string += " {0}".format(t)
42
+ format_string += "\n)"
43
+ return format_string
44
+
45
+
46
+ class RgbToGray(object):
47
+ """Convert image to grayscale.
48
+ Converts a numpy.ndarray (H x W x C) in the range
49
+ [0, 255] to a numpy.ndarray of shape (H x W x C) in the range [0.0, 1.0].
50
+ """
51
+
52
+ def __call__(self, frames):
53
+ """
54
+ Args:
55
+ img (numpy.ndarray): Image to be converted to gray.
56
+ Returns:
57
+ numpy.ndarray: grey image
58
+ """
59
+ frames = np.stack([cv2.cvtColor(_, cv2.COLOR_RGB2GRAY) for _ in frames], axis=0)
60
+ return frames
61
+
62
+ def __repr__(self):
63
+ return self.__class__.__name__ + "()"
64
+
65
+
66
+ class Normalize(object):
67
+ """Normalize a ndarray image with mean and standard deviation."""
68
+
69
+ def __init__(self, mean, std):
70
+ self.mean = mean
71
+ self.std = std
72
+
73
+ def __call__(self, frames):
74
+ """
75
+ Args:
76
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
77
+ Returns:
78
+ Tensor: Normalized Tensor image.
79
+ """
80
+ frames = (frames - self.mean) / self.std
81
+ return frames
82
+
83
+ def __repr__(self):
84
+ return self.__class__.__name__ + "(mean={0}, std={1})".format(
85
+ self.mean, self.std
86
+ )
87
+
88
+
89
+ class CenterCrop(object):
90
+ """Crop the given image at the center"""
91
+
92
+ def __init__(self, size):
93
+ self.size = size
94
+
95
+ def __call__(self, frames):
96
+ """
97
+ Args:
98
+ img (numpy.ndarray): Images to be cropped.
99
+ Returns:
100
+ numpy.ndarray: Cropped image.
101
+ """
102
+ t, h, w = frames.shape
103
+ th, tw = self.size
104
+ delta_w = int(round((w - tw)) / 2.0)
105
+ delta_h = int(round((h - th)) / 2.0)
106
+ frames = frames[:, delta_h : delta_h + th, delta_w : delta_w + tw]
107
+ return frames
108
+
109
+
110
+ class RandomCrop(object):
111
+ """Crop the given image at the center"""
112
+
113
+ def __init__(self, size):
114
+ self.size = size
115
+
116
+ def __call__(self, frames):
117
+ """
118
+ Args:
119
+ img (numpy.ndarray): Images to be cropped.
120
+ Returns:
121
+ numpy.ndarray: Cropped image.
122
+ """
123
+ t, h, w = frames.shape
124
+ th, tw = self.size
125
+ delta_w = random.randint(0, w - tw)
126
+ delta_h = random.randint(0, h - th)
127
+ frames = frames[:, delta_h : delta_h + th, delta_w : delta_w + tw]
128
+ return frames
129
+
130
+ def __repr__(self):
131
+ return self.__class__.__name__ + "(size={0})".format(self.size)
132
+
133
+
134
+ class HorizontalFlip(object):
135
+ """Flip image horizontally."""
136
+
137
+ def __init__(self, flip_ratio):
138
+ self.flip_ratio = flip_ratio
139
+
140
+ def __call__(self, frames):
141
+ """
142
+ Args:
143
+ img (numpy.ndarray): Images to be flipped with a probability flip_ratio
144
+ Returns:
145
+ numpy.ndarray: Flipped images.
146
+ """
147
+ t, h, w = frames.shape
148
+ if random.random() < self.flip_ratio:
149
+ for index in range(t):
150
+ frames[index] = cv2.flip(frames[index], 1)
151
+ return frames
152
+
153
+
154
+ def get_preprocessing_pipelines():
155
+ # -- preprocess for the video stream
156
+ preprocessing = {}
157
+ # -- LRW config
158
+ crop_size = (88, 88)
159
+ (mean, std) = (0.421, 0.165)
160
+ preprocessing["train"] = Compose(
161
+ [
162
+ Normalize(0.0, 255.0),
163
+ RandomCrop(crop_size),
164
+ HorizontalFlip(0.5),
165
+ Normalize(mean, std),
166
+ ]
167
+ )
168
+ preprocessing["val"] = Compose(
169
+ [Normalize(0.0, 255.0), CenterCrop(crop_size), Normalize(mean, std)]
170
+ )
171
+ preprocessing["test"] = preprocessing["val"]
172
+ return preprocessing
173
+
174
+ def get_preprocessing_opt_pipelines():
175
+ preprocessing = {}
176
+ # -- LRW config
177
+ crop_size = (88, 88)
178
+ (mean, std) = (0.421, 0.165)
179
+ preprocessing["train"] = torchvision.transforms.Compose([
180
+ torchvision.transforms.Normalize(0.0, 255.0),
181
+ torchvision.transforms.RandomCrop(crop_size),
182
+ torchvision.transforms.RandomHorizontalFlip(0.5),
183
+ torchvision.transforms.Normalize(mean, std)
184
+ ])
185
+ preprocessing["val"] = torchvision.transforms.Compose([
186
+ torchvision.transforms.Normalize(0.0, 255.0),
187
+ torchvision.transforms.CenterCrop(crop_size),
188
+ torchvision.transforms.Normalize(mean, std)
189
+ ])
190
+ preprocessing["test"] = preprocessing["val"]
191
+ return preprocessing
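A small usage sketch for the pipelines above on dummy grayscale lip frames; the 96x96 input size is an assumption about the upstream crop, and the import path assumes the `look2hear` package is importable:

```python
# Sketch only: apply the validation pipeline to dummy (T, H, W) grayscale frames.
import numpy as np
from look2hear.datas.transform import get_preprocessing_pipelines

pipelines = get_preprocessing_pipelines()
frames = np.random.randint(0, 256, size=(50, 96, 96)).astype(np.float32)  # 50 frames, 96x96
out = pipelines["val"](frames)   # scale to [0, 1], center-crop to 88x88, normalize
print(out.shape)                 # (50, 88, 88)
```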
look2hear/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .dolphin import Dolphin
look2hear/models/dolphin.py ADDED
@@ -0,0 +1,1376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dolphin Model
3
+
4
+ This implementation is inspired by and borrows concepts from Sepformer.
5
+ The original Sepformer work is licensed under the Apache-2.0 License.
6
+
7
+ References:
8
+ - SepReformer: https://github.com/dmlguq456/SepReformer
9
+ - Apache-2.0 License: https://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ """
12
+
13
+ from re import S
14
+ import torch
15
+ import numpy as np
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import math
19
+ from vector_quantize_pytorch import ResidualVQ
20
+ from .video_compoent import *
21
+ from huggingface_hub import PyTorchModelHubMixin
22
+
23
+ class LayerScale(torch.nn.Module):
24
+ def __init__(self, dims, input_size, Layer_scale_init=1.0e-5):
25
+ super().__init__()
26
+ if dims == 1:
27
+ self.layer_scale = torch.nn.Parameter(torch.ones(input_size)*Layer_scale_init, requires_grad=True)
28
+ elif dims == 2:
29
+ self.layer_scale = torch.nn.Parameter(torch.ones(1,input_size)*Layer_scale_init, requires_grad=True)
30
+ elif dims == 3:
31
+ self.layer_scale = torch.nn.Parameter(torch.ones(1,1,input_size)*Layer_scale_init, requires_grad=True)
32
+
33
+ def forward(self, x):
34
+ return x*self.layer_scale
35
+
36
+ class Masking(torch.nn.Module):
37
+ def __init__(self, input_dim):
38
+ super(Masking, self).__init__()
39
+ self.gate_act = torch.nn.ReLU()
40
+
41
+ def forward(self, x, skip):
42
+ return self.gate_act(x) * skip
43
+
44
+
45
+ class FFN(torch.nn.Module):
46
+ def __init__(self, in_channels, dropout_rate, Layer_scale_init=1.0e-5):
47
+ super().__init__()
48
+ expand_factor = 3
49
+ self.net1 = torch.nn.Sequential(
50
+ torch.nn.LayerNorm(in_channels),
51
+ torch.nn.Linear(in_channels, in_channels * expand_factor))
52
+ self.depthwise = torch.nn.Conv1d(in_channels * expand_factor, in_channels * expand_factor, 3, padding=1, groups=in_channels * expand_factor)
53
+ self.net2 = torch.nn.Sequential(
54
+ torch.nn.GLU(),
55
+ torch.nn.Dropout(dropout_rate),
56
+ torch.nn.Linear(in_channels * expand_factor // 2, in_channels),
57
+ torch.nn.Dropout(dropout_rate))
58
+ self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)
59
+
60
+ def forward(self, x):
61
+ y = self.net1(x)
62
+ y = y.permute(0, 2, 1).contiguous()
63
+ y = self.depthwise(y)
64
+ y = y.permute(0, 2, 1).contiguous()
65
+ y = self.net2(y)
66
+ return x + self.Layer_scale(y)
67
+
68
+
69
+ class MultiHeadAttention(torch.nn.Module):
70
+ """
71
+ Multi-Head Attention layer.
72
+ :param int n_head: the number of heads
73
+ :param int n_feat: the number of features
74
+ :param float dropout_rate: dropout rate
75
+ """
76
+ def __init__(self, n_head: int, in_channels: int, dropout_rate: float, Layer_scale_init=1.0e-5):
77
+ super().__init__()
78
+ assert in_channels % n_head == 0
79
+ self.d_k = in_channels // n_head # We assume d_v always equals d_k
80
+ self.h = n_head
81
+ self.layer_norm = torch.nn.LayerNorm(in_channels)
82
+ self.linear_q = torch.nn.Linear(in_channels, in_channels)
83
+ self.linear_k = torch.nn.Linear(in_channels, in_channels)
84
+ self.linear_v = torch.nn.Linear(in_channels, in_channels)
85
+ self.linear_out = torch.nn.Linear(in_channels, in_channels)
86
+ self.attn = None
87
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
88
+ self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)
89
+
90
+ def forward(self, x, pos_k, mask):
91
+ """
92
+ Compute 'Scaled Dot Product Attention'.
93
+ :param torch.Tensor mask: (batch, time1, time2)
94
+ :param torch.nn.Dropout dropout:
95
+ :return torch.Tensor: attention-weighted and transformed `value` (batch, time1, d_model)
96
+ weighted by the query dot key attention (batch, head, time1, time2)
97
+ """
98
+ n_batch = x.size(0)
99
+ x = self.layer_norm(x)
100
+ q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
101
+ k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d)
102
+ v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
103
+ q = q.transpose(1, 2)
104
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
105
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
106
+ A = torch.matmul(q, k.transpose(-2, -1))
107
+ reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1)
108
+ if pos_k is not None:
109
+ B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
110
+ B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
111
+ scores = (A + B) / math.sqrt(self.d_k)
112
+ else:
113
+ scores = A / math.sqrt(self.d_k)
114
+ if mask is not None:
115
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
116
+ min_value = float(np.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
117
+ scores = scores.masked_fill(mask, min_value)
118
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
119
+ else:
120
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
121
+ p_attn = self.dropout(self.attn)
122
+ x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
123
+ x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
124
+ return self.Layer_scale(self.dropout(self.linear_out(x))) # (batch, time1, d_model)
125
+
126
+ class DU_MHSA(torch.nn.Module):
127
+ def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
128
+ super().__init__()
129
+ self.block = torch.nn.ModuleDict({
130
+ 'self_attn': MultiHeadAttention(
131
+ n_head=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
132
+ 'linear': torch.nn.Sequential(
133
+ torch.nn.LayerNorm(normalized_shape=in_channels),
134
+ torch.nn.Linear(in_features=in_channels, out_features=in_channels),
135
+ torch.nn.Sigmoid())
136
+ })
137
+
138
+ def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
139
+ """
140
+ Compute encoded features.
141
+ :param torch.Tensor x: encoded source features (batch, max_time_in, size)
142
+ :param torch.Tensor mask: mask for x (batch, max_time_in)
143
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
144
+ """
145
+ down_len = pos_k.shape[0]
146
+ x_down = torch.nn.functional.adaptive_avg_pool1d(input=x, output_size=down_len)
147
+ x = x.permute([0, 2, 1])
148
+ x_down = x_down.permute([0, 2, 1])
149
+ x_down = self.block['self_attn'](x_down, pos_k, None)
150
+ x_down = x_down.permute([0, 2, 1])
151
+ x_downup = torch.nn.functional.upsample(input=x_down, size=x.shape[1])
152
+ x_downup = x_downup.permute([0, 2, 1])
153
+ x = x + self.block['linear'](x) * x_downup
154
+
155
+ return x
156
+
157
+ class Heat1D(nn.Module):
158
+ """
159
+ 1D Heat Equation Adaptation:
160
+ du/dt - k d²u/dx² = 0;
161
+ du/dx_{x=0, x=a} = 0
162
+ =>
163
+ A_n = C(a, n==0) * sum_{0}^{a} { \phi(x) cos(n π / a x) dx }
164
+ core = cos(n π / a x) exp(- (n π / a)^2 k t)
165
+ u_{x, t} = sum_{0}^{\infinite} { core }
166
+
167
+ Assume a = T; x in [0, T]; n in [0, T]; with some slight changes
168
+ =>
169
+ (\phi(x) = linear(dwconv(input(x))))
170
+ A(n) = DCT1D(\phi(x))
171
+ u(x, t) = IDCT1D(A(n) * exp(- (n π / a)^2 kt))
172
+ """
173
+ def __init__(self, dim=96, hidden_dim=96, **kwargs):
174
+ super().__init__()
175
+ self.dwconv = nn.Conv1d(dim, hidden_dim, kernel_size=3, padding=1, groups=hidden_dim)
176
+ self.hidden_dim = hidden_dim
177
+ self.linear = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=3, padding=1, groups=hidden_dim)
178
+ self.out_norm = nn.LayerNorm(hidden_dim)
179
+ self.out_linear = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, groups=hidden_dim)
180
+ self.to_k = nn.Sequential(
181
+ nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, groups=hidden_dim),
182
+ nn.GELU(),
183
+ )
184
+
185
+ self.k = nn.Parameter(torch.ones(hidden_dim))
186
+
187
+ @staticmethod
188
+ def get_cos_map(N=224, device=torch.device("cpu"), dtype=torch.float):
189
+ # cos((x + 0.5) / N * n * π) which is also the form of DCT and IDCT
190
+ # DCT: F(n) = sum( (sqrt(2/N) if n > 0 else sqrt(1/N)) * cos((x + 0.5) / N * n * π) * f(x) )
191
+ # IDCT: f(x) = sum( (sqrt(2/N) if n > 0 else sqrt(1/N)) * cos((x + 0.5) / N * n * π) * F(n) )
192
+ # returns: (Res_n, Res_x)
193
+ weight_x = (torch.linspace(0, N - 1, N, device=device, dtype=dtype).view(1, -1) + 0.5) / N
194
+ weight_n = torch.linspace(0, N - 1, N, device=device, dtype=dtype).view(-1, 1)
195
+ weight = torch.cos(weight_n * weight_x * torch.pi) * math.sqrt(2 / N)
196
+ weight[0, :] = weight[0, :] / math.sqrt(2)
197
+ return weight
198
+
199
+ @staticmethod
200
+ def get_decay_map(resolution=224, device=torch.device("cpu"), dtype=torch.float):
201
+ # exp(- (n π / T)^2) for 1D
202
+ # returns: (Res_t,)
203
+ res_t = resolution
204
+ weight_n = torch.linspace(0, torch.pi, res_t + 1, device=device, dtype=dtype)[:res_t]
205
+ weight = torch.pow(weight_n, 2)
206
+ weight = torch.exp(-weight)
207
+ return weight
208
+
209
+ def forward(self, x: torch.Tensor, freq_embed=None):
210
+ B, T, C = x.shape
211
+ x = x.transpose(1, 2) # [B, T, C] -> [B, C, T]
212
+ x = self.dwconv(x) # [B, hidden_dim, T]
213
+
214
+ x = self.linear(x) # [B, 2 * hidden_dim, T]
215
+ x, z = x.chunk(chunks=2, dim=1) # [B, hidden_dim, T], [B, hidden_dim, T]
216
+
217
+ if (T == getattr(self, "__RES__", 0)) and (getattr(self, "__WEIGHT_COSN__", None).device == x.device):
218
+ weight_cosn = getattr(self, "__WEIGHT_COSN__", None)
219
+ weight_exp = getattr(self, "__WEIGHT_EXP__", None)
220
+ assert weight_cosn is not None
221
+ assert weight_exp is not None
222
+ else:
223
+ weight_cosn = self.get_cos_map(T, device=x.device).detach_()
224
+ weight_exp = self.get_decay_map(T, device=x.device).detach_()
225
+ setattr(self, "__RES__", T)
226
+ setattr(self, "__WEIGHT_COSN__", weight_cosn)
227
+ setattr(self, "__WEIGHT_EXP__", weight_exp)
228
+
229
+ N = weight_cosn.shape[0] # N == T
230
+
231
+ x = x.transpose(1, 2).contiguous() # [B, T, hidden_dim]
232
+
233
+ x = F.conv1d(x.contiguous().view(B, T, -1), weight_cosn.contiguous().view(N, T, 1)) # [B, N, hidden_dim]
234
+
235
+ weight_exp = torch.pow(weight_exp[:, None], self.k)
236
+ x = torch.einsum("bnc,nc->bnc", x, weight_exp) # exp decay
237
+
238
+ x = F.conv1d(x.contiguous().view(B, N, -1), weight_cosn.t().contiguous().view(T, N, 1)) # [B, T, hidden_dim]
239
+
240
+ x = self.out_norm(x) # [B, T, hidden_dim]
241
+
242
+ z = z.transpose(1, 2).contiguous() # [B, T, hidden_dim]
243
+ x = x * nn.functional.silu(z) # [B, T, hidden_dim]
244
+
245
+ x = x.transpose(1, 2).contiguous() # [B, hidden_dim, T]
246
+ x = self.out_linear(x) # [B, hidden_dim, T]
247
+
248
+ x = x.transpose(1, 2).contiguous() # [B, T, hidden_dim]
249
+
250
+ return x
251
+
252
+ class CLA(torch.nn.Module):
253
+ def __init__(self, in_channels, kernel_size, dropout_rate, Layer_scale_init=1.0e-5):
254
+ super().__init__()
255
+ # self.layer_norm = torch.nn.LayerNorm(in_channels)
256
+ self.heat1d = Heat1D(in_channels, in_channels)
257
+ self.GN1 = torch.nn.GroupNorm(1, in_channels)
258
+ self.dw_conv_1d = torch.nn.Conv1d(in_channels, in_channels, kernel_size, padding='same', groups=in_channels)
259
+ self.GN2 = torch.nn.GroupNorm(1, in_channels)
260
+ self.linear3 = torch.nn.Sequential(
261
+ torch.nn.GELU(),
262
+ torch.nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels),
263
+ torch.nn.Dropout(dropout_rate))
264
+ self.Layer_scale = LayerScale(dims=3, input_size=in_channels, Layer_scale_init=Layer_scale_init)
265
+
266
+ def forward(self, x):
267
+ # y = self.layer_norm(x)
268
+ y = self.heat1d(x)
269
+ y = y.permute([0, 2, 1]) # B, F, T
270
+ y = self.GN1(y)
271
+ y = self.dw_conv_1d(y)
272
+ y = self.linear3(y)
273
+ y = self.GN2(y) # [B, in_channels, T]
274
+ y = y.permute(0, 2, 1) # B, T, in_channels
275
+ return x + self.Layer_scale(y)
276
+
277
+ class GlobalBlock(torch.nn.Module):
278
+ def __init__(self, in_channels: int, num_mha_heads: int, dropout_rate: float):
279
+ super().__init__()
280
+ self.block = torch.nn.ModuleDict({
281
+ 'DU_MHSA': DU_MHSA(
282
+ num_mha_heads=num_mha_heads, in_channels=in_channels, dropout_rate=dropout_rate),
283
+ 'FFN': FFN(in_channels=in_channels, dropout_rate=dropout_rate)
284
+ })
285
+
286
+ def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
287
+ """
288
+ Compute encoded features.
289
+ :param torch.Tensor x: encoded source features (batch, max_time_in, size)
290
+ :param torch.Tensor mask: mask for x (batch, max_time_in)
291
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
292
+ """
293
+ x = self.block['DU_MHSA'](x, pos_k)
294
+ x = self.block['FFN'](x)
295
+ x = x.permute([0, 2, 1])
296
+
297
+ return x
298
+
299
+
300
+ class LocalBlock(torch.nn.Module):
301
+ def __init__(self, in_channels: int, kernel_size: int, dropout_rate: float):
302
+ super().__init__()
303
+ self.block = torch.nn.ModuleDict({
304
+ 'CLA': CLA(in_channels, kernel_size, dropout_rate),
305
+ 'FFN': FFN(in_channels, dropout_rate)
306
+ })
307
+
308
+ def forward(self, x: torch.Tensor):
309
+ x = self.block['CLA'](x)
310
+ x = self.block['FFN'](x)
311
+
312
+ return x
313
+
314
+ class AudioEncoder(torch.nn.Module):
315
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, groups: int, bias: bool):
316
+ super().__init__()
317
+ self.conv1d = torch.nn.Conv1d(
318
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, groups=groups, bias=bias)
319
+ self.gelu = torch.nn.GELU()
320
+
321
+ def forward(self, x: torch.Tensor):
322
+ x = torch.unsqueeze(x, dim=0) if len(x.shape) == 1 else torch.unsqueeze(x, dim=1) # [T] - >[1, T] OR [B, T] -> [B, 1, T]
323
+ x = self.conv1d(x)
324
+ x = self.gelu(x)
325
+ return x
326
+
327
+ class FeatureProjector(torch.nn.Module):
328
+ def __init__(self, num_channels: int, in_channels: int, out_channels: int, kernel_size: int, bias: bool):
329
+ super().__init__()
330
+ self.norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels, eps=1e-8)
331
+ self.conv1d = torch.nn.Conv1d(
332
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)
333
+
334
+ def forward(self, x: torch.Tensor):
335
+ x = self.norm(x)
336
+ x = self.conv1d(x)
337
+ return x
338
+
339
+ class HeatConvNorm(nn.Module):
340
+ """
341
+ This class defines a Heat1D-based convolution layer followed by normalization
342
+ """
343
+
344
+ def __init__(
345
+ self, nIn, nOut, kSize, stride=1, groups=1, bias=True, norm_type="gLN"
346
+ ):
347
+ """
348
+ :param nIn: number of input channels
349
+ :param nOut: number of output channels
350
+ :param kSize: kernel size
351
+ :param stride: stride rate for down-sampling. Default is 1
352
+ """
353
+ super().__init__()
354
+ padding = int((kSize - 1) / 2)
355
+ self.conv = Heat1D(
356
+ nIn, nOut, groups=groups
357
+ )
358
+ if norm_type == "gLN":
359
+ self.norm = nn.GroupNorm(1, nOut, eps=1e-8)
360
+ if norm_type == "BN":
361
+ self.norm = nn.BatchNorm1d(nOut)
362
+
363
+ def forward(self, input):
364
+ input = input.permute(0, 2, 1)
365
+ output = self.conv(input).permute(0, 2, 1)
366
+ return self.norm(output)
367
+
368
+ class ConvNorm(nn.Module):
369
+ """
370
+ This class defines the convolution layer with normalization (no activation)
371
+ """
372
+
373
+ def __init__(
374
+ self, nIn, nOut, kSize, stride=1, groups=1, bias=True, norm_type="gLN"
375
+ ):
376
+ """
377
+ :param nIn: number of input channels
378
+ :param nOut: number of output channels
379
+ :param kSize: kernel size
380
+ :param stride: stride rate for down-sampling. Default is 1
381
+ """
382
+ super().__init__()
383
+ padding = int((kSize - 1) / 2)
384
+ self.conv = nn.Conv1d(
385
+ nIn, nOut, kSize, stride=stride, padding=padding, bias=bias, groups=groups
386
+ )
387
+ if norm_type == "gLN":
388
+ self.norm = nn.GroupNorm(1, nOut, eps=1e-8)
389
+ if norm_type == "BN":
390
+ self.norm = nn.BatchNorm1d(nOut)
391
+
392
+ def forward(self, input):
393
+ output = self.conv(input)
394
+ return self.norm(output)
395
+
396
+ class AVFModule(nn.Module):
397
+ """
398
+ 1D Attention Fusion Cell: tensor_b guides the key & value of tensor_a:
399
+ Input:
400
+ tensor_a: [B, Ca, T]
401
+ tensor_b: [B, Cb, Tb]
402
+ Output:
403
+ [B, Ca, T]
404
+ """
405
+ def __init__(self,
406
+ in_chan_a: int,
407
+ in_chan_b: int,
408
+ kernel_size: int = 1):
409
+ super().__init__()
410
+ self.in_chan_a = in_chan_a
411
+ self.in_chan_b = in_chan_b
412
+ self.kernel_size = kernel_size
413
+ # audio key embedding (depthwise 1×1)
414
+ self.key_embed = ConvNormAct(
415
+ nIn=in_chan_a, nOut=in_chan_a, kSize=1,
416
+ groups=in_chan_a, norm_type="gLN"
417
+ )
418
+ # audio value embedding (depthwise 1×1)
419
+ self.value_embed = ConvNormAct(
420
+ nIn=in_chan_a, nOut=in_chan_a, kSize=1,
421
+ groups=in_chan_a, norm_type="gLN"
422
+ )
423
+
424
+ self.resize = ConvNormAct(
425
+ nIn=in_chan_b, nOut=in_chan_a, kSize=1,
426
+ norm_type="gLN"
427
+ )
428
+
429
+ self.attention_embed = ConvNormAct(
430
+ nIn=in_chan_b,
431
+ nOut=in_chan_a * kernel_size,
432
+ kSize=1,
433
+ groups=in_chan_b,
434
+ norm_type="gLN"
435
+ )
436
+
437
+ def forward(self, tensor_a: torch.Tensor, tensor_b: torch.Tensor):
438
+ """
439
+ tensor_a: [B, Ca, T]
440
+ tensor_b: [B, Cb, Tb]
441
+ """
442
+ B, Ca, T = tensor_a.shape
443
+ # 1) Use video to guide key_embed
444
+ b2a = self.resize(tensor_b) # [B, Ca, Tb]
445
+ b2a = F.interpolate(b2a, size=T, mode="nearest") # [B, Ca, T]
446
+ k1 = self.key_embed(tensor_a) * b2a # [B, Ca, T]
447
+ # 2) audio value
448
+ v = self.value_embed(tensor_a) # [B, Ca, T]
449
+ # 3) Calculate attention scores
450
+ att = self.attention_embed(tensor_b) # [B, Ca*kernel, Tb]
451
+ # reshape → [B, Ca, kernel, Tb]
452
+ att = att.view(B, Ca, self.kernel_size, -1)
453
+ att = att.mean(dim=2)
454
+ att = torch.softmax(att, dim=-1) # [B, Ca, Tb]
455
+ att = F.interpolate(att, size=T, mode="nearest") # [B, Ca, T]
456
+ # 4) k2 = attention * value
457
+ k2 = att * v
458
+
459
+ fused = k1 + k2 # [B, Ca, T]
460
+ return fused
461
+
462
+ class RelativePositionalEncoding(torch.nn.Module):
463
+ def __init__(self, in_channels: int, num_heads: int, maxlen: int, embed_v=False):
464
+ super().__init__()
465
+ self.in_channels = in_channels
466
+ self.num_heads = num_heads
467
+ self.embedding_dim = self.in_channels // self.num_heads
468
+ self.maxlen = maxlen
469
+ self.pe_k = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim)
470
+ self.pe_v = torch.nn.Embedding(num_embeddings=2*maxlen, embedding_dim=self.embedding_dim) if embed_v else None
471
+
472
+ def forward(self, pos_seq: torch.Tensor):
473
+ pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
474
+ pos_seq += self.maxlen
475
+ pe_k_output = self.pe_k(pos_seq)
476
+ pe_v_output = self.pe_v(pos_seq) if self.pe_v is not None else None
477
+ return pe_k_output, pe_v_output
478
+
479
+ class DownConvLayer(torch.nn.Module):
480
+ def __init__(self, in_channels: int, samp_kernel_size: int):
481
+ """Construct an EncoderLayer object."""
482
+ super().__init__()
483
+ self.down_conv = torch.nn.Conv1d(
484
+ in_channels=in_channels, out_channels=in_channels, kernel_size=samp_kernel_size, stride=2, padding=(samp_kernel_size-1)//2, groups=in_channels)
485
+ self.GN = nn.GroupNorm(1, num_channels=in_channels)
486
+ self.gelu = torch.nn.GELU()
487
+
488
+ def forward(self, x: torch.Tensor):
489
+ x = x.permute([0, 2, 1])
490
+ x = self.down_conv(x)
491
+ x = self.GN(x)
492
+ x = self.gelu(x)
493
+ x = x.permute([0, 2, 1])
494
+ return x
495
+
496
+ class ConvNormAct(nn.Module):
497
+ """
498
+ This class defines the convolution layer with normalization and a PReLU
499
+ activation
500
+ """
501
+
502
+ def __init__(self, nIn, nOut, kSize, stride=1, groups=1, norm_type="gLN"):
503
+ """
504
+ :param nIn: number of input channels
505
+ :param nOut: number of output channels
506
+ :param kSize: kernel size
507
+ :param stride: stride rate for down-sampling. Default is 1
508
+ """
509
+ super().__init__()
510
+ padding = int((kSize - 1) / 2)
511
+ self.conv = nn.Conv1d(
512
+ nIn, nOut, kSize, stride=stride, padding=padding, bias=True, groups=groups
513
+ )
514
+ if norm_type == "gLN":
515
+ self.norm = nn.GroupNorm(1, nOut, eps=1e-8)
516
+ if norm_type == "BN":
517
+ self.norm = nn.BatchNorm1d(nOut)
518
+ self.act = nn.PReLU()
519
+
520
+ def forward(self, input):
521
+ output = self.conv(input)
522
+ output = self.norm(output)
523
+ return self.act(output)
524
+
525
+ class DilatedConvNorm(nn.Module):
526
+ """
527
+ This class defines the dilated convolution with normalized output.
528
+ """
529
+
530
+ def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1, norm_type="gLN"):
531
+ """
532
+ :param nIn: number of input channels
533
+ :param nOut: number of output channels
534
+ :param kSize: kernel size
535
+ :param stride: optional stride rate for down-sampling
536
+ :param d: optional dilation rate
537
+ """
538
+ super().__init__()
539
+ self.conv = nn.Conv1d(
540
+ nIn,
541
+ nOut,
542
+ kSize,
543
+ stride=stride,
544
+ dilation=d,
545
+ padding=((kSize - 1) // 2) * d,
546
+ groups=groups,
547
+ )
548
+ # self.norm = nn.GroupNorm(1, nOut, eps=1e-08)
549
+ if norm_type == "gLN":
550
+ self.norm = nn.GroupNorm(1, nOut, eps=1e-8)
551
+ if norm_type == "BN":
552
+ self.norm = nn.BatchNorm1d(nOut)
553
+
554
+ def forward(self, input):
555
+ output = self.conv(input)
556
+ return self.norm(output)
557
+
558
+ class Mlp(nn.Module):
559
+ def __init__(self, in_features, hidden_size, drop=0.1, norm_type="gLN"):
560
+ super().__init__()
561
+ self.fc1 = ConvNorm(
562
+ in_features, hidden_size, 1, bias=False, norm_type=norm_type
563
+ )
564
+ self.dwconv = nn.Conv1d(
565
+ hidden_size, hidden_size, 5, 1, 2, bias=True, groups=hidden_size
566
+ )
567
+ self.act = nn.ReLU()
568
+ self.fc2 = ConvNorm(
569
+ hidden_size, in_features, 1, bias=False, norm_type=norm_type
570
+ )
571
+ self.drop = nn.Dropout(drop)
572
+
573
+ def forward(self, x):
574
+ x = self.fc1(x)
575
+ x = self.dwconv(x)
576
+ x = self.act(x)
577
+ x = self.drop(x)
578
+ x = self.fc2(x)
579
+ x = self.drop(x)
580
+ return x
581
+
582
+
583
+ class InjectionMultiSum(nn.Module):
584
+ def __init__(self, inp: int, oup: int, kernel: int = 1, norm_type="gLN") -> None:
585
+ super().__init__()
586
+ groups = 1
587
+ if inp == oup:
588
+ groups = inp
589
+ self.local_embedding = HeatConvNorm(
590
+ inp, oup, kernel, groups=groups, bias=False, norm_type=norm_type
591
+ )
592
+ self.global_embedding = HeatConvNorm(
593
+ inp, oup, kernel, groups=groups, bias=False, norm_type=norm_type
594
+ )
595
+ self.global_act = HeatConvNorm(
596
+ inp, oup, kernel, groups=groups, bias=False, norm_type=norm_type
597
+ )
598
+ self.act = nn.Sigmoid()
599
+
600
+ def forward(self, x_l, x_g):
601
+ """
602
+ x_g: global features
603
+ x_l: local features
604
+ """
605
+ B, N, T = x_l.shape
606
+ local_feat = self.local_embedding(x_l)
607
+
608
+ global_act = self.global_act(x_g)
609
+ sig_act = torch.nn.functional.interpolate(self.act(global_act), size=T, mode="nearest")
610
+ # sig_act = self.act(global_act)
611
+
612
+ global_feat = self.global_embedding(x_g)
613
+ global_feat = torch.nn.functional.interpolate(global_feat, size=T, mode="nearest")
614
+
615
+ out = local_feat * sig_act + global_feat
616
+ return out
617
+
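+ # Shape sketch for InjectionMultiSum (sizes are illustrative; relies on the conv-norm helpers defined earlier in this file):
+ # fuse = InjectionMultiSum(inp=512, oup=512) # depthwise 1x1 embeddings when inp == oup
+ # x_l = torch.randn(2, 512, 1000) # local features at full resolution
+ # x_g = torch.randn(2, 512, 125) # coarser global features
+ # y = fuse(x_l, x_g) # [2, 512, 1000]: local * sigmoid(global) + global, both matched to T=1000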
618
+ class UConvBlock(nn.Module):
619
+ """
620
+ This class defines the block which performs successive downsampling and
621
+ upsampling in order to be able to analyze the input features in multiple
622
+ resolutions.
623
+ """
624
+
625
+ def __init__(
626
+ self, out_channels=128, in_channels=512, upsampling_depth=4, norm_type="gLN"
627
+ ):
628
+ super().__init__()
629
+ self.proj_1x1 = ConvNormAct(out_channels, in_channels, 1, stride=1, groups=1, norm_type=norm_type)
630
+ self.depth = upsampling_depth
631
+ self.spp_dw = nn.ModuleList()
632
+ self.spp_dw.append(
633
+ DilatedConvNorm(
634
+ in_channels, in_channels, kSize=5, stride=1, groups=in_channels, d=1, norm_type=norm_type
635
+ )
636
+ )
637
+ for i in range(1, upsampling_depth):
638
+ self.spp_dw.append(
639
+ DilatedConvNorm(
640
+ in_channels,
641
+ in_channels,
642
+ kSize=5,
643
+ stride=2,
644
+ groups=in_channels,
645
+ d=1,
646
+ norm_type=norm_type
647
+ )
648
+ )
649
+
650
+ self.loc_glo_fus = nn.ModuleList([])
651
+ for i in range(upsampling_depth):
652
+ self.loc_glo_fus.append(InjectionMultiSum(in_channels, in_channels, norm_type=norm_type))
653
+
654
+ self.res_conv = nn.Conv1d(in_channels, out_channels, 1)
655
+
656
+ self.globalatt = Mlp(in_channels, in_channels, drop=0.1)
657
+
658
+ self.last_layer = nn.ModuleList([])
659
+ for i in range(self.depth - 1):
660
+ self.last_layer.append(InjectionMultiSum(in_channels, in_channels, 5, norm_type=norm_type))
661
+
662
+ def forward(self, x):
663
+ """
664
+ :param x: input feature map
665
+ :return: transformed feature map
666
+ """
667
+ residual = x.clone()
668
+ # Reduce --> project high-dimensional feature maps to low-dimensional space
669
+ output1 = self.proj_1x1(x)
670
+ output = [self.spp_dw[0](output1)]
671
+
672
+ # Do the downsampling process from the previous level
673
+ for k in range(1, self.depth):
674
+ out_k = self.spp_dw[k](output[-1])
675
+ output.append(out_k)
676
+
677
+ # global features
678
+ global_f = torch.zeros(
679
+ output[-1].shape, requires_grad=True, device=output1.device
680
+ )
681
+ for fea in output:
682
+ global_f = global_f + torch.nn.functional.adaptive_avg_pool1d(
683
+ fea, output_size=output[-1].shape[-1]
684
+ )
685
+ # global_f = global_f + fea
686
+ global_f = self.globalatt(global_f) # [B, N, T]
687
+
688
+ x_fused = []
689
+ # Gather them now in reverse order
690
+ for idx in range(self.depth):
691
+ local = output[idx]
692
+ x_fused.append(self.loc_glo_fus[idx](local, global_f))
693
+
694
+ expanded = None
695
+ for i in range(self.depth - 2, -1, -1):
696
+ if i == self.depth - 2:
697
+ expanded = self.last_layer[i](x_fused[i], x_fused[i - 1])
698
+ else:
699
+ expanded = self.last_layer[i](x_fused[i], expanded)
700
+ # import pdb; pdb.set_trace()
701
+ return self.res_conv(expanded) + residual
702
+
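+ # UConvBlock flow, with illustrative shapes (assumptions, not fixed by the code):
+ # x [B, 128, T] --proj_1x1--> [B, 512, T] --stride-2 depthwise convs--> pyramid [T, T/2, T/4, ...];
+ # every level is pooled to the coarsest length, summed, and refined by the Mlp "globalatt";
+ # InjectionMultiSum re-injects this global summary at each scale before the expansion path,
+ # and res_conv plus the residual connection brings the result back to [B, 128, T].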
703
+ class EncoderLayer(torch.nn.Module):
704
+ def __init__(self, global_blocks: dict, local_blocks: dict, down_conv_layer: dict, down_conv=True):
705
+ super().__init__()
706
+
707
+ self.g_block_1 = GlobalBlock(**global_blocks)
708
+ self.l_block_1 = LocalBlock(**local_blocks)
709
+
710
+ self.g_block_2 = GlobalBlock(**global_blocks)
711
+ self.l_block_2 = LocalBlock(**local_blocks)
712
+
713
+ self.downconv = DownConvLayer(**down_conv_layer) if down_conv else None
714
+
715
+ def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
716
+ '''
717
+ x: [B, N, T]
718
+ '''
719
+ x = self.g_block_1(x, pos_k)
720
+ x = x.permute(0, 2, 1).contiguous()
721
+ x = self.l_block_1(x)
722
+ x = x.permute(0, 2, 1).contiguous()
723
+
724
+ x = self.g_block_2(x, pos_k)
725
+ x = x.permute(0, 2, 1).contiguous()
726
+ x = self.l_block_2(x)
727
+ x = x.permute(0, 2, 1).contiguous()
728
+
729
+ skip = x
730
+ if self.downconv:
731
+ x = x.permute(0, 2, 1).contiguous()
732
+ x = self.downconv(x)
733
+ x = x.permute(0, 2, 1).contiguous()
734
+ # [BK, S, N]
735
+ return x, skip
736
+
737
+ class DecoderLayer(torch.nn.Module):
738
+ def __init__(self, global_blocks: dict, local_blocks: dict, spk_attention: dict):
739
+ super().__init__()
740
+
741
+ self.g_block_1 = GlobalBlock(**global_blocks)
742
+ self.l_block_1 = LocalBlock(**local_blocks)
743
+
744
+ self.g_block_2 = GlobalBlock(**global_blocks)
745
+ self.l_block_2 = LocalBlock(**local_blocks)
746
+
747
+ self.g_block_3 = GlobalBlock(**global_blocks)
748
+ self.l_block_3 = LocalBlock(**local_blocks)
749
+
750
+ def forward(self, x: torch.Tensor, pos_k: torch.Tensor):
751
+ '''
752
+ x: [B, N, T]
753
+ '''
754
+ # [BS, K, H]
755
+ x = self.g_block_1(x, pos_k)
756
+ x = x.permute(0, 2, 1).contiguous()
757
+ x = self.l_block_1(x)
758
+ x = x.permute(0, 2, 1).contiguous()
759
+
760
+ x = self.g_block_2(x, pos_k)
761
+ x = x.permute(0, 2, 1).contiguous()
762
+ x = self.l_block_2(x)
763
+ x = x.permute(0, 2, 1).contiguous()
764
+
765
+ x = self.g_block_3(x, pos_k)
766
+ x = x.permute(0, 2, 1).contiguous()
767
+ x = self.l_block_3(x)
768
+ x = x.permute(0, 2, 1).contiguous()
769
+
770
+ skip = x
771
+
772
+ return x, skip
773
+
774
+ class Separator(torch.nn.Module):
775
+ def __init__(self, num_stages: int, relative_positional_encoding: dict, enc_stage: dict, simple_fusion:dict, dec_stage: dict):
776
+ super().__init__()
777
+
778
+ self.num_stages = num_stages
779
+ self.pos_emb = RelativePositionalEncoding(**relative_positional_encoding)
780
+
781
+ # Temporal Contracting Part
782
+ self.enc_stages = torch.nn.ModuleList([])
783
+ for _ in range(self.num_stages):
784
+ self.enc_stages.append(EncoderLayer(**enc_stage, down_conv=True))
785
+
786
+ self.bottleneck_G = nn.ModuleList([
787
+ MultiHeadAttention(
788
+ n_head=enc_stage['global_blocks']['num_mha_heads'],
789
+ in_channels=enc_stage['global_blocks']['in_channels'],
790
+ dropout_rate=enc_stage['global_blocks']['dropout_rate']
791
+ ),
792
+ FFN(
793
+ in_channels=enc_stage['global_blocks']['in_channels'],
794
+ dropout_rate=enc_stage['global_blocks']['dropout_rate']
795
+ )
796
+ ])
797
+
798
+ # top-down fusion
799
+ self.loc_glo_fus = nn.ModuleList([])
800
+ for i in range(self.num_stages):
801
+ self.loc_glo_fus.append(InjectionMultiSum(simple_fusion['out_channels'], simple_fusion['out_channels']))
802
+
803
+ # Temporal Expanding Part
804
+ self.simple_fusion = torch.nn.ModuleList([])
805
+ self.dec_stages = torch.nn.ModuleList([])
806
+ for _ in range(self.num_stages):
807
+ self.simple_fusion.append(InjectionMultiSum(simple_fusion['out_channels'], simple_fusion['out_channels'], kernel=5))
808
+ self.dec_stages.append(DecoderLayer(**dec_stage))
809
+
810
+ def forward(self, input: torch.Tensor):
811
+ '''input: [B, N, L]'''
812
+ # feature projection
813
+ x, _ = self.pad_signal(input)
814
+ len_x = x.shape[-1]
815
+ # Temporal Contracting Part
816
+ min_len = len_x//2**(self.num_stages-1)
817
+ pos_seq = torch.arange(0, len_x//2**self.num_stages).long().to(x.device)
818
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
819
+ pos_k, _ = self.pos_emb(pos_seq)
820
+
821
+ skip = []
822
+ fusion_x = torch.zeros([x.shape[0], x.shape[1], min_len], requires_grad=True, device=x.device)
823
+ for idx in range(self.num_stages):
824
+ x, skip_ = self.enc_stages[idx](x, pos_k)
825
+ skip.append(skip_)
826
+ fusion_x = fusion_x + F.adaptive_avg_pool1d(x, min_len)
827
+
828
+ global_x = self.bottleneck_G[0](fusion_x.permute(0, 2, 1).contiguous(), None, None)
829
+ global_x = self.bottleneck_G[1](global_x).permute(0, 2, 1).contiguous()
830
+
831
+ # Global topdown attention
832
+ fusion_skip = []
833
+ for idx in range(self.num_stages):
834
+ fusion_skip.append(self.loc_glo_fus[idx](skip[idx], global_x))
835
+
836
+ each_stage_outputs = []
837
+ # Temporal Expanding Part
838
+ for idx in range(self.num_stages):
839
+ each_stage_outputs.append(x)
840
+ idx_en = self.num_stages - (idx + 1)
841
+ x = self.simple_fusion[idx](fusion_skip[idx_en], x)
842
+ x, _ = self.dec_stages[idx](x, pos_k)
843
+
844
+ last_stage_output = x
845
+ return last_stage_output, each_stage_outputs
846
+
847
+ def pad_signal(self, input: torch.Tensor):
848
+ # (B, T) or (B, 1, T)
849
+ if input.dim() == 1: input = input.unsqueeze(0)
850
+ elif input.dim() not in [2, 3]: raise RuntimeError("Input can only be 2 or 3 dimensional.")
851
+ elif input.dim() == 2: input = input.unsqueeze(1)
852
+ L = 2**self.num_stages
853
+ batch_size = input.size(0)
854
+ ndim = input.size(1)
855
+ nframe = input.size(2)
856
+ padded_len = (nframe//L + 1)*L
857
+ rest = 0 if nframe%L == 0 else padded_len - nframe
858
+ if rest > 0:
859
+ pad = torch.zeros(batch_size, ndim, rest, dtype=input.dtype, device=input.device)
860
+ input = torch.cat([input, pad], dim=-1)
861
+ return input, rest
862
+
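+ # Worked padding example (numbers are illustrative assumptions):
+ # with num_stages = 4, L = 2**4 = 16; an input of nframe = 1000 gives
+ # padded_len = (1000 // 16 + 1) * 16 = 1008 and rest = 8 zero samples appended,
+ # so every encoder stage can halve the time axis without a remainder.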
863
+
864
+ class OutputLayer(torch.nn.Module):
865
+ def __init__(self, in_channels: int, out_channels: int, masking: bool = False):
866
+ super().__init__()
867
+ # feature expansion back
868
+ self.masking = masking
869
+ self.spe_block = Masking(in_channels)
870
+ self.end_conv1x1 = torch.nn.Sequential(
871
+ torch.nn.Linear(out_channels, 4*out_channels),
872
+ torch.nn.GLU(),
873
+ torch.nn.Linear(2*out_channels, in_channels))
874
+
875
+ def forward(self, x: torch.Tensor, input: torch.Tensor):
876
+ x = x[...,:input.shape[-1]]
877
+ x = x.permute([0, 2, 1])
878
+ x = self.end_conv1x1(x)
879
+ x = x.permute([0, 2, 1])
880
+
881
+ if self.masking:
882
+ x = self.spe_block(x, input)
883
+
884
+ return x
885
+
886
+ class AudioDecoder(torch.nn.ConvTranspose1d):
887
+ def __init__(self, *args, **kwargs):
888
+ super().__init__(*args, **kwargs)
889
+
890
+ def forward(self, x):
891
+ # x: [B, N, L]
892
+ if x.dim() not in [2, 3]:
893
+ raise RuntimeError("{} accepts only 2/3D tensors as input".format(self.__class__.__name__))
894
+ x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
895
+ x = torch.squeeze(x, dim=1) if torch.squeeze(x).dim() == 1 else torch.squeeze(x)
896
+ return x
897
+
898
+ class ReconstructionPath(nn.Module):
899
+ def __init__(
900
+ self,
901
+ layers = [
902
+ 'residual',
903
+ 'residual',
904
+ 'residual'
905
+ ],
906
+ image_size=88,
907
+ in_channel=1,
908
+ init_channel=16,
909
+ max_dim=128,
910
+ # conv-related settings
911
+ input_conv_kernel_size = [7, 7, 7],
912
+ output_conv_kernel_size = [3, 3, 3],
913
+ residual_conv_kernel_size=3,
914
+ pad_mode="constant",
915
+ # attention-related settings
916
+ attn_dim_head = 32,
917
+ attn_heads = 8,
918
+ attn_dropout = 0.,
919
+ flash_attn = True,
920
+ linear_attn_dim_head = 8,
921
+ linear_attn_heads = 16,
922
+ fuse_dim=32,
923
+ # quantizer-related settings
924
+ num_quantizers = 1,
925
+ codebook_size = 256,
926
+ codebook_dim=64,
927
+ commitment_cost=0.25,
928
+ ):
929
+ super().__init__()
930
+ input_conv_kernel_size=tuple(input_conv_kernel_size)
931
+
932
+ self.conv_in = nn.Conv3d(in_channel, init_channel, input_conv_kernel_size,padding='same')
933
+
934
+ layer_fmap_size=image_size
935
+ self.encoder_layers = nn.ModuleList([])
936
+ dim=init_channel
937
+ dim_out=dim
938
+ time_downsample_factor=1
939
+
940
+ for layer_type in layers:
941
+ if layer_type == 'residual':
942
+ encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
943
+
944
+ elif layer_type == 'consecutive_residual':
945
+ num_consecutive = 2
946
+ encoder_layer = Sequential(*[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)])
947
+
948
+ elif layer_type == 'compress_space':
949
+ dim_out = dim * 2
950
+ dim_out = min(dim_out, max_dim)
951
+
952
+ encoder_layer = SpatialDownsample2x(dim, dim_out)
953
+
954
+ assert layer_fmap_size > 1
955
+ layer_fmap_size //= 2
956
+
957
+ elif layer_type == 'compress_time':
958
+ dim_out = dim * 2
959
+ dim_out = min(dim_out, max_dim)
960
+
961
+ encoder_layer = TimeDownsample2x(dim, dim_out)
962
+
963
+ time_downsample_factor *= 2
964
+
965
+ elif layer_type == 'attend_space':
966
+ attn_kwargs = dict(
967
+ dim = dim,
968
+ dim_head = attn_dim_head,
969
+ heads = attn_heads,
970
+ dropout = attn_dropout,
971
+ flash = flash_attn
972
+ )
973
+
974
+ encoder_layer = Sequential(
975
+ Residual(SpaceAttention(**attn_kwargs)),
976
+ Residual(FeedForward(dim))
977
+ )
978
+
979
+ elif layer_type == 'linear_attend_space':
980
+ linear_attn_kwargs = dict(
981
+ dim = dim,
982
+ dim_head = linear_attn_dim_head,
983
+ heads = linear_attn_heads
984
+ )
985
+
986
+ encoder_layer = Sequential(
987
+ Residual(LinearSpaceAttention(**linear_attn_kwargs)),
988
+ Residual(FeedForward(dim))
989
+ )
990
+
991
+ else:
992
+ raise ValueError(f'unknown layer type {layer_type}')
993
+
994
+ self.encoder_layers.append(encoder_layer)
995
+
996
+ dim = dim_out
997
+
998
+ self.encoder_layers.append(Sequential(
999
+ Rearrange('b c ... -> b ... c'),
1000
+ nn.LayerNorm(dim),
1001
+ Rearrange('b ... c -> b c ...'),
1002
+ ))
1003
+
1004
+
1005
+ def forward(self, x, semantic_fea=None):
1006
+ x = self.conv_in(x)
1007
+ for layer in self.encoder_layers:
1008
+ x = layer(x)
1009
+ z_e = x
1010
+
1011
+ z_q=z_e
1012
+
1013
+ if semantic_fea is not None:
1014
+ B,C,T,H,W=z_q.shape
1015
+ z_q=z_q.contiguous().permute(0,2,1,3,4)
1016
+ z_q=z_q.contiguous().view(B,T,-1)
1017
+ z_q=z_q + semantic_fea
1018
+
1019
+ return z_q
1020
+
1021
+ class SemanticPath(nn.Module):
1022
+ def __init__(
1023
+ self,
1024
+ layers = [
1025
+ 'residual',
1026
+ 'residual',
1027
+ 'residual'
1028
+ ],
1029
+ image_size=88,
1030
+ in_channel=1,
1031
+ init_channel=4,
1032
+ max_dim=32,
1033
+ input_conv_kernel_size = [7, 7, 7],
1034
+ output_conv_kernel_size = [3, 3, 3],
1035
+ residual_conv_kernel_size=3,
1036
+ pad_mode="constant",
1037
+ attn_dim_head = 32,
1038
+ attn_heads = 8,
1039
+ attn_dropout = 0.,
1040
+ flash_attn = True,
1041
+ linear_attn_dim_head = 8,
1042
+ linear_attn_heads = 16,
1043
+ num_quantizers = 1,
1044
+ codebook_size = 256,
1045
+ codebook_dim= 64,
1046
+ commitment_cost=0.25,
1047
+ distill_dim=1024,
1048
+ config=None,
1049
+ pretrain=None
1050
+ ):
1051
+ super().__init__()
1052
+ input_conv_kernel_size=tuple(input_conv_kernel_size)
1053
+
1054
+ self.conv_in = nn.Conv3d(in_channel, init_channel, input_conv_kernel_size,padding='same')
1055
+
1056
+ layer_fmap_size=image_size
1057
+ self.encoder_layers = nn.ModuleList([])
1058
+ dim=init_channel
1059
+ dim_out=dim
1060
+ time_downsample_factor=1
1061
+
1062
+ for layer_type in layers:
1063
+ if layer_type == 'residual':
1064
+ encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
1065
+
1066
+ elif layer_type == 'consecutive_residual':
1067
+ num_consecutive = 2
1068
+ encoder_layer = Sequential(*[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)])
1069
+
1070
+ elif layer_type == 'compress_space':
1071
+ dim_out = dim * 2
1072
+ dim_out = min(dim_out, max_dim)
1073
+
1074
+ encoder_layer = SpatialDownsample2x(dim, dim_out)
1075
+
1076
+ assert layer_fmap_size > 1
1077
+ layer_fmap_size //= 2
1078
+
1079
+ elif layer_type == 'compress_time':
1080
+ dim_out = dim * 2
1081
+ dim_out = min(dim_out, max_dim)
1082
+
1083
+ encoder_layer = TimeDownsample2x(dim, dim_out)
1084
+
1085
+ time_downsample_factor *= 2
1086
+
1087
+ elif layer_type == 'attend_space':
1088
+ attn_kwargs = dict(
1089
+ dim = dim,
1090
+ dim_head = attn_dim_head,
1091
+ heads = attn_heads,
1092
+ dropout = attn_dropout,
1093
+ flash = flash_attn
1094
+ )
1095
+
1096
+ encoder_layer = Sequential(
1097
+ Residual(SpaceAttention(**attn_kwargs)),
1098
+ Residual(FeedForward(dim))
1099
+ )
1100
+
1101
+ elif layer_type == 'linear_attend_space':
1102
+ linear_attn_kwargs = dict(
1103
+ dim = dim,
1104
+ dim_head = linear_attn_dim_head,
1105
+ heads = linear_attn_heads
1106
+ )
1107
+
1108
+ encoder_layer = Sequential(
1109
+ Residual(LinearSpaceAttention(**linear_attn_kwargs)),
1110
+ Residual(FeedForward(dim))
1111
+ )
1112
+
1113
+ else:
1114
+ raise ValueError(f'unknown layer type {layer_type}')
1115
+
1116
+ self.encoder_layers.append(encoder_layer)
1117
+
1118
+ dim = dim_out
1119
+
1120
+ self.encoder_layers.append(Sequential(
1121
+ Rearrange('b c ... -> b ... c'),
1122
+ nn.LayerNorm(dim),
1123
+ Rearrange('b ... c -> b c ...'),
1124
+ ))
1125
+
1126
+ # layer_fmap_size = 3
1127
+ self.quantizer = ResidualVQ(
1128
+ dim = dim*layer_fmap_size*layer_fmap_size,
1129
+ num_quantizers = num_quantizers,
1130
+ codebook_size = codebook_size,
1131
+ codebook_dim = codebook_dim,
1132
+ quantize_dropout=False,
1133
+ stochastic_sample_codes = True,
1134
+ sample_codebook_temp = 0.1,
1135
+ kmeans_init = True,
1136
+ kmeans_iters = 10
1137
+ )
1138
+
1139
+ def forward(self, x):
1140
+ x = self.conv_in(x)
1141
+ for layer in self.encoder_layers:
1142
+ x = layer(x)
1143
+ b,c,t,h,w=x.shape
1144
+ x = x.contiguous().permute(0,2,1,3,4)
1145
+ z_e = x.contiguous().view(b,t,-1)
1146
+
1147
+ z_q,_,_=self.quantizer(z_e)
1148
+
1149
+ return z_q
1150
+
1151
+ class VideoEncoder(nn.Module):
1152
+ def __init__(
1153
+ self,
1154
+ layers,
1155
+ image_size=88,
1156
+ in_channel=1,
1157
+ init_channel=16,
1158
+ max_dim=128,
1159
+ input_conv_kernel_size = [7, 7, 7],
1160
+ output_conv_kernel_size = [3, 3, 3],
1161
+ residual_conv_kernel_size=3,
1162
+ pad_mode="constant",
1163
+ # attention-related settings
1164
+ attn_dim_head = 32,
1165
+ attn_heads = 8,
1166
+ attn_dropout = 0.,
1167
+ flash_attn = True,
1168
+ linear_attn_dim_head = 8,
1169
+ linear_attn_heads = 16,
1170
+ num_quantizers = 1,
1171
+ codebook_size = 256,
1172
+ codebook_dim=64,
1173
+ commitment_cost=0.25,
1174
+ distill_cost=1.0,
1175
+ ):
1176
+ super().__init__()
1177
+ self.semantic_model=SemanticPath(
1178
+ layers=layers,
1179
+ image_size=image_size,
1180
+ in_channel=in_channel,
1181
+ init_channel=init_channel,
1182
+ max_dim=max_dim,
1183
+ input_conv_kernel_size=input_conv_kernel_size,
1184
+ output_conv_kernel_size=output_conv_kernel_size,
1185
+ residual_conv_kernel_size=residual_conv_kernel_size,
1186
+ pad_mode=pad_mode,
1187
+ attn_dim_head = attn_dim_head,
1188
+ attn_heads = attn_heads,
1189
+ attn_dropout = attn_dropout,
1190
+ flash_attn = flash_attn,
1191
+ linear_attn_dim_head = linear_attn_dim_head,
1192
+ linear_attn_heads = linear_attn_heads,
1193
+ num_quantizers = num_quantizers,
1194
+ codebook_size = codebook_size,
1195
+ codebook_dim = codebook_dim,
1196
+ commitment_cost = commitment_cost,
1197
+ )
1198
+ self.recon_model=ReconstructionPath(
1199
+ layers=layers,
1200
+ image_size=image_size,
1201
+ in_channel=in_channel,
1202
+ init_channel=init_channel,
1203
+ max_dim=max_dim,
1204
+ input_conv_kernel_size=input_conv_kernel_size,
1205
+ output_conv_kernel_size=output_conv_kernel_size,
1206
+ residual_conv_kernel_size=residual_conv_kernel_size,
1207
+ pad_mode=pad_mode,
1208
+ attn_dim_head = attn_dim_head,
1209
+ attn_heads = attn_heads,
1210
+ attn_dropout = attn_dropout,
1211
+ flash_attn = flash_attn,
1212
+ linear_attn_dim_head = linear_attn_dim_head,
1213
+ linear_attn_heads = linear_attn_heads,
1214
+ num_quantizers = num_quantizers,
1215
+ codebook_size = codebook_size,
1216
+ codebook_dim = codebook_dim,
1217
+ commitment_cost = commitment_cost,
1218
+ )
1219
+
1220
+ def forward(self, x):
1221
+ semantic_fea = self.semantic_model(x)
1222
+ return self.recon_model(x,semantic_fea)
1223
+
1224
+ class Dolphin(nn.Module, PyTorchModelHubMixin):
1225
+ def __init__(self,
1226
+ num_stages: int,
1227
+ sample_rate: int,
1228
+ module_audio_enc: dict,
1229
+ module_feature_projector: dict,
1230
+ module_separator: dict,
1231
+ module_output_layer: dict,
1232
+ module_audio_dec: dict,
1233
+ video_encoder_params: dict,
1234
+ vpre_channels=512,
1235
+ vmid_channels=512,
1236
+ vin_channels=64,
1237
+ vout_channels=64,):
1238
+ super(Dolphin, self).__init__()
1239
+
1240
+ self.pre_v1 = ConvNormAct(vpre_channels, vin_channels, kSize=3, norm_type="BN")
1241
+
1242
+ self.num_stages = num_stages
1243
+ self.audio_encoder = AudioEncoder(**module_audio_enc)
1244
+ self.feature_projector = FeatureProjector(**module_feature_projector)
1245
+ self.separator = Separator(**module_separator)
1246
+ self.out_layer = OutputLayer(**module_output_layer)
1247
+ self.audio_decoder = AudioDecoder(**module_audio_dec)
1248
+
1249
+ self.video_blocks = UConvBlock(vin_channels, vout_channels, 3, norm_type="BN")
1250
+ self.modalfuse = AVFModule(module_feature_projector["out_channels"], vout_channels)
1251
+
1252
+ self.video_encoder = VideoEncoder(**video_encoder_params)
1253
+
1254
+ @classmethod
1255
+ def _from_pretrained(
1256
+ cls,
1257
+ *,
1258
+ model_id: str,
1259
+ revision: str,
1260
+ cache_dir: str,
1261
+ force_download: bool,
1262
+ proxies: dict,
1263
+ resume_download: bool,
1264
+ local_files_only: bool,
1265
+ token: str,
1266
+ map_location: str = "cpu",
1267
+ strict: bool = False,
1268
+ **model_kwargs,
1269
+ ):
1270
+ """Load model from HuggingFace Hub with proper configuration handling."""
1271
+ import json
1272
+ from huggingface_hub import hf_hub_download
1273
+
1274
+ # Download config file
1275
+ config_file = hf_hub_download(
1276
+ repo_id=model_id,
1277
+ filename="config.json",
1278
+ revision=revision,
1279
+ cache_dir=cache_dir,
1280
+ force_download=force_download,
1281
+ proxies=proxies,
1282
+ resume_download=resume_download,
1283
+ local_files_only=local_files_only,
1284
+ token=token,
1285
+ )
1286
+
1287
+ # Load configuration
1288
+ with open(config_file, "r") as f:
1289
+ config = json.load(f)
1290
+
1291
+ # Extract only the model parameters, excluding HF metadata
1292
+ hf_metadata_keys = {
1293
+ "model_type", "task", "framework", "license", "tags",
1294
+ "architectures", "auto_map"
1295
+ }
1296
+ model_config = {k: v for k, v in config.items() if k not in hf_metadata_keys}
1297
+
1298
+ # Create model instance with config
1299
+ model = cls(**model_config)
1300
+
1301
+ # Try to download different possible model file formats
1302
+ import torch
1303
+ model_files_to_try = [
1304
+ "model.safetensors",
1305
+ ]
1306
+
1307
+ state_dict = None
1308
+ for filename in model_files_to_try:
1309
+ try:
1310
+ model_file = hf_hub_download(
1311
+ repo_id=model_id,
1312
+ filename=filename,
1313
+ revision=revision,
1314
+ cache_dir=cache_dir,
1315
+ force_download=force_download,
1316
+ proxies=proxies,
1317
+ resume_download=resume_download,
1318
+ local_files_only=local_files_only,
1319
+ token=token,
1320
+ )
1321
+
1322
+ # Try to load the state dict
1323
+ if filename.endswith('.safetensors'):
1324
+ # Handle safetensors format
1325
+ try:
1326
+ from safetensors.torch import load_file
1327
+ state_dict = load_file(model_file, device=map_location)
1328
+ except ImportError:
1329
+ print("safetensors not available, skipping .safetensors files")
1330
+ continue
1331
+ else:
1332
+ # Handle PyTorch format
1333
+ checkpoint = torch.load(model_file, map_location=map_location, weights_only=False)
1334
+
1335
+ # Handle different checkpoint formats
1336
+ if isinstance(checkpoint, dict):
1337
+ if 'state_dict' in checkpoint:
1338
+ state_dict = checkpoint['state_dict']
1339
+ elif 'model_state_dict' in checkpoint:
1340
+ state_dict = checkpoint['model_state_dict']
1341
+ else:
1342
+ state_dict = checkpoint
1343
+ else:
1344
+ state_dict = checkpoint
1345
+
1346
+ # If we successfully loaded a state dict, break
1347
+ if state_dict is not None:
1348
+ break
1349
+
1350
+ except Exception as e:
1351
+ print(f"Failed to load {filename}: {e}")
1352
+ continue
1353
+
1354
+ if state_dict is None:
1355
+ raise RuntimeError(f"Could not load model weights from any of the tried files: {model_files_to_try}")
1356
+
1357
+ model.load_state_dict(state_dict, strict=strict)
1358
+
1359
+ return model
1360
+
1361
+ def forward(self, input, mouth):
1362
+ mouth = self.video_encoder(mouth).permute(0, 2, 1).contiguous()
1363
+ v=self.pre_v1(mouth)
1364
+ v=self.video_blocks(v)
1365
+
1366
+ encoder_output = self.audio_encoder(input)
1367
+ projected_feature = self.feature_projector(encoder_output)
1368
+
1369
+ projected_feature = self.modalfuse(projected_feature,v)
1370
+
1371
+ last_stage_output, each_stage_outputs = self.separator(projected_feature)
1372
+
1373
+ out_layer_output = self.out_layer(last_stage_output, encoder_output)
1374
+ audio=self.audio_decoder(out_layer_output)
1375
+
1376
+ return audio.unsqueeze(dim=1)
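+
+ # Minimal end-to-end sketch (the repo id and tensor shapes are placeholders/assumptions):
+ # model = Dolphin.from_pretrained("<hub-repo-id>").eval()
+ # audio = torch.randn(1, 16000) # [B, T_audio] mixture waveform
+ # mouth = torch.randn(1, 1, 25, 88, 88) # [B, C, T_video, H, W] lip-region frames
+ # with torch.no_grad():
+ #     est = model(audio, mouth) # [B, 1, T_audio] estimated target speech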
look2hear/models/video_compoent.py ADDED
@@ -0,0 +1,876 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ import torch.nn.functional as F
5
+ from torch import einsum,Tensor
6
+ from functools import partial
7
+ from taylor_series_linear_attention import TaylorSeriesLinearAttn
8
+
9
+ from beartype import beartype
10
+ from beartype.typing import Tuple, List
11
+
12
+ from einops import rearrange, repeat, reduce, pack, unpack
13
+ from einops.layers.torch import Rearrange
14
+
15
+ from typing import Union
16
+
17
+ from functools import partial
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ from torch import nn, einsum, Tensor
22
+ import torch.nn.functional as F
23
+
24
+ from collections import namedtuple
25
+ from functools import wraps
26
+ from packaging import version
27
+
28
+ # constants
29
+
30
+ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
31
+
32
+ # helpers
33
+
34
+ def exists(val):
35
+ return val is not None
36
+
37
+ def default(val, d):
38
+ return val if exists(val) else d
39
+
40
+ def compact(arr):
41
+ return [*filter(exists, arr)]
42
+
43
+ def once(fn):
44
+ called = False
45
+ @wraps(fn)
46
+ def inner(x):
47
+ nonlocal called
48
+ if called:
49
+ return
50
+ called = True
51
+ return fn(x)
52
+ return inner
53
+
54
+ print_once = once(print)
55
+
56
+ # functions for creating causal mask
57
+ # need a special one for onnx cpu (no support for .triu)
58
+
59
+ def create_causal_mask(i, j, device):
60
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
61
+
62
+ def onnx_create_causal_mask(i, j, device):
63
+ r = torch.arange(i, device = device)
64
+ causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
65
+ causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
66
+ return causal_mask
67
+
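+ # Worked example: create_causal_mask(3, 5, device) marks entries to block (True = masked),
+ # right-aligned so each of the 3 query tokens attends to the 5-token key prefix ending at its own position:
+ # [[F, F, F, T, T],
+ #  [F, F, F, F, T],
+ #  [F, F, F, F, F]]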
68
+ # main class
69
+
70
+ class Attend(nn.Module):
71
+ def __init__(
72
+ self,
73
+ *,
74
+ dropout = 0.,
75
+ causal = False,
76
+ heads = None,
77
+ scale = None,
78
+ flash = False,
79
+ onnxable = False,
80
+ sdp_kwargs: dict = dict(
81
+ enable_flash = True,
82
+ enable_math = True,
83
+ enable_mem_efficient = True
84
+ )
85
+ ):
86
+ super().__init__()
87
+ self.scale = scale
88
+
89
+ self.causal = causal
90
+ self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
91
+
92
+ self.dropout = dropout
93
+ self.attn_dropout = nn.Dropout(dropout)
94
+
95
+ # flash attention
96
+
97
+ self.flash = flash and torch.cuda.is_available()
98
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
99
+
100
+ self.sdp_kwargs = sdp_kwargs
101
+
102
+ def flash_attn(
103
+ self,
104
+ q, k, v,
105
+ mask = None,
106
+ attn_bias = None
107
+ ):
108
+ batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
109
+
110
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
111
+
112
+ # manage scale, since scale is not customizable in sdp, hack around it
113
+
114
+ if exists(self.scale):
115
+ q = q * self.scale / (q.shape[-1] ** -0.5)
116
+
117
+ # Check if mask exists and expand to compatible shape
118
+ # The mask is B L, so it would have to be expanded to B H N L
119
+
120
+ causal = self.causal
121
+
122
+ # in the case of kv caching with one token (q_len == 1), just turn off causal masking
123
+ # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
124
+
125
+ if q_len == 1 and causal:
126
+ causal = False
127
+
128
+ # expand key padding mask
129
+
130
+ if exists(mask):
131
+ assert mask.ndim == 4
132
+ mask = mask.expand(batch, heads, q_len, k_len)
133
+
134
+ # handle kv cache - this should be bypassable in updated flash attention 2
135
+
136
+ if k_len > q_len and causal:
137
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
138
+ if not exists(mask):
139
+ mask = ~causal_mask
140
+ else:
141
+ mask = mask & ~causal_mask
142
+ causal = False
143
+
144
+ # manually handle causal mask, if another mask was given
145
+
146
+ row_is_entirely_masked = None
147
+
148
+ if exists(mask) and causal:
149
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
150
+ mask = mask & ~causal_mask
151
+
152
+ # protect against an entire row being masked out
153
+
154
+ row_is_entirely_masked = ~mask.any(dim = -1)
155
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
156
+
157
+ causal = False
158
+
159
+ # handle alibi positional bias
160
+ # convert from bool to float
161
+
162
+ if exists(attn_bias):
163
+ attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
164
+
165
+ # if mask given, the mask would already contain the causal mask from above logic
166
+ # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
167
+
168
+ mask_value = -torch.finfo(q.dtype).max
169
+
170
+ if exists(mask):
171
+ attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
172
+ elif causal:
173
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
174
+ attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
175
+ causal = False
176
+
177
+ # scaled_dot_product_attention handles attn_mask either as bool or additive bias
178
+ # make it an additive bias here
179
+
180
+ mask = attn_bias
181
+
182
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
183
+
184
+ with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
185
+ out = F.scaled_dot_product_attention(
186
+ q, k, v,
187
+ attn_mask = mask,
188
+ dropout_p = self.dropout if self.training else 0.,
189
+ is_causal = causal
190
+ )
191
+
192
+ # for a row that is entirely masked out, should zero out the output of that row token
193
+
194
+ if exists(row_is_entirely_masked):
195
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
196
+
197
+ return out
198
+
199
+ def forward(
200
+ self,
201
+ q, k, v,
202
+ mask = None,
203
+ attn_bias = None,
204
+ prev_attn = None
205
+ ):
206
+ """
207
+ einstein notation
208
+ b - batch
209
+ h - heads
210
+ n, i, j - sequence length (base sequence length, source, target)
211
+ d - feature dimension
212
+ """
213
+
214
+ n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
215
+
216
+ scale = default(self.scale, q.shape[-1] ** -0.5)
217
+
218
+ causal = self.causal
219
+
220
+ # handle kv cached decoding
221
+
222
+ if n == 1 and causal:
223
+ causal = False
224
+
225
+ # handle zero kv, as means for allowing network to attend to nothing
226
+
227
+ if self.flash:
228
+ assert not exists(prev_attn), 'residual attention not compatible with flash attention'
229
+ return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
230
+
231
+ dots = einsum(f'b h i d, b h j d -> b h i j', q, k) * scale
232
+
233
+ if exists(prev_attn):
234
+ dots = dots + prev_attn
235
+
236
+ if exists(attn_bias):
237
+ dots = dots + attn_bias
238
+
239
+ i, j, dtype = *dots.shape[-2:], dots.dtype
240
+
241
+ mask_value = -torch.finfo(dots.dtype).max
242
+
243
+ if exists(mask):
244
+ dots = dots.masked_fill(~mask, mask_value)
245
+
246
+ if causal:
247
+ causal_mask = self.create_causal_mask(i, j, device = device)
248
+ dots = dots.masked_fill(causal_mask, mask_value)
249
+
250
+ attn = dots.softmax(dim = -1)
251
+
252
+ attn = self.attn_dropout(attn)
253
+
254
+ out = einsum(f'b h i j, b h j d -> b h i d', attn, v)
255
+
256
+ return out
257
+
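+ # Minimal usage sketch for Attend (shapes are illustrative assumptions):
+ # attend = Attend(causal=True, dropout=0.1, flash=False)
+ # q = k = v = torch.randn(2, 8, 128, 64) # [batch, heads, seq, dim_head]
+ # out = attend(q, k, v) # [2, 8, 128, 64]; flash=True requires PyTorch >= 2.0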
258
+ def exists(v):
259
+ return v is not None
260
+
261
+ def default(v, d):
262
+ return v if exists(v) else d
263
+
264
+ def safe_get_index(it, ind, default = None):
265
+ if ind < len(it):
266
+ return it[ind]
267
+ return default
268
+
269
+ def pair(t):
270
+ return t if isinstance(t, tuple) else (t, t)
271
+
272
+ def identity(t, *args, **kwargs):
273
+ return t
274
+
275
+ def divisible_by(num, den):
276
+ return (num % den) == 0
277
+
278
+ def pack_one(t, pattern):
279
+ return pack([t], pattern)
280
+
281
+ def unpack_one(t, ps, pattern):
282
+ return unpack(t, ps, pattern)[0]
283
+
284
+ def append_dims(t, ndims: int):
285
+ return t.reshape(*t.shape, *((1,) * ndims))
286
+
287
+ def is_odd(n):
288
+ return not divisible_by(n, 2)
289
+
290
+ def maybe_del_attr_(o, attr):
291
+ if hasattr(o, attr):
292
+ delattr(o, attr)
293
+
294
+ def cast_tuple(t, length = 1):
295
+ return t if isinstance(t, tuple) else ((t,) * length)
296
+ class ResBlock(nn.Module):
297
+ def __init__(self, in_channel, channel):
298
+ super().__init__()
299
+
300
+ self.conv = nn.Sequential(
301
+ nn.ReLU(),
302
+ nn.Conv2d(in_channel, channel, 3, padding=1),
303
+ nn.ReLU(inplace=True),
304
+ nn.Conv2d(channel, in_channel, 1),
305
+ )
306
+
307
+ def forward(self, input):
308
+ out = self.conv(input)
309
+ out += input
310
+
311
+ return out
312
+
313
+
314
+ class EncoderAE(nn.Module):
315
+ def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
316
+ super().__init__()
317
+
318
+ if stride == 4:
319
+ blocks = [
320
+ nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
321
+ nn.ReLU(inplace=True),
322
+ nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
323
+ nn.ReLU(inplace=True),
324
+ nn.Conv2d(channel, channel, 3, padding=1),
325
+ ]
326
+
327
+ elif stride == 2:
328
+ blocks = [
329
+ nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
330
+ nn.ReLU(inplace=True),
331
+ nn.Conv2d(channel // 2, channel, 3, padding=1),
332
+ ]
333
+
334
+ for i in range(n_res_block):
335
+ blocks.append(ResBlock(channel, n_res_channel))
336
+
337
+ blocks.append(nn.ReLU(inplace=True))
338
+
339
+ self.blocks = nn.Sequential(*blocks)
340
+
341
+ def forward(self, input):
342
+ return self.blocks(input)
343
+
344
+
345
+ class DecoderAE(nn.Module):
346
+ def __init__(
347
+ self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
348
+ ):
349
+ super().__init__()
350
+
351
+ blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
352
+
353
+ for i in range(n_res_block):
354
+ blocks.append(ResBlock(channel, n_res_channel))
355
+
356
+ blocks.append(nn.ReLU(inplace=True))
357
+
358
+ if stride == 4:
359
+ blocks.extend(
360
+ [
361
+ nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
362
+ nn.ReLU(inplace=True),
363
+ nn.ConvTranspose2d(
364
+ channel // 2, out_channel, 4, stride=2, padding=1
365
+ ),
366
+ ]
367
+ )
368
+
369
+ elif stride == 2:
370
+ blocks.append(
371
+ nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
372
+ )
373
+
374
+ self.blocks = nn.Sequential(*blocks)
375
+
376
+ def forward(self, input):
377
+ return self.blocks(input)
378
+
379
+ class CausalConv3d(nn.Module):
380
+ # Causal 3D convolution; in practice it differs little from a plain 3D convolution.
381
+ @beartype
382
+ def __init__(
383
+ self,
384
+ chan_in,
385
+ chan_out,
386
+ kernel_size: Union[int, Tuple[int, int, int]],
387
+ pad_mode = 'constant',
388
+ **kwargs
389
+ ):
390
+ super().__init__()
391
+ kernel_size = cast_tuple(kernel_size, 3)
392
+
393
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
394
+
395
+ # the height_pad / width_pad values here and below are chosen so the output H/W size stays unchanged
396
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
397
+
398
+ dilation = kwargs.pop('dilation', 1)
399
+ stride = kwargs.pop('stride', 1)
400
+
401
+ self.pad_mode = pad_mode
402
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
403
+ height_pad = height_kernel_size // 2
404
+ width_pad = width_kernel_size // 2
405
+
406
+ self.time_pad = time_pad
407
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
408
+
409
+ stride = (stride, 1, 1)
410
+ dilation = (dilation, 1, 1)
411
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
412
+
413
+ def forward(self, x):
414
+ pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
415
+
416
+ x = F.pad(x, self.time_causal_padding, mode = pad_mode)
417
+ return self.conv(x)
418
+
419
+ class SqueezeExcite(nn.Module):
420
+ # global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375)
421
+ # a lightweight channel-wise attention
422
+ def __init__(
423
+ self,
424
+ dim,
425
+ *,
426
+ dim_out = None,
427
+ dim_hidden_min = 16,
428
+ init_bias = -10
429
+ ):
430
+ super().__init__()
431
+ dim_out = default(dim_out, dim)
432
+
433
+ self.to_k = nn.Conv2d(dim, 1, 1)
434
+ dim_hidden = max(dim_hidden_min, dim_out // 2)
435
+
436
+ self.net = nn.Sequential(
437
+ nn.Conv2d(dim, dim_hidden, 1),
438
+ nn.LeakyReLU(0.1),
439
+ nn.Conv2d(dim_hidden, dim_out, 1),
440
+ nn.Sigmoid()
441
+ )
442
+
443
+ nn.init.zeros_(self.net[-2].weight)
444
+ nn.init.constant_(self.net[-2].bias, init_bias)
445
+
446
+ def forward(self, x):
447
+ orig_input, batch = x, x.shape[0]
448
+ is_video = x.ndim == 5
449
+
450
+ if is_video:
451
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
452
+
453
+ # build a context map over the spatial (H, W) positions with a 1x1 conv
454
+ context = self.to_k(x)
455
+
456
+ context = rearrange(context, 'b c h w -> b c (h w)').softmax(dim = -1)
457
+ spatial_flattened_input = rearrange(x, 'b c h w -> b c (h w)')
458
+
459
+ out = einsum('b i n, b c n -> b c i', context, spatial_flattened_input)
460
+ out = rearrange(out, '... -> ... 1')
461
+ gates = self.net(out)
462
+
463
+ if is_video:
464
+ gates = rearrange(gates, '(b f) c h w -> b c f h w', b = batch)
465
+
466
+ return gates * orig_input
467
+
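+ # Shape trace (illustrative): a video input [B, C, F, H, W] is folded to [(B*F), C, H, W];
+ # to_k + softmax gives one attention weight per spatial position, the weighted sum yields a
+ # [(B*F), C, 1, 1] context vector, and the small conv net turns it into per-channel sigmoid
+ # gates that rescale the original input.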
468
+ class Residual(nn.Module):
469
+ @beartype
470
+ def __init__(self, fn: nn.Module):
471
+ super().__init__()
472
+ self.fn = fn
473
+
474
+ def forward(self, x, **kwargs):
475
+ return self.fn(x, **kwargs) + x
476
+
477
+ def ResidualUnit(
478
+ dim,
479
+ kernel_size: Union[int, Tuple[int, int, int]],
480
+ pad_mode: str = 'constant'
481
+ ):
482
+ net = nn.Sequential(
483
+ # causal 3D convolution (replaced by the plain 'same'-padded Conv3d below)
484
+ # CausalConv3d(dim, dim, kernel_size, pad_mode = pad_mode),
485
+ nn.Conv3d(dim, dim, kernel_size,padding='same'),
486
+ nn.ELU(),
487
+ nn.Conv3d(dim, dim, 1),
488
+ nn.ELU(),
489
+ # a channel-wise conv + softmax global-context attention (SqueezeExcite)
490
+ SqueezeExcite(dim)
491
+ )
492
+
493
+ return Residual(net)
494
+
495
+ # strided conv downsamples
496
+
497
+ class SpatialDownsample2x(nn.Module):
498
+ def __init__(
499
+ self,
500
+ dim,
501
+ dim_out = None,
502
+ kernel_size = 3,
503
+ antialias = False
504
+ ):
505
+ super().__init__()
506
+ dim_out = default(dim_out, dim)
507
+ self.conv = nn.Conv2d(dim, dim_out, kernel_size, stride = 2, padding = kernel_size // 2)
508
+
509
+ def forward(self, x):
510
+
511
+ x = rearrange(x, 'b c t h w -> b t c h w')
512
+ x, ps = pack_one(x, '* c h w')
513
+
514
+ out = self.conv(x)
515
+
516
+ out = unpack_one(out, ps, '* c h w')
517
+ out = rearrange(out, 'b t c h w -> b c t h w')
518
+ return out
519
+
520
+ class TimeDownsample2x(nn.Module):
521
+ def __init__(
522
+ self,
523
+ dim,
524
+ dim_out = None,
525
+ kernel_size = 3,
526
+ antialias = False
527
+ ):
528
+ super().__init__()
529
+ dim_out = default(dim_out, dim)
530
+ self.time_causal_padding = (kernel_size - 1, 0)
531
+ self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride = 2)
532
+
533
+ def forward(self, x):
534
+ x = rearrange(x, 'b c t h w -> b h w c t')
535
+ x, ps = pack_one(x, '* c t')
536
+
537
+ x = F.pad(x, self.time_causal_padding)
538
+ out = self.conv(x)
539
+
540
+ out = unpack_one(out, ps, '* c t')
541
+ out = rearrange(out, 'b h w c t -> b c t h w')
542
+ return out
543
+
544
+ # depth to space upsamples
545
+
546
+ class SpatialUpsample2x(nn.Module):
547
+ def __init__(
548
+ self,
549
+ dim,
550
+ dim_out = None
551
+ ):
552
+ super().__init__()
553
+ dim_out = default(dim_out, dim)
554
+ conv = nn.Conv2d(dim, dim_out * 4, 1)
555
+
556
+ self.net = nn.Sequential(
557
+ conv,
558
+ nn.SiLU(),
559
+ Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
560
+ )
561
+
562
+ self.init_conv_(conv)
563
+
564
+ def init_conv_(self, conv):
565
+ o, i, h, w = conv.weight.shape
566
+ conv_weight = torch.empty(o // 4, i, h, w)
567
+ nn.init.kaiming_uniform_(conv_weight)
568
+ conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
569
+
570
+ conv.weight.data.copy_(conv_weight)
571
+ nn.init.zeros_(conv.bias.data)
572
+
573
+ def forward(self, x):
574
+ x = rearrange(x, 'b c t h w -> b t c h w')
575
+ x, ps = pack_one(x, '* c h w')
576
+
577
+ out = self.net(x)
578
+
579
+ out = unpack_one(out, ps, '* c h w')
580
+ out = rearrange(out, 'b t c h w -> b c t h w')
581
+ return out
582
+
583
+ class TimeUpsample2x(nn.Module):
584
+ def __init__(
585
+ self,
586
+ dim,
587
+ dim_out = None
588
+ ):
589
+ super().__init__()
590
+ dim_out = default(dim_out, dim)
591
+ conv = nn.Conv1d(dim, dim_out * 2, 1)
592
+
593
+ self.net = nn.Sequential(
594
+ conv,
595
+ nn.SiLU(),
596
+ Rearrange('b (c p) t -> b c (t p)', p = 2)
597
+ )
598
+
599
+ self.init_conv_(conv)
600
+
601
+ def init_conv_(self, conv):
602
+ o, i, t = conv.weight.shape
603
+ conv_weight = torch.empty(o // 2, i, t)
604
+ nn.init.kaiming_uniform_(conv_weight)
605
+ conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...')
606
+
607
+ conv.weight.data.copy_(conv_weight)
608
+ nn.init.zeros_(conv.bias.data)
609
+
610
+ def forward(self, x):
611
+ x = rearrange(x, 'b c t h w -> b h w c t')
612
+ x, ps = pack_one(x, '* c t')
613
+
614
+ out = self.net(x)
615
+
616
+ out = unpack_one(out, ps, '* c t')
617
+ out = rearrange(out, 'b h w c t -> b c t h w')
618
+ return out
619
+
620
+ class RMSNorm(nn.Module):
621
+ def __init__(
622
+ self,
623
+ dim,
624
+ channel_first = False,
625
+ images = False,
626
+ bias = False
627
+ ):
628
+ super().__init__()
629
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
630
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
631
+
632
+ self.channel_first = channel_first
633
+ self.scale = dim ** 0.5
634
+ self.gamma = nn.Parameter(torch.ones(shape))
635
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
636
+
637
+ def forward(self, x):
638
+ return F.normalize(x, dim = (1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
639
+
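+ # Equivalently: RMSNorm(x) = x / ||x||_2 * sqrt(dim) * gamma + bias = x / RMS(x) * gamma + bias,
+ # since ||x||_2 = sqrt(dim) * RMS(x) along the normalized dimension.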
640
+ class AdaptiveRMSNorm(nn.Module):
641
+ def __init__(
642
+ self,
643
+ dim,
644
+ *,
645
+ dim_cond,
646
+ channel_first = False,
647
+ images = False,
648
+ bias = False
649
+ ):
650
+ super().__init__()
651
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
652
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
653
+
654
+ self.dim_cond = dim_cond
655
+ self.channel_first = channel_first
656
+ self.scale = dim ** 0.5
657
+
658
+ self.to_gamma = nn.Linear(dim_cond, dim)
659
+ self.to_bias = nn.Linear(dim_cond, dim) if bias else None
660
+
661
+ nn.init.zeros_(self.to_gamma.weight)
662
+ nn.init.ones_(self.to_gamma.bias)
663
+
664
+ if bias:
665
+ nn.init.zeros_(self.to_bias.weight)
666
+ nn.init.zeros_(self.to_bias.bias)
667
+
668
+ @beartype
669
+ def forward(self, x: Tensor, *, cond: Tensor):
670
+ batch = x.shape[0]
671
+ assert cond.shape == (batch, self.dim_cond)
672
+
673
+ gamma = self.to_gamma(cond)
674
+
675
+ bias = 0.
676
+ if exists(self.to_bias):
677
+ bias = self.to_bias(cond)
678
+
679
+ if self.channel_first:
680
+ gamma = append_dims(gamma, x.ndim - 2)
681
+
682
+ if exists(self.to_bias):
683
+ bias = append_dims(bias, x.ndim - 2)
684
+
685
+ return F.normalize(x, dim = (1 if self.channel_first else -1)) * self.scale * gamma + bias
686
+
687
+ class Attention(nn.Module):
688
+ @beartype
689
+ def __init__(
690
+ self,
691
+ *,
692
+ dim,
693
+ dim_cond: Union[int,None] = None,
694
+ causal = False,
695
+ dim_head = 32,
696
+ heads = 8,
697
+ flash = False,
698
+ dropout = 0.,
699
+ num_memory_kv = 4
700
+ ):
701
+ super().__init__()
702
+ dim_inner = dim_head * heads
703
+
704
+ self.need_cond = exists(dim_cond)
705
+
706
+ if self.need_cond:
707
+ self.norm = AdaptiveRMSNorm(dim, dim_cond = dim_cond)
708
+ else:
709
+ self.norm = RMSNorm(dim)
710
+
711
+ self.to_qkv = nn.Sequential(
712
+ nn.Linear(dim, dim_inner * 3, bias = False),
713
+ Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
714
+ )
715
+
716
+ assert num_memory_kv > 0
717
+ self.mem_kv = nn.Parameter(torch.randn(2, heads, num_memory_kv, dim_head))
718
+
719
+ self.attend = Attend(
720
+ causal = causal,
721
+ dropout = dropout,
722
+ flash = flash
723
+ )
724
+
725
+ self.to_out = nn.Sequential(
726
+ Rearrange('b h n d -> b n (h d)'),
727
+ nn.Linear(dim_inner, dim, bias = False)
728
+ )
729
+
730
+ @beartype
731
+ def forward(
732
+ self,
733
+ x,
734
+ mask: Union[Tensor,None] = None,
735
+ cond: Union[Tensor,None] = None
736
+ ):
737
+ maybe_cond_kwargs = dict(cond = cond) if self.need_cond else dict()
738
+
739
+ x = self.norm(x, **maybe_cond_kwargs)
740
+
741
+ q, k, v = self.to_qkv(x)
742
+
743
+ mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = q.shape[0]), self.mem_kv)
744
+ k = torch.cat((mk, k), dim = -2)
745
+ v = torch.cat((mv, v), dim = -2)
746
+
747
+ out = self.attend(q, k, v, mask = mask)
748
+ return self.to_out(out)
749
+
750
+ class LinearAttention(nn.Module):
751
+ """
752
+ using the specific linear attention proposed in https://arxiv.org/abs/2106.09681
753
+ """
754
+
755
+ @beartype
756
+ def __init__(
757
+ self,
758
+ *,
759
+ dim,
760
+ dim_cond: Union[int,None] = None,
761
+ dim_head = 8,
762
+ heads = 8,
763
+ dropout = 0.
764
+ ):
765
+ super().__init__()
766
+ dim_inner = dim_head * heads
767
+
768
+ self.need_cond = exists(dim_cond)
769
+
770
+ if self.need_cond:
771
+ self.norm = AdaptiveRMSNorm(dim, dim_cond = dim_cond)
772
+ else:
773
+ self.norm = RMSNorm(dim)
774
+
775
+ self.attn = TaylorSeriesLinearAttn(
776
+ dim = dim,
777
+ dim_head = dim_head,
778
+ heads = heads
779
+ )
780
+
781
+ def forward(
782
+ self,
783
+ x,
784
+ cond: Union[Tensor,None] = None
785
+ ):
786
+ maybe_cond_kwargs = dict(cond = cond) if self.need_cond else dict()
787
+
788
+ x = self.norm(x, **maybe_cond_kwargs)
789
+
790
+ return self.attn(x)
791
+
792
+ class LinearSpaceAttention(LinearAttention):
793
+ def forward(self, x, *args, **kwargs):
794
+ x = rearrange(x, 'b c ... h w -> b ... h w c')
795
+ x, batch_ps = pack_one(x, '* h w c')
796
+ x, seq_ps = pack_one(x, 'b * c')
797
+
798
+ x = super().forward(x, *args, **kwargs)
799
+
800
+ x = unpack_one(x, seq_ps, 'b * c')
801
+ x = unpack_one(x, batch_ps, '* h w c')
802
+ return rearrange(x, 'b ... h w c -> b c ... h w')
803
+
804
+ class SpaceAttention(Attention):
805
+ def forward(self, x, *args, **kwargs):
806
+ x = rearrange(x, 'b c t h w -> b t h w c')
807
+ x, batch_ps = pack_one(x, '* h w c')
808
+ x, seq_ps = pack_one(x, 'b * c')
809
+
810
+ x = super().forward(x, *args, **kwargs)
811
+
812
+ x = unpack_one(x, seq_ps, 'b * c')
813
+ x = unpack_one(x, batch_ps, '* h w c')
814
+ return rearrange(x, 'b t h w c -> b c t h w')
815
+
816
+ class TimeAttention(Attention):
817
+ def forward(self, x, *args, **kwargs):
818
+ x = rearrange(x, 'b c t h w -> b h w t c')
819
+ x, batch_ps = pack_one(x, '* t c')
820
+
821
+ x = super().forward(x, *args, **kwargs)
822
+
823
+ x = unpack_one(x, batch_ps, '* t c')
824
+ return rearrange(x, 'b h w t c -> b c t h w')
825
+
826
+ class GEGLU(nn.Module):
827
+ def forward(self, x):
828
+ x, gate = x.chunk(2, dim = 1)
829
+ return F.gelu(gate) * x
830
+
831
+ class FeedForward(nn.Module):
832
+ @beartype
833
+ def __init__(
834
+ self,
835
+ dim,
836
+ *,
837
+ dim_cond: Union[int,None] = None,
838
+ mult = 4,
839
+ images = False
840
+ ):
841
+ super().__init__()
842
+ conv_klass = nn.Conv2d if images else nn.Conv3d
843
+
844
+ rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond = dim_cond)
845
+
846
+ maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first = True, images = images)
847
+
848
+ dim_inner = int(dim * mult * 2 / 3)
849
+
850
+ self.norm = maybe_adaptive_norm_klass(dim)
851
+
852
+ self.net = Sequential(
853
+ conv_klass(dim, dim_inner * 2, 1),
854
+ GEGLU(),
855
+ conv_klass(dim_inner, dim, 1)
856
+ )
857
+
858
+ @beartype
859
+ def forward(
860
+ self,
861
+ x: Tensor,
862
+ *,
863
+ cond: Union[Tensor,None] = None
864
+ ):
865
+ maybe_cond_kwargs = dict(cond = cond) if exists(cond) else dict()
866
+
867
+ x = self.norm(x, **maybe_cond_kwargs)
868
+ return self.net(x)
869
+
870
+ def Sequential(*modules):
871
+ modules = [*filter(exists, modules)]
872
+
873
+ if len(modules) == 0:
874
+ return nn.Identity()
875
+
876
+ return nn.Sequential(*modules)
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ torch
2
+ torchaudio
3
+ torchvision
4
+ moviepy
5
+ face_alignment
6
+ beartype
7
+ taylor_series_linear_attention
8
+ huggingface_hub
9
+ einops
10
+ vector_quantize_pytorch
11
+ spaces
12
+ tf-keras
13
+ retina-face
14
+ safetensors