import warnings
warnings.filterwarnings("ignore")
import os
import argparse
import face_alignment
import torch
import torchaudio
import numpy as np
import cv2
from PIL import Image, ImageDraw
from moviepy import *
from collections import deque
from skimage import transform as tf
import yaml

from look2hear.models import Dolphin
from look2hear.datas.transform import get_preprocessing_pipelines
from face_detection_utils import detect_faces


# -- Landmark interpolation:
def linear_interpolate(landmarks, start_idx, stop_idx):
    start_landmarks = landmarks[start_idx]
    stop_landmarks = landmarks[stop_idx]
    delta = stop_landmarks - start_landmarks
    for idx in range(1, stop_idx - start_idx):
        landmarks[start_idx + idx] = start_landmarks + idx / float(stop_idx - start_idx) * delta
    return landmarks


# -- Face transformation
def warp_img(src, dst, img, std_size):
    tform = tf.estimate_transform('similarity', src, dst)  # find the transformation matrix
    warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size)  # warp the frame image
    warped = warped * 255  # note: the output of tf.warp is a double image (value range [0, 1])
    warped = warped.astype('uint8')
    return warped, tform


def apply_transform(transform, img, std_size):
    warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size)
    warped = warped * 255  # note: the output of tf.warp is a double image (value range [0, 1])
    warped = warped.astype('uint8')
    return warped


# -- Crop
def cut_patch(img, landmarks, height, width, threshold=5):
    center_x, center_y = np.mean(landmarks, axis=0)
    if center_y - height < 0:
        center_y = height
    if center_y - height < 0 - threshold:
        raise Exception('too much bias in height')
    if center_x - width < 0:
        center_x = width
    if center_x - width < 0 - threshold:
        raise Exception('too much bias in width')
    if center_y + height > img.shape[0]:
        center_y = img.shape[0] - height
    if center_y + height > img.shape[0] + threshold:
        raise Exception('too much bias in height')
    if center_x + width > img.shape[1]:
        center_x = img.shape[1] - width
    if center_x + width > img.shape[1] + threshold:
        raise Exception('too much bias in width')
    cutted_img = np.copy(img[int(round(center_y) - round(height)): int(round(center_y) + round(height)),
                             int(round(center_x) - round(width)): int(round(center_x) + round(width))])
    return cutted_img


# -- BGR to GRAY
def convert_bgr2gray(data):
    return np.stack([cv2.cvtColor(_, cv2.COLOR_BGR2GRAY) for _ in data], axis=0)


def save2npz(filename, data=None):
    assert data is not None, "data is {}".format(data)
    if not os.path.exists(os.path.dirname(filename)):
        os.makedirs(os.path.dirname(filename))
    np.savez_compressed(filename, data=data)


def read_video(filename):
    """Read video frames using MoviePy for better compatibility."""
    try:
        video_clip = VideoFileClip(filename)
        for frame in video_clip.iter_frames():
            # Convert RGB to BGR to match the cv2 convention
            frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            yield frame_bgr
        video_clip.close()
    except Exception as e:
        print(f"Error reading video {filename}: {e}")
        return


def face2head(boxes, scale=1.5):
    """Expand each face box to a square head box whose side is `scale` times the larger face dimension."""
    new_boxes = []
    for box in boxes:
        width = box[2] - box[0]
        height = box[3] - box[1]
        width_center = (box[2] + box[0]) / 2
        height_center = (box[3] + box[1]) / 2
        square_width = int(max(width, height) * scale)
        new_box = [width_center - square_width / 2,
                   height_center - square_width / 2,
                   width_center + square_width / 2,
                   height_center + square_width / 2]
        new_boxes.append(new_box)
    return new_boxes


def bb_intersection_over_union(boxA, boxB):
    # determine the (x, y)-coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    # compute the area of the intersection rectangle
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    # compute the area of both the prediction and ground-truth rectangles
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    # compute the intersection over union by dividing the intersection area
    # by the union area (sum of both areas minus the intersection area)
    iou = interArea / float(boxAArea + boxBArea - interArea)
    # return the intersection over union value
    return iou
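# Example (illustrative values): two partially overlapping 11x11-pixel boxes.
# The intersection is 6x6 = 36, the union is 121 + 121 - 36 = 206, so IoU ≈ 0.17:
#
#   bb_intersection_over_union([0, 0, 10, 10], [5, 5, 15, 15])   # -> ~0.175
#
# This score is what detectface() below uses to match a freshly detected box to the
# speaker whose previous-frame box it overlaps the most.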
def detectface(video_input_path, output_path, detect_every_N_frame, scalar_face_detection, number_of_speakers):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Running on device: {}'.format(device))

    os.makedirs(os.path.join(output_path, 'faces'), exist_ok=True)
    os.makedirs(os.path.join(output_path, 'landmark'), exist_ok=True)

    landmarks_dic = {}
    faces_dic = {}
    boxes_dic = {}
    for i in range(number_of_speakers):
        landmarks_dic[i] = []
        faces_dic[i] = []
        boxes_dic[i] = []

    video_clip = VideoFileClip(video_input_path)
    print("Video statistics: size=({}, {}), fps={}".format(video_clip.w, video_clip.h, video_clip.fps))
    frames = [Image.fromarray(frame) for frame in video_clip.iter_frames()]
    print('Number of frames in video: ', len(frames))
    video_clip.close()

    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False, device=str(device))

    for i, frame in enumerate(frames):
        print('\rTracking frame: {}'.format(i + 1), end='')

        # Detect faces every N frames (frames in between reuse the most recent detections)
        if i % detect_every_N_frame == 0:
            frame_array = np.array(frame)
            detected_boxes, _ = detect_faces(
                frame_array,
                threshold=0.9,
                allow_upscaling=False,
            )
            if detected_boxes is None or len(detected_boxes) == 0:
                # Retry with a lower threshold and upscaling if nothing was found
                detected_boxes, _ = detect_faces(
                    frame_array,
                    threshold=0.7,
                    allow_upscaling=True,
                )
            if detected_boxes is not None and len(detected_boxes) > 0:
                detected_boxes = detected_boxes[:number_of_speakers]
                detected_boxes = face2head(detected_boxes, scalar_face_detection)
            else:
                detected_boxes = []

        # Process the detection results
        if i == 0:
            # First frame - initialize tracking
            if len(detected_boxes) < number_of_speakers:
                raise ValueError(f"First frame must detect at least {number_of_speakers} faces, but only found {len(detected_boxes)}")
            # Assign first detections to speakers in order
            for j in range(number_of_speakers):
                box = detected_boxes[j]
                face = frame.crop((box[0], box[1], box[2], box[3])).resize((224, 224))
                preds = fa.get_landmarks(np.array(face))
                if preds is None:
                    raise ValueError(f"Face landmarks not detected in initial frame for speaker {j}")
                faces_dic[j].append(face)
                landmarks_dic[j].append(preds)
                boxes_dic[j].append(box)
        else:
            # For subsequent frames, match detected boxes to speakers
            matched_speakers = set()
            speaker_boxes = [None] * number_of_speakers

            # Greedily assign each detected box to the speaker whose previous box it overlaps most
            for box in detected_boxes:
                iou_scores = []
                for speaker_id in range(number_of_speakers):
                    if speaker_id in matched_speakers:
                        iou_scores.append(-1)  # Already matched
                    else:
                        last_box = boxes_dic[speaker_id][-1]
                        iou_score = bb_intersection_over_union(box, last_box)
                        iou_scores.append(iou_score)
                if max(iou_scores) > 0:  # Valid match found
                    best_speaker = iou_scores.index(max(iou_scores))
                    speaker_boxes[best_speaker] = box
                    matched_speakers.add(best_speaker)

            # Process each speaker
            for speaker_id in range(number_of_speakers):
                if speaker_boxes[speaker_id] is not None:
                    # Use the detected box
                    box = speaker_boxes[speaker_id]
                else:
                    # Fall back to the previous box for this speaker
                    box = boxes_dic[speaker_id][-1]

                # Extract face crop and landmarks
                face = frame.crop((box[0], box[1], box[2], box[3])).resize((224, 224))
                preds = fa.get_landmarks(np.array(face))
                if preds is None:
                    # Reuse previous landmarks if detection fails
                    preds = landmarks_dic[speaker_id][-1]

                faces_dic[speaker_id].append(face)
                landmarks_dic[speaker_id].append(preds)
                boxes_dic[speaker_id].append(box)

    # Verify all speakers have the same number of frames
    frame_counts = [len(boxes_dic[s]) for s in range(number_of_speakers)]
    print(f"\nFrame counts per speaker: {frame_counts}")
    assert all(count == len(frames) for count in frame_counts), f"Inconsistent frame counts: {frame_counts}"

    # Save a tracked video per speaker with the bounding box drawn on every frame
    for s in range(number_of_speakers):
        frames_tracked = []
        for i, frame in enumerate(frames):
            frame_draw = frame.copy()
            draw = ImageDraw.Draw(frame_draw)
            draw.rectangle(boxes_dic[s][i], outline=(255, 0, 0), width=6)
            frames_tracked.append(frame_draw)
        tracked_frames = [np.array(frame) for frame in frames_tracked]
        if tracked_frames:
            tracked_clip = ImageSequenceClip(tracked_frames, fps=25.0)
            tracked_video_path = os.path.join(output_path, 'video_tracked' + str(s + 1) + '.mp4')
            tracked_clip.write_videofile(tracked_video_path, codec='libx264', audio=False, logger=None)
            tracked_clip.close()

    # Save landmarks and face videos
    for i in range(number_of_speakers):
        save2npz(os.path.join(output_path, 'landmark', 'speaker' + str(i + 1) + '.npz'), data=landmarks_dic[i])
        face_frames = [np.array(frame) for frame in faces_dic[i]]
        if face_frames:
            face_clip = ImageSequenceClip(face_frames, fps=25.0)
            face_video_path = os.path.join(output_path, 'faces', 'speaker' + str(i + 1) + '.mp4')
            face_clip.write_videofile(face_video_path, codec='libx264', audio=False, logger=None)
            face_clip.close()

    # Write the filename/person-id csv consumed by crop_mouth()
    parts = video_input_path.split('/')
    video_name = parts[-1][:-4]
    if not os.path.exists(os.path.join(output_path, 'filename_input')):
        os.mkdir(os.path.join(output_path, 'filename_input'))
    csvfile = open(os.path.join(output_path, 'filename_input', str(video_name) + '.csv'), 'w')
    for i in range(number_of_speakers):
        csvfile.write('speaker' + str(i + 1) + ',0\n')
    csvfile.close()
    return os.path.join(output_path, 'filename_input', str(video_name) + '.csv')
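# For an input "clip.mp4" and two speakers, detectface() leaves roughly the following
# layout under output_path, which the later stages consume by path convention:
#
#   faces/speaker1.mp4, faces/speaker2.mp4          224x224 face crops at 25 fps
#   landmark/speaker1.npz, landmark/speaker2.npz    per-frame 68-point landmarks
#   video_tracked1.mp4, video_tracked2.mp4          input frames with the tracked box drawn
#   filename_input/clip.csv                         "speaker1,0" / "speaker2,0" lines read by crop_mouth()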
def crop_patch(mean_face_landmarks, video_pathname, landmarks, window_margin, start_idx, stop_idx,
               crop_height, crop_width, STD_SIZE=(256, 256)):
    """Crop the mouth patch from every frame of a video.
    :param str video_pathname: pathname of the video
    :param list landmarks: interpolated landmarks, one (68, 2) array per frame
    """
    stablePntsIDs = [33, 36, 39, 42, 45]  # base of the nose and the four eye corners (68-point layout)
    frame_idx = 0
    frame_gen = read_video(video_pathname)
    while True:
        try:
            frame = next(frame_gen)  # -- BGR
        except StopIteration:
            break
        if frame_idx == 0:
            q_frame, q_landmarks = deque(), deque()
            sequence = []
        q_landmarks.append(landmarks[frame_idx])
        q_frame.append(frame)
        if len(q_frame) == window_margin:
            smoothed_landmarks = np.mean(q_landmarks, axis=0)
            cur_landmarks = q_landmarks.popleft()
            cur_frame = q_frame.popleft()
            # -- affine transformation
            trans_frame, trans = warp_img(smoothed_landmarks[stablePntsIDs, :],
                                          mean_face_landmarks[stablePntsIDs, :],
                                          cur_frame,
                                          STD_SIZE)
            trans_landmarks = trans(cur_landmarks)
            # -- crop mouth patch
            sequence.append(cut_patch(trans_frame,
                                      trans_landmarks[start_idx:stop_idx],
                                      crop_height // 2,
                                      crop_width // 2))
        if frame_idx == len(landmarks) - 1:
            # deal with the corner case of a video shorter than the smoothing window
            if len(landmarks) < window_margin:
                smoothed_landmarks = np.mean(q_landmarks, axis=0)
                cur_landmarks = q_landmarks.popleft()
                cur_frame = q_frame.popleft()
                # -- affine transformation
                trans_frame, trans = warp_img(smoothed_landmarks[stablePntsIDs, :],
                                              mean_face_landmarks[stablePntsIDs, :],
                                              cur_frame,
                                              STD_SIZE)
                trans_landmarks = trans(cur_landmarks)
                # -- crop mouth patch
                sequence.append(cut_patch(trans_frame,
                                          trans_landmarks[start_idx:stop_idx],
                                          crop_height // 2,
                                          crop_width // 2))
            # flush the remaining queued frames using the last estimated transform
            while q_frame:
                cur_frame = q_frame.popleft()
                # -- transform frame
                trans_frame = apply_transform(trans, cur_frame, STD_SIZE)
                # -- transform landmarks
                trans_landmarks = trans(q_landmarks.popleft())
                # -- crop mouth patch
                sequence.append(cut_patch(trans_frame,
                                          trans_landmarks[start_idx:stop_idx],
                                          crop_height // 2,
                                          crop_width // 2))
            return np.array(sequence)
        frame_idx += 1
    return None
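# A minimal sketch (with hypothetical `frame` / `frame_landmarks` inputs) of the per-frame
# alignment performed above: landmarks are mapped onto the canonical mean face with a
# similarity transform estimated from the five stable points, and the mouth patch
# (landmark indices 48-67) is then cut from the warped frame:
#
#   mean_face = np.load('assets/20words_mean_face.npy')      # (68, 2) reference landmarks
#   stable_ids = [33, 36, 39, 42, 45]
#   warped, tform = warp_img(frame_landmarks[stable_ids, :],
#                            mean_face[stable_ids, :], frame, (256, 256))
#   mouth = cut_patch(warped, tform(frame_landmarks)[48:68], 48, 48)   # -> 96x96 patch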
def landmarks_interpolate(landmarks):
    """Interpolate landmarks for frames where detection failed.
    :param list landmarks: landmarks detected in raw videos (None for missed frames)
    """
    valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
    if not valid_frames_idx:
        return None
    for idx in range(1, len(valid_frames_idx)):
        if valid_frames_idx[idx] - valid_frames_idx[idx - 1] == 1:
            continue
        else:
            landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx])
    valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
    # -- Corner case: fill frames at the beginning or at the end that failed to be detected
    if valid_frames_idx:
        landmarks[:valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
        landmarks[valid_frames_idx[-1]:] = [landmarks[valid_frames_idx[-1]]] * (len(landmarks) - valid_frames_idx[-1])
    valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
    assert len(valid_frames_idx) == len(landmarks), "not every frame has landmark"
    return landmarks


def crop_mouth(video_direc, landmark_direc, filename_path, save_direc, convert_gray=False, testset_only=False):
    lines = open(filename_path).read().splitlines()
    lines = list(filter(lambda x: 'test' in x, lines)) if testset_only else lines

    for filename_idx, line in enumerate(lines):
        filename, person_id = line.split(',')
        print('idx: {} \tProcessing.\t{}'.format(filename_idx, filename))

        video_pathname = os.path.join(video_direc, filename + '.mp4')
        landmarks_pathname = os.path.join(landmark_direc, filename + '.npz')
        dst_pathname = os.path.join(save_direc, filename + '.npz')

        # if os.path.exists(dst_pathname):
        #     continue

        multi_sub_landmarks = np.load(landmarks_pathname, allow_pickle=True)['data']
        landmarks = [None] * len(multi_sub_landmarks)
        for frame_idx in range(len(landmarks)):
            try:
                # landmarks[frame_idx] = multi_sub_landmarks[frame_idx][int(person_id)]['facial_landmarks']  # original for LRW
                landmarks[frame_idx] = multi_sub_landmarks[frame_idx][int(person_id)]  # VOXCELEB2
            except (IndexError, TypeError):
                continue

        # -- pre-process landmarks: interpolate frames that were not detected
        preprocessed_landmarks = landmarks_interpolate(landmarks)
        if not preprocessed_landmarks:
            continue

        # -- crop
        mean_face_landmarks = np.load('assets/20words_mean_face.npy')
        sequence = crop_patch(mean_face_landmarks, video_pathname, preprocessed_landmarks, 12, 48, 68, 96, 96)
        assert sequence is not None, "cannot crop from {}.".format(filename)

        # -- save
        data = convert_bgr2gray(sequence) if convert_gray else sequence[..., ::-1]
        save2npz(dst_pathname, data=data)
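# A minimal sketch (hypothetical paths) of running the mouth-ROI stage on its own.
# With the call above (window_margin=12, mouth indices 48-68, 96x96 crop) and
# convert_gray=True, each saved .npz holds an array of shape (num_frames, 96, 96):
#
#   crop_mouth(video_direc="out/faces",
#              landmark_direc="out/landmark",
#              filename_path="out/filename_input/clip.csv",
#              save_direc="out/mouthroi",
#              convert_gray=True)
#   roi = np.load("out/mouthroi/speaker1.npz")["data"]   # -> (T, 96, 96)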
def convert_video_fps(input_file, output_file, target_fps=25):
    """Convert a video to the target FPS using MoviePy."""
    video = VideoFileClip(input_file)
    video_fps = video.fps
    if video_fps != target_fps:
        video.write_videofile(
            output_file,
            fps=target_fps,
            codec='libx264',
            audio_codec='aac',
            temp_audiofile='temp-audio.m4a',
            remove_temp=True,
        )
    else:
        # Already at the target fps, just copy the file
        import shutil
        shutil.copy2(input_file, output_file)
    video.close()
    print(f'Video has been converted to {target_fps} fps and saved to {output_file}')


def extract_audio(video_file, audio_output_file, sample_rate=16000):
    """Extract audio from a video using MoviePy."""
    video = VideoFileClip(video_file)
    audio = video.audio
    # Save audio at the specified sample rate as 16-bit PCM
    audio.write_audiofile(audio_output_file, fps=sample_rate, nbytes=2, codec='pcm_s16le')
    video.close()
    audio.close()


def merge_video_audio(video_file, audio_file, output_file):
    """Merge a video and an audio track using MoviePy."""
    video = VideoFileClip(video_file)
    audio = AudioFileClip(audio_file)

    # Attach audio (MoviePy v2 renamed set_audio to with_audio)
    set_audio_fn = getattr(video, "set_audio", None)
    if callable(set_audio_fn):
        final_video = set_audio_fn(audio)
    else:
        with_audio_fn = getattr(video, "with_audio", None)
        if not callable(with_audio_fn):
            video.close()
            audio.close()
            raise AttributeError("VideoFileClip object lacks both set_audio and with_audio methods")
        final_video = with_audio_fn(audio)

    # Write the result
    final_video.write_videofile(output_file, codec='libx264', audio_codec='aac',
                                temp_audiofile='temp-audio.m4a', remove_temp=True)

    # Clean up
    video.close()
    audio.close()
    final_video.close()


def process_video(input_file, output_path, number_of_speakers=2, detect_every_N_frame=8,
                  scalar_face_detection=1.5, config_path="checkpoints/vox2/conf.yml", cuda_device=None):
    """Main processing function for video speaker separation."""
    # Set CUDA device if specified
    if cuda_device is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device)

    # Create output directory
    os.makedirs(output_path, exist_ok=True)

    # Convert video to 25 fps
    temp_25fps_file = os.path.join(output_path, 'temp_25fps.mp4')
    convert_video_fps(input_file, temp_25fps_file, target_fps=25)

    # Detect and track faces
    filename_path = detectface(video_input_path=temp_25fps_file,
                               output_path=output_path,
                               detect_every_N_frame=detect_every_N_frame,
                               scalar_face_detection=scalar_face_detection,
                               number_of_speakers=number_of_speakers)

    # Extract audio
    audio_output = os.path.join(output_path, 'audio.wav')
    extract_audio(temp_25fps_file, audio_output, sample_rate=16000)

    # Crop mouth ROIs
    crop_mouth(video_direc=os.path.join(output_path, "faces"),
               landmark_direc=os.path.join(output_path, "landmark"),
               filename_path=filename_path,
               save_direc=os.path.join(output_path, "mouthroi"),
               convert_gray=True,
               testset_only=False)

    # Load model
    audiomodel = Dolphin.from_pretrained("JusperLee/Dolphin")
    audiomodel.cuda()
    audiomodel.eval()

    # Process each speaker
    with torch.no_grad():
        for i in range(number_of_speakers):
            mouth_roi = np.load(os.path.join(output_path, "mouthroi", f"speaker{i+1}.npz"))["data"]
            mouth_roi = get_preprocessing_pipelines()["val"](mouth_roi)

            mix, sr = torchaudio.load(audio_output)
            mix = mix.cuda().mean(dim=0)

            window_size = 4 * sr
            hop_size = 4 * sr
            all_estimates = []

            # Sliding-window processing: separate 4-second chunks of audio, each paired
            # with the corresponding mouth-ROI frames (video at 25 fps)
            start_idx = 0
            while start_idx < len(mix):
                end_idx = min(start_idx + window_size, len(mix))
                window_mix = mix[start_idx:end_idx]

                start_frame = int(start_idx / sr * 25)
                end_frame = int(end_idx / sr * 25)
                end_frame = min(end_frame, len(mouth_roi))
                window_mouth_roi = mouth_roi[start_frame:end_frame]

                est_sources = audiomodel(window_mix[None],
                                         torch.from_numpy(window_mouth_roi[None, None]).float().cuda())

                all_estimates.append({
                    'start': start_idx,
                    'end': end_idx,
                    'estimate': est_sources[0].cpu()
                })

                start_idx += hop_size
                if start_idx >= len(mix):
                    break

            # Merge the windowed estimates back into one waveform with Hann-window weighting
            output_length = len(mix)
            merged_output = torch.zeros(1, output_length)
            weights = torch.zeros(output_length)

            for est in all_estimates:
                window_len = est['end'] - est['start']
                hann_window = torch.hann_window(window_len)
                merged_output[0, est['start']:est['end']] += est['estimate'][0, :window_len] * hann_window
                weights[est['start']:est['end']] += hann_window

            merged_output[:, weights > 0] /= weights[weights > 0]

            torchaudio.save(os.path.join(output_path, f"speaker{i+1}_est.wav"), merged_output, sr)

    # Merge the tracked video with the separated audio for each speaker
    output_files = []
    for i in range(number_of_speakers):
        video_input = os.path.join(output_path, f"video_tracked{i+1}.mp4")
        audio_input = os.path.join(output_path, f"speaker{i+1}_est.wav")
        video_output = os.path.join(output_path, f"s{i+1}.mp4")
        merge_video_audio(video_input, audio_input, video_output)
        output_files.append(video_output)

    # Clean up temporary file
    if os.path.exists(temp_25fps_file):
        os.remove(temp_25fps_file)

    return output_files
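# A minimal sketch of calling process_video() programmatically (illustrative paths;
# the CLI entry point below wires up the same call from argparse arguments):
#
#   outputs = process_video(input_file="demo/interview.mp4",
#                           output_path="demo/interview_output",
#                           number_of_speakers=2,
#                           cuda_device=0)
#   # -> ["demo/interview_output/s1.mp4", "demo/interview_output/s2.mp4"]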
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Video Speaker Separation using Dolphin model')
    parser.add_argument('--input', '-i', type=str, required=True,
                        help='Path to input video file')
    parser.add_argument('--output', '-o', type=str, default=None,
                        help='Output directory path (default: creates directory based on input filename)')
    parser.add_argument('--speakers', '-s', type=int, default=2,
                        help='Number of speakers to separate (default: 2)')
    parser.add_argument('--detect-every-n', type=int, default=8,
                        help='Detect faces every N frames (default: 8)')
    parser.add_argument('--face-scale', type=float, default=1.5,
                        help='Face detection bounding box scale factor (default: 1.5)')
    parser.add_argument('--cuda-device', type=int, default=0,
                        help='CUDA device ID to use (default: 0, set to -1 for CPU)')
    parser.add_argument('--config', type=str, default="checkpoints/vox2/conf.yml",
                        help='Path to model configuration file')

    args = parser.parse_args()

    # Verify that the input file exists
    if not os.path.exists(args.input):
        print(f"Error: Input file '{args.input}' does not exist")
        exit(1)

    # If no output path is given, create an output directory based on the input filename
    if args.output is None:
        input_basename = os.path.splitext(os.path.basename(args.input))[0]
        args.output = os.path.join(os.path.dirname(args.input), input_basename + "_output")

    # Select the CUDA device
    cuda_device = args.cuda_device if args.cuda_device >= 0 else None

    print(f"Processing video: {args.input}")
    print(f"Output directory: {args.output}")
    print(f"Number of speakers: {args.speakers}")
    print(f"CUDA device: {cuda_device if cuda_device is not None else 'CPU'}")

    # Process the video
    output_files = process_video(
        input_file=args.input,
        output_path=args.output,
        number_of_speakers=args.speakers,
        detect_every_N_frame=args.detect_every_n,
        scalar_face_detection=args.face_scale,
        config_path=args.config,
        cuda_device=cuda_device
    )
print("\nProcessing completed!") print("Output files:") for i, output_file in enumerate(output_files): print(f" Speaker {i+1}: {output_file}")