| """ | |
| DiT Alignment Score Module | |
| This module provides lyrics-to-audio alignment using cross-attention matrices | |
| from DiT model for generating LRC timestamps. | |
| Refactored from lyrics_alignment_infos.py for integration with ACE-Step. | |
| """ | |
| import numba | |
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from dataclasses import dataclass, asdict | |
| from typing import List, Dict, Any, Optional, Tuple, Union | |

# ================= Data Classes =================

@dataclass
class TokenTimestamp:
    """Stores per-token timing information."""
    token_id: int
    text: str
    start: float
    end: float
    probability: float


@dataclass
class SentenceTimestamp:
    """Stores per-sentence timing information with token list."""
    text: str
    start: float
    end: float
    tokens: List[TokenTimestamp]
    confidence: float
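
# Note: these dataclasses are plain containers; `asdict` (imported above) can
# convert them into JSON-serializable dicts when needed, e.g. asdict(token_timestamp),
# where `token_timestamp` is any TokenTimestamp instance.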

# ================= DTW Algorithm (Numba Optimized) =================

@numba.jit(nopython=True)
def dtw_cpu(x: np.ndarray):
    """
    Dynamic Time Warping algorithm optimized with Numba.

    Args:
        x: Cost matrix of shape [N, M]

    Returns:
        Path array of shape (2, path_len) that unpacks as
        (text_indices, time_indices)
    """
    N, M = x.shape
    # Use float32 for memory efficiency
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t
    return _backtrace(trace, N, M)


@numba.jit(nopython=True)
def _backtrace(trace: np.ndarray, N: int, M: int):
    """
    Optimized backtrace function for DTW.

    Args:
        trace: Trace matrix of shape (N+1, M+1)
        N, M: Original matrix dimensions

    Returns:
        Path array of shape (2, path_len) - first row is text indices, second is time indices
    """
    # Boundary handling
    trace[0, :] = 2
    trace[:, 0] = 1
    # Pre-allocate the path array; max path length is N+M
    max_path_len = N + M
    path = np.zeros((2, max_path_len), dtype=np.int32)
    i, j = N, M
    path_idx = max_path_len - 1
    while i > 0 or j > 0:
        path[0, path_idx] = i - 1  # text index
        path[1, path_idx] = j - 1  # time index
        path_idx -= 1
        t = trace[i, j]
        if t == 0:
            i -= 1
            j -= 1
        elif t == 1:
            i -= 1
        elif t == 2:
            j -= 1
        else:
            break
    return path[:, path_idx + 1:max_path_len]
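

# A minimal usage sketch (not part of the original module): the cost matrix is
# negated before calling dtw_cpu because DTW minimizes cost while attention
# energy should be maximized. The random matrix and 30-second duration below
# are purely illustrative.
def _dtw_usage_example():
    similarity = np.random.rand(12, 200)  # [Tokens, Frames]
    text_indices, time_indices = dtw_cpu(-similarity.astype(np.float64))
    seconds_per_frame = 30.0 / similarity.shape[-1]
    # Frames assigned to token 0, converted to seconds
    return time_indices[text_indices == 0] * seconds_per_frame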


# ================= Utility Functions =================

def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
    """
    Apply a median filter along the last dimension of a tensor.

    Args:
        x: Input tensor
        filter_width: Width of the median filter

    Returns:
        Filtered tensor
    """
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        return x
    # `F.pad` with mode="reflect" needs at least a 3D input here, so temporarily
    # add a leading dimension for 2D inputs and remove it again afterwards.
    added_batch_dim = x.ndim == 2
    if added_batch_dim:
        x = x[None, :]
    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
    result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
    if added_batch_dim:
        result = result.squeeze(0)
    return result


# ================= Main Aligner Class =================

class MusicStampsAligner:
    """
    Aligner class for generating lyrics timestamps from cross-attention matrices.

    Uses bidirectional consensus denoising and DTW for alignment.
    """

    def __init__(self, tokenizer):
        """
        Initialize the aligner.

        Args:
            tokenizer: Text tokenizer for decoding tokens
        """
        self.tokenizer = tokenizer

    def _apply_bidirectional_consensus(
        self,
        weights_stack: torch.Tensor,
        violence_level: float,
        medfilt_width: int
    ) -> tuple:
        """
        Core denoising logic using bidirectional consensus.

        Args:
            weights_stack: Attention weights [Heads, Tokens, Frames]
            violence_level: Denoising strength coefficient
            medfilt_width: Median filter width

        Returns:
            Tuple of (calc_matrix, energy_matrix) as numpy arrays
        """
        # A. Bidirectional consensus
        row_prob = F.softmax(weights_stack, dim=-1)  # Token -> Frame
        col_prob = F.softmax(weights_stack, dim=-2)  # Frame -> Token
        processed = row_prob * col_prob
        # B. Median suppression
        # 1. Row suppression (kill horizontal crossing lines)
        row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
        processed = processed - (violence_level * row_medians)
        processed = torch.relu(processed)
        # 2. Column suppression (kill vertical crossing lines)
        col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
        processed = processed - (violence_level * col_medians)
        processed = torch.relu(processed)
        # C. Power sharpening
        processed = processed ** 2
        # Energy matrix for confidence
        energy_matrix = processed.mean(dim=0).cpu().numpy()
        # D. Z-score normalization
        std, mean = torch.std_mean(processed, unbiased=False)
        weights_processed = (processed - mean) / (std + 1e-9)
        # E. Median filtering
        weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
        calc_matrix = weights_processed.mean(dim=0).numpy()
        return calc_matrix, energy_matrix

    def _preprocess_attention(
        self,
        attention_matrix: torch.Tensor,
        custom_config: Dict[int, List[int]],
        violence_level: float,
        medfilt_width: int = 7
    ) -> tuple:
        """
        Preprocess attention matrix for alignment.

        Args:
            attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
            custom_config: Dict mapping layer indices to head indices
            violence_level: Denoising strength
            medfilt_width: Median filter width

        Returns:
            Tuple of (calc_matrix, energy_matrix, visual_matrix)
        """
        if not isinstance(attention_matrix, torch.Tensor):
            weights = torch.tensor(attention_matrix)
        else:
            weights = attention_matrix.clone()
        weights = weights.cpu().float()
        selected_tensors = []
        for layer_idx, head_indices in custom_config.items():
            for head_idx in head_indices:
                if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
                    head_matrix = weights[layer_idx, head_idx]
                    selected_tensors.append(head_matrix)
        if not selected_tensors:
            return None, None, None
        # Stack selected heads: [Heads, Tokens, Frames]
        weights_stack = torch.stack(selected_tensors, dim=0)
        visual_matrix = weights_stack.mean(dim=0).numpy()
        calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
            weights_stack, violence_level, medfilt_width
        )
        return calc_matrix, energy_matrix, visual_matrix

    def stamps_align_info(
        self,
        attention_matrix: torch.Tensor,
        lyrics_tokens: List[int],
        total_duration_seconds: float,
        custom_config: Dict[int, List[int]],
        return_matrices: bool = False,
        violence_level: float = 2.0,
        medfilt_width: int = 1
    ) -> Dict[str, Any]:
        """
        Get alignment information from attention matrix.

        Args:
            attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
            lyrics_tokens: List of lyrics token IDs
            total_duration_seconds: Total audio duration in seconds
            custom_config: Dict mapping layer indices to head indices
            return_matrices: Whether to return intermediate matrices
            violence_level: Denoising strength
            medfilt_width: Median filter width

        Returns:
            Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
            and optionally energy_matrix and vis_matrix
        """
        calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
            attention_matrix, custom_config, violence_level, medfilt_width
        )
        if calc_matrix is None:
            return {
                "calc_matrix": None,
                "lyrics_tokens": lyrics_tokens,
                "total_duration_seconds": total_duration_seconds,
                "error": "No valid attention heads found"
            }
        return_dict = {
            "calc_matrix": calc_matrix,
            "lyrics_tokens": lyrics_tokens,
            "total_duration_seconds": total_duration_seconds
        }
        if return_matrices:
            return_dict['energy_matrix'] = energy_matrix
            return_dict['vis_matrix'] = visual_matrix
        return return_dict

    def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
        """
        Decode tokens incrementally to properly handle multi-byte UTF-8 characters.

        For Chinese and other multi-byte characters, the tokenizer may split them
        into multiple byte-level tokens. Decoding each token individually produces
        invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
        to correctly track which characters each token contributes.

        Args:
            token_ids: List of token IDs

        Returns:
            List of decoded text for each token position
        """
        decoded_tokens = []
        prev_bytes = b""
        for i in range(len(token_ids)):
            # Decode tokens from the start to the current position
            current_text = self.tokenizer.decode(token_ids[:i + 1], skip_special_tokens=False)
            current_bytes = current_text.encode('utf-8', errors='surrogatepass')
            # The contribution of the current token is the newly added bytes
            if len(current_bytes) >= len(prev_bytes):
                new_bytes = current_bytes[len(prev_bytes):]
                # Try to decode the new bytes; if incomplete, use an empty string
                try:
                    token_text = new_bytes.decode('utf-8')
                except UnicodeDecodeError:
                    # Incomplete UTF-8 sequence; this token doesn't complete a character
                    token_text = ""
            else:
                # Edge case: current decode is shorter (shouldn't happen normally)
                token_text = ""
            decoded_tokens.append(token_text)
            prev_bytes = current_bytes
        return decoded_tokens
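
    # How the incremental decode behaves (illustrative): each token's text is the
    # byte-level difference between successive prefix decodes. When a multi-byte
    # character (e.g. a Chinese character) is split across several byte-level
    # tokens, the earlier diffs are not valid UTF-8 on their own and yield "",
    # and the whole character is attributed to the token that completes it.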

    def token_timestamps(
        self,
        calc_matrix: np.ndarray,
        lyrics_tokens: List[int],
        total_duration_seconds: float
    ) -> List[TokenTimestamp]:
        """
        Generate per-token timestamps using DTW.

        Args:
            calc_matrix: Processed attention matrix [Tokens, Frames]
            lyrics_tokens: List of token IDs
            total_duration_seconds: Total audio duration

        Returns:
            List of TokenTimestamp objects
        """
        n_frames = calc_matrix.shape[-1]
        text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))
        seconds_per_frame = total_duration_seconds / n_frames
        alignment_results = []
        # Use incremental decoding to properly handle multi-byte UTF-8 characters
        decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)
        for i in range(len(lyrics_tokens)):
            mask = (text_indices == i)
            if not np.any(mask):
                start = alignment_results[-1].end if alignment_results else 0.0
                end = start
                token_conf = 0.0
            else:
                times = time_indices[mask] * seconds_per_frame
                start = times[0]
                end = times[-1]
                token_conf = 0.0
            if end < start:
                end = start
            alignment_results.append(TokenTimestamp(
                token_id=lyrics_tokens[i],
                text=decoded_tokens[i],
                start=float(start),
                end=float(end),
                probability=token_conf
            ))
        return alignment_results

    def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
        """
        Decode a sentence by decoding all token IDs together.

        This avoids UTF-8 encoding issues from joining individual token texts.

        Args:
            tokens: List of TokenTimestamp objects

        Returns:
            Properly decoded sentence text
        """
        token_ids = [t.token_id for t in tokens]
        return self.tokenizer.decode(token_ids, skip_special_tokens=False)

    def sentence_timestamps(
        self,
        token_alignment: List[TokenTimestamp]
    ) -> List[SentenceTimestamp]:
        """
        Group token timestamps into sentence timestamps.

        Args:
            token_alignment: List of TokenTimestamp objects

        Returns:
            List of SentenceTimestamp objects
        """
        results = []
        current_tokens = []
        for token in token_alignment:
            current_tokens.append(token)
            if '\n' in token.text:
                # Decode all token IDs together to avoid UTF-8 issues
                full_text = self._decode_sentence_from_tokens(current_tokens)
                if full_text.strip():
                    valid_scores = [t.probability for t in current_tokens if t.probability > 0]
                    sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
                    results.append(SentenceTimestamp(
                        text=full_text.strip(),
                        start=round(current_tokens[0].start, 3),
                        end=round(current_tokens[-1].end, 3),
                        tokens=list(current_tokens),
                        confidence=sent_conf
                    ))
                current_tokens = []
        # Handle the last sentence
        if current_tokens:
            # Decode all token IDs together to avoid UTF-8 issues
            full_text = self._decode_sentence_from_tokens(current_tokens)
            if full_text.strip():
                valid_scores = [t.probability for t in current_tokens if t.probability > 0]
                sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
                results.append(SentenceTimestamp(
                    text=full_text.strip(),
                    start=round(current_tokens[0].start, 3),
                    end=round(current_tokens[-1].end, 3),
                    tokens=list(current_tokens),
                    confidence=sent_conf
                ))
        # Normalize confidence scores
        if results:
            all_scores = [s.confidence for s in results]
            min_score = min(all_scores)
            max_score = max(all_scores)
            score_range = max_score - min_score
            if score_range > 1e-9:
                for s in results:
                    normalized_score = (s.confidence - min_score) / score_range
                    s.confidence = round(normalized_score, 2)
            else:
                for s in results:
                    s.confidence = round(s.confidence, 2)
        return results

    def format_lrc(
        self,
        sentence_timestamps: List[SentenceTimestamp],
        include_end_time: bool = False
    ) -> str:
        """
        Format sentence timestamps as LRC lyrics format.

        Args:
            sentence_timestamps: List of SentenceTimestamp objects
            include_end_time: Whether to include end time (enhanced LRC format)

        Returns:
            LRC formatted string
        """
        lines = []
        for sentence in sentence_timestamps:
            # Convert seconds to mm:ss.xx format
            start_minutes = int(sentence.start // 60)
            start_seconds = sentence.start % 60
            if include_end_time:
                end_minutes = int(sentence.end // 60)
                end_seconds = sentence.end % 60
                timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
            else:
                timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
            # The sentence text is used as-is; structural tags such as [verse]
            # or [chorus] are not stripped here.
            text = sentence.text
            lines.append(f"{timestamp}{text}")
        return "\n".join(lines)

    def get_timestamps_and_lrc(
        self,
        calc_matrix: np.ndarray,
        lyrics_tokens: List[int],
        total_duration_seconds: float
    ) -> Dict[str, Any]:
        """
        Convenience method to get both timestamps and LRC in one call.

        Args:
            calc_matrix: Processed attention matrix
            lyrics_tokens: List of token IDs
            total_duration_seconds: Total audio duration

        Returns:
            Dict containing token_timestamps, sentence_timestamps, and lrc_text
        """
        token_stamps = self.token_timestamps(
            calc_matrix=calc_matrix,
            lyrics_tokens=lyrics_tokens,
            total_duration_seconds=total_duration_seconds
        )
        sentence_stamps = self.sentence_timestamps(token_stamps)
        lrc_text = self.format_lrc(sentence_stamps)
        return {
            "token_timestamps": token_stamps,
            "sentence_timestamps": sentence_stamps,
            "lrc_text": lrc_text
        }
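

# A minimal end-to-end sketch of the aligner path (not part of the original
# module). The tokenizer, attention tensor, duration and layer/head config are
# hypothetical placeholders; in practice the cross-attention tensor comes from
# a DiT forward pass.
def _aligner_usage_example(tokenizer, attention_matrix, lyrics_tokens):
    aligner = MusicStampsAligner(tokenizer)
    info = aligner.stamps_align_info(
        attention_matrix=attention_matrix,  # [Layers, Heads, Tokens, Frames]
        lyrics_tokens=lyrics_tokens,
        total_duration_seconds=30.0,        # illustrative clip length
        custom_config={0: [0, 1]},          # hypothetical layer -> heads mapping
    )
    if info["calc_matrix"] is None:
        return None
    return aligner.get_timestamps_and_lrc(
        calc_matrix=info["calc_matrix"],
        lyrics_tokens=info["lyrics_tokens"],
        total_duration_seconds=info["total_duration_seconds"],
    )["lrc_text"]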


class MusicLyricScorer:
    """
    Scorer class for evaluating lyrics-to-audio alignment quality.

    Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
    using tensor operations for potential differentiability or GPU acceleration.
    """

    def __init__(self, tokenizer: Any):
        """
        Initialize the scorer.

        Args:
            tokenizer: Tokenizer instance (must implement .decode()).
        """
        self.tokenizer = tokenizer

    def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
        """
        Generate a mask distinguishing lyrics (1) from structural tags (0).

        Uses self.tokenizer to decode tokens.

        Args:
            token_ids: List of token IDs.

        Returns:
            Numpy array of shape [len(token_ids)] with 1 or 0.
        """
        decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
        mask = np.ones(len(token_ids), dtype=np.int32)
        in_bracket = False
        for i, token_str in enumerate(decoded_tokens):
            if '[' in token_str:
                in_bracket = True
            if in_bracket:
                mask[i] = 0
            if ']' in token_str:
                in_bracket = False
                mask[i] = 0
        return mask
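
    # Example (token texts illustrative): for tokens decoding to
    #   ["[", "verse", "]", "Hello", " world", "\n"]
    # the mask is [0, 0, 0, 1, 1, 1]: everything from "[" through "]" counts as
    # a structural tag, the remaining tokens as sung lyrics.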

    def _preprocess_attention(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        custom_config: Dict[int, List[int]],
        medfilt_width: int = 1
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
        """
        Extracts and normalizes the attention matrix.

        Logic V4: Uses Min-Max normalization to highlight energy differences.

        Args:
            attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
            custom_config: Config mapping layers to heads.
            medfilt_width: Width for median filtering.

        Returns:
            Tuple of (calc_matrix, energy_matrix, avg_weights_tensor).
        """
        # 1. Prepare tensor
        if not isinstance(attention_matrix, torch.Tensor):
            weights = torch.tensor(attention_matrix)
        else:
            weights = attention_matrix.clone()
        weights = weights.cpu().float()
        # 2. Select heads based on config
        selected_tensors = []
        for layer_idx, head_indices in custom_config.items():
            for head_idx in head_indices:
                if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
                    selected_tensors.append(weights[layer_idx, head_idx])
        if not selected_tensors:
            return None, None, None
        weights_stack = torch.stack(selected_tensors, dim=0)
        # 3. Average heads
        avg_weights = weights_stack.mean(dim=0)  # [Tokens, Frames]
        # 4. Preprocessing logic
        # Min-Max normalization preserving the energy distribution;
        # the median filter is applied to the energy matrix.
        energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
        energy_matrix = energy_tensor.numpy()
        e_min, e_max = energy_matrix.min(), energy_matrix.max()
        if e_max - e_min > 1e-9:
            energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
        else:
            energy_matrix = np.zeros_like(energy_matrix)
        # Contrast enhancement for DTW pathfinding:
        # calc_matrix is used for pathfinding, energy_matrix for scoring.
        calc_matrix = energy_matrix ** 2
        return calc_matrix, energy_matrix, avg_weights

    def _compute_alignment_metrics(
        self,
        energy_matrix: torch.Tensor,
        path_coords: torch.Tensor,
        type_mask: torch.Tensor,
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Tuple[float, float, float]:
        """
        Core metric calculation logic using high-precision tensor operations.

        Args:
            energy_matrix: Normalized energy [Rows, Cols].
            path_coords: DTW path coordinates [Steps, 2].
            type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed overlap for monotonicity check.
            instrumental_weight: Weight for non-lyric tokens in confidence calc.

        Returns:
            Tuple of (coverage, monotonicity, confidence).
        """
        # Ensure high precision for internal calculation
        energy_matrix = energy_matrix.to(dtype=torch.float64)
        path_coords = path_coords.long()
        type_mask = type_mask.long()
        device = energy_matrix.device
        rows, cols = energy_matrix.shape
        is_lyrics_row = (type_mask == 1)

        # ================= A. Coverage Score =================
        # Ratio of lyric lines that have a significant energy peak
        row_max_energies = energy_matrix.max(dim=1).values
        total_sung_rows = is_lyrics_row.sum().double()
        coverage_threshold = 0.1
        valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
        valid_sung_rows = valid_sung_mask.sum().double()
        if total_sung_rows > 0:
            coverage_score = valid_sung_rows / total_sung_rows
        else:
            coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= B. Monotonicity Score =================
        # Check if the "center of mass" of lyric lines moves forward in time
        col_indices = torch.arange(cols, device=device, dtype=torch.float64)
        # Zero out low-energy noise
        weights = torch.where(
            energy_matrix > time_weight,
            energy_matrix,
            torch.zeros_like(energy_matrix)
        )
        sum_w = weights.sum(dim=1)
        sum_t = (weights * col_indices).sum(dim=1)
        # Calculate centroids
        centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
        valid_w_mask = sum_w > 1e-9
        centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]
        # Extract the sequence of valid lyric centroids
        valid_sequence_mask = is_lyrics_row & (centroids >= 0)
        sung_centroids = centroids[valid_sequence_mask]
        cnt = sung_centroids.shape[0]
        if cnt > 1:
            curr_c = sung_centroids[:-1]
            next_c = sung_centroids[1:]
            # Check non-decreasing order with overlap tolerance
            non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
            pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
            monotonicity_score = non_decreasing / pairs
        else:
            monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= C. Path Confidence =================
        # Average energy along the optimal path
        if path_coords.shape[0] > 0:
            p_rows = path_coords[:, 0]
            p_cols = path_coords[:, 1]
            path_energies = energy_matrix[p_rows, p_cols]
            step_weights = torch.ones_like(path_energies)
            # Lower weight for instrumental/tag steps
            is_inst_step = (type_mask[p_rows] == 0)
            step_weights[is_inst_step] = instrumental_weight
            total_energy = (path_energies * step_weights).sum()
            total_steps = step_weights.sum()
            if total_steps > 0:
                path_confidence = total_energy / total_steps
            else:
                path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
        else:
            path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
        return coverage_score.item(), monotonicity_score.item(), path_confidence.item()

    def lyrics_alignment_info(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        token_ids: List[int],
        custom_config: Dict[int, List[int]],
        return_matrices: bool = False,
        medfilt_width: int = 1
    ) -> Dict[str, Any]:
        """
        Generates the alignment path and processed matrices.

        Args:
            attention_matrix: Input attention tensor.
            token_ids: Corresponding token IDs.
            custom_config: Layer/Head configuration.
            return_matrices: If True, returns intermediate matrices in the output.
            medfilt_width: Median filter width.

        Returns:
            Dict containing path_coords, type_mask and energy_matrix
            (plus calc_matrix and vis_matrix when return_matrices is True).
        """
        calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
            attention_matrix, custom_config, medfilt_width
        )
        if calc_matrix is None:
            return {
                "calc_matrix": None,
                "error": "No valid attention heads found"
            }
        # 1. Generate semantic mask (1=Lyrics, 0=Tags)
        # Uses self.tokenizer internally
        type_mask = self._generate_token_type_mask(token_ids)
        # Safety check for shape mismatch
        if len(type_mask) != energy_matrix.shape[0]:
            # Fall back to treating all rows as lyrics if shapes don't align
            type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)
        # 2. DTW pathfinding
        # Use the negated calc_matrix because DTW minimizes cost
        text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
        path_coords = np.stack([text_indices, time_indices], axis=1)
        return_dict = {
            "path_coords": path_coords,
            "type_mask": type_mask,
            "energy_matrix": energy_matrix
        }
        if return_matrices:
            return_dict['calc_matrix'] = calc_matrix
            return_dict['vis_matrix'] = vis_matrix
        return return_dict

    def calculate_score(
        self,
        energy_matrix: Union[torch.Tensor, np.ndarray],
        type_mask: Union[torch.Tensor, np.ndarray],
        path_coords: Union[torch.Tensor, np.ndarray],
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Dict[str, Any]:
        """
        Calculates the final alignment score based on pre-computed components.

        Args:
            energy_matrix: Processed energy matrix.
            type_mask: Token type mask.
            path_coords: DTW path coordinates.
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed backward movement in frames.
            instrumental_weight: Weight for non-lyric path steps.

        Returns:
            Dict containing the final "lyrics_score".
        """
        # Ensure inputs are tensors on the same device (GPU when available)
        if not isinstance(energy_matrix, torch.Tensor):
            target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            energy_matrix = torch.tensor(energy_matrix, device=target_device, dtype=torch.float32)
        device = energy_matrix.device
        if not isinstance(type_mask, torch.Tensor):
            type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
        else:
            type_mask = type_mask.to(device=device, dtype=torch.long)
        if not isinstance(path_coords, torch.Tensor):
            path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
        else:
            path_coords = path_coords.to(device=device, dtype=torch.long)
        # Compute metrics
        coverage, monotonicity, confidence = self._compute_alignment_metrics(
            energy_matrix=energy_matrix,
            path_coords=path_coords,
            type_mask=type_mask,
            time_weight=time_weight,
            overlap_frames=overlap_frames,
            instrumental_weight=instrumental_weight
        )
        # Final score: (Coverage^2 * Monotonicity^2 * Confidence), clipped to [0, 1]
        final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
        final_score = float(np.clip(final_score, 0.0, 1.0))
        return {
            "lyrics_score": round(final_score, 4)
        }
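

# A minimal runnable sketch of the scoring path (not part of the original
# module). The dummy tokenizer, random attention tensor and head config are
# hypothetical stand-ins for a real tokenizer and DiT cross-attention output.
if __name__ == "__main__":
    class _DummyTokenizer:
        def decode(self, ids, skip_special_tokens=False):
            return " ".join(f"tok{i}" for i in ids)

    torch.manual_seed(0)
    attn = torch.rand(2, 4, 16, 120)        # [Layers, Heads, Tokens, Frames]
    token_ids = list(range(16))
    scorer = MusicLyricScorer(_DummyTokenizer())
    info = scorer.lyrics_alignment_info(
        attention_matrix=attn,
        token_ids=token_ids,
        custom_config={0: [0, 1], 1: [2]},  # hypothetical layer -> heads mapping
    )
    score = scorer.calculate_score(
        energy_matrix=info["energy_matrix"],
        type_mask=info["type_mask"],
        path_coords=info["path_coords"],
    )
    print(score)  # e.g. {'lyrics_score': ...}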